# mypy: ignore-errors

import math
from copy import copy
from dataclasses import dataclass
from functools import partial
from typing import Optional

import torch
from torch.fx.experimental.symbolic_shapes import is_nested_int
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.opinfo.core import (
    BinaryUfuncInfo,
    ReductionOpInfo,
    SampleInput,
    UnaryUfuncInfo,
)
from torch.utils._pytree import tree_flatten, tree_map


@dataclass
class ExtraOpData:
    """
    Contains info on top of the typical OpInfo data that is useful for NJT test generation.

    The process that converts the standard op_db -> an NJT-compatible op_db will attach this
    data onto each associated OpInfo entry.
    """

    # Indicates whether the associated op is a view op
    is_view: bool = False

    # Specifies the names of any dim-related args that the op takes in. This is useful
    # for NJT tests because there is often asymmetry across the supported set of dims for
    # an op; it may make sense to operate over the batch dim but not the ragged dim, for
    # example. The length of this list should match the number of relevant overloads.
    # Each list item of the outer list should specify dim argnames. Ellipses should be used
    # to indicate multi-dim support for a given overload.
    #
    # For example, squeeze() has both a dim and multi-dim overload, where the argname for
    # each is simply "dim". Its entry should be: [["dim"], ["dim..."]].
    #
    # If no overload of the op accepts dim-related args, this should be None.
    dim_args: Optional[list[list[str]]] = None

    def get_dim_argnames(self) -> tuple[Optional[str], Optional[str]]:
        """
        Extract the names of the op's dim-related args.

        Only overloads that take exactly one dim-related arg are considered. An
        overload whose argname ends in "..." supplies the dim-list argname (with the
        "..." stripped); that name is also used as the single-dim argname unless
        some other single-arg overload provides one.

        Returns:
            (single dim argname or None, dim list argname or None). If the op takes
            no dim-related args at all, or only has overloads with multiple dim args
            (e.g. transpose()), this is (None, None).
        """
        if self.dim_args is None:
            return (None, None)

        # name for the dim arg that supports a single dim
        single_dim_argname = None
        # name for the dim arg that supports a list of dims
        dimlist_argname = None
        for overload in self.dim_args:
            # only consider overloads with a single dim-related arg
            if len(overload) != 1:
                continue
            if overload[0].endswith("..."):
                dimlist_argname = overload[0].replace("...", "")
                # a multi-dim arg can also be given a single dim; fall back to it
                if single_dim_argname is None:
                    single_dim_argname = dimlist_argname
            else:
                single_dim_argname = overload[0]
        return (single_dim_argname, dimlist_argname)


# Mapping of OpInfo full names -> extra data to tack onto the OpInfo entry for use
# in test generation.
#
# Each value is an ExtraOpData recording whether the op is a view op (is_view) and the
# names of its dim-related args (dim_args). Within dim_args, each inner list describes
# one overload; a trailing "..." marks an argument that accepts a list of dims.
# NOTE(review): ops without an entry here presumably receive a default ExtraOpData()
# when the NJT op_db is built — confirm against the code that attaches this data.
extra_op_data = {
    "_segment_reduce.lengths": ExtraOpData(dim_args=[["axis0"]]),
    "_segment_reduce.offsets": ExtraOpData(dim_args=[["axis0"]]),
    "all": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
    "argmax": ExtraOpData(dim_args=[["dim"]]),
    "argmin": ExtraOpData(dim_args=[["dim"]]),
    "amax": ExtraOpData(dim_args=[["dim..."]]),
    "amin": ExtraOpData(dim_args=[["dim..."]]),
    "any": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
    "argsort": ExtraOpData(dim_args=[["dim"]]),
    "broadcast_to": ExtraOpData(is_view=True),
    "cat": ExtraOpData(dim_args=[["dim"]]),
    "chunk": ExtraOpData(is_view=True, dim_args=[["dim"]]),
    "conj": ExtraOpData(is_view=True),
    "contiguous": ExtraOpData(is_view=True),
    "count_nonzero": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
    "cummax": ExtraOpData(dim_args=[["dim"]]),
    "cummin": ExtraOpData(dim_args=[["dim"]]),
    "cumprod": ExtraOpData(dim_args=[["dim"]]),
    "cumsum": ExtraOpData(dim_args=[["dim"]]),
    "cumulative_trapezoid": ExtraOpData(dim_args=[["dim"]]),
    "diag_embed": ExtraOpData(dim_args=[["dim1", "dim2"]]),
    "diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]),
    "diagonal_copy": ExtraOpData(dim_args=[["dim1", "dim2"]]),
    "diagonal_scatter": ExtraOpData(dim_args=[["dim1", "dim2"]]),
    "diff": ExtraOpData(dim_args=[["dim"]]),
    "expand": ExtraOpData(is_view=True),
    "expand_as": ExtraOpData(is_view=True),
    "fft.fft": ExtraOpData(dim_args=[["dim"]]),
    "fft.hfft": ExtraOpData(dim_args=[["dim"]]),
    "fft.ifft": ExtraOpData(dim_args=[["dim"]]),
    "fft.ihfft": ExtraOpData(dim_args=[["dim"]]),
    "fft.irfft": ExtraOpData(dim_args=[["dim"]]),
    "fft.rfft": ExtraOpData(dim_args=[["dim"]]),
    "flatten": ExtraOpData(is_view=True, dim_args=[["start_dim", "end_dim"]]),
    "flip": ExtraOpData(dim_args=[["dims..."]]),
    "gather": ExtraOpData(dim_args=[["dim"]]),
    "imag": ExtraOpData(is_view=True),
    "index_add": ExtraOpData(dim_args=[["dim"]]),
    "index_copy": ExtraOpData(dim_args=[["dim"]]),
    "index_fill": ExtraOpData(dim_args=[["dim"]]),
    "index_reduce.amax": ExtraOpData(dim_args=[["dim"]]),
    "index_reduce.amin": ExtraOpData(dim_args=[["dim"]]),
    "index_reduce.mean": ExtraOpData(dim_args=[["dim"]]),
    "index_reduce.prod": ExtraOpData(dim_args=[["dim"]]),
    "index_select": ExtraOpData(dim_args=[["dim"]]),
    "kthvalue": ExtraOpData(dim_args=[["dim"]]),
    "linalg.cross": ExtraOpData(dim_args=[["dim"]]),
    "linalg.diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]),
    "linalg.tensorsolve": ExtraOpData(dim_args=[["dims..."]]),
    "linalg.vecdot": ExtraOpData(dim_args=[["dim"]]),
    "linalg.vector_norm": ExtraOpData(dim_args=[["dim..."]]),
    "log_softmax": ExtraOpData(dim_args=[["dim"]]),
    "logcumsumexp": ExtraOpData(dim_args=[["dim"]]),
    "masked.amax": ExtraOpData(dim_args=[["dim"]]),
    "masked.amin": ExtraOpData(dim_args=[["dim"]]),
    "masked.argmax": ExtraOpData(dim_args=[["dim"]]),
    "masked.argmin": ExtraOpData(dim_args=[["dim"]]),
    "masked.logsumexp": ExtraOpData(dim_args=[["dim"]]),
    "masked.mean": ExtraOpData(dim_args=[["dim"]]),
    "masked.norm": ExtraOpData(dim_args=[["dim"]]),
    "masked.prod": ExtraOpData(dim_args=[["dim"]]),
    "masked.std": ExtraOpData(dim_args=[["dim"]]),
    "masked.sum": ExtraOpData(dim_args=[["dim"]]),
    "masked.var": ExtraOpData(dim_args=[["dim"]]),
    "max.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]),
    "median": ExtraOpData(dim_args=[["dim"]]),
    "mean": ExtraOpData(dim_args=[["dim..."]]),
    "min.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]),
    "mode": ExtraOpData(dim_args=[["dim"]]),
    "movedim": ExtraOpData(
        dim_args=[["source", "destination"], ["source...", "destination..."]]
    ),
    "nanmean": ExtraOpData(dim_args=[["dim..."]]),
    "nanmedian": ExtraOpData(dim_args=[["dim"]]),
    "nansum": ExtraOpData(dim_args=[["dim..."]]),
    "narrow": ExtraOpData(is_view=True, dim_args=[["dim"]]),
    "narrow_copy": ExtraOpData(dim_args=[["dim"]]),
    "nn.functional.cosine_similarity": ExtraOpData(dim_args=[["dim"]]),
    "nn.functional.glu": ExtraOpData(dim_args=[["dim"]]),
    "permute": ExtraOpData(is_view=True, dim_args=[["dims..."]]),
    "positive": ExtraOpData(is_view=True),
    "prod": ExtraOpData(dim_args=[["dim"]]),
    "ravel": ExtraOpData(is_view=True),
    "real": ExtraOpData(is_view=True),
    "renorm": ExtraOpData(dim_args=[["dim"]]),
    "reshape": ExtraOpData(is_view=True),
    "reshape_as": ExtraOpData(is_view=True),
    "roll": ExtraOpData(dim_args=[["dims..."]]),
    "rot90": ExtraOpData(dim_args=[["dims..."]]),
    "scatter": ExtraOpData(dim_args=[["dim"]]),
    "scatter_add": ExtraOpData(dim_args=[["dim"]]),
    "scatter_reduce.amax": ExtraOpData(dim_args=[["dim"]]),
    "scatter_reduce.amin": ExtraOpData(dim_args=[["dim"]]),
    "scatter_reduce.mean": ExtraOpData(dim_args=[["dim"]]),
    "scatter_reduce.prod": ExtraOpData(dim_args=[["dim"]]),
    "scatter_reduce.sum": ExtraOpData(dim_args=[["dim"]]),
    "select": ExtraOpData(is_view=True, dim_args=[["dim"]]),
    "select_scatter": ExtraOpData(dim_args=[["dim"]]),
    "slice": ExtraOpData(is_view=True, dim_args=[["dim"]]),
    "slice_scatter": ExtraOpData(dim_args=[["dim"]]),
    "softmax": ExtraOpData(dim_args=[["dim"]]),
    "sort": ExtraOpData(dim_args=[["dim"]]),
    "split": ExtraOpData(is_view=True, dim_args=[["dim"]]),
    "split_with_sizes": ExtraOpData(is_view=True, dim_args=[["dim"]]),
    "split_with_sizes_copy": ExtraOpData(dim_args=[["dim"]]),
    "squeeze": ExtraOpData(is_view=True, dim_args=[["dim"], ["dim..."]]),
    "squeeze_copy": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
    "stack": ExtraOpData(dim_args=[["dim"]]),
    "std": ExtraOpData(dim_args=[["dim..."]]),
    "std.unbiased": ExtraOpData(dim_args=[["dim..."]]),
    "sum": ExtraOpData(dim_args=[["dim..."]]),
    "t": ExtraOpData(is_view=True),
    "tensor_split": ExtraOpData(is_view=True, dim_args=[["dim"]]),
    "tensordot": ExtraOpData(dim_args=[["dims..."]]),
    "tile": ExtraOpData(dim_args=[["dims..."]]),
    "topk": ExtraOpData(dim_args=[["dim"]]),
    "transpose": ExtraOpData(is_view=True, dim_args=[["dim0", "dim1"]]),
    "transpose_copy": ExtraOpData(dim_args=[["dim0", "dim1"]]),
    "trapezoid": ExtraOpData(dim_args=[["dim"]]),
    "trapz": ExtraOpData(dim_args=[["dim"]]),
    "unbind": ExtraOpData(is_view=True, dim_args=[["dim"]]),
    "unflatten": ExtraOpData(is_view=True, dim_args=[["dim"]]),
    "unfold": ExtraOpData(is_view=True, dim_args=[["dimension"]]),
    "unfold_copy": ExtraOpData(dim_args=[["dimension"]]),
    "unsafe_chunk": ExtraOpData(dim_args=[["dim"]]),
    "unsafe_split": ExtraOpData(dim_args=[["dim"]]),
    "unsqueeze": ExtraOpData(is_view=True, dim_args=[["dim"]]),
    "unsqueeze_copy": ExtraOpData(dim_args=[["dim"]]),
    "var": ExtraOpData(dim_args=[["dim..."]]),
    "var.unbiased": ExtraOpData(dim_args=[["dim..."]]),
    "view": ExtraOpData(is_view=True),
    "view_as": ExtraOpData(is_view=True),
    "view_as_complex": ExtraOpData(is_view=True),
    "view_as_real": ExtraOpData(is_view=True),
}


# random integer used for sizes
def _rnd():
    return torch.randint(3, 8, ()).item()


def _raggedness_matches(nt1, nt2):
    return (
        nt1.is_nested
        and nt2.is_nested
        and nt1._ragged_idx == nt2._ragged_idx
        and nt1.shape[nt1._ragged_idx] == nt2.shape[nt2._ragged_idx]
    )


# Helper function to avoid reusing the exact same tensor / NJT across SampleInputs,
# as this causes autograd problems.
def _clone(t):
    requires_grad = t.requires_grad
    return t.detach().clone().requires_grad_(requires_grad)


def _update_sample(sample, new_kwargs):
    """
    Return a new SampleInput derived from ``sample`` with ``new_kwargs`` applied.

    The input is cloned, ``args`` are reused as-is, ``new_kwargs`` override the
    existing kwargs, and each new ``key=value`` pair is appended to the sample
    name (comma-separated).
    """
    merged_kwargs = {**sample.kwargs, **new_kwargs}
    name_parts = [sample.name]
    name_parts.extend(f"{key}={value}" for key, value in new_kwargs.items())
    return SampleInput(
        _clone(sample.input),
        args=sample.args,
        kwargs=merged_kwargs,
        name=", ".join(name_parts),
    )


# Generates a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used
def random_nt_from_dims(
    dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
):
    sizes = [[d if d is not None else _rnd() for d in dims[1:]] for d in range(dims[0])]
    return torch.nested.nested_tensor(
        [torch.randn(*size) for size in sizes],
        device=device,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
    )


# Helper function to get a reasonable string representation of an NJT for use in
# SampleInput names.
def _describe_njt(njt) -> str:
    contig_type = "_contig" if njt.is_contiguous() else "_noncontig"
    if njt._lengths is not None and njt._offsets is not None:
        contig_type += "_holes"
    elif njt._ragged_idx != 1:
        contig_type += "_transposed"

    cached_data = "_without_seqlen_cache"
    if njt._max_seqlen_tensor is not None:
        cached_data = "_with_seqlen_cache"

    return f"{njt.dim()}D{contig_type}{cached_data}"


# Helper function to get a reasonable string representation of a given dim wrt an NJT.
def _describe_dim(njt, dim):
    if dim == 0:
        return "batch_dim"
    elif dim == njt._ragged_idx:
        return "ragged_dim"
    return "normal_dim"


def _sample_njts(device, dtype, requires_grad=False, dims=None):
    """
    Yield a comprehensive set of jagged-layout NJT sample inputs.

    ``dims`` lists the NJT dimensionalities to generate. It defaults to
    ``[2, 3, 4]``; a non-list/tuple value is treated as a single entry. For each
    ``ndim`` this yields, in order:
      1. a contiguous NJT of shape (B, j, ...) from ``random_nt_from_dims``
         (with min / max seqlen cached)
      2. an NJT rebuilt from cloned values / offsets of (1), without the seqlen cache
      3. (ndim > 2 only) a non-contiguous view of (1) with the ragged dim swapped
         with the last dim
      4. a non-contiguous NJT "with holes": cloned values / offsets of (1) plus
         ``lengths`` one shorter than each component's offset span
    """
    if dims is None:
        dims = [2, 3, 4]
    if not isinstance(dims, (list, tuple)):
        dims = [dims]

    for ndim in dims:
        # batch size first, then the ragged dim (None), then random normal dims
        base = random_nt_from_dims(
            (_rnd(), None, *[_rnd() for _ in range(ndim - 2)]),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            layout=torch.jagged,
        )
        yield base

        yield torch.nested.nested_tensor_from_jagged(
            _clone(base.values()), _clone(base.offsets())
        ).requires_grad_(requires_grad)

        # a transposed NJT is not possible for 2D
        if ndim > 2:
            yield base.transpose(-1, base._ragged_idx)

        holey_values = _clone(base.values())
        holey_offsets = _clone(base.offsets())
        yield torch.nested.nested_tensor_from_jagged(
            values=holey_values,
            offsets=holey_offsets,
            # subtract 1 from each component length to cause holes
            lengths=_clone(holey_offsets.diff() - 1),
        ).requires_grad_(requires_grad)


# Computes an unbind-based reference for a given OpInfo on a given SampleInput.
# This reference unbinds the input NJT and invokes the op on each of the components,
# optionally wrapping the result in an NJT.
def unbind_reference(op, sample, wrap_output_as_njt=True):
    """
    Compute a per-component reference for ``op`` applied to ``sample``.

    The first NJT found determines the expected ragged structure. The search checks
    ``sample.input`` first, then the first NJT among ``sample.args``. For each batch
    index ``i``, the inputs are prepared and ``op.op`` is called on the results:
      * every tensor sharing that ragged structure is replaced by its ``i``-th component;
      * tensors tagged with a ``_batch_dim`` attribute are sliced along that dim
        (or indexed at 0 when that dim has size 1);
      * dim-related kwargs (per ``op._extra_op_data.dim_args``) are mapped from
        outer NJT dim space to inner component dim space.

    Args:
        op: OpInfo-like object exposing ``op``, ``full_name`` and ``_extra_op_data``.
        sample: SampleInput-like object exposing ``input``, ``args`` and ``kwargs``.
        wrap_output_as_njt: if True, pack per-component outputs back into jagged
            NJT(s). For list / tuple outputs, one NJT is built per output position.

    Returns:
        An NJT, or a same-typed list / tuple of NJTs, when ``wrap_output_as_njt`` is
        True. Otherwise, the plain list of per-component outputs.

    Raises:
        ValueError: if neither ``sample.input`` nor any positional arg is an NJT.
    """
    from torch.nested._internal.ops import _outer_to_inner_dim

    def _is_njt(t):
        return isinstance(t, torch.Tensor) and t.is_nested

    # first NJT in the arglist determines expected ragged structure
    if _is_njt(sample.input):
        nt_inp = sample.input
    else:
        # TODO: look in kwargs too?
        nt_inp = next((a for a in sample.args if _is_njt(a)), None)
        if nt_inp is None:
            raise ValueError(
                "unbind_reference() requires an NJT as the sample input or in its args"
            )

    # Loop-invariant dim info: every dim-related argname this op could take
    # (with the "..." multi-dim marker stripped), plus the ndim used to
    # canonicalize dims. The SampleInput may override the ndim via _ndim.
    dim_argnames = set()
    if op._extra_op_data.dim_args is not None:
        dim_argnames = {
            a.replace("...", "") for a in tree_flatten(op._extra_op_data.dim_args)[0]
        }
    ndim = nt_inp._ndim if hasattr(nt_inp, "_ndim") else nt_inp.dim()

    out_ref_components = []
    for i in range(nt_inp.shape[0]):

        def _slice_input(t, i=i, inp=nt_inp):
            # any NJT with the same ragged structure as the input should
            # be sliced to pass to the reference
            if isinstance(t, torch.Tensor) and _raggedness_matches(t, inp):
                return t[i]
            # allow the SampleInput to tell us how to slice it for ref calculation
            elif isinstance(t, torch.Tensor) and hasattr(t, "_batch_dim"):
                bdim = t._batch_dim  # type: ignore[attr]
                if t.shape[bdim] == 1:
                    return t[0]
                else:
                    return t.select(bdim, i)
            else:
                return t

        inp = _slice_input(sample.input)
        args = tree_map(_slice_input, sample.args)
        # tree_map returns a new dict, so mutating kwargs below is safe
        kwargs = tree_map(_slice_input, sample.kwargs)

        # Handle indices in index_put
        if "index_put" in op.full_name and "indices" in kwargs:
            if len(kwargs["indices"]) > 1:
                # If after unrolling we still have indices left, use them
                kwargs["indices"] = [t[i] for t in kwargs["indices"][1:]]
            else:
                # If no indices are left, create them so they match the NJT implementation
                sequence_put = kwargs["indices"][0].tolist()
                if i in sequence_put:
                    kwargs["indices"] = [
                        torch.tensor(
                            list(range(inp.shape[0])),
                            dtype=torch.int32,
                            device=kwargs["indices"][0].device,
                        )
                    ]
                else:
                    kwargs["indices"] = [
                        torch.tensor(
                            [], dtype=torch.int32, device=kwargs["indices"][0].device
                        )
                    ]

        # Need to adjust dims to apply on NJT component: for all dim-related args
        # present, convert from outer -> inner dim space
        for argname in dim_argnames:
            if argname in kwargs:
                kwargs[argname] = _outer_to_inner_dim(
                    ndim, kwargs[argname], nt_inp._ragged_idx, canonicalize=True
                )

        out_ref_components.append(op.op(inp, *args, **kwargs))

    if wrap_output_as_njt:
        # handle list / tuple of outputs
        if len(out_ref_components) > 0 and isinstance(
            out_ref_components[0], (list, tuple)
        ):
            num_returns = len(out_ref_components[0])
            # ensure we get the same number of returns for each invocation
            assert all(len(o) == num_returns for o in out_ref_components)
            # construct NJTs from same index returns from each invocation
            njt_returns = [
                torch.nested.as_nested_tensor(
                    [o[r] for o in out_ref_components], layout=torch.jagged
                )
                for r in range(num_returns)
            ]
            return type(out_ref_components[0])(njt_returns)
        return torch.nested.as_nested_tensor(out_ref_components, layout=torch.jagged)

    return out_ref_components


def unary_dimwise_reference(op, sample, batchwise_reference=None):
    """
    Reference for a non-reduction unary op applied along a single dim.

    Operating over the batch dim (dim == 0) cannot use the unbind reference, so it
    is delegated to ``batchwise_reference(op, sample)``. Every other dim uses
    ``unbind_reference``. Only ops with a single non-list dim arg are supported.
    """
    extra = op._extra_op_data
    assert extra.dim_args is not None
    dim_argname, list_argname = extra.get_dim_argnames()
    assert list_argname is None
    assert dim_argname is not None

    if sample.kwargs[dim_argname] != 0:
        return unbind_reference(op, sample)

    assert batchwise_reference is not None
    return batchwise_reference(op, sample)


# Computes the reference value for a reduction op.
def reduction_reference(op, sample):
    assert sample.input.is_nested

    # extract info about the dim args this op supports
    assert op._extra_op_data.dim_args is not None
    single_dim_argname, dimlist_argname = op._extra_op_data.get_dim_argnames()
    assert single_dim_argname is not None

    dim = sample.kwargs.get(
        dimlist_argname, sample.kwargs.get(single_dim_argname, None)
    )
    keepdim = sample.kwargs.get("keepdim", False)
    assert dim != 0, "reductions over just the batch dim are not supported"
    if isinstance(dim, (tuple, list)):
        reduce_on_ragged = sample.input._ragged_idx in dim
        reduce_on_batch = 0 in dim
    else:
        reduce_on_ragged = sample.input._ragged_idx == dim
        reduce_on_batch = dim == 0

    if dim is None:
        # calculate reference value by running reduction on values buffer
        return op.op(sample.input.values(), *sample.args, **sample.kwargs)

    if reduce_on_ragged and reduce_on_batch:
        # run reference directly on buffer with dims converted to inner space
        from torch.nested._internal.ops import _outer_to_inner_dim

        ref_kwargs = dict(sample.kwargs)
        assert dimlist_argname is not None
        ref_kwargs[dimlist_argname] = _outer_to_inner_dim(
            sample.input.dim(), dim, sample.input._ragged_idx, canonicalize=True
        )
        out = op.op(sample.input.values(), *sample.args, **ref_kwargs)
        if keepdim:
            if isinstance(out, (tuple, list)):
                # some ops return multiple things; unsqueeze all of them
                out = type(out)(o.unsqueeze(0) for o in out)
            else:
                out = out.unsqueeze(0)
        return out

    if reduce_on_ragged and not reduce_on_batch:
        # calculate reference value by running an unbind reference and stacking
        out_ref_components = unbind_reference(op, sample, wrap_output_as_njt=False)
        if len(out_ref_components) > 0 and isinstance(
            out_ref_components[0], (tuple, list)
        ):
            # some ops return multiple things; stack all of them
            num_returns = len(out_ref_components[0])
            # ensure we get the same number of returns for each invocation
            assert all(len(o) == num_returns for o in out_ref_components)
            # stack same index returns from each invocation
            stacked_returns = [
                torch.stack([o[r] for o in out_ref_components], dim=0)
                for r in range(num_returns)
            ]
            return type(out_ref_components[0])(stacked_returns)
        return torch.stack(out_ref_components, dim=0)

    # unbind reference works for other reductions
    return unbind_reference(op, sample)


def sample_inputs_elementwise_njt_unary(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    """
    Yield one SampleInput per NJT from ``_sample_njts()`` (2D, 3D and 4D).

    Each sample gets its own copy of ``op_kwargs`` and is named after the NJT layout.
    """
    base_kwargs = op_kwargs or {}
    njts = _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    )
    for njt in njts:
        yield SampleInput(njt, kwargs=dict(base_kwargs), name=_describe_njt(njt))


def sample_inputs_elementwise_njt_binary(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    """
    Yield SampleInputs for elementwise binary ops where at least one operand is an NJT.

    For every NJT from ``_sample_njts()``, yield an (NT, NT) sample, then (NT, T) and
    (T, NT) samples against dense tensors that broadcast in these ways:
      * with size 1 over the ragged dim: (B, j0, ...) vs (B, 1, ...)
      * with all sizes 1
      * over the dims after the ragged dim only (when the NJT has any)
      * as a 0-dim scalar tensor
    Finish with a mixed-broadcasting pair: (B, j0, 1) vs (B, 1, D).

    NOTE(review): the second (NT, NT) operand comes from ``torch.randn_like`` and is
    not given ``requires_grad`` here — confirm whether grads w.r.t. it should be tested.
    """
    if not op_kwargs:
        op_kwargs = {}

    def _dense(shape):
        # random dense operand with the requested device / dtype / requires_grad
        return torch.randn(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    def _both_orders(nt, dense, desc, case, batched):
        # Yields the (NT, T) sample followed by the (T, NT) sample. Batched dense
        # operands keep a leading batch dim and are tagged with _batch_dim = 0,
        # which unbind_reference() uses for slicing.
        if batched:
            dense_first, dense_second = dense, _clone(dense)
            dense_first._batch_dim = 0
            dense_second._batch_dim = 0
        else:
            dense_first, dense_second = _clone(dense), _clone(dense)
        yield SampleInput(
            _clone(nt),
            args=(dense_first,),
            kwargs=dict(op_kwargs),
            name=f"{desc}: (NT, T) {case}",
        )
        yield SampleInput(
            dense_second,
            args=(_clone(nt),),
            kwargs=dict(op_kwargs),
            name=f"{desc}: (T, NT) {case}",
        )

    for njt1 in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        njt_desc = _describe_njt(njt1)
        njt2 = torch.randn_like(njt1)
        yield SampleInput(
            _clone(njt1),
            args=(njt2,),
            kwargs=dict(op_kwargs),
            name=f"{njt_desc}: (NT, NT)",
        )

        # (B, j0, ...) with (B, 1, ...)
        ones_over_ragged = list(njt1.shape)
        ones_over_ragged[njt1._ragged_idx] = 1
        yield from _both_orders(
            njt1,
            _dense(ones_over_ragged),
            njt_desc,
            "broadcasting 1 over ragged",
            batched=True,
        )

        # (B, j0, ...) with (1, 1, ...)
        yield from _both_orders(
            njt1, _dense([1] * njt1.dim()), njt_desc, "broadcasting all 1s", batched=True
        )

        # (B, j0, ...) with (...): only the dims after the ragged dim
        if njt1.dim() > njt1._ragged_idx + 1:
            yield from _both_orders(
                njt1,
                _dense(njt1.shape[njt1._ragged_idx + 1 :]),
                njt_desc,
                "broadcasting normal dims",
                batched=False,
            )

        # (B, j0, ...) with a 0-dim scalar tensor
        yield from _both_orders(
            njt1, _dense(()), njt_desc, "broadcasting with scalar", batched=False
        )

    # mixed broadcasting case: (B, j0, 1) with (B, 1, D)
    B, D = 4, 16
    mixed_njt = random_nt_from_dims(
        (B, None, 1),
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        layout=torch.jagged,
    )
    mixed_desc = _describe_njt(mixed_njt)
    yield from _both_orders(
        mixed_njt, _dense((B, 1, D)), mixed_desc, "mixed broadcasting", batched=True
    )


def sample_inputs_njt_reduction(
    op_info,
    device,
    dtype,
    requires_grad,
    supports_keepdim=True,
    op_kwargs=None,
    **kwargs,
):
    """
    Yield SampleInputs covering reduction patterns for NJT reductions.

    For each NJT from ``_sample_njts()`` and each keepdim setting, the samples are
    generated as listed below. The keepdim settings are ``False`` and ``True`` when
    ``supports_keepdim``; otherwise there is a single pass with no ``keepdim`` kwarg.
      * a single-dim reduction over every non-batch dim (including the ragged dim)
      * if the op has a dim-list overload: batch+ragged, batch+ragged+each later dim,
        two normal dims (4D+ NJTs with ragged dim 1 only), and all-dims reductions
    Finally, each NJT gets one full-reduction sample that passes no dim / keepdim
    kwargs at all.

    Reduction over the batch dim alone is intentionally not generated.
    """
    if not op_kwargs:
        op_kwargs = {}

    # extract info about the dim args this op supports
    assert op_info._extra_op_data.dim_args is not None
    (
        single_dim_argname,
        dimlist_argname,
    ) = op_info._extra_op_data.get_dim_argnames()
    assert single_dim_argname is not None
    supports_dimlist = dimlist_argname is not None
    keepdim_values = [False, True] if supports_keepdim else [None]

    def _dim_sample(njt, njt_desc, dim_kwargs, desc, keepdim):
        # One dim-wise reduction sample; keepdim is passed (and named) only when supported.
        sample_kwargs = {**op_kwargs, **dim_kwargs}
        keepdim_suffix = ""
        if supports_keepdim:
            sample_kwargs["keepdim"] = keepdim
            keepdim_suffix = f" with keepdim={keepdim}"
        return SampleInput(
            _clone(njt),
            kwargs=sample_kwargs,
            name=f"{njt_desc}: {desc}{keepdim_suffix}",
        )

    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        njt_desc = _describe_njt(njt)
        ragged_idx = njt._ragged_idx
        ndim = njt.dim()
        for keepdim in keepdim_values:
            # single dim-wise reduction; includes reduction over the ragged dim
            # NB: reduction over the batch dim is not supported!
            # TODO: Cover this in the set of error inputs
            for dim in range(1, ndim):
                dim_desc = "normal" if dim != ragged_idx else "ragged"
                yield _dim_sample(
                    njt,
                    njt_desc,
                    {single_dim_argname: dim},
                    f"{dim_desc} dim reduction",
                    keepdim,
                )

            if not supports_dimlist:
                continue

            # reduce on both batch and ragged dims
            yield _dim_sample(
                njt,
                njt_desc,
                {dimlist_argname: [0, ragged_idx]},
                "batch+ragged reduction",
                keepdim,
            )

            # reduce on batch, ragged, and other dims
            for other_dim in range(ragged_idx + 1, ndim):
                yield _dim_sample(
                    njt,
                    njt_desc,
                    {dimlist_argname: [0, ragged_idx, other_dim]},
                    f"batch+ragged+dim={other_dim} reduction",
                    keepdim,
                )

            # reduce on two non-ragged, non-batch dims
            if ndim > 3 and ragged_idx == 1:
                yield _dim_sample(
                    njt,
                    njt_desc,
                    {dimlist_argname: [ndim - 2, ndim - 1]},
                    "two normal dim reduction",
                    keepdim,
                )

            # full reduction by specifying all dims
            yield _dim_sample(
                njt,
                njt_desc,
                {dimlist_argname: list(range(ndim))},
                "all dim reduction",
                keepdim,
            )

            # TODO: Reducing on ragged dim and non-batch dim is not supported;
            # cover this in the set of error inputs.

        # full reduction: no dim / keepdim kwargs are passed, so the name must not
        # report a keepdim value (the loop variable is stale at this point)
        yield SampleInput(
            _clone(njt),
            kwargs=dict(op_kwargs),
            name=f"{njt_desc}: full reduction",
        )


def unsupported_sample_inputs_func(op_name):
    """
    Build a sample_inputs_func placeholder for ops without NJT support.

    The returned function raises RuntimeError naming the op (its ``op_name``
    parameter defaults to the name given here) whenever it is called.
    """

    def _raise_unsupported(
        op_info, device, dtype, requires_grad, op_name=op_name, **kwargs
    ):
        message = (
            f"OpInfo for {op_name} does not support NJT. Support can be added by modifying "
            "torch/testing/_internal/opinfo/definitions/nested.py."
        )
        raise RuntimeError(message)

    return _raise_unsupported


def unsupported_reference(op_name):
    """
    Build a reference placeholder for ops whose OpInfo defines no ref().

    The returned ``(op, sample)`` function always raises RuntimeError naming the op.
    """

    def _raise_missing_ref(op, sample):
        message = (
            f"OpInfo for {op_name} does not define a ref() function. Support can be added by "
            "modifying torch/testing/_internal/opinfo/definitions/nested.py."
        )
        raise RuntimeError(message)

    return _raise_missing_ref


# === BEGIN OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES ===
def sample_inputs_unary_dimwise(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    """
    Yield one SampleInput per (NJT, dim) pair for ops applied along a single dim.

    Every dim of each NJT from ``_sample_njts()`` is covered, including the batch
    dim, and passed under the op's single dim argname. Entries in ``op_kwargs``
    take precedence over the generated dim kwarg. Only ops with a single non-list
    dim arg are supported.
    """
    if op_kwargs is None:
        op_kwargs = {}

    # only support a single non-list dim arg for now
    assert op_info._extra_op_data is not None
    single_dim_argname, dimlist_argname = op_info._extra_op_data.get_dim_argnames()
    assert single_dim_argname is not None
    assert dimlist_argname is None

    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        for dim in range(njt.dim()):
            # separate name so the function's own **kwargs is not shadowed
            sample_kwargs = {single_dim_argname: dim, **op_kwargs}
            yield SampleInput(
                _clone(njt),
                kwargs=sample_kwargs,
                name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}",
            )


def batchwise_reference_chunk(op, sample):
    # reference for chunk() over dim=0
    B = sample.input.size(0)
    num_chunks = sample.kwargs["chunks"]
    chunk_size = math.ceil(B / num_chunks)
    num_full_chunks = B // chunk_size
    chunk_sizes = [chunk_size for _ in range(num_full_chunks)]
    if B % chunk_size != 0:
        # final chunk contains the leftovers
        chunk_sizes.append(B % chunk_size)

    # split unbound components into chunks according to calculated sizes
    components = list(sample.input.unbind())
    start = 0
    chunks = []
    for chunk_size in chunk_sizes:
        chunks.append(components[start : start + chunk_size])
        start += chunk_size

    # rejoin into NJT outputs
    return [torch.nested.as_nested_tensor(lst, layout=torch.jagged) for lst in chunks]


def batchwise_reference_narrow(op, sample):
    """Reference for narrow() over the batch dim; not implemented yet (always raises)."""
    # TODO: write this!
    raise NotImplementedError()


def batchwise_reference_select(op, sample):
    """Reference for select() over the batch dim (dim=0): pick one unbound component."""
    components = sample.input.unbind()
    return components[sample.kwargs["index"]]


def batchwise_reference_split(op, sample):
    """Reference for split() over the batch dim; not implemented yet (always raises)."""
    # TODO: write this!
    raise NotImplementedError()


def batchwise_reference_split_with_sizes(op, sample):
    """Reference for split_with_sizes() over the batch dim; not implemented yet (always raises)."""
    # TODO: write this!
    raise NotImplementedError()


def batchwise_reference_unflatten(op, sample):
    """Reference for unflatten() over the batch dim; not implemented yet (always raises)."""
    # TODO: write this!
    raise NotImplementedError()


def batchwise_reference_unsqueeze(op, sample):
    """Batch-dim reference for unsqueeze(); always raises because that case is unsupported by design."""
    message = "unsqueeze() is not intended to operate on the batch dim"
    raise ValueError(message)


def sample_inputs_clone(op_info, device, dtype, requires_grad, **kwargs):
    """
    Sample inputs for clone().

    Yields each standard 2D-4D sample NJT with no kwargs. It then yields a
    "non-contiguous with holes" NJT once per memory format (``torch.contiguous_format``
    and ``torch.preserve_format``), passed as the ``memory_format`` kwarg. In that NJT
    the explicit ``lengths`` are shorter than the spacing given by ``offsets``.
    """
    # standard sample NJTs from the shared generator
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        yield SampleInput(njt, name=_describe_njt(njt))

    for memory_format in (torch.contiguous_format, torch.preserve_format):
        # construct a "non-contiguous with holes" NJT: offsets space the components
        # 2 / 2 / 6 apart, but lengths only use 2 / 1 / 3 of those slots
        values = torch.randn(
            10, 5, device=device, dtype=dtype, requires_grad=requires_grad
        )
        offsets = torch.tensor([0, 2, 4, 10], device=device, dtype=torch.int64)
        lengths = torch.tensor([2, 1, 3], device=device, dtype=torch.int64)
        njt = torch.nested.nested_tensor_from_jagged(
            values, offsets=offsets, lengths=lengths
        )

        njt_desc = _describe_njt(njt)
        yield SampleInput(
            njt,
            kwargs={"memory_format": memory_format},
            # fixed: the name previously ended with a stray ")"
            name=f"{njt_desc}: {memory_format}",
        )


def sample_inputs_fill(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for fill(): the elementwise unary NJT samples with a scalar fill value."""
    # scalar case
    yield from sample_inputs_elementwise_njt_unary(
        op_info, device, dtype, requires_grad, op_kwargs={"value": 42.0}
    )

    # TODO: add Tensor case


def sample_inputs_mvl_gamma(p):
    """Return a sample_inputs_func yielding elementwise unary NJT samples with kwarg ``p``."""
    fixed_kwargs = {"p": p}
    return partial(sample_inputs_elementwise_njt_unary, op_kwargs=fixed_kwargs)


def sample_inputs_polygamma_n(n):
    """Return a sample_inputs_func yielding elementwise unary NJT samples with kwarg ``n``."""
    fixed_kwargs = {"n": n}
    return partial(sample_inputs_elementwise_njt_unary, op_kwargs=fixed_kwargs)


def sample_inputs_special_polygamma_n(n):
    """
    Return a sample_inputs_func for special.polygamma: elementwise unary NJT samples
    with kwarg ``n`` (same shape of samples as for polygamma).
    """
    fixed_kwargs = {"n": n}
    return partial(sample_inputs_elementwise_njt_unary, op_kwargs=fixed_kwargs)


def sample_inputs_to(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs):
    """
    Sample inputs for to().

    For each 2D-4D sample NJT, yields one dtype conversion to each of
    float32 / half / double that differs from ``dtype``. For CUDA inputs it also yields
    a device transfer to CPU. ``op_kwargs`` is accepted for signature compatibility
    but unused.
    """
    for njt in _sample_njts(
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        dims=[2, 3, 4],
    ):
        njt_desc = _describe_njt(njt)
        other_dtypes = (
            d for d in (torch.float32, torch.half, torch.double) if d is not dtype
        )
        for other_dtype in other_dtypes:
            # Convert to the *other* dtype. Passing ``dtype`` here, as before, made every
            # sample a no-op (to() returns self when dtype and device already match).
            yield SampleInput(
                _clone(njt),
                kwargs={"dtype": other_dtype},
                name=f"{njt_desc}: {dtype} -> {other_dtype}",
            )

        # only include device transfer for CUDA inputs
        if "cuda" in device:
            other_device = "cpu"
            sample_name = f"{njt_desc}: {device} -> {other_device}"
            yield SampleInput(
                _clone(njt), kwargs={"device": other_device}, name=sample_name
            )


def sample_inputs_bmm(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs):
    """
    Sample inputs for bmm(): a 3D NJT ragged at dim 1 times a dense batch of matrices,
    i.e. (B, j1, D) x (B, D, E) => (B, j1, E) with E = D + 2.

    ``op_kwargs`` is accepted for signature compatibility but unused.
    """
    # TODO (need factory functions):
    # (B, D, j1) x (B, j1, E) => (B, D, E)
    for njt_3d in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3]
    ):
        # only NJTs ragged at dim 1 match the (B, j1, D) layout
        if njt_3d._ragged_idx != 1:
            continue

        batch, inner = njt_3d.shape[0], njt_3d.shape[-1]
        mat2 = torch.randn(batch, inner, inner + 2, device=device, dtype=dtype)
        # used for slicing in unbind_reference()
        mat2._batch_dim = 0
        njt_desc = _describe_njt(njt_3d)
        yield SampleInput(
            _clone(njt_3d),
            kwargs={"mat2": mat2},
            name=f"{njt_desc}: (B, j, D) x (B, D, E)",
        )


def reference_bmm(op, sample):
    """
    Reference implementation for bmm() on NJT samples.

    unbind reduces a dim, and bmm requires 3D inputs, so torch.matmul is applied
    per component through unbind_reference() instead. matmul calls its second operand
    ``other`` rather than ``mat2``, so that kwarg is renamed.

    The caller's ``op`` and ``sample`` are left untouched.
    """
    matmul_op = copy(op)
    matmul_op.op = torch.matmul
    # copy() is shallow: mutating ``modified_sample.kwargs`` in place would also delete
    # "mat2" from the caller's sample. Build a fresh kwargs dict instead.
    modified_kwargs = dict(sample.kwargs)
    modified_kwargs["other"] = modified_kwargs.pop("mat2")
    modified_sample = copy(sample)
    modified_sample.kwargs = modified_kwargs
    return unbind_reference(matmul_op, modified_sample)


def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs):
    """
    Sample inputs for chunk(), derived from the per-dim unary samples.

    The ragged dim gets a single ``chunks`` value (3). Any other dim of size D gets
    chunks values 1, D // 2, D - 1 and D.
    """
    for base in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        dim = base.kwargs["dim"]
        if dim == base.input._ragged_idx:
            chunk_counts = [3]
        else:
            size = base.input.size(dim)
            chunk_counts = [1, size // 2, size - 1, size]

        for chunks in chunk_counts:
            yield _update_sample(base, {"chunks": chunks})


def sample_inputs_matmul(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    """
    Sample inputs for matmul() covering NJT x dense, dense x NJT and NJT x NJT cases.

    Yields, in order:
      * every bmm() sample, with its ``mat2`` kwarg renamed to ``other``;
      * (B, j1, D) x (D, E) for 3D NJTs ragged at dim 1;
      * (B, j1, D, E) x (E, F) for 4D NJTs ragged at dim 1;
      * dense (B, F, E) x NJT (B, E, j1) for 3D NJTs ragged at dim 2; here the dense
        tensor is the sample input and the NJT is passed positionally;
      * NJT (B, E, j1) x NJT (B, j1, F) for contiguous 3D NJTs ragged at dim 2.

    ``op_kwargs`` and extra ``kwargs`` are accepted for signature compatibility but unused.
    """
    # also run bmm samples through
    for sample_input in sample_inputs_bmm(op_info, device, dtype, requires_grad):
        # change arg name from mat2 -> other
        other = sample_input.kwargs["mat2"]
        del sample_input.kwargs["mat2"]
        sample_input.kwargs["other"] = other
        yield sample_input

    # 3D cases not covered by bmm
    for njt_3d in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3]
    ):
        # (B, j1, D) x (D, E) => (B, j1, E)
        if njt_3d._ragged_idx == 1:
            D = njt_3d.shape[-1]
            E = D + 2
            njt_desc = _describe_njt(njt_3d)
            yield SampleInput(
                _clone(njt_3d),
                kwargs={"other": torch.randn(D, E, device=device, dtype=dtype)},
                name=f"{njt_desc}: (B, j, D) x (D, E)",
            )

    # 4D cases
    for njt_4d in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[4]
    ):
        # (B, j1, D, E) x (E, F) => (B, j1, D, F)
        if njt_4d._ragged_idx == 1:
            E = njt_4d.shape[-1]
            F = E + 2
            njt_desc = _describe_njt(njt_4d)
            yield SampleInput(
                _clone(njt_4d),
                kwargs={"other": torch.randn(E, F, device=device, dtype=dtype)},
                name=f"{njt_desc}: (B, j, D, E) x (E, F)",
            )

    # Dense x NJT cases
    for njt_3d in _sample_njts(
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        dims=[3],
    ):
        # (B, F, E) x (B, E, j1) => (B, F, j1)
        if njt_3d._ragged_idx == 2:
            B = njt_3d.shape[0]
            E = njt_3d.shape[1]
            F = E + 2
            njt_desc = _describe_njt(njt_3d)
            dense_t = torch.randn(
                B, F, E, device=device, dtype=dtype, requires_grad=requires_grad
            )
            dense_t._batch_dim = 0  # for unbind_reference()
            yield SampleInput(
                dense_t,
                args=(_clone(njt_3d),),
                name=f"{njt_desc}: (B, F, E) x (B, E, j1)",
            )

    # NJT x NJT => Dense case
    for njt_3d in _sample_njts(
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        dims=[3],
    ):
        # (B, E, j1) x (B, j1, F) => (B, E, F)
        if njt_3d._ragged_idx == 2 and njt_3d.is_contiguous():
            B, E, _ = njt_3d.shape
            # NOTE(review): len() gives values().shape[0]. That equals sum(j1) only if the
            # packed ragged dim is dim 0 of values(). For an NJT ragged at outer dim 2 it
            # may instead be dim 1, in which case this yields E. Confirm against how
            # _sample_njts builds ragged_idx == 2 contiguous NJTs.
            sum_j1 = len(njt_3d.values())
            other_cont = torch.randn(
                sum_j1, E + 2, device=device, dtype=dtype, requires_grad=requires_grad
            )
            # reuse the offsets (and lengths, if any) of njt_3d so both operands share
            # the same ragged structure j1
            other_njt = torch.nested.nested_tensor_from_jagged(
                other_cont, njt_3d.offsets(), lengths=njt_3d._lengths
            )
            njt_desc = _describe_njt(njt_3d)
            yield SampleInput(
                _clone(njt_3d),
                kwargs={"other": _clone(other_njt)},
                name=f"{njt_desc}: (B, E, j1) x (B, j1, F)",
            )

        # TODO (need factory functions):
        # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F)


def sample_inputs_masked_select(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    """
    Sample inputs for masked_select() on 2D NJTs.

    Each sample NJT is paired with a random boolean mask of the same (jagged) shape.
    The mask is True where a fresh standard-normal draw is negative. ``op_kwargs`` is
    unused.
    """
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2]
    ):
        mask = torch.randn_like(njt, requires_grad=False) < 0.0
        yield SampleInput(njt, kwargs={"mask": mask}, name=_describe_njt(njt))


def sample_inputs_narrow(op_info, device, dtype, requires_grad, **kwargs):
    """
    Sample inputs for narrow(), derived from the per-dim unary samples.

    The ragged dim gets a single (start=1, length=2) window. Any other dim of size D
    gets windows (0, D), (0, D - 1), (1, D - 1) and (D - 1, 1).
    """
    for base in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        dim = base.kwargs["dim"]
        if dim == base.input._ragged_idx:
            windows = [(1, 2)]
        else:
            size = base.input.size(dim)
            windows = [(0, size), (0, size - 1), (1, size - 1), (size - 1, 1)]

        for start, length in windows:
            yield _update_sample(base, {"start": start, "length": length})


def sample_inputs_nn_functional_embedding(
    op_info, device, dtype, requires_grad, **kwargs
):
    """
    Sample inputs for nn.functional.embedding() with jagged int64 indices.

    Yields the same (weight, indices) pair twice: once plain and once with
    ``padding_idx=1``. The weight always has requires_grad set, regardless of the
    ``requires_grad`` argument.
    """
    index_rows = ([0, 2, 1, 3], [4, 2, 1], [6, 7, 5, 2, 4])
    indices = torch.nested.nested_tensor(
        [torch.tensor(row) for row in index_rows],
        layout=torch.jagged,
        dtype=torch.int64,
        device=device,
    )

    NUM_EMBEDDINGS = 20
    EMBEDDING_DIM = 32
    weight = torch.randn(NUM_EMBEDDINGS, EMBEDDING_DIM, device=device, dtype=dtype)

    # NB: the OpInfo entry for embedding expects weight first so the gradients
    # can be checked
    yield SampleInput(_clone(weight).requires_grad_(), args=(indices,))
    yield SampleInput(
        _clone(weight).requires_grad_(), args=(indices,), kwargs={"padding_idx": 1}
    )


def sample_inputs_index_put(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    """
    Sample inputs for index_put().

    For each 2D-4D sample NJT and each ``dim`` in range(njt.dim()), the indices select
    every batch item and pin each of the next (dim - 1) dims to 0. The value written is
    a scalar 1.0. A final hand-built non-contiguous NJT (lengths shorter than the
    offsets spacing) is indexed along all of its dims. ``op_kwargs`` is unused.
    """
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
    ):
        for dim in range(njt.dim()):
            batch_size = njt.size(0)
            # NOTE(review): range(dim - 1) is empty for both dim == 0 and dim == 1, so
            # those two samples share identical indices; confirm this is intended.
            index_tensors = [torch.tensor(list(range(batch_size)), device=njt.device)]
            index_tensors.extend(
                torch.tensor([0] * batch_size, device=njt.device)
                for _ in range(dim - 1)
            )
            njt_desc = _describe_njt(njt)
            yield SampleInput(
                _clone(njt),
                kwargs={
                    "indices": index_tensors,
                    "values": torch.tensor(1.0, device=njt.device),
                },
                name=f"{njt_desc}: up to dim {dim - 1}",
            )

    # Non-cont NJT for completeness
    holey_offsets = torch.tensor([0, 2, 5, 7], device=device)
    holey_lengths = torch.tensor([2, 2, 2], device=device)
    all_dim_indices = [
        torch.tensor([0, 1, 2], device=device),
        torch.tensor([0, 1, 1], device=device),
        torch.tensor([0, 0, 0], device=device),
    ]
    holey_njt = torch.nested.nested_tensor_from_jagged(
        torch.zeros(7, 3, device=device), holey_offsets, holey_lengths
    ).requires_grad_(requires_grad)

    njt_desc = _describe_njt(holey_njt)
    yield SampleInput(
        _clone(holey_njt),
        kwargs={
            "indices": all_dim_indices,
            "values": torch.tensor(1.0, device=holey_njt.device),
        },
        name=f"{njt_desc}: all dims",
    )


def sample_inputs_nn_functional_embedding_bag(
    op_info, device, dtype, requires_grad, **kwargs
):
    """
    Sample inputs for nn.functional.embedding_bag() with a jagged batch of 3 bags.

    Covers modes "sum", "mean" and "max" without per-sample weights, plus "sum" with
    random per-sample weights. The bags have lengths 2, 3 and 4 with random indices.
    """
    NUM_EMBEDDINGS = 10
    EMBEDDING_DIM = 32

    for with_weights in (True, False):
        for mode in ("sum", "mean", "max"):
            # per_sample_weights is only supported for mode='sum'
            if with_weights and mode != "sum":
                continue

            weight = torch.randn(
                NUM_EMBEDDINGS, EMBEDDING_DIM, dtype=dtype, device=device
            )
            bags = [
                torch.randint(0, NUM_EMBEDDINGS, size=(bag_len,))
                for bag_len in (2, 3, 4)
            ]
            njt = torch.nested.nested_tensor(
                bags, layout=torch.jagged, dtype=torch.int64, device=device
            )
            per_sample_weights = (
                torch.randn_like(njt, dtype=dtype) if with_weights else None
            )

            # NB: the OpInfo entry for embedding_bag expects weight first so the gradients
            # can be checked
            yield SampleInput(
                weight,
                args=(njt,),
                kwargs={"mode": mode, "per_sample_weights": per_sample_weights},
            )


def reference_nn_functional_embedding_bag(op, sample):
    """
    Reference for nn.functional.embedding_bag() on NJT samples.

    Samples arrive as (weight, indices_njt, ...). They are flipped back to
    (indices, weight) and run through unbind_reference() so that
    F.embedding_bag sees one bag (batch item) at a time, with ``offsets=[0]``. The
    per-bag outputs are then concatenated along dim 0.

    The caller's ``op`` is not modified.
    """
    # run reference on a single bag at a time
    new_kwargs = dict(sample.kwargs)
    new_kwargs["offsets"] = torch.tensor(
        [0], dtype=torch.int64, device=sample.input.device
    )
    # flip input / weight back to what unbind_reference() expects
    flipped = SampleInput(sample.args[0], args=(sample.input,), kwargs=new_kwargs)
    # Use a shallow copy rather than temporarily patching ``op.op``: patching left the
    # caller's op pointing at embedding_bag whenever unbind_reference() raised.
    ref_op = copy(op)
    ref_op.op = torch.nn.functional.embedding_bag
    output = unbind_reference(ref_op, flipped, wrap_output_as_njt=False)
    # concat bag outputs to get final output
    return torch.cat(output, dim=0)


def sample_inputs_nn_functional_linear(op_info, device, dtype, requires_grad, **kwargs):
    """
    Sample inputs for nn.functional.linear() on 3D-5D NJTs.

    Each NJT whose last dim is not ragged is projected to 10 output features, once
    with a bias and once without.
    """
    NUM_OUTPUT = 10
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4, 5]
    ):
        in_features = njt.size(-1)
        # projection over a ragged dim is not currently supported
        if is_nested_int(in_features):
            continue

        weight = torch.randn(
            NUM_OUTPUT, in_features, device=device, dtype=dtype, requires_grad=requires_grad
        )
        bias = torch.randn(
            NUM_OUTPUT, device=device, dtype=dtype, requires_grad=requires_grad
        )

        # with bias
        yield SampleInput(
            _clone(njt),
            kwargs={"weight": _clone(weight), "bias": _clone(bias)},
            name=f"{_describe_njt(njt)}: with bias",
        )
        # without bias
        yield SampleInput(
            _clone(njt),
            kwargs={"weight": _clone(weight)},
            name=f"{_describe_njt(njt)}: without bias",
        )


def sample_inputs_nn_functional_prelu(op_info, device, dtype, requires_grad, **kwargs):
    """
    Sample inputs for nn.functional.prelu() on 3D / 4D NJTs.

    Dim 1 is treated as the channel dim, so NJTs that are ragged there are skipped.
    Each remaining NJT is paired with a 1D per-channel weight and with a scalar
    tensor weight (4.2).
    """
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4]
    ):
        channels = njt.size(1)
        # a ragged channel dim is not supported for now
        if is_nested_int(channels):
            continue

        per_channel_weight = torch.randn(
            channels, device=device, dtype=dtype, requires_grad=requires_grad
        )
        desc = _describe_njt(njt)

        yield SampleInput(
            _clone(njt),
            kwargs={"weight": _clone(per_channel_weight)},
            name=f"{desc}: 1D weight",
        )
        yield SampleInput(
            _clone(njt),
            kwargs={"weight": torch.tensor(4.2, device=device, dtype=dtype)},
            name=f"{desc}: scalar tensor weight",
        )


def sample_inputs_nn_functional_rms_norm(
    op_info, device, dtype, requires_grad, **kwargs
):
    """
    Sample inputs for nn.functional.rms_norm() on 3D / 4D NJTs.

    Normalization only covers trailing dims strictly after the ragged dim. One sample
    is yielded per possible start dim, each with a random weight of the normalized shape.
    """
    for njt in _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4]
    ):
        # normalize over non-ragged dims only
        for start_dim in range(njt._ragged_idx + 1, njt.dim()):
            trailing_shape = njt.shape[start_dim:]
            weight = torch.randn(
                trailing_shape, device=device, dtype=dtype, requires_grad=requires_grad
            )
            yield SampleInput(
                _clone(njt),
                kwargs={"normalized_shape": trailing_shape, "weight": weight},
                name=f"{_describe_njt(njt)}",
            )


# Sample inputs for nn.functional.threshold(): the elementwise unary NJT samples with a
# fixed threshold and replacement value (-9). float.fromhex("0x1.3ap-3") is
# 0.1533203125, which is exactly representable as a double; presumably chosen so the
# threshold comparison is not affected by rounding — TODO confirm.
sample_inputs_nn_functional_threshold = partial(
    sample_inputs_elementwise_njt_unary,
    op_kwargs={"threshold": float.fromhex("0x1.3ap-3"), "value": -9},
)


def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
    """
    Sample inputs for select(), derived from the per-dim unary samples.

    The ragged dim gets a single index (0). Any other dim of size D gets indices
    0, D // 2 and D - 1.
    """
    for base in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        dim = base.kwargs["dim"]
        if dim == base.input._ragged_idx:
            # ragged dim selection: test a single index
            indices = [0]
        else:
            # other dim selection: test first, middle and last indices
            size = base.input.size(dim)
            indices = [0, size // 2, size - 1]

        for index in indices:
            yield _update_sample(base, {"index": index})


def sample_inputs_split(op_info, device, dtype, requires_grad, **kwargs):
    """
    Sample inputs for split(), derived from the per-dim unary samples.

    The ragged dim gets a single split size (3). Any other dim of size D gets split
    sizes 1, D // 2, D - 1 and D.
    """
    for base in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        dim = base.kwargs["dim"]
        if dim == base.input._ragged_idx:
            # ragged dim splitting: test a single split size
            split_sizes = [3]
        else:
            # other dim splitting: test different split sizes
            size = base.input.size(dim)
            split_sizes = [1, size // 2, size - 1, size]

        for split_size in split_sizes:
            yield _update_sample(base, {"split_size_or_sections": split_size})


def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs):
    """
    Sample inputs for split_with_sizes(), derived from the per-dim unary samples.

    Ragged-dim samples are skipped. For any other dim of size D, yields one sample
    with two random split sizes summing to D.
    """
    for sample_input in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # It will never make sense to operate on the ragged dim.
        # TODO: Handle this with error_inputs
        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
            continue

        D = sample_input.input.size(sample_input.kwargs["dim"])
        # splits should add up to D. randint's upper bound is exclusive, and
        # randint(0, 0) raises, so clamp the bound to 1 for D <= 1. Draws for D >= 2 are
        # unchanged.
        split1 = torch.randint(0, max(D - 1, 1), size=()).item()
        split2 = D - split1
        yield _update_sample(sample_input, {"split_sizes": [split1, split2]})


def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs):
    """
    Sample inputs for squeeze().

    Uses a dedicated set of NJTs built from shape (4, None, 1, 3, 1), so some dims are
    size 1. The variants are: contiguous, without cached min/max seqlen,
    transposed (non-contiguous), and with holes. For each NJT, yields one sample per
    single dim (skipping the batch and ragged dims). It also yields one sample that calls
    squeeze() with no dim argument.
    """

    # squeeze-specific NJT generator (need to ensure there are some 1s in the shape)
    def _get_njts():
        njt = random_nt_from_dims(
            (4, None, 1, 3, 1),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
            layout=torch.jagged,
        )
        yield njt
        # without min / max seqlen cached
        values = njt.values().detach().clone()
        offsets = njt.offsets().detach().clone()
        yield torch.nested.nested_tensor_from_jagged(values, offsets)
        # non-contiguous transposed
        yield njt.transpose(1, 3)
        # non-contiguous with holes
        values = njt.values().detach().clone()
        offsets = njt.offsets().detach().clone()
        # subtract 1 to cause holes
        lengths = (offsets.diff() - 1).detach().clone()
        yield torch.nested.nested_tensor_from_jagged(
            values=values,
            offsets=offsets,
            lengths=lengths,
        )

    for njt in _get_njts():
        # single dim operation
        for dim in range(njt.dim()):
            # Operation on batch / ragged dim is never expected to work.
            # TODO: Handle these via error_inputs.
            if dim == 0 or dim == njt._ragged_idx:
                continue

            yield SampleInput(
                _clone(njt),
                kwargs={"dim": dim},
                name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}",
            )

        # multiple dim operation (pass no args), so squeeze() handles every size-1 dim at
        # once. Don't pass the leftover loop variable ``dim`` here; that only duplicates
        # the last single-dim sample.
        yield SampleInput(
            _clone(njt),
            name=f"{_describe_njt(njt)}: multiple dims",
        )


def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs):
    """
    Sample inputs for unflatten(), derived from the per-dim unary samples.

    Ragged-dim samples are skipped. A dim of size D is unflattened to [D, 1] and
    [1, D], and additionally to [D // 2, 2] and [2, D // 2] when D is even.
    """
    for base in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        dim = base.kwargs["dim"]
        # It will never make sense to operate on the ragged dim.
        # TODO: Handle this with error_inputs
        if dim == base.input._ragged_idx:
            continue

        size = base.input.size(dim)
        # every candidate shape multiplies out to size
        shapes = [[size, 1], [1, size]]
        if size % 2 == 0:
            half = size // 2
            shapes += [[half, 2], [2, half]]

        for sizes in shapes:
            yield _update_sample(base, {"sizes": sizes})


def sample_inputs_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
    """
    Sample inputs for unsqueeze().

    Yields every per-dim unary sample. After each one it yields a dim=-1 variant that
    adds a new trailing dim.
    """
    for base in sample_inputs_unary_dimwise(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        yield base

        trailing = _update_sample(base, {"dim": -1})
        njt = trailing.input
        trailing.name = f"{_describe_njt(njt)}: add dim to the end"
        # unsqueeze() accepts a dim one past the last to mean "append a dim", so tell
        # the unbind reference to canonicalize dim kwargs against ndim + 1.
        njt._ndim = njt.dim() + 1
        yield trailing


def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
    """
    Sample inputs for where(), adapted from the elementwise binary NJT samples.

    The binary sample's positional operand becomes the ``other`` kwarg. ``condition``
    is set to ``input > 0``. The sample name is tagged to reflect the NT condition.
    """
    for sample in sample_inputs_elementwise_njt_binary(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        rhs = sample.args[0]
        sample.args = ()
        sample.kwargs.update(other=rhs, condition=sample.input > 0.0)
        sample.name = sample.name.replace("(", "(NT, ")
        yield sample


# === END OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES ===


# Mapping of OpInfo full names -> sample_inputs_funcs, which define the set of sample inputs
# (involving NJTs) to pass to the op. Full name consists of the OpInfo's name and variant name
# separated by a period (e.g. special.polygamma.special_polygamma_n_0). These are necessary
# to specify if they cannot be auto-generated for some reason. Try to keep these sorted
# in alphabetical order!
njt_sample_inputs = {
    "bmm": sample_inputs_bmm,
    "chunk": sample_inputs_chunk,
    "clone": sample_inputs_clone,
    "count_nonzero": partial(sample_inputs_njt_reduction, supports_keepdim=False),
    "fill": sample_inputs_fill,
    "index_put": sample_inputs_index_put,
    "masked_select": sample_inputs_masked_select,
    "matmul": sample_inputs_matmul,
    # these two don't have ReductionOpInfo entries
    "max.reduction_with_dim": sample_inputs_njt_reduction,
    "min.reduction_with_dim": sample_inputs_njt_reduction,
    # each variant must be paired with its own p (previously every variant used p=1)
    **{f"mvlgamma.mvlgamma_p_{p}": sample_inputs_mvl_gamma(p=p) for p in (1, 3, 5)},
    "narrow": sample_inputs_narrow,
    "nn.functional.embedding": sample_inputs_nn_functional_embedding,
    "nn.functional.embedding_bag": sample_inputs_nn_functional_embedding_bag,
    "nn.functional.linear": sample_inputs_nn_functional_linear,
    "nn.functional.prelu": sample_inputs_nn_functional_prelu,
    "nn.functional.rms_norm": sample_inputs_nn_functional_rms_norm,
    "nn.functional.threshold": sample_inputs_nn_functional_threshold,
    **{f"polygamma.polygamma_n_{n}": sample_inputs_polygamma_n(n=n) for n in range(5)},
    "select": sample_inputs_select,
    "special.polygamma.special_polygamma_n_0": sample_inputs_special_polygamma_n(n=0),
    "split": sample_inputs_split,
    "split_with_sizes": sample_inputs_split_with_sizes,
    "squeeze": sample_inputs_squeeze,
    "to": sample_inputs_to,
    "unflatten": sample_inputs_unflatten,
    "unsqueeze": sample_inputs_unsqueeze,
    "where": sample_inputs_where,
}

# Mapping of OpInfo full names -> reference implementations used to check NJT outputs.
# Ops listed in njt_sample_inputs but missing here fall back to unbind_reference()
# in translate_opinfo(). Kept in alphabetical order.
njt_references = {
    "bmm": reference_bmm,
    "chunk": partial(
        unary_dimwise_reference, batchwise_reference=batchwise_reference_chunk
    ),
    "count_nonzero": reduction_reference,
    # these two don't have ReductionOpInfo entries
    "max.reduction_with_dim": reduction_reference,
    "min.reduction_with_dim": reduction_reference,
    "narrow": partial(
        unary_dimwise_reference, batchwise_reference=batchwise_reference_narrow
    ),
    "nn.functional.embedding_bag": reference_nn_functional_embedding_bag,
    "select": partial(
        unary_dimwise_reference, batchwise_reference=batchwise_reference_select
    ),
    "split": partial(
        unary_dimwise_reference, batchwise_reference=batchwise_reference_split
    ),
    "split_with_sizes": partial(
        unary_dimwise_reference,
        batchwise_reference=batchwise_reference_split_with_sizes,
    ),
    "squeeze": unbind_reference,
    "unflatten": partial(
        unary_dimwise_reference, batchwise_reference=batchwise_reference_unflatten
    ),
    "unsqueeze": partial(
        unary_dimwise_reference, batchwise_reference=batchwise_reference_unsqueeze
    ),
}


# Translates an OpInfo entry to one that operates on NJTs.
def translate_opinfo(op):
    """
    Return a shallow copy of ``op`` configured for NJT testing.

    Explicit entries in ``njt_sample_inputs`` take priority, with the reference taken
    from ``njt_references`` and defaulting to ``unbind_reference``. Otherwise, unary
    ufuncs, binary ufuncs and reductions (checked in that order) get generic NJT sample
    inputs and references. Any other op is marked ``supports_njt = False`` and given
    stand-ins that raise when used.
    """
    full_name = op.full_name
    new_op = copy(op)
    new_op.supports_njt = True
    # add some extra info for use in generating tests on the right subset of ops
    new_op._extra_op_data = extra_op_data.get(full_name, ExtraOpData())

    if full_name in njt_sample_inputs:
        new_op.sample_inputs_func = njt_sample_inputs[full_name]
        new_op.ref = njt_references.get(full_name, unbind_reference)
        return new_op

    # generic translations, checked in order
    generic_rules = (
        (UnaryUfuncInfo, sample_inputs_elementwise_njt_unary, unbind_reference),
        (BinaryUfuncInfo, sample_inputs_elementwise_njt_binary, unbind_reference),
        (ReductionOpInfo, sample_inputs_njt_reduction, reduction_reference),
    )
    for opinfo_cls, sample_func, reference in generic_rules:
        if isinstance(op, opinfo_cls):
            new_op.sample_inputs_func = partial(sample_func, op_kwargs=None)
            new_op.ref = reference
            return new_op

    # TODO: Translate the rest of the OpInfos
    new_op.sample_inputs_func = unsupported_sample_inputs_func(full_name)
    new_op.ref = unsupported_reference(full_name)
    new_op.supports_njt = False
    return new_op


# NJT-enabled view of the full OpInfo database, one translated entry per original OpInfo.
njt_op_db = list(map(translate_opinfo, op_db))
