# mypy: ignore-errors

from abc import abstractmethod
import tempfile
import unittest

from copy import deepcopy
from functools import reduce, partial
from itertools import product
from operator import mul


import torch
import torch.cuda
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import _reduction as _Reduction
from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
    gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo
from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors
from torch.autograd import Variable
from torch.types import _TensorOrTensors
import torch.backends.cudnn

from typing import Callable, Union, Any
from collections.abc import Sequence

# Module-level aliases: ``TemporaryFile`` re-exposes ``tempfile.TemporaryFile``;
# ``PRECISION`` (1e-5) is presumably the default comparison tolerance for tests.
TemporaryFile, PRECISION = tempfile.TemporaryFile, 1e-5


def get_reduction(m):
    """Return the reduction mode string of loss module ``m``.

    The modern ``reduction`` attribute wins when present and not None.
    Otherwise the legacy ``sizeAverage`` flag (possibly missing) is translated
    through ``_Reduction.legacy_get_string`` with ``reduce=True`` and without
    emitting a deprecation warning.
    """
    explicit = getattr(m, 'reduction', None)
    if explicit is not None:
        return explicit
    legacy = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False)
    assert legacy is not None
    return legacy


def get_weight(m):
    """Return ``m.weight`` unless it is missing/None, else ``m.weights`` (or None)."""
    weight = getattr(m, 'weight', None)
    return weight if weight is not None else getattr(m, 'weights', None)

# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
#
# The way to check API parity is to add parity tests for the NN module / functional of interest.
# Here are the detailed steps:
#
# For NN module:
# 1. Make sure you already have a test dict with the module configuration you want to test.
# 2. Add `cpp_constructor_args` entry to the test dict, with its value exactly matching
#    the Python module constructor arguments. For example, if in the test dict we pass
#    `(10, 8)` to `torch.nn.Linear` constructor, then we should pass `torch::nn::LinearOptions(10, 8)`
#    as the corresponding C++ constructor argument to `torch::nn::Linear`.
# 3. If in the process of performing the above step you referenced any variables
#    in the `cpp_constructor_args` entry, you must add `cpp_var_map` entry
#    to the test dict to make sure that those variables are populated with the right Python values.
#    For example, if the Python constructor call is
#    `torch.nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=random_samples)`,
#    the corresponding C++ constructor argument is
#    `torch::nn::FractionalMaxPool2dOptions(2).output_ratio(0.5)._random_samples(random_samples)`,
#    and the `cpp_var_map` entry must be
#    `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples`
#    used in the C++ constructor argument with the Python tensor value `random_samples`.
#
# For NN functional:
# 1. Make sure you already have a test dict with the functional configuration you want to test.
# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`,
#    then you must add `cpp_options_args` entry to the test dict, with its value exactly matching the Python
#    functional optional arguments. For example, if the test dict's `constructor` entry is
#    `wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest')`,
#    then the `cpp_options_args` entry should be
#    "F::InterpolateFuncOptions().size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)".
# 3. Otherwise, if the test dict's `constructor` entry looks like
#    `wrap_functional(lambda i: F.some_functional_name(...))`,
#    then you must add `cpp_function_call` entry to the test dict, with its value exactly matching the Python
#    functional function call. For example, if the test dict's `constructor` entry is
#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
#    then the `cpp_function_call` entry should be
#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
# 4. If in the process of performing step 2 or step 3 above you referenced any variables
#    in the `cpp_options_args` or `cpp_function_call` entry, you must
#    add `cpp_var_map` entry to the test dict to make sure that those variables
#    are populated with the right Python values. For example, if the test dict's `constructor` entry is
#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
#    then the `cpp_function_call` entry should be
#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
#    Notice that there are two variables `i` and `t` that need to have their values provided,
#    and the way to do so is to add a `cpp_var_map` entry: `cpp_var_map={'i': '_get_input()', 't': t}`.
#    (Note that for `i`, since we want it to take the Python input value, we pass '_get_input()' string as value
#    and the C++ parity test mechanism will populate `i` with the Python input value correctly.)
#
# There are also a few optional flags in the test dict to control the C++ parity test behavior:
#
# - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: True.
# - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: True.


# Test configurations for ``torch.nn`` modules, one dict per case. The
# ``cpp_constructor_args`` entries mirror the Python constructor arguments for
# the C++ parity tests.
module_tests = [
    {
        'module_name': 'Linear',
        'constructor_args': (10, 8),
        'cpp_constructor_args': 'torch::nn::LinearOptions(10, 8)',
        'input_size': (4, 10),
        # y = x @ W^T + b, with the bias broadcast over the batch of 4.
        'reference_fn': lambda i, p, _: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
        'with_tf32': True,
        'tf32_precision': 0.005,
        'default_dtype': torch.double,
    },
    {
        'module_name': 'Linear',
        'constructor_args': (10, 8, False),
        'cpp_constructor_args': 'torch::nn::LinearOptions(10, 8).bias(false)',
        'input_size': (4, 10),
        'desc': 'no_bias',
        'reference_fn': lambda i, p, _: torch.mm(i, p[0].t()),
        'with_tf32': True,
        'tf32_precision': 0.005,
        'default_dtype': torch.double,
    },
    {
        'module_name': 'RReLU',
        'input_size': (1, 2, 2),
        'test_cuda': False,
        'default_dtype': torch.double,
    },
    {
        'module_name': 'RReLU',
        'constructor_args': (0.1, 0.9),
        'cpp_constructor_args': 'torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
        'input_size': (4, 4, 5),
        'desc': 'with_up_down',
        'test_cuda': False,
        'default_dtype': torch.double,
    },
    {
        'module_name': 'Flatten',
        'input_size': (2, 3, 4, 5),
        'reference_fn': lambda i, *_: torch.flatten(i, 1),
        'default_dtype': torch.double,
    },
    # TODO: reference function
    {
        'module_name': 'CrossMapLRN2d',
        'constructor_args': (5, 5e-3, 1e-3, 2),
        'cpp_constructor_args': 'torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2)',
        'input_size': (2, 3, 6, 6),
        'check_gradgrad': False,
        # TODO(#50743): Figure out the error. "RuntimeError: Unrecognized tensor type ID: Batched"
        'check_batched_grad': False,
        'default_dtype': torch.double,
    },
]


# Generates rand tensor with non-equal values. This ensures that duplicate
# values won't be causing test failure for modules like MaxPooling.
# size should be small, otherwise randperm fails / long overflows.
def _rand_tensor_non_equal(*size):
    total = reduce(mul, size, 1)
    return torch.randperm(total).view(*size).double()


def wrap_functional(fn, **kwargs):
    """Wrap the functional ``fn`` in a fresh ``nn.Module`` subclass.

    Returns the class itself (not an instance). Its ``forward(*inputs)``
    evaluates ``fn(*inputs, **kwargs)`` using the keyword arguments captured
    at wrap time.
    """
    class FunctionalModule(nn.Module):
        def forward(self, *inputs):
            # Positional inputs come from the caller; options are fixed here.
            return fn(*inputs, **kwargs)

    return FunctionalModule


def poissonnllloss_no_reduce_test():
    """Build the no-reduction ``F.poisson_nll_loss`` test dict (10x10 input)."""
    target = torch.randn(10, 10)

    def forward(inp):
        return F.poisson_nll_loss(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        # Default (log_input) Poisson NLL without the constant term: exp(x) - t*x.
        return inp.exp() - target.mul(inp)

    return {
        'fullname': 'PoissonNLLLoss_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::poisson_nll_loss('
                             'i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))',
        'input_fn': lambda: torch.rand(10, 10),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def bceloss_no_reduce_test():
    """Build the no-reduction ``F.binary_cross_entropy`` test dict (15x10 input)."""
    target = Variable(torch.randn(15, 10).gt(0).to(torch.double))

    def forward(inp):
        return F.binary_cross_entropy(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        # Element-wise BCE: -(t*log(x) + (1-t)*log(1-x)).
        return -(target * inp.log() + (1 - target) * (1 - inp).log())

    return {
        'fullname': 'BCELoss_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::binary_cross_entropy('
                             'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
        # Inputs are clamped away from 0 and 1 so both log terms stay finite.
        'input_fn': lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'pickle': False,
        'precision': 7e-4,
        'default_dtype': torch.double,
    }


def bceloss_no_reduce_scalar_test():
    """Build the 0-dim (scalar) variant of the no-reduction BCE test dict."""
    target = torch.randn(()).gt(0).to(torch.double)

    def forward(inp):
        return F.binary_cross_entropy(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        return -(target * inp.log() + (1 - target) * (1 - inp).log())

    return {
        'fullname': 'BCELoss_no_reduce_scalar',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::binary_cross_entropy('
                             'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
        'input_fn': lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def bceloss_weights_no_reduce_test():
    """Build the no-reduction BCE test dict with a per-column weight vector."""
    target = Variable(torch.randn(15, 10, dtype=torch.double).gt(0).to(torch.double))
    column_weights = torch.rand(10, dtype=torch.double)

    def forward(inp):
        return F.binary_cross_entropy(inp, target.type_as(inp),
                                      weight=column_weights.type_as(inp), reduction='none')

    def reference(inp, params, module):
        # The (10,) weight vector broadcasts across the 15 rows.
        return -(target * inp.log() + (1 - target) * (1 - inp).log()) * column_weights

    return {
        'fullname': 'BCELoss_weights_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::binary_cross_entropy('
                             'i, t.to(i.options()), '
                             'F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))',
        'input_fn': lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        'cpp_var_map': {'i': '_get_input()', 't': target, 'weights': column_weights},
        'reference_fn': reference,
        'pickle': False,
        'precision': 3e-4,
        'default_dtype': torch.double,
    }


def bceloss_weights_no_reduce_scalar_test():
    """Build the 0-dim (scalar) variant of the weighted no-reduction BCE test dict."""
    target = torch.randn(()).gt(0).to(torch.double)
    scalar_weight = torch.rand((), dtype=torch.double)

    def forward(inp):
        return F.binary_cross_entropy(inp, target.type_as(inp),
                                      weight=scalar_weight.type_as(inp), reduction='none')

    def reference(inp, *_):
        return -(target * inp.log() + (1 - target) * (1 - inp).log()) * scalar_weight

    return {
        'fullname': 'BCELoss_weights_no_reduce_scalar',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::binary_cross_entropy(
            i, t.to(i.options()),
            F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
        'cpp_var_map': {'i': '_get_input()', 't': target, 'weights': scalar_weight},
        'input_fn': lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def bce_with_logistic_legacy_enum_test():
    """Build the BCE-with-logits test dict that uses the legacy ``reduce=False`` flag.

    The C++ call spells the same request as ``reduction(torch::kNone)``.
    """
    target = Variable(torch.randn(15, 10).gt(0).to(torch.double))
    squash = nn.Sigmoid()

    def forward(inp):
        return F.binary_cross_entropy_with_logits(inp, target.type_as(inp), reduce=False)

    def reference(inp, *_):
        return -(target * squash(inp).log() + (1 - target) * (1 - squash(inp)).log())

    return {
        'fullname': 'BCEWithLogitsLoss_legacy_enum',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
        'input_fn': lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'check_gradgrad': False,
        'pickle': False,
        'default_dtype': torch.double,
    }


def bce_with_logistic_no_reduce_test():
    """Build the no-reduction ``F.binary_cross_entropy_with_logits`` test dict (15x10)."""
    target = Variable(torch.randn(15, 10).gt(0).to(torch.double))
    squash = nn.Sigmoid()

    def forward(inp):
        return F.binary_cross_entropy_with_logits(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        # BCE applied to sigmoid(logits).
        return -(target * squash(inp).log() + (1 - target) * (1 - squash(inp)).log())

    return {
        'fullname': 'BCEWithLogitsLoss_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
        'input_fn': lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'check_gradgrad': False,
        'pickle': False,
        'default_dtype': torch.double,
    }


def bce_with_logistic_no_reduce_scalar_test():
    """Build the 0-dim (scalar) variant of the no-reduction BCE-with-logits test dict."""
    target = torch.randn(()).gt(0).to(torch.double)
    squash = nn.Sigmoid()

    def forward(inp):
        return F.binary_cross_entropy_with_logits(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        return -(target * squash(inp).log() + (1 - target) * (1 - squash(inp)).log())

    return {
        'fullname': 'BCEWithLogitsLoss_no_reduce_scalar',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
        'input_fn': lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'check_gradgrad': False,
        'pickle': False,
        'default_dtype': torch.double,
    }


def kldivloss_with_target_no_reduce_test():
    """Build the no-reduction ``F.kl_div`` test dict (log-prob input, prob target)."""
    target = torch.rand(10, 10, dtype=torch.double)

    def forward(inp):
        return F.kl_div(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        return loss_reference_fns['KLDivLoss'](inp, target.type_as(inp), reduction='none')

    return {
        'fullname': 'KLDivLoss_with_target_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
        'input_fn': lambda: torch.rand(10, 10).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def kldivloss_no_reduce_test():
    """Build the 10x10 no-reduction ``F.kl_div`` test dict."""
    target = torch.rand(10, 10, dtype=torch.double)

    def forward(inp):
        return F.kl_div(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        return loss_reference_fns['KLDivLoss'](inp, target.type_as(inp), reduction='none')

    return {
        'fullname': 'KLDivLoss_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
        'input_fn': lambda: torch.rand(10, 10).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def kldivloss_no_reduce_scalar_test():
    """Build the 0-dim (scalar) variant of the no-reduction ``F.kl_div`` test dict."""
    target = torch.rand((), dtype=torch.double)

    def forward(inp):
        return F.kl_div(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        return loss_reference_fns['KLDivLoss'](inp, target.type_as(inp), reduction='none')

    return {
        'fullname': 'KLDivLoss_no_reduce_scalar',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
        'input_fn': lambda: torch.rand(()).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def kldivloss_with_log_target_no_reduce_test():
    """Build the no-reduction ``F.kl_div`` test dict with ``log_target=True``."""
    log_target = torch.rand(10, 10, dtype=torch.double).log()

    def forward(inp):
        return F.kl_div(inp, log_target.type_as(inp), reduction='none', log_target=True)

    def reference(inp, *_):
        return loss_reference_fns['KLDivLoss_log_target'](inp, log_target.type_as(inp), reduction='none')

    return {
        'fullname': 'KLDivLoss_with_log_target_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
        'input_fn': lambda: torch.rand(10, 10).log(),
        'cpp_var_map': {'i': '_get_input()', 't': log_target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def kldivloss_no_reduce_log_target_test():
    """Build the 10x10 no-reduction ``F.kl_div`` test dict whose target is in log space."""
    log_target = torch.rand(10, 10, dtype=torch.double).log()

    def forward(inp):
        return F.kl_div(inp, log_target.type_as(inp), reduction='none', log_target=True)

    def reference(inp, *_):
        return loss_reference_fns['KLDivLoss_log_target'](inp, log_target.type_as(inp), reduction='none')

    return {
        'fullname': 'KLDivLoss_no_reduce_log_target',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
        'input_fn': lambda: torch.rand(10, 10).log(),
        'cpp_var_map': {'i': '_get_input()', 't': log_target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def kldivloss_no_reduce_scalar_log_target_test():
    """Build the 0-dim (scalar) ``log_target=True`` no-reduction ``F.kl_div`` test dict."""
    log_target = torch.rand((), dtype=torch.double).log()

    def forward(inp):
        return F.kl_div(inp, log_target.type_as(inp), reduction='none', log_target=True)

    def reference(inp, *_):
        return loss_reference_fns['KLDivLoss_log_target'](inp, log_target.type_as(inp), reduction='none')

    return {
        'fullname': 'KLDivLoss_no_reduce_scalar_log_target',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
        'input_fn': lambda: torch.rand(()).log(),
        'cpp_var_map': {'i': '_get_input()', 't': log_target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def l1loss_no_reduce_test():
    """Build the no-reduction ``F.l1_loss`` test dict for a (2, 3, 4) input."""
    target = torch.randn(2, 3, 4, dtype=torch.double)

    def forward(inp):
        return F.l1_loss(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        return (inp - target.type_as(inp)).abs()

    return {
        'fullname': 'L1Loss_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
        'input_fn': lambda: torch.randn(2, 3, 4),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def l1loss_no_reduce_complex_test():
    """Build the complex (``cdouble``) variant of the no-reduction ``F.l1_loss`` test dict.

    Unlike the real-valued variant, no ``default_dtype`` entry is set.
    """
    target = torch.randn(2, 3, 4, dtype=torch.cdouble)

    def forward(inp):
        return F.l1_loss(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        return (inp - target.type_as(inp)).abs()

    return {
        'fullname': 'L1Loss_no_reduce_complex',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
        'input_fn': lambda: torch.randn(2, 3, 4, dtype=torch.cdouble),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
    }


def l1loss_no_reduce_scalar_test():
    """Build the 0-dim (scalar) variant of the no-reduction ``F.l1_loss`` test dict."""
    target = torch.randn((), dtype=torch.double)

    def forward(inp):
        return F.l1_loss(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        return (inp - target.type_as(inp)).abs()

    return {
        'fullname': 'L1Loss_no_reduce_scalar',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
        'input_fn': lambda: torch.randn(()),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def mseloss_no_reduce_test():
    """Build the no-reduction ``F.mse_loss`` test dict for a (2, 3, 4, 5) input."""
    shape = (2, 3, 4, 5)
    target = torch.randn(*shape, dtype=torch.double)

    def forward(inp):
        return F.mse_loss(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        return (inp - target).pow(2)

    return {
        'fullname': 'MSELoss_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
        'input_size': shape,
        'cpp_var_map': {'i': '_get_input()', 'target': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def mseloss_no_reduce_scalar_test():
    """Build the 0-dim (scalar) variant of the no-reduction ``F.mse_loss`` test dict."""
    shape = ()
    target = torch.randn(shape, dtype=torch.double)

    def forward(inp):
        return F.mse_loss(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        return (inp - target).pow(2)

    return {
        'fullname': 'MSELoss_no_reduce_scalar',
        'constructor': wrap_functional(forward),
        'cpp_function_call': 'F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
        'input_size': shape,
        'cpp_var_map': {'i': '_get_input()', 'target': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def nllloss_no_reduce_test():
    """Build the no-reduction ``F.nll_loss`` test dict (15 samples, 10 classes)."""
    # Integer class labels in [0, 10).
    target = Variable(torch.empty(15).uniform_().mul(10).floor().long())
    loss_kwargs = {'reduction': 'none'}

    def forward(inp):
        return F.nll_loss(inp, target.type_as(inp).long(), reduction=loss_kwargs['reduction'])

    def reference(inp, *_):
        return loss_reference_fns['NLLLoss'](inp, target.type_as(inp).long(), **loss_kwargs)

    return {
        'fullname': 'NLLLoss_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
        'input_fn': lambda: torch.rand(15, 10).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def nllloss_no_reduce_ignore_index_test():
    """Build the no-reduction ``F.nll_loss`` test dict with ``ignore_index=2``."""
    target = Variable(torch.empty(15).uniform_().mul(10).floor().long())
    loss_kwargs: dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}

    def forward(inp):
        # int()/str() narrow the Union-typed values for the functional call.
        return F.nll_loss(inp, target.type_as(inp).long(),
                          ignore_index=int(loss_kwargs['ignore_index']),
                          reduction=str(loss_kwargs['reduction']))

    def reference(inp, *_):
        return loss_reference_fns['NLLLoss'](inp, target.type_as(inp).long(), **loss_kwargs)

    return {
        'fullname': 'NLLLoss_no_reduce_ignore_index',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''',
        'input_fn': lambda: torch.rand(15, 10).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def nllloss_no_reduce_weights_test():
    """Build the no-reduction, class-weighted ``F.nll_loss`` test dict (10 classes)."""
    target = Variable(torch.empty(15).uniform_().mul(10).floor().long())
    class_weight = torch.rand(10)

    def weighted_kwargs(x):
        # Cast the class weights to the input's type on every call.
        return {'weight': class_weight.type_as(x), 'reduction': 'none'}

    def forward(inp):
        return F.nll_loss(inp, target.type_as(inp).long(), **weighted_kwargs(inp))

    def reference(inp, *_):
        return loss_reference_fns['NLLLoss'](inp, target.type_as(inp).long(), **weighted_kwargs(inp))

    return {
        'fullname': 'NLLLoss_no_reduce_weights',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
        'input_fn': lambda: torch.rand(15, 10).add(1e-2).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target, 'weight': class_weight},
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def nllloss_no_reduce_weights_ignore_index_test():
    """Build the no-reduction, class-weighted ``F.nll_loss`` test dict with ``ignore_index=2``.

    Labels ``t`` are 15 integers in [0, 10), and ``weight`` has one entry per
    class. The same ``kwargs(i)`` helper feeds both the module under test and
    the reference function, so both see identical options.
    """
    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
    weight = torch.rand(10)

    def kwargs(i):
        # Cast the class weights to the input's type on every call.
        return {'weight': weight.type_as(i), 'reduction': 'none',
                'ignore_index': 2}

    return dict(
        fullname='NLLLoss_no_reduce_weights_ignore_index',
        constructor=wrap_functional(
            # Pass ``i`` itself rather than the legacy ``i.data``, matching
            # ``reference_fn`` below and the sibling weighted NLL tests;
            # ``type_as`` only reads the tensor type, so no detach is needed.
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))''',
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
        pickle=False,
        default_dtype=torch.double)


def nllloss_no_reduce_weights_ignore_index_neg_test():
    """Build the class-weighted no-reduction ``F.nll_loss`` test dict with ``ignore_index=-1``."""
    target = Variable(torch.empty(15).uniform_().mul(10).floor().long())
    class_weight = torch.rand(10)

    def weighted_kwargs(x):
        return {'weight': class_weight.type_as(x), 'reduction': 'none',
                'ignore_index': -1}

    def forward(inp):
        return F.nll_loss(inp, target.type_as(inp).long(), **weighted_kwargs(inp))

    def reference(inp, *_):
        return loss_reference_fns['NLLLoss'](inp, target.type_as(inp).long(), **weighted_kwargs(inp))

    return {
        'fullname': 'NLLLoss_no_reduce_weights_ignore_index_neg',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))''',
        # NOTE(review): unlike the sibling NLL tests this supplies a pre-built
        # ``input`` tensor (drawn when this factory runs) instead of ``input_fn``.
        'input': torch.rand(15, 10, dtype=torch.double).add(1e-2).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target, 'weight': class_weight},
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def nllloss2d_no_reduce_test():
    """Build the 2-D spatial no-reduction ``F.nll_loss`` test dict (3 classes, 5x5 maps)."""
    # (2, 5, 5) integer labels in [0, 3).
    target = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
    loss_kwargs = {'reduction': 'none'}

    def forward(inp):
        return F.nll_loss(inp, target.type_as(inp).long(), reduction=loss_kwargs['reduction'])

    def reference(inp, *_):
        return loss_reference_fns['NLLLossNd'](inp, target.type_as(inp).long(), **loss_kwargs)

    return {
        'fullname': 'NLLLoss2d_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
        'input_fn': lambda: torch.rand(2, 3, 5, 5).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def nllloss2d_no_reduce_ignore_index_test():
    """Build the 2-D spatial no-reduction ``F.nll_loss`` test dict with ``ignore_index=1``."""
    target = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
    loss_kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}

    def forward(inp):
        return F.nll_loss(inp, target.type_as(inp).long(),
                          ignore_index=int(loss_kwargs['ignore_index']),
                          reduction=str(loss_kwargs['reduction']))

    def reference(inp, *_):
        return loss_reference_fns['NLLLossNd'](inp, target.type_as(inp).long(), **loss_kwargs)

    return {
        'fullname': 'NLLLoss2d_no_reduce_ignore_index',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
        'input_fn': lambda: torch.rand(2, 3, 5, 5).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def nllloss2d_no_reduce_weights_test():
    """Build the 2-D spatial, class-weighted no-reduction ``F.nll_loss`` test dict."""
    target = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
    class_weight = torch.rand(3)

    def weighted_kwargs(x):
        return {'weight': class_weight.type_as(x), 'reduction': 'none'}

    def forward(inp):
        return F.nll_loss(inp, target.type_as(inp).long(), **weighted_kwargs(inp))

    def reference(inp, *_):
        return loss_reference_fns['NLLLossNd'](inp, target.type_as(inp).long(), **weighted_kwargs(inp))

    return {
        'fullname': 'NLLLoss2d_no_reduce_weights',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
        'input_fn': lambda: torch.rand(2, 3, 5, 5).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target, 'weight': class_weight},
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def nlllossNd_no_reduce_test():
    """Build the N-D (three spatial dims) no-reduction ``F.nll_loss`` test dict."""
    # (2, 5, 5, 2, 2) integer labels in [0, 3).
    target = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    loss_kwargs = {'reduction': 'none'}

    def forward(inp):
        return F.nll_loss(inp, target.type_as(inp).long(), reduction=loss_kwargs['reduction'])

    def reference(inp, *_):
        return loss_reference_fns['NLLLossNd'](inp, target.type_as(inp).long(), **loss_kwargs)

    return {
        'fullname': 'NLLLossNd_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
        'input_fn': lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def nlllossNd_no_reduce_ignore_index_test():
    """Build the N-D no-reduction ``F.nll_loss`` test dict with ``ignore_index=1``."""
    target = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    loss_kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}

    def forward(inp):
        return F.nll_loss(inp, target.type_as(inp).long(),
                          ignore_index=int(loss_kwargs['ignore_index']),
                          reduction=str(loss_kwargs['reduction']))

    def reference(inp, *_):
        return loss_reference_fns['NLLLossNd'](inp, target.type_as(inp).long(), **loss_kwargs)

    return {
        'fullname': 'NLLLossNd_no_reduce_ignore_index',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
        'input_fn': lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def nlllossNd_no_reduce_weights_test():
    """Build the N-D, class-weighted no-reduction ``F.nll_loss`` test dict."""
    target = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    class_weight = torch.rand(3)

    def weighted_kwargs(x):
        return {'weight': class_weight.type_as(x), 'reduction': 'none'}

    def forward(inp):
        return F.nll_loss(inp, target.type_as(inp).long(), **weighted_kwargs(inp))

    def reference(inp, *_):
        return loss_reference_fns['NLLLossNd'](inp, target.type_as(inp).long(), **weighted_kwargs(inp))

    return {
        'fullname': 'NLLLossNd_no_reduce_weights',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
        'input_fn': lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
        'cpp_var_map': {'i': '_get_input()', 't': target, 'weight': class_weight},
        'reference_fn': reference,
        'pickle': False,
        'default_dtype': torch.double,
    }


def smoothl1loss_no_reduce_test():
    """Build the no-reduction ``F.smooth_l1_loss`` test dict for a (2, 3, 4) input."""
    target = torch.randn(2, 3, 4, dtype=torch.double)

    def forward(inp):
        return F.smooth_l1_loss(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        return loss_reference_fns['SmoothL1Loss'](inp, target.type_as(inp), reduction='none')

    return {
        'fullname': 'SmoothL1Loss_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
        'input_fn': lambda: torch.randn(2, 3, 4),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def smoothl1loss_no_reduce_scalar_test():
    """Build the 0-dim (scalar) variant of the no-reduction ``F.smooth_l1_loss`` test dict."""
    target = torch.randn((), dtype=torch.double)

    def forward(inp):
        return F.smooth_l1_loss(inp, target.type_as(inp), reduction='none')

    def reference(inp, *_):
        return loss_reference_fns['SmoothL1Loss'](inp, target.type_as(inp), reduction='none')

    return {
        'fullname': 'SmoothL1Loss_no_reduce_scalar',
        'constructor': wrap_functional(forward),
        'cpp_function_call': '''F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
        'input_fn': lambda: torch.randn(()),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def smoothl1loss_beta_test():
    """Return the parameter dict for the SmoothL1Loss test with ``beta=0.5``.

    The C++ call passes beta as the trailing positional argument after the
    options object; the Python side uses the ``beta`` keyword.
    """
    target = torch.randn(2, 3, 4, dtype=torch.double)

    def forward(i):
        return F.smooth_l1_loss(i, target.type_as(i), reduction='none', beta=0.5)

    def reference(i, *_):
        return loss_reference_fns['SmoothL1Loss'](i, target.type_as(i), reduction='none', beta=0.5)

    cpp_call = '''F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0.5)'''
    return {
        'fullname': 'SmoothL1Loss_beta',
        'constructor': wrap_functional(forward),
        'cpp_function_call': cpp_call,
        'input_fn': lambda: torch.randn(2, 3, 4),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def smoothl1loss_zero_beta_test():
    """Return the parameter dict for the SmoothL1Loss test with the edge value ``beta=0``.

    Apart from beta, the setup mirrors the other SmoothL1Loss tests: a (2, 3, 4)
    double target and ``reduction='none'``.
    """
    target = torch.randn(2, 3, 4, dtype=torch.double)

    def forward(i):
        return F.smooth_l1_loss(i, target.type_as(i), reduction='none', beta=0)

    def reference(i, *_):
        return loss_reference_fns['SmoothL1Loss'](i, target.type_as(i), reduction='none', beta=0)

    cpp_call = '''F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0)'''
    return {
        'fullname': 'SmoothL1Loss_zero_beta',
        'constructor': wrap_functional(forward),
        'cpp_function_call': cpp_call,
        'input_fn': lambda: torch.randn(2, 3, 4),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def huberloss_delta_test():
    """Return the parameter dict for the HuberLoss test with ``delta=0.5`` and ``reduction='none'``.

    The target is captured by the functional under test, the C++ parity call
    (as ``t``) and the reference implementation.
    """
    # Build the target in double precision, like the SmoothL1Loss tests above.
    # Otherwise it is drawn with the import-time default dtype (float32) and only
    # carries float32 precision, even though the test runs with default_dtype=double.
    t = torch.randn(2, 3, 4, dtype=torch.double)
    return dict(
        fullname='HuberLoss_delta',
        constructor=wrap_functional(
            lambda i: F.huber_loss(i, t.type_as(i), reduction='none', delta=0.5)),
        cpp_function_call='''F::huber_loss(
            i, t.to(i.options()), F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5))''',
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['HuberLoss'](i, t.type_as(i), reduction='none', delta=0.5),
        supports_forward_ad=True,
        pickle=False,
        default_dtype=torch.double)


def multilabelmarginloss_0d_no_reduce_test():
    """Return the parameter dict for the MultiLabelMarginLoss test on a 0-d input with ``reduction='none'``.

    The target is the constant 0-d index ``0``, so no RNG state is consumed here.
    """
    t = torch.zeros((), dtype=torch.long)
    return dict(
        fullname='MultiLabelMarginLoss_0d_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(()),
        cpp_var_map={'i': '_get_input()', 't': t},
        # `t` never requires grad, so the discouraged `.data` accessor is unnecessary.
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiLabelMarginLoss'](i, t.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)


def multilabelmarginloss_1d_no_reduce_test():
    """Return the parameter dict for the MultiLabelMarginLoss test on a 1-D input of size 10 with ``reduction='none'``.

    Target indices are integers in [0, 9] (``floor(rand * 10)``).
    """
    # A plain tensor is enough: `t` never requires grad, so the deprecated
    # `Variable` wrapper and the `.data` accessor are unnecessary.
    t = torch.rand(10).mul(10).floor().long()
    return dict(
        fullname='MultiLabelMarginLoss_1d_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiLabelMarginLoss'](i, t.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)


def multilabelmarginloss_index_neg_test():
    """Return the parameter dict for the MultiLabelMarginLoss test whose targets contain ``-1`` entries.

    Targets are ``floor((rand - 0.5) * 20)`` clamped to a minimum of -1, so they
    are integers in [-1, 9] for the (5, 10) input.
    """
    # `t` never requires grad, so the deprecated `Variable` wrapper and the
    # `.data` accessor are unnecessary.
    t = torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1)
    return dict(
        fullname='MultiLabelMarginLoss_index_neg',
        constructor=wrap_functional(
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiLabelMarginLoss'](i, t.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)


def multilabelmarginloss_no_reduce_test():
    """Return the parameter dict for the batched MultiLabelMarginLoss test with ``reduction='none'``.

    The input is (5, 10); target indices are integers in [0, 9].
    """
    # `t` never requires grad, so the deprecated `Variable` wrapper and the
    # `.data` accessor are unnecessary.
    t = torch.rand(5, 10).mul(10).floor().long()
    return dict(
        fullname='MultiLabelMarginLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiLabelMarginLoss'](i, t.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)


def hingeembeddingloss_no_reduce_test():
    """Return the parameter dict for the HingeEmbeddingLoss test with ``reduction='none'``.

    Targets are double-precision labels in {-1, +1}: ``(randn > 0) * 2 - 1``.
    """
    # `t` never requires grad, so the deprecated `Variable` wrapper is unnecessary.
    t = torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1)
    return dict(
        fullname='HingeEmbeddingLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::hinge_embedding_loss(
            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'),
        check_sum_reduction=True,
        pickle=False,
        default_dtype=torch.double)


def hingeembeddingloss_margin_no_reduce_test():
    """Return the parameter dict for the HingeEmbeddingLoss test with ``margin=0.5`` and ``reduction='none'``.

    Targets are double-precision labels in {-1, +1}: ``(randn > 0) * 2 - 1``.
    """
    # `t` never requires grad, so the deprecated `Variable` wrapper is unnecessary.
    t = torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1)
    return dict(
        fullname='HingeEmbeddingLoss_margin_no_reduce',
        constructor=wrap_functional(
            lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')),
        cpp_function_call='''F::hinge_embedding_loss(
            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'),
        check_sum_reduction=True,
        pickle=False,
        default_dtype=torch.double)


def softmarginloss_no_reduce_test():
    """Return the parameter dict for the SoftMarginLoss test with ``reduction='none'``.

    A (5, 5) double-precision target is drawn once and shared by the functional,
    the C++ parity call (as ``t``) and the reference.
    """
    target = torch.randn(5, 5, dtype=torch.double)

    def forward(i):
        return F.soft_margin_loss(i, target.type_as(i), reduction='none')

    def reference(i, *_):
        return loss_reference_fns['SoftMarginLoss'](i, target.type_as(i), reduction='none')

    cpp_call = '''F::soft_margin_loss(
            i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))'''
    return {
        'fullname': 'SoftMarginLoss_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': cpp_call,
        'input_fn': lambda: torch.randn(5, 5),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'supports_forward_ad': True,
        'pickle': False,
        'default_dtype': torch.double,
    }


def multilabelsoftmarginloss_no_reduce_test():
    """Return the parameter dict for the MultiLabelSoftMarginLoss test with ``reduction='none'``.

    Targets are 0/1 values (``floor(rand * 2)``) for a (5, 10) input. The
    reference computes the per-sample loss inline: the mean over classes of the
    binary cross-entropy on ``sigmoid(i)``.
    """
    target = torch.rand(5, 10).mul(2).floor()

    def forward(i):
        return F.multilabel_soft_margin_loss(i, target.type_as(i), reduction='none')

    def reference(i, *_):
        log_pos = i.sigmoid().log()
        log_neg = (-i).sigmoid().log()
        per_class = -(target * log_pos + (1 - target) * log_neg)
        return per_class.sum(dim=1) / i.size(1)

    cpp_call = '''F::multilabel_soft_margin_loss(
            i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))'''
    return {
        'fullname': 'MultiLabelSoftMarginLoss_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': cpp_call,
        'input_fn': lambda: torch.randn(5, 10),
        'cpp_var_map': {'i': '_get_input()', 't': target},
        'reference_fn': reference,
        'check_gradgrad': False,
        'pickle': False,
        'default_dtype': torch.double,
    }


def multilabelsoftmarginloss_weights_no_reduce_test():
    """Return the parameter dict for the class-weighted MultiLabelSoftMarginLoss test with ``reduction='none'``.

    Targets are 0/1 values for a (5, 10) input, and ``weights`` holds one random
    value per class. The target is drawn before the weights, so RNG consumption
    order is fixed.
    """
    target = torch.rand(5, 10).mul(2).floor()
    weights = torch.rand(10)

    def forward(i):
        return F.multilabel_soft_margin_loss(i, target.type_as(i),
                                             weight=weights.type_as(i), reduction='none')

    def reference(i, *_):
        log_pos = i.sigmoid().log()
        log_neg = (-i).sigmoid().log()
        per_class = -(target * log_pos + (1 - target) * log_neg)
        return (per_class * weights).sum(dim=1) / i.size(1)

    cpp_call = '''F::multilabel_soft_margin_loss(
            i, t.to(i.options()),
            F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))'''
    return {
        'fullname': 'MultiLabelSoftMarginLoss_weights_no_reduce',
        'constructor': wrap_functional(forward),
        'cpp_function_call': cpp_call,
        'input_fn': lambda: torch.randn(5, 10),
        'cpp_var_map': {'i': '_get_input()', 't': target, 'weights': weights},
        'reference_fn': reference,
        'check_sum_reduction': True,
        'check_gradgrad': False,
        'pickle': False,
        'default_dtype': torch.double,
    }


def multimarginloss_no_reduce_test():
    """Return the parameter dict for the batched MultiMarginLoss test with ``reduction='none'``.

    Five class indices in [0, 7] (``floor(rand * 8)``) are drawn for a (5, 10) input.
    """
    t = torch.rand(5).mul(8).floor().long()
    return dict(
        fullname='MultiMarginLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        # `t` never requires grad, so the discouraged `.data` accessor is unnecessary.
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiMarginLoss'](i, t.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)


def multimarginloss_1d_no_reduce_test():
    """Return the parameter dict for the MultiMarginLoss test on a 1-D input of size 10 with a 1-element target.

    The single target index is in [0, 7] (``floor(rand * 8)``).
    """
    t = torch.rand(1).mul(8).floor().long()
    return dict(
        fullname='MultiMarginLoss_1d_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        # `t` never requires grad, so the discouraged `.data` accessor is unnecessary.
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiMarginLoss'](i, t.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)


def multimarginloss_1d_input_0d_target_no_reduce_test():
    """Return the parameter dict for the MultiMarginLoss test on a 1-D input of size 10 with a 0-d target.

    The 0-d target index is in [0, 7] (``floor(rand * 8)``).
    """
    t = torch.rand(()).mul(8).floor().long()
    return dict(
        fullname='multimarginloss_1d_input_0d_target_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        # `t` never requires grad, so the discouraged `.data` accessor is unnecessary.
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiMarginLoss'](i, t.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)


def multimarginloss_p_no_reduce_test():
    """Return the parameter dict for the MultiMarginLoss test with ``p=2`` and ``reduction='none'``.

    The (5, 10) input is clamped in place to [1e-2, 1 - 1e-2]; targets are
    indices in [0, 7].
    """
    t = torch.rand(5).mul(8).floor().long()
    return dict(
        fullname='MultiMarginLoss_p_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')),
        cpp_function_call='''F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
        cpp_var_map={'i': '_get_input()', 't': t},
        # `t` never requires grad, so the discouraged `.data` accessor is unnecessary.
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiMarginLoss'](i, t.type_as(i).long(), p=2, reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)


def multimarginloss_margin_no_reduce_test():
    """Return the parameter dict for the MultiMarginLoss test with ``margin=0.5`` and ``reduction='none'``.

    Five class indices in [0, 7] are drawn for a (5, 10) input.
    """
    t = torch.rand(5).mul(8).floor().long()
    return dict(
        fullname='MultiMarginLoss_margin_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')),
        cpp_function_call='''F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        # `t` never requires grad, so the discouraged `.data` accessor is unnecessary.
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiMarginLoss'](i, t.type_as(i).long(),
                                                  margin=0.5, reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)


def multimarginloss_weights_no_reduce_test():
    """Return the parameter dict for the class-weighted MultiMarginLoss test with ``reduction='none'``.

    Targets are five indices in [0, 7] for a (5, 10) input. ``weights`` has one
    double-precision value per class and is drawn after the targets.
    """
    t = torch.rand(5).mul(8).floor().long()
    weights = torch.rand(10, dtype=torch.double)
    return dict(
        fullname='MultiMarginLoss_weights_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i),
                                          reduction='none')),
        cpp_function_call='''F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
        # Convert the weights exactly as the forward call does, so the reference also
        # works when `i` has a different dtype/device. `t` never requires grad, so
        # the discouraged `.data` accessor is unnecessary.
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiMarginLoss'](i, t.type_as(i).long(),
                                                  weight=weights.type_as(i), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False,
        default_dtype=torch.double)


def single_batch_reference_fn(input, parameters, module):
    """Reference function for modules supporting no batch dimensions.

    The module is passed the input and target in batched form with a single item.
    The output is squeezed to compare with the no-batch input.
    """
    def unsqueeze_inp(inp):
        if isinstance(inp, (list, tuple)):
            return [t.unsqueeze(0) for t in inp]
        return inp.unsqueeze(0)

    single_batch_input = unsqueeze_inp(input)
    single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input
    with freeze_rng_state():
        return module(*single_batch_input).squeeze(0)


def get_new_module_tests():
    new_module_tests = [
        poissonnllloss_no_reduce_test(),
        bceloss_no_reduce_test(),
        bceloss_weights_no_reduce_test(),
        bce_with_logistic_legacy_enum_test(),
        bce_with_logistic_no_reduce_test(),
        bceloss_no_reduce_scalar_test(),
        bceloss_weights_no_reduce_scalar_test(),
        bce_with_logistic_no_reduce_scalar_test(),
        kldivloss_with_target_no_reduce_test(),
        kldivloss_no_reduce_test(),
        kldivloss_no_reduce_scalar_test(),
        kldivloss_with_log_target_no_reduce_test(),
        kldivloss_no_reduce_log_target_test(),
        kldivloss_no_reduce_scalar_log_target_test(),
        l1loss_no_reduce_test(),
        l1loss_no_reduce_complex_test(),
        l1loss_no_reduce_scalar_test(),
        mseloss_no_reduce_test(),
        mseloss_no_reduce_scalar_test(),
        nllloss_no_reduce_test(),
        nllloss_no_reduce_ignore_index_test(),
        nllloss_no_reduce_weights_test(),
        nllloss_no_reduce_weights_ignore_index_test(),
        nllloss_no_reduce_weights_ignore_index_neg_test(),
        nllloss2d_no_reduce_test(),
        nllloss2d_no_reduce_weights_test(),
        nllloss2d_no_reduce_ignore_index_test(),
        nlllossNd_no_reduce_test(),
        nlllossNd_no_reduce_weights_test(),
        nlllossNd_no_reduce_ignore_index_test(),
        smoothl1loss_no_reduce_test(),
        smoothl1loss_no_reduce_scalar_test(),
        smoothl1loss_beta_test(),
        smoothl1loss_zero_beta_test(),
        huberloss_delta_test(),
        multilabelmarginloss_0d_no_reduce_test(),
        multilabelmarginloss_1d_no_reduce_test(),
        multilabelmarginloss_index_neg_test(),
        multilabelmarginloss_no_reduce_test(),
        hingeembeddingloss_no_reduce_test(),
        hingeembeddingloss_margin_no_reduce_test(),
        softmarginloss_no_reduce_test(),
        multilabelsoftmarginloss_no_reduce_test(),
        multilabelsoftmarginloss_weights_no_reduce_test(),
        multimarginloss_no_reduce_test(),
        multimarginloss_1d_no_reduce_test(),
        multimarginloss_1d_input_0d_target_no_reduce_test(),
        multimarginloss_p_no_reduce_test(),
        multimarginloss_margin_no_reduce_test(),
        multimarginloss_weights_no_reduce_test(),
        dict(
            module_name='Conv1d',
            constructor_args=(4, 5, 3),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
            input_size=(2, 4, 10),
            cudnn=True,
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv1d',
            constructor_args=(4, 5, 3, 2),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(2)',
            input_size=(2, 4, 10),
            cudnn=True,
            desc='stride',
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv1d',
            constructor_args=(4, 5, 3, 1, 1),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)',
            input_size=(2, 4, 10),
            cudnn=True,
            desc='pad1',
            with_tf32=True,
            tf32_precision=0.01,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv1d',
            constructor_args=(4, 5, 5, 1, 2),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)',
            input_size=(2, 4, 10),
            cudnn=True,
            desc='pad2',
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv1d',
            constructor_args=(4, 4, 3, 1, 1),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)',
            input_size=(1, 4, 1),
            cudnn=True,
            desc='pad1size1',
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv1d',
            constructor_args=(4, 4, 5, 1, 2),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)',
            input_size=(1, 4, 1),
            cudnn=True,
            desc='pad2size1',
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv1d',
            constructor_args=(4, 5, 3),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
            input_size=(0, 4, 10),
            cudnn=True,
            desc='zero_batch',
            with_tf32=True,
            tf32_precision=0.005,
        ),
        dict(
            fullname='Conv1d_dilated',
            constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)',
            input_size=(2, 4, 10),
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv1d_groups',
            constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)',
            input_size=(2, 4, 6),
            cudnn=True,
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv1d_pad_valid',
            constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)',
            input_size=(2, 4, 10),
            cudnn=True,
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv1d_pad_same',
            constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)',
            input_size=(2, 4, 10),
            cudnn=True,
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv1d_pad_same2',
            constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)',
            input_size=(2, 4, 10),
            cudnn=True,
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv1d_pad_same_dilated',
            constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2),
            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame).dilation(2)',
            input_size=(2, 4, 10),
            cudnn=True,
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='ConvTranspose1d',
            constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)),
            cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)',
            cudnn=True,
            input_size=(1, 3, 7),
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='ConvTranspose1d',
            constructor_args=(3, 4, 3, 2, 1, 1, 1, False),
            cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
                                    .stride(2).padding(1).output_padding(1).groups(1).bias(false)''',
            input_size=(1, 3, 6),
            cudnn=True,
            desc='no_bias',
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='ConvTranspose1d',
            constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2),
            cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
                                    .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)''',
            input_size=(1, 3, 6),
            cudnn=True,
            desc='dilated',
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='ConvTranspose1d_groups',
            constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2),
            cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(4, 6, 3)
                                    .stride(3).padding(1).output_padding(1).groups(2)''',
            cudnn=True,
            input_size=(2, 4, 7),
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv2d',
            constructor_args=(3, 4, (3, 2)),
            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
            input_size=(2, 3, 7, 5),
            cudnn=True,
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv2d',
            constructor_args=(3, 4, (3, 3), (2, 2)),
            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})',
            input_size=(2, 3, 6, 6),
            cudnn=True,
            desc='strided',
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv2d',
            constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)),
            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})',
            input_size=(2, 3, 6, 6),
            cudnn=True,
            desc='padding',
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv2d',
            constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)),
            cpp_constructor_args='torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})',
            input_size=(2, 3, 8, 8),
            cudnn=True,
            desc='dilated',
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv2d',
            constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False),
            cpp_constructor_args='''torch::nn::Conv2dOptions(3, 4, {3, 2})
                                    .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
            input_size=(2, 3, 6, 5),
            cudnn=True,
            desc='no_bias',
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.015,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv2d',
            constructor_args=(3, 4, (3, 2)),
            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
            input_size=(0, 3, 7, 5),
            cudnn=True,
            desc='zero_batch',
            check_with_long_tensor=True,
            with_tf32=True,
        ),
        dict(
            fullname='Conv2d_groups',
            constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
            cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
            input_size=(2, 4, 6, 5),
            cudnn=True,
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.015,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv2d_groups_thnn',
            constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
            cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
            input_size=(2, 4, 6, 5),
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.015,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv2d_pad_valid',
            constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"),
            cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)',
            input_size=(2, 2, 6, 5),
            cudnn=True,
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv2d_pad_same',
            constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"),
            cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)',
            input_size=(2, 2, 6, 5),
            cudnn=True,
            with_tf32=True,
            tf32_precision=0.01,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv2d_pad_same_dilated',
            constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2),
            cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)',
            input_size=(2, 2, 6, 5),
            cudnn=True,
            with_tf32=True,
            tf32_precision=0.01,
            default_dtype=torch.double,
        ),
        dict(
            module_name='ConvTranspose2d',
            constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
            cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
                                    .stride({3, 2}).padding(1).output_padding({1, 1})''',
            cudnn=True,
            input_size=(1, 3, 7, 6),
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.01,
            default_dtype=torch.double,
        ),
        dict(
            module_name='ConvTranspose2d',
            constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)),
            cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
                                    .stride({2, 3})
                                    .padding(1)
                                    .output_padding({1, 1})
                                    .groups(1)
                                    .bias(false)
                                    .dilation({2, 2})''',
            input_size=(1, 3, 6, 7),
            cudnn=True,
            desc='dilated',
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.01,
            default_dtype=torch.double,
        ),
        dict(
            module_name='ConvTranspose2d',
            constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False),
            cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
                                    .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)''',
            input_size=(1, 3, 6, 7),
            cudnn=True,
            desc='no_bias',
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.01,
            default_dtype=torch.double,
        ),
        dict(
            fullname='ConvTranspose2d_groups',
            constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2),
            cpp_constructor_args='torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)',
            input_size=(1, 2, 4, 5),
            cudnn=True,
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.01,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv2d_depthwise',
            constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)',
            input_size=(2, 4, 6, 6),
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv2d_depthwise_with_multiplier',
            constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
            cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)',
            input_size=(2, 4, 6, 6),
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv2d_depthwise_strided',
            constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)',
            input_size=(2, 4, 6, 6),
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv2d_depthwise_padded',
            constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)',
            input_size=(2, 4, 6, 6),
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv2d_depthwise_dilated',
            constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)',
            input_size=(2, 4, 5, 5),
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv3d',
            constructor_args=(2, 3, (2, 3, 2)),
            cpp_constructor_args='torch::nn::Conv3dOptions(2, 3, {2, 3, 2})',
            input_size=(1, 2, 4, 5, 4),
            cudnn=True,
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv3d',
            constructor_args=(2, 3, (2, 3, 4), 1, 0, 1, 1, False),
            cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4})
                                    .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
            input_size=(1, 2, 3, 4, 5),
            cudnn=True,
            desc='no_bias',
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv3d',
            constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False),
            # The C++ options must mirror constructor_args exactly (see the API-parity
            # NOTE at the top of the file), including the 1x1x1 kernel this case exists
            # to exercise; a different kernel here would test a different module.
            cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {1, 1, 1})
                                    .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
            input_size=(1, 2, 3, 4, 5),
            cudnn=True,
            desc='1x1x1_no_bias',
            check_with_long_tensor=False,
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv3d',
            constructor_args=(3, 4, 2, 2),
            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2)',
            input_size=(2, 3, 5, 5, 5),
            cudnn=True,
            desc='stride',
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv3d',
            constructor_args=(3, 4, 2, 2, 1),
            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)',
            input_size=(2, 3, 5, 5, 5),
            cudnn=True,
            desc='stride_padding',
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Conv3d',
            constructor_args=(3, 4, (2, 3, 4)),
            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4})',
            input_size=(0, 3, 3, 4, 5),
            cudnn=True,
            check_with_long_tensor=True,
            desc='zero_batch',
            with_tf32=True,
        ),
        dict(
            fullname='Conv3d_groups',
            constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2),
            cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)',
            input_size=(1, 2, 4, 5, 4),
            cudnn=True,
            check_with_long_tensor=True,
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv3d_dilated',
            constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)',
            input_size=(2, 3, 5, 5, 5),
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv3d_dilated_strided',
            constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)',
            input_size=(2, 3, 5, 5, 5),
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv3d_pad_valid',
            constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"),
            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)',
            input_size=(2, 3, 6, 5, 4),
            cudnn=True,
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv3d_pad_same',
            constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"),
            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)',
            input_size=(2, 3, 6, 5, 4),
            cudnn=True,
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Conv3d_pad_same_dilated',
            constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2),
            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)',
            input_size=(2, 3, 6, 5, 4),
            cudnn=True,
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            module_name='ConvTranspose3d',
            constructor_args=(2, 3, (2, 3, 2)),
            cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})',
            cudnn=True,
            input_size=(1, 2, 4, 5, 4),
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            module_name='ConvTranspose3d',
            constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)),
            cpp_constructor_args='''torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})
                                    .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})''',
            cudnn=True,
            input_size=(1, 2, 4, 5, 4),
            desc='dilated',
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            module_name='ReplicationPad3d',
            constructor_args=((1, 2, 3, 3, 2, 1),),
            cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
            input_size=(2, 3, 2, 2, 2),
            default_dtype=torch.double,
        ),
        dict(
            module_name='ReplicationPad3d',
            constructor_args=((1, 2, 3, 3, 2, 1),),
            cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
            input_size=(3, 2, 2, 2),
            reference_fn=single_batch_reference_fn,
            desc='no_batch_dim',
            default_dtype=torch.double,
        ),
        dict(
            module_name='ReplicationPad3d',
            constructor_args=((1, 2, 3, 3, 2, 1),),
            cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
            input_fn=lambda: torch.rand(2, 3, 2, 2, 2, dtype=torch.complex128, requires_grad=True),
            skip_half=True,
            desc='complex'
        ),
        dict(
            module_name='Embedding',
            constructor_args=(4, 3),
            cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
            check_gradgrad=False,
            default_dtype=torch.double,
            decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
        ),
        dict(
            module_name='Embedding',
            constructor_args=(4, 3),
            cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
            input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
            check_gradgrad=False,
            desc='discontiguous',
            default_dtype=torch.double,
            decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
        ),
        dict(
            module_name='EmbeddingBag',
            constructor_args=(4, 3),
            cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
            check_gradgrad=False,
            desc='mean',
            default_dtype=torch.double,
        ),
        dict(
            module_name='EmbeddingBag',
            constructor_args=(4, 3),
            cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
            input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
            check_gradgrad=False,
            desc='discontiguous',
            default_dtype=torch.double,
        ),
        dict(
            module_name='EmbeddingBag',
            constructor_args=(4, 3, None, 2., False, 'sum'),
            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''',
            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
            check_gradgrad=False,
            desc='sum',
            default_dtype=torch.double,
        ),
        dict(
            module_name='EmbeddingBag',
            constructor_args=(4, 3, None, 2., False, 'max'),
            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''',
            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
            check_gradgrad=False,
            desc='max',
            default_dtype=torch.double,
        ),
        dict(
            fullname='EmbeddingBag_mean_padding_idx',
            constructor=lambda: nn.EmbeddingBag(4, 3, padding_idx=1),
            cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)',
            input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
            check_gradgrad=False,
            default_dtype=torch.double,
        ),
        dict(
            fullname='EmbeddingBag_sum_padding_idx',
            constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'sum', padding_idx=1),
            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)''',
            input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
            check_gradgrad=False,
            default_dtype=torch.double,
        ),
        dict(
            fullname='EmbeddingBag_max_padding_idx',
            constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'max', padding_idx=1),
            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)''',
            input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
            check_gradgrad=False,
            default_dtype=torch.double,
        ),
        dict(
            fullname='EmbeddingBag_sparse',
            constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double),
            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
                                    .sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))''',
            input_fn=lambda: torch.randperm(2).repeat(1, 2),
            check_gradgrad=False,
            has_sparse_gradients=True,
        ),
        dict(
            constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True),
            cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))',
            input_fn=lambda: torch.randperm(2).repeat(1, 2),
            fullname='Embedding_sparse',
            check_gradgrad=False,
            has_sparse_gradients=True,
        ),
        dict(
            module_name='PixelShuffle',
            constructor_args=(3,),
            cpp_constructor_args='torch::nn::PixelShuffleOptions(3)',
            input_size=(1, 9, 4, 4),
            default_dtype=torch.double,
        ),
        dict(
            module_name='PixelUnshuffle',
            constructor_args=(3,),
            cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)',
            input_size=(1, 1, 12, 12),
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
            input_size=(1, 2, 4),
            fullname='interpolate_nearest_1d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
            input_size=(0, 2, 4),
            fullname='interpolate_nearest_1d_zero_dim',
            pickle=False,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=None, mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
            input_size=(1, 2, 3),
            fullname='interpolate_nearest_tuple_1d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt).scale_factor(std::vector<double>({4.})).mode(torch::kNearest)''',
            input_size=(1, 2, 4),
            fullname='interpolate_nearest_scale_1d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kLinear)
                                .align_corners(false)''',
            input_size=(1, 2, 4),
            fullname='interpolate_linear_1d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=None, mode='linear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({4}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kLinear)
                                .align_corners(false)''',
            input_size=(1, 2, 3),
            fullname='interpolate_linear_tuple_1d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({4.}))
                                .mode(torch::kLinear)
                                .align_corners(false)''',
            input_size=(1, 2, 4),
            fullname='interpolate_linear_scale_1d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kLinear)
                                .align_corners(false)''',
            input_size=(0, 2, 4),
            fullname='interpolate_linear_1d_zero_dim',
            pickle=False,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=True),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kLinear)
                                .align_corners(true)''',
            input_size=(1, 2, 4),
            fullname='interpolate_linear_1d_align_corners',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({4.}))
                                .mode(torch::kLinear)
                                .align_corners(true)''',
            input_size=(1, 2, 4),
            fullname='interpolate_linear_scale_1d_align_corners',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=2, scale_factor=None, mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({2, 2}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kNearest)''',
            input_size=(1, 128, 1, 1),
            fullname='interpolate_nearest_2d_launch_configs',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12, 12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kNearest)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_nearest_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12, 16}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kNearest)''',
            input_size=(1, 2, 3, 4),
            fullname='interpolate_nearest_tuple_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({4., 4.}))
                                .mode(torch::kNearest)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_nearest_scale_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12, 12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kNearest)''',
            input_size=(0, 2, 4, 4),
            fullname='interpolate_nearest_2d_zero_dim',
            pickle=False,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12, 12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kBilinear)
                                .align_corners(false)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_bilinear_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12, 12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kBilinear)
                                .align_corners(false)''',
            input_size=(0, 2, 4, 4),
            fullname='interpolate_bilinear_2d_zero_dim',
            pickle=False,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
                                        mode='bilinear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({4, 6}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kBilinear)
                                .align_corners(false)''',
            input_size=(1, 2, 2, 3),
            fullname='interpolate_bilinear_tuple_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4.,
                                        mode='bilinear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({4., 4.}))
                                .mode(torch::kBilinear)
                                .align_corners(false)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_bilinear_scale_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
                                        mode='bilinear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({2., 2.}))
                                .mode(torch::kBilinear)
                                .align_corners(false)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_bilinear_scale_tuple_shared_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
                                        mode='bilinear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({2., 1.}))
                                .mode(torch::kBilinear)
                                .align_corners(false)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_bilinear_scale_tuple_skewed_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({4, 6}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kBilinear)
                                .align_corners(true)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_bilinear_tuple_2d_align_corners',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
                                        mode='bilinear', align_corners=True),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({2., 1.}))
                                .mode(torch::kBilinear)
                                .align_corners(true)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12, 12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kBicubic)
                                .align_corners(false)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_bicubic_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12, 12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kBicubic)
                                .align_corners(false)''',
            input_size=(0, 2, 4, 4),
            fullname='interpolate_bicubic_2d_zero_dim',
            pickle=False,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
                                        mode='bicubic', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({4, 6}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kBicubic)
                                .align_corners(false)''',
            input_size=(1, 2, 2, 3),
            fullname='interpolate_bicubic_tuple_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({4., 4.}))
                                .mode(torch::kBicubic)
                                .align_corners(false)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_bicubic_scale_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
                                        mode='bicubic', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({2., 2.}))
                                .mode(torch::kBicubic)
                                .align_corners(false)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_bicubic_scale_tuple_shared_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
                                        mode='bicubic', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({2., 1.}))
                                .mode(torch::kBicubic)
                                .align_corners(false)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_bicubic_scale_tuple_skewed_2d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({4, 6}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kBicubic)
                                .align_corners(true)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_bicubic_tuple_2d_align_corners',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
                                        mode='bicubic', align_corners=True),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({2., 1.}))
                                .mode(torch::kBicubic)
                                .align_corners(true)''',
            input_size=(1, 2, 4, 4),
            fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12, 12, 12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kNearest)''',
            input_size=(1, 2, 4, 4, 4),
            fullname='interpolate_nearest_3d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12, 12, 12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kNearest)''',
            input_size=(0, 2, 4, 4, 4),
            fullname='interpolate_nearest_3d_zero_dim',
            pickle=False,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12, 16, 16}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kNearest)''',
            input_size=(1, 2, 3, 4, 4),
            fullname='interpolate_nearest_tuple_3d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({4., 4., 4.}))
                                .mode(torch::kNearest)''',
            input_size=(1, 2, 4, 4, 4),
            fullname='interpolate_nearest_scale_3d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12, 12, 12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kTrilinear)
                                .align_corners(false)''',
            input_size=(1, 2, 4, 4, 4),
            fullname='interpolate_trilinear_3d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({12, 12, 12}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kTrilinear)
                                .align_corners(false)''',
            input_size=(0, 2, 4, 4, 4),
            fullname='interpolate_trilinear_3d_zero_dim',
            pickle=False,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=(4, 6, 6),
                                        scale_factor=None, mode='trilinear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({4, 6, 6}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kTrilinear)
                                .align_corners(false)''',
            input_size=(1, 2, 2, 3, 3),
            fullname='interpolate_trilinear_tuple_3d',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({3., 3., 3.}))
                                .mode(torch::kTrilinear)
                                .align_corners(false)''',
            input_size=(1, 2, 3, 4, 5),
            fullname='interpolate_trilinear_scale_3d',
            # See https://github.com/pytorch/pytorch/issues/5006
            precision=3e-4,
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None,
                                        mode='trilinear', align_corners=True),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::vector<int64_t>({4, 6, 6}))
                                .scale_factor(std::nullopt)
                                .mode(torch::kTrilinear)
                                .align_corners(true)''',
            input_size=(1, 2, 2, 3, 3),
            fullname='interpolate_trilinear_tuple_3d_align_corners',
            pickle=False,
            default_dtype=torch.double
        ),
        dict(
            constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True),
            cpp_options_args='''F::InterpolateFuncOptions()
                                .size(std::nullopt)
                                .scale_factor(std::vector<double>({3., 3., 3.}))
                                .mode(torch::kTrilinear)
                                .align_corners(true)''',
            input_size=(1, 2, 3, 4, 4),
            fullname='interpolate_trilinear_scale_3d_align_corners',
            # See https://github.com/pytorch/pytorch/issues/5006
            precision=3e-4,
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.softmax, dim=-1),
            cpp_options_args='F::SoftmaxFuncOptions(-1)',
            input_size=(2, 128),  # trigger the last-dim algo in CUDA
            fullname='softmax_lastdim',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
            cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
            input_size=(2, 128),
            fullname='softmax_lastdim_dtype',
            pickle=False,
            test_cuda=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.softmax, dim=1),
            cpp_options_args='F::SoftmaxFuncOptions(1)',
            input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
            fullname='softmax_spatial_special',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.softmax, dim=1),
            cpp_options_args='F::SoftmaxFuncOptions(1)',
            input_size=(2, 2, 4, 4),  # regular spatial algorithm
            fullname='softmax_spatial',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
            cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
            input_size=(2, 2, 4, 4),  # regular spatial algorithm
            fullname='softmax_spatial_dtype',
            pickle=False,
            test_cuda=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.softmax, dim=0),
            cpp_options_args='F::SoftmaxFuncOptions(0)',
            input_size=(2, 3, 4, 5),
            fullname='softmax_functional_dim0',
            test_cuda=False,
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.softmax, dim=3),
            cpp_options_args='F::SoftmaxFuncOptions(3)',
            input_size=(2, 3, 4, 5),
            fullname='softmax_functional_dim3',
            test_cuda=False,
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.softmax, dim=-1),
            cpp_options_args='F::SoftmaxFuncOptions(-1)',
            input_size=(),
            fullname='softmax_functional_scalar',
            test_cuda=False,
            pickle=False,
        ),
        dict(
            constructor=wrap_functional(F.log_softmax, dim=-1),
            cpp_options_args='F::LogSoftmaxFuncOptions(-1)',
            input_size=(2, 128),  # trigger the last-dim algo in CUDA
            fullname='log_softmax_lastdim',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.log_softmax, dim=1),
            cpp_options_args='F::LogSoftmaxFuncOptions(1)',
            input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
            fullname='log_softmax_spatial_special',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.log_softmax, dim=1),
            cpp_options_args='F::LogSoftmaxFuncOptions(1)',
            input_size=(2, 2, 4, 4),  # regular spatial algorithm
            fullname='log_softmax_spatial',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.log_softmax, dim=0),
            cpp_options_args='F::LogSoftmaxFuncOptions(0)',
            input_size=(2, 3, 4, 5),
            fullname='log_softmax_dim0',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.log_softmax, dim=3),
            cpp_options_args='F::LogSoftmaxFuncOptions(3)',
            input_size=(2, 3, 4, 5),
            fullname='log_softmax_dim3',
            pickle=False,
            default_dtype=torch.double,
        ),
        dict(
            constructor=wrap_functional(F.log_softmax, dim=0),
            cpp_options_args='F::LogSoftmaxFuncOptions(0)',
            input_size=(),
            fullname='log_softmax_scalar',
            pickle=False,
        ),
        dict(
            fullname='Unfold',
            constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)),
            cpp_constructor_args='torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
            input_size=(2, 4, 3, 3),
            check_gradgrad=False,
            test_cuda=True,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Fold',
            constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
            cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
            input_size=(2, 16, 4),
            check_gradgrad=False,
            test_cuda=True,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Fold_no_batch_dim_input',
            constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
            cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
            input_size=(16, 4),
            check_gradgrad=False,
            ref=single_batch_reference_fn,
            test_cuda=True,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Unfold_int_input',
            constructor=lambda: nn.Unfold(2, 1, 0, 1),
            cpp_constructor_args='torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)',
            input_size=(2, 4, 3, 3),
            check_gradgrad=False,
            test_cuda=True,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Fold_int_input',
            constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
            cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
            input_size=(2, 16, 4),
            check_gradgrad=False,
            test_cuda=True,
            default_dtype=torch.double,
        ),
        dict(
            fullname='Fold_no_batch_dim_int_input',
            constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
            cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
            input_size=(16, 4),
            ref=single_batch_reference_fn,
            check_gradgrad=False,
            test_cuda=True,
            default_dtype=torch.double,
        ),
        dict(
            module_name='RReLU',
            constructor_args=(0.1, 0.9),
            cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
            input_size=(),
            desc='with_up_down_scalar',
            test_cuda=False,
            default_dtype=torch.double,
        ),
        dict(
            module_name='PairwiseDistance',
            input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
            default_dtype=torch.double,
        ),
        dict(
            module_name='PairwiseDistance',
            input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)),
            desc='broadcast_lhs',
            default_dtype=torch.double,
        ),
        dict(
            module_name='PairwiseDistance',
            input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)),
            desc='broadcast_rhs',
            default_dtype=torch.double,
        ),
        dict(
            module_name='PairwiseDistance',
            constructor_args=(1.5, 1e-05, True),
            cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)',
            input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
            desc='with_non_default_args',
            default_dtype=torch.double,
        ),
        dict(
            module_name='PairwiseDistance',
            input_fn=lambda: (torch.randn(8), torch.randn(8)),
            reference_fn=single_batch_reference_fn,
            desc='no_batch_dim',
            default_dtype=torch.double,
        ),
        dict(
            module_name='TransformerEncoderLayer',
            constructor_args=(4, 2, 16, 0.0),
            cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
                                    .dim_feedforward(16)
                                    .dropout(0.0)''',
            input_size=(2, 3, 4),
            desc='relu_activation',
            with_tf32=True,
            tf32_precision=0.1,
            # TODO(#50743): figure out the error
            # RuntimeError: The size of tensor a (6) must match the size of tensor b (4)
            # at non-singleton dimension 2
            check_batched_grad=False,
            check_gradgrad=False,
            default_dtype=torch.double,
        ),
        dict(
            module_name='TransformerEncoderLayer',
            constructor_args=(4, 2, 8, 0.0, F.gelu),
            cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
                                    .dim_feedforward(8)
                                    .dropout(0.0)
                                    .activation(torch::kGELU)''',
            input_size=(2, 3, 4),
            check_gradgrad=False,
            desc='gelu_activation',
            with_tf32=True,
            tf32_precision=0.08 if SM90OrLater else 0.05,
            default_dtype=torch.double,
        ),
        dict(
            module_name='TransformerDecoderLayer',
            constructor_args=(4, 2, 8, 0.0),
            cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
                                    .dim_feedforward(8)
                                    .dropout(0.0)''',
            input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
            check_gradgrad=False,
            desc='relu_activation',
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            module_name='TransformerDecoderLayer',
            constructor_args=(4, 2, 8, 0.0, F.gelu),
            cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
                                    .dim_feedforward(8)
                                    .dropout(0.0)
                                    .activation(torch::kGELU)''',
            input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
            check_gradgrad=False,
            desc='gelu_activation',
            with_tf32=True,
            tf32_precision=0.05,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Transformer',
            constructor_args=(4, 2, 2, 2, 8, 0.0, F.relu),
            cpp_constructor_args='''torch::nn::TransformerOptions()
                                    .d_model(4)
                                    .nhead(2)
                                    .num_encoder_layers(2)
                                    .num_decoder_layers(2)
                                    .dim_feedforward(8)
                                    .dropout(0.0)
                                    .activation(torch::kReLU)''',
            input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)),
            check_gradgrad=False,
            desc='multilayer_coder',
            with_tf32=True,
            tf32_precision=0.05 if SM90OrLater else 0.03,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Linear',
            constructor_args=(3, 5),
            cpp_constructor_args='torch::nn::LinearOptions(3, 5)',
            input_fn=lambda: torch.rand(3),
            reference_fn=lambda i, p, _: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1],
            desc="no_batch_dim",
            with_tf32=True,
            tf32_precision=0.005,
            default_dtype=torch.double,
        ),
        dict(
            module_name='Flatten',
            cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)',
            constructor_args=(-3, -1),
            input_size=(3, 4, 5),
            reference_fn=single_batch_reference_fn,
            desc="no_batch_dim",
            default_dtype=torch.double,
        ),
        dict(
            module_name='Unflatten',
            cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})',
            constructor_args=(-2, torch.Size([2, 2])),
            input_size=(3, 4, 5),
            reference_fn=single_batch_reference_fn,
            desc="no_batch_dim",
            default_dtype=torch.double,
        ),
        dict(
            module_name='LayerNorm',
            constructor_args=([56, 56, 56], 1e-5, False),
            cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)',
            input_size=(4, 56, 56, 56),
            cudnn=True,
            check_eval=True,
            gradcheck_fast_mode=True,
            check_half=True,
            desc='3d_no_affine_large_feature',
        ),
    ]

    # add conv padding mode tests:
    for padding_mode, cpp_padding_mode in zip(
            ['reflect', 'circular', 'replicate', 'zeros'],
            ['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros']):
        # conv signature:
        #     in_channels, out_channels, kernel_size, stride=1,
        #     padding=0, dilation=1, groups=1,
        #     bias=True, padding_mode='zeros'
        for d in (1, 2, 3):
            if d == 3 and padding_mode == 'reflect':
                # FIXME: remove after implementing reflection pad 3d
                #        https://github.com/pytorch/pytorch/issues/27655
                continue
            padding = tuple(range(1, d + 1))
            cpp_padding = '{' + ', '.join(map(str, padding)) + '}'
            input_size = (2, 2) + (4,) * d
            output_size = (2, 3) + tuple(p + 1 for p in padding)  # simplified from `(4 + 2 * p - 3) // 2 + 1`
            new_module_tests.append(
                dict(
                    module_name=f'Conv{d}d',
                    constructor_args=(2, 3, 3, 2, padding, 1, 1, True, padding_mode),
                    cpp_constructor_args=f'''torch::nn::Conv{d}dOptions(2, 3, 3)
                                            .stride(2)
                                            .padding({cpp_padding})
                                            .dilation(1)
                                            .groups(1)
                                            .bias(true)
                                            .padding_mode({cpp_padding_mode})''',
                    input_size=input_size,
                    output_size=output_size,
                    cudnn=True,
                    desc=f'{padding_mode}_stride2_pad2',
                    with_tf32=True,
                    tf32_precision=0.05,
                    default_dtype=torch.double,
                ),
            )

    # Check that non linear activations work with no batch dimensions
    non_linear_activations_no_batch = [
        'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU',
        'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU',
        'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh',
        'Tanhshrink', 'Threshold'
    ]
    non_linear_activations_extra_info: dict[str, dict] = {
        'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double},
        'Threshold': {'constructor_args': (2., 1.)},
        'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
        'Hardswish': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
        # For RRelu, test that compare CPU and GPU results fail because RNG
        # is different between CPU and GPU
        'RReLU': {'test_cuda': False, 'default_dtype': torch.double},
        'ELU': {'default_dtype': torch.double},
        'GELU': {'default_dtype': torch.double},
        'GLU': {'default_dtype': torch.double},
        'Hardshrink': {'default_dtype': torch.double},
        'Hardtanh': {'default_dtype': torch.double},
        'LeakyReLU': {'default_dtype': torch.double},
        'LogSigmoid': {'default_dtype': torch.double},
        'Mish': {'default_dtype': torch.double},
        'PReLU': {'default_dtype': torch.double},
        'ReLU6': {'default_dtype': torch.double},
        'ReLU': {'default_dtype': torch.double},
        'SELU': {'default_dtype': torch.double},
        'SiLU': {'default_dtype': torch.double},
        'Sigmoid': {'default_dtype': torch.double},
        'Softplus': {'default_dtype': torch.double},
        'Softshrink': {'default_dtype': torch.double},
        'Softsign': {'default_dtype': torch.double},
        'Tanh': {'default_dtype': torch.double},
        'Tanhshrink': {'default_dtype': torch.double},
    }
    for non_linear_activation in non_linear_activations_no_batch:
        activation_test_info = dict(
            module_name=non_linear_activation,
            input_size=(4,),
            reference_fn=single_batch_reference_fn,
            desc='no_batch_dim',
            test_cpp_api_parity=False,
        )
        extra_info = non_linear_activations_extra_info.get(non_linear_activation, {})
        activation_test_info.update(extra_info)
        new_module_tests.append(activation_test_info)


    return new_module_tests


def kldivloss_reference(input, target, reduction='mean', log_target=False):
    """Reference Kullback-Leibler divergence loss.

    ``input`` is expected to hold log-probabilities. ``target`` holds
    probabilities, or log-probabilities when ``log_target`` is true.

    Reductions: ``'mean'`` averages every element, ``'sum'`` adds them, and
    ``'batchmean'`` divides the sum by ``size(0)``. A 0-dim result under
    ``'batchmean'`` (or any other reduction string) is returned unreduced.
    """
    pointwise = (target.exp() * (target - input) if log_target
                 else target * (target.log() - input))

    if reduction == 'mean':
        return pointwise.mean()
    if reduction == 'sum':
        return pointwise.sum()
    if reduction == 'batchmean' and pointwise.dim() != 0:
        return pointwise.sum() / pointwise.size(0)
    return pointwise


def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
                        reduction='mean'):
    """Reference negative log-likelihood loss for inputs with spatial dims.

    ``input`` has layout ``(N, C, d1, ..., dk)`` with ``k >= 1``. ``target``
    holds class indices with shape ``(N, d1, ..., dk)``. Positions whose target
    equals ``ignore_index`` contribute zero loss and zero weight.

    ``'mean'`` divides the summed loss by the total weight of the non-ignored
    positions, ``'sum'`` returns the plain sum, and anything else returns the
    per-position losses.
    """
    assert input.dim() >= 3
    batch_size = input.size(0)
    num_classes = input.size(1)
    out_shape = (batch_size,) + input.size()[2:]
    losses = torch.zeros(out_shape).type_as(input)

    class_weight = torch.ones(num_classes).type_as(input) if weight is None else weight

    total_weight = 0
    for position in product(*map(range, out_shape)):
        cls = target[position]
        norm = 0. if ignore_index == cls else class_weight[cls].item()
        # Index the input at (batch, class, *spatial) for this output position.
        losses[position] = -input[(position[0], cls) + position[1:]] * norm
        total_weight += norm

    if reduction == 'mean':
        return losses.sum() / total_weight
    if reduction == 'sum':
        return losses.sum()
    return losses


def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean',
                                             label_smoothing=0.0):
    """Reference cross-entropy loss for class-probability targets.

    ``target`` has the same shape as ``input``, with class probabilities along
    dim 1. The per-sample loss is ``-sum_c w_c * target_c * log_softmax(input)_c``.
    When ``label_smoothing > 0`` the target is first blended with a uniform
    distribution over the ``C`` classes.

    ``'mean'`` averages the per-sample losses, ``'sum'`` adds them, and anything
    else returns them unreduced.
    """
    assert input.dim() >= 2

    log_probs = torch.log_softmax(input, 1)
    num_classes = log_probs.size(1)
    if weight is None:
        weight = torch.ones(num_classes).type_as(log_probs)
    # View as (1, C, 1, ..., 1) so the class weights broadcast over the other dims.
    broadcast_shape = [1, num_classes] + [1] * (log_probs.dim() - 2)
    weight = weight.view(*broadcast_shape)

    if label_smoothing > 0.0:
        assert label_smoothing <= 1.0
        target = target * (1 - label_smoothing) + label_smoothing / num_classes

    per_sample = -(log_probs * target * weight).sum(dim=1)
    if reduction == 'mean':
        return per_sample.mean()
    if reduction == 'sum':
        return per_sample.sum()
    return per_sample


def cross_entropy_loss_indices_target_reference(input, target, weight=None, ignore_index=-100,
                                                reduction='mean', label_smoothing=0.0):
    """Reference cross-entropy loss for class-index targets, with optional label smoothing.

    The result is ``(1 - label_smoothing) * nll + (label_smoothing / C) * smooth``.
    ``nll`` is ``F.nll_loss`` applied to ``log_softmax(input, 1)``. ``smooth`` is
    the negative sum of the (optionally class-weighted) log-probabilities over
    dim 1. Without smoothing, the plain ``nll`` term is returned.

    Args:
        input: logits; the class dimension is dim 1.
        target: class indices; entries equal to ``ignore_index`` contribute nothing.
        weight: optional per-class weights of length ``C``.
        ignore_index: target value to ignore.
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.
        label_smoothing: smoothing factor in ``[0, 1]``.
    """
    # Computed once and shared by the NLL term and the smoothing term.
    log_probs = torch.log_softmax(input, 1)
    nllloss = F.nll_loss(
        log_probs,
        target,
        weight,
        ignore_index=ignore_index,
        reduction=reduction)

    if label_smoothing == 0.0:
        return nllloss

    assert 0.0 < label_smoothing <= 1.0

    C = log_probs.size(1)
    if weight is not None:
        # Broadcast the per-class weights along dim 1: shape (1, C, 1, ..., 1).
        log_probs = log_probs * weight.view(1, C, *(1 for _ in log_probs.shape[2:]))

    smooth_loss = -torch.sum(log_probs, 1)

    ignore_mask = target == ignore_index
    smooth_loss.masked_fill_(ignore_mask, 0.0)

    if reduction == 'mean':
        if weight is not None:
            # TODO: This code path can be removed if #61309 is resolved.
            # The loss is normalized by the weights of the non-ignored targets,
            # to be consistent with nll_loss_nd.
            ret = torch.sum(smooth_loss) / weight.gather(0, target.masked_select(ignore_mask.logical_not()).flatten()).sum()
        else:
            ret = torch.mean(smooth_loss.masked_select(ignore_mask.logical_not()))
    elif reduction == 'sum':
        ret = torch.sum(smooth_loss)
    else:
        ret = smooth_loss

    return (1 - label_smoothing) * nllloss + ret * (label_smoothing / C)


def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean',
                                 label_smoothing=0.0):
    """Reference cross-entropy loss that picks the right variant for ``target``.

    A ``target`` with exactly the same shape as ``input`` is treated as class
    probabilities; any other shape is treated as class indices. ``ignore_index``
    only applies to the index variant.
    """
    targets_are_probabilities = input.shape == target.shape
    if not targets_are_probabilities:
        return cross_entropy_loss_indices_target_reference(
            input, target, weight=weight, reduction=reduction,
            ignore_index=ignore_index, label_smoothing=label_smoothing
        )
    return cross_entropy_loss_prob_target_reference(
        input,
        target,
        weight=weight,
        reduction=reduction,
        label_smoothing=label_smoothing)


def nllloss_reference(input, target, weight=None, ignore_index=-100,
                      reduction='mean'):
    """Reference negative log-likelihood loss for ``(N, C)`` inputs.

    Each row of ``input`` is paired with the class index in ``target``. Rows
    whose target equals ``ignore_index`` contribute zero loss and zero weight.
    ``'mean'`` divides the summed loss by the summed weights, ``'sum'`` returns
    the sum, and anything else returns the per-row losses as a tensor.
    """

    def _loss_and_norm(row, cls):
        # Returns (loss, weight) for one sample; ignored samples yield (0, 0).
        if cls == ignore_index:
            return (0, 0)
        scale = 1 if weight is None else weight[cls]
        return (-row[cls] * scale, scale)

    losses, norms = zip(*(_loss_and_norm(row, cls) for row, cls in zip(input, target)))
    losses_tensor = input.new_tensor(losses)

    if reduction == 'mean':
        return sum(losses_tensor) / sum(norms)
    if reduction == 'sum':
        return sum(losses_tensor)
    return losses_tensor


def smoothl1loss_reference(input, target, reduction='mean', beta=1.0):
    """Reference Smooth L1 loss.

    With ``d = |input - target|``, each element is ``0.5 * d**2 / beta`` when
    ``d < beta`` and ``d - 0.5 * beta`` otherwise. When ``beta == 0`` the
    quadratic region is empty and the loss is plain L1 (``d``).
    """
    diff = (input - target).abs()
    if beta == 0:
        # Special-cased to avoid dividing by zero in the quadratic term.
        loss = diff
    else:
        linear_part = (diff >= beta).type_as(diff)
        quadratic_part = (diff < beta).type_as(diff)
        loss = linear_part * (diff - 0.5 * beta) + quadratic_part * 0.5 * (diff ** 2) / beta

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss


def huberloss_reference(input, target, reduction='mean', delta=1.0):
    """Reference Huber loss.

    With ``d = |input - target|``, each element is ``0.5 * d**2`` when
    ``d < delta`` and ``delta * (d - 0.5 * delta)`` otherwise.
    """
    diff = (input - target).abs()
    in_linear_region = diff >= delta
    in_quadratic_region = diff < delta
    # Boolean masks select the branch for each element through multiplication.
    loss = in_linear_region * delta * (diff - 0.5 * delta) + in_quadratic_region * 0.5 * (diff ** 2)

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss


def _multilabelmarginloss_reference(input, target):
    targets = []
    for target_index in target:
        if target_index < 0:
            break
        targets.append(target_index)

    sum = 0
    for target_index in targets:
        for i in range(0, len(input)):
            if i not in targets:
                sum += max(0, 1 - input[target_index] + input[i])

    return sum


def multilabelmarginloss_reference(input, target, reduction='mean'):
    """Reference multi-label margin loss.

    Inputs with fewer than two dims are promoted to ``(1, C)`` / ``(1, 1)``
    before per-row evaluation. Every reduction divides by the number of classes
    ``C``. For unbatched inputs with no reduction, the result is squeezed back
    to an unbatched shape.
    """
    original_dim = input.dim()
    if original_dim < 2:
        assert target.dim() < 2
        # Add leading dims until both tensors are 2-D.
        while input.dim() < 2:
            input = input.unsqueeze(0)
        while target.dim() < 2:
            target = target.unsqueeze(0)

    num_samples = input.size(0)
    num_classes = input.size(1)
    per_sample = input.new_zeros(num_samples)
    for row in range(num_samples):
        per_sample[row] = _multilabelmarginloss_reference(input[row], target[row])

    if reduction == 'mean':
        return per_sample.mean() / num_classes
    if reduction == 'sum':
        return per_sample.sum() / num_classes
    if original_dim < 2:
        # A (1, C) input yields a (1,) result; squeezing restores the unbatched shape.
        return per_sample.squeeze() / num_classes
    return per_sample / num_classes


def hingeembeddingloss_reference(input, target, margin=1.0, reduction='mean'):
    """Reference hinge embedding loss.

    Elements whose target equals 1 contribute ``input``. All other elements
    contribute ``max(0, margin - input)``.
    """
    hinge = (margin - input).clamp(min=0).type_as(input)
    per_element = torch.where(target == 1, input, hinge)

    if reduction == 'mean':
        return per_element.mean()
    if reduction == 'sum':
        return per_element.sum()
    return per_element


def softmarginloss_reference(input, target, reduction='mean'):
    """Reference soft margin loss: ``log(1 + exp(-input * target))`` per element."""
    per_element = torch.log(1 + torch.exp(-input * target))

    if reduction == 'mean':
        return per_element.mean()
    if reduction == 'sum':
        return per_element.sum()
    return per_element


def _multimarginloss_reference(input, target_idx, p, margin, weight):
    if weight is None:
        weight = input.new(len(input)).fill_(1)

    output = 0
    for i in range(0, len(input)):
        if i != target_idx:
            output += weight[target_idx] * (max(0, (margin - input[target_idx] + input[i])) ** p)
    return output


def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='mean'):
    """Reference implementation of ``nn.MultiMarginLoss``.

    ``input`` is promoted to a 2-D batch (1-D -> ``(1, C)``, 0-D -> ``(1, 1)``)
    and a 0-D ``target`` to shape ``(1,)``. Each per-sample loss is divided by
    the class count ``C``. With ``reduction='none'`` and a 0-D target, the
    batch dimension is squeezed away again.
    """
    if input.dim() == 1:
        input = input.unsqueeze(0)
    elif input.dim() == 0:
        input = input.unsqueeze(0).unsqueeze(0)

    scalar_target = target.dim() == 0
    if scalar_target:
        target = target.unsqueeze(0)

    batch_size = input.size(0)
    num_classes = input.size(1)
    per_sample = input.new_empty(batch_size)
    for row in range(batch_size):
        per_sample[row] = _multimarginloss_reference(input[row], target[row], p, margin, weight)

    if reduction == 'mean':
        return per_sample.mean() / num_classes
    if reduction == 'sum':
        return per_sample.sum() / num_classes
    if scalar_target:
        return per_sample.squeeze(0) / num_classes
    return per_sample / num_classes


def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='mean'):
    """Reference implementation of ``nn.CosineEmbeddingLoss``.

    The row-wise cosine similarity adds ``1e-12`` to each squared norm.

    * Where ``target == 1`` the loss is ``1 - cos``.
    * Everywhere else it is ``max(0, cos - margin)``.
    """
    def row_cosine(a, b):
        sims = a.new(a.size(0))
        for row in range(a.size(0)):
            dot = (a[row] * b[row]).sum()
            sq_norm_a = (a[row] * a[row]).sum() + 1e-12
            sq_norm_b = (b[row] * b[row]).sum() + 1e-12
            sims[row] = dot / ((sq_norm_a * sq_norm_b) ** 0.5)
        return sims

    # one evaluation feeds both branches of the where()
    cos = row_cosine(input1, input2)
    losses = torch.where(target == 1, 1 - cos, (cos - margin).clamp(min=0))

    if reduction == 'mean':
        return losses.mean()
    if reduction == 'sum':
        return losses.sum()
    return losses


def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
                                reduction='mean'):
    """Reference implementation of ``nn.TripletMarginLoss``.

    The loss is ``max(0, margin + d(anchor, positive) - d(anchor, negative))``,
    where ``d`` is ``torch.pairwise_distance`` with norm ``p`` and ``eps``.
    With ``swap=True``, the negative distance is replaced by the smaller of
    ``d(anchor, negative)`` and ``d(positive, negative)``.
    """
    def dist(x, y):
        return torch.pairwise_distance(x, y, p, eps)

    positive_dist = dist(anchor, positive)
    negative_dist = dist(anchor, negative)
    if swap:
        negative_dist = torch.min(negative_dist, dist(positive, negative))

    losses = (margin + positive_dist - negative_dist).clamp(min=0.0)
    if reduction == 'mean':
        return losses.mean()
    if reduction == 'sum':
        return losses.sum()
    return losses


def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mean'):
    """Reference implementation of ``nn.MarginRankingLoss``.

    Elementwise ``max(0, -target * (input1 - input2) + margin)``, reduced by
    mean or sum, or returned unreduced for any other ``reduction``.
    """
    difference = input1 - input2
    losses = torch.clamp(-target * difference + margin, min=0)
    if reduction == 'sum':
        return losses.sum()
    return losses.mean() if reduction == 'mean' else losses


# this directly follows Graves et al.'s paper, in contrast to the production implementation, it does not use log-space
def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'):
    """Reference CTC loss following Graves et al.'s alpha (forward) recursion.

    Unlike the production implementation, this works in probability space
    rather than log space. Everything is therefore promoted to double, and the
    result is cast back to ``log_probs.dtype``.

    Args:
        log_probs: indexed as ``log_probs[t, i, c]`` (time step, sample, class).
        targets: either padded 2-D (``targets[i, :target_length]``) or a 1-D
            concatenation of all samples' target sequences.
        input_lengths, target_lengths: per-sample lengths (sequences or tensors).
        blank: index of the blank label.
        reduction: ``'mean'`` divides each loss by its target length (clamped to
            at least 1) and averages over the batch. ``'sum'`` sums the losses;
            anything else returns the per-sample losses.
    """
    input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
    target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
    dt = log_probs.dtype
    log_probs = log_probs.double()  # we need the accuracy as we are not in logspace
    targets = targets.long()
    cum_target_lengths = target_lengths.cumsum(0)
    losses = []
    for i in range(log_probs.size(1)):
        input_length = input_lengths[i].item()
        target_length = target_lengths[i].item()
        cum_target_length = cum_target_lengths[i].item()
        # Extended label sequence: blank, l_1, blank, l_2, ..., l_S, blank
        targets_prime = targets.new_full((2 * target_length + 1,), blank)
        if targets.dim() == 2:
            targets_prime[1::2] = targets[i, :target_length]
        else:
            targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length]
        probs = log_probs[:input_length, i].exp()
        alpha = log_probs.new_zeros((target_length * 2 + 1,))
        alpha[0] = probs[0, blank]
        # An empty target leaves only the single blank state; there is no first label to start in.
        if target_length > 0:
            alpha[1] = probs[0, targets_prime[1]]
        # A transition may skip a blank only between two different labels.
        mask_third = (targets_prime[:-2] != targets_prime[2:])
        for t in range(1, input_length):
            alpha_next = alpha.clone()
            alpha_next[1:] += alpha[:-1]
            alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1))
            alpha = probs[t, targets_prime] * alpha_next
        # Valid paths end either on the last label or on the trailing blank.
        losses.append(-alpha[-2:].sum().log()[None])
    output = torch.cat(losses, 0)
    if reduction == 'mean':
        # Clamp like torch's ctc_loss does, so empty targets do not divide by zero.
        normalizer = target_lengths.to(dtype=output.dtype, device=output.device).clamp(min=1)
        output = (output / normalizer).mean()
    elif reduction == 'sum':
        output = output.sum()
    output = output.to(dt)
    return output


# Maps a loss name (an nn module name, or a variant such as
# 'KLDivLoss_log_target') to its pure-Python reference implementation.
# The key type is the builtin `str`; it was previously written as the
# quoted literal 'str'.
loss_reference_fns: dict[str, Callable] = {
    'KLDivLoss': kldivloss_reference,
    'KLDivLoss_log_target': partial(kldivloss_reference, log_target=True),
    'NLLLoss': nllloss_reference,
    'NLLLossNd': nlllossNd_reference,
    'SmoothL1Loss': smoothl1loss_reference,
    'HuberLoss': huberloss_reference,
    'MultiLabelMarginLoss': multilabelmarginloss_reference,
    'HingeEmbeddingLoss': hingeembeddingloss_reference,
    'SoftMarginLoss': softmarginloss_reference,
    'MultiMarginLoss': multimarginloss_reference,
    'CosineEmbeddingLoss': cosineembeddingloss_reference,
    'TripletMarginLoss': tripletmarginloss_reference,
    'MarginRankingLoss': marginrankingloss_reference,
    'CTCLoss': ctcloss_reference,
    'CrossEntropyLoss': cross_entropy_loss_reference
}


criterion_tests = []


def single_batch_reference_criterion_fn(*args):
    """Reference function for criterion supporting no batch dimensions.

    The criterion is passed the input and target in batched form with a single item.
    The output is squeezed to compare with the no-batch input.

    The last positional argument is the criterion module. Every preceding
    argument gets a leading batch dimension of size 1; list/tuple arguments
    (e.g. the ``(x1, x2)`` pair of ``MarginRankingLoss``) are unsqueezed
    element-wise and then flattened into the criterion's positional arguments.
    """
    criterion = args[-1]

    def unsqueeze_inp(inp):
        if isinstance(inp, (list, tuple)):
            return [t.unsqueeze(0) for t in inp]
        return inp.unsqueeze(0)

    def flatten(xs):
        result = []
        if isinstance(xs, (list, tuple)):
            for x in xs:
                result.extend(flatten(x))
        else:
            result.append(xs)
        return result

    single_batch_input_args = flatten([unsqueeze_inp(input) for input in args[:-1]])

    output = criterion(*single_batch_input_args)
    reduction = get_reduction(criterion)

    if reduction == 'none':
        return output.squeeze(0)
    # reduction is 'sum' or 'mean' which results in a scalar
    return output


# NOTE: every `constructor` lambda below binds the loop variables as defaults.
# A plain closure would read the module-level `reduction` only when the test
# runs, i.e. the value left over from the last loop ('sum'), so all tests would
# silently use the same reduction regardless of their name.

# Check that regression criterion work with no batch dimensions
regression_criterion_no_batch = [
    'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss'
]
reductions = ['none', 'mean', 'sum']
for name, reduction in product(regression_criterion_no_batch, reductions):
    regression_test_info = dict(
        fullname=f"{name}_no_batch_dim_{reduction}",
        constructor=lambda *args, name=name, reduction=reduction: getattr(nn, name)(reduction=reduction),
        input_size=(3, ),
        target_size=(3, ),
        reference_fn=single_batch_reference_criterion_fn,
        test_cpp_api_parity=False,
        default_dtype=torch.double,
    )
    criterion_tests.append(regression_test_info)


for reduction in reductions:
    regression_test_info = dict(
        fullname=f"KLDivLoss_no_batch_dim_{reduction}",
        constructor=lambda *args, reduction=reduction: nn.KLDivLoss(reduction=reduction),
        input_fn=lambda: torch.rand((3,)).log(),
        target_fn=lambda: torch.rand((3,)),
        reference_fn=single_batch_reference_criterion_fn,
        test_cpp_api_parity=False,
        default_dtype=torch.double,
    )
    criterion_tests.append(regression_test_info)


# Check that classification criterion work with no batch dimensions
# List of tuples of (name, input_fn, target_fn)
classification_criterion_no_batch = [
    (
        'BCELoss',
        lambda: torch.sigmoid(torch.randn(9, dtype=torch.double)),
        lambda: torch.randn(9, dtype=torch.double).gt(0).to(torch.double)
    ),
    ('BCEWithLogitsLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9, dtype=torch.double)),
    ('HingeEmbeddingLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)),
    ('MultiLabelMarginLoss', lambda: torch.randn(4, dtype=torch.double), lambda: torch.tensor([3, 0, -1, 1])),
    ('SoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)),
    ('NLLLoss', lambda: F.log_softmax(torch.randn(3, dtype=torch.double), dim=0), lambda: torch.tensor(1)),
    (
        'CosineEmbeddingLoss',
        lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)),
        lambda: torch.tensor(1, dtype=torch.double)
    ),
    # For MarginRankingLoss, input_fn : (x1, x2) and target_fn : target
    ('MarginRankingLoss', lambda: (torch.randn(()), torch.randn(())), lambda: torch.randn(()).sign()),
    # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative
    (
        'TripletMarginLoss',
        lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)),
        lambda: torch.randn(9, dtype=torch.double)
    ),
    ('MultiLabelSoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9)),
]
classification_criterion_no_batch_extra_info: dict[str, dict] = {
    'MultiLabelMarginLoss': {'check_gradgrad': False},
}
# TODO : Fix these discrepancies
classification_cpp_parity = {
    'BCELoss': False,
    'BCEWithLogitsLoss': False,
    'HingeEmbeddingLoss': False,
    'NLLLoss': False,
    'SoftMarginLoss': False,
}
for (name, input_fn, target_fn), reduction in product(classification_criterion_no_batch,
                                                      reductions):
    classification_test_info = dict(
        fullname=f"{name}_no_batch_dim_{reduction}",
        constructor=lambda *args, name=name, reduction=reduction: getattr(nn, name)(reduction=reduction),
        input_fn=lambda f=input_fn: f(),
        target_fn=lambda f=target_fn: f(),
        reference_fn=single_batch_reference_criterion_fn,
        test_cpp_api_parity=True,
        has_parity=classification_cpp_parity.get(name, True)
    )
    extra_info = classification_criterion_no_batch_extra_info.get(name, {})
    classification_test_info.update(extra_info)
    criterion_tests.append(classification_test_info)


class NNTestCase(TestCase):
    """Test case with Jacobian-checking helpers for NN module tests.

    Subclasses implement the hooks ``_forward``, ``_get_parameters`` (returning
    ``(params, grads)``, where a grad may be ``None``), ``_zero_grad_parameters``
    and ``_backward``. The helpers below use them to compare analytical and
    numerical Jacobians.
    """

    # _forward is defined in classes inheriting from NNTestCase
    @abstractmethod
    def _forward(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def _get_parameters(self, module: nn.Module) -> tuple[list[nn.Parameter], list[nn.Parameter]]:
        raise NotImplementedError

    @abstractmethod
    def _zero_grad_parameters(self, module: nn.Module) -> None:
        raise NotImplementedError

    @abstractmethod
    def _backward(self, module: nn.Module,
                  input: _TensorOrTensors, output: torch.Tensor,
                  grad_output: Union[torch.Tensor, Sequence[torch.Tensor]],
                  create_graph: bool = False):
        raise NotImplementedError

    def _jacobian(self, input, num_out):
        """Allocate zero Jacobian buffer(s) of shape ``(input.nelement(), num_out)`` mirroring the nesting of ``input``."""
        if isinstance(input, tuple):
            return tuple(self._jacobian(elem, num_out) for elem in input)
        elif isinstance(input, list):
            return [self._jacobian(elem, num_out) for elem in input]
        else:
            return torch.zeros(input.nelement(), num_out)

    def _flatten_tensors(self, x):
        """Flatten a tensor (densifying sparse ones), or each tensor of a nested sequence, to 1-D."""
        if isinstance(x, torch.Tensor):
            if x.is_sparse:
                return x.to_dense().view(-1)
            else:
                return x.view(-1)
        else:
            return tuple(self._flatten_tensors(a) for a in x)

    def _zero_grad_input(self, input):
        """Zero (and detach) ``.grad`` on every input tensor that requires grad and already has one."""
        if isinstance(input, torch.Tensor):
            if input.requires_grad and input.grad is not None:
                input.grad.zero_()
                input.grad.detach_()
        else:
            for i in input:
                self._zero_grad_input(i)

    def _analytical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
        """Build the Jacobian with one backward pass per output element.

        Returns a tuple containing, in order:

        * the input Jacobian buffer(s) from ``_jacobian``, if ``jacobian_input``;
        * the ``(num_params, output_size)`` parameter Jacobian, if
          ``jacobian_parameters``.
        """
        output = self._forward(module, input)
        output_size = output.nelement()

        if jacobian_input:
            jacobian_inp = self._jacobian(input, output_size)
            flat_jacobian_input = list(_iter_tensors(jacobian_inp))

        if jacobian_parameters:
            num_param = sum(p.numel() for p in self._get_parameters(module)[0])
            jacobian_param = torch.zeros(num_param, output_size)

        for i in range(output_size):
            # One-hot grad_output selects output element i (column i of the Jacobian).
            d_out = torch.zeros_like(output)
            flat_d_out = d_out.view(-1)
            flat_d_out[i] = 1

            if jacobian_parameters:
                self._zero_grad_parameters(module)
            # Tensors will accumulate gradient from multiple steps
            if jacobian_input:
                self._zero_grad_input(input)
            d_input = self._backward(module, input, output, d_out)

            if jacobian_input:
                for jacobian_x, d_x in zip(flat_jacobian_input, _iter_tensors(d_input)):
                    jacobian_x[:, i] = d_x.contiguous().view(-1)
            if jacobian_parameters:
                # Read the parameter gradients only after backward has run. A grad that
                # was still None before this pass is created fresh by backward, so a list
                # captured earlier would miss it (and would be stale if a grad is replaced).
                param, d_param = self._get_parameters(module)
                # make non grad zeros
                d_param = [torch.zeros_like(p) if d is None else d for (p, d) in zip(param, d_param)]
                jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)

        res: tuple[torch.Tensor, ...] = ()
        if jacobian_input:
            res += jacobian_inp,
        if jacobian_parameters:
            res += jacobian_param,

        return res

    def _numerical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
        """Finite-difference (eps=1e-6) counterpart of ``_analytical_jacobian``, with the same output layout."""
        def fw(*input):
            return self._forward(module, input).detach()

        res: tuple[torch.Tensor, ...] = ()
        if jacobian_input:
            res += _get_numerical_jacobian(fw, input, eps=1e-6),
        if jacobian_parameters:
            param, _ = self._get_parameters(module)
            to_cat = []
            for p in param:
                jacobian = _get_numerical_jacobian(fw, input, target=p, eps=1e-6)
                # get_numerical_jacobian returns a list of tuples but we require a tensor
                to_cat.append(jacobian[0][0])
            res += (torch.cat(to_cat, 0),)
        return res

    def check_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True):
        """Assert that analytical and numerical Jacobians differ by at most ``PRECISION`` elementwise."""
        jacobian_parameters = bool(self._get_parameters(module)[0])
        analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters)
        numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters)
        analytical_t = list(_iter_tensors(analytical))
        numerical_t = list(_iter_tensors(numerical))

        differences = []
        for a, n in zip(analytical_t, numerical_t):
            if a.numel() != 0:
                differences.append(a.add(n, alpha=-1).abs().max())
            # TODO: compare structure (ensure analytic jacobian has correct shape)
        if len(differences) > 0:
            self.assertLessEqual(max(differences), PRECISION)  # type: ignore[type-var]


class TestBase:
    """Common plumbing for declarative module/criterion test descriptions.

    Each name in ``_required_arg_names`` can be supplied in one of three ways:

    * directly (``name=...``);
    * via a zero-argument factory (``name_fn=...``);
    * for tensors, by size (``name_size=...``).

    ``constructor_args`` and ``extra_args`` default to ``()``. Resolved values
    are cached per instance, so factories run at most once.
    """

    _required_arg_names = {'constructor_args', 'input', 'extra_args'}

    def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs):
        self.desc = desc
        self.fullname = fullname
        self.constructor = constructor
        self.reference_fn = reference_fn
        for name in self._required_arg_names:
            if name not in kwargs and name + '_fn' not in kwargs and name + '_size' not in kwargs:
                if name in {'constructor_args', 'extra_args'}:
                    kwargs[name] = ()
                else:
                    raise ValueError(f"{self.get_name()}: Specify {name} by a value, a function to generate it, or it's size!")
        self._extra_kwargs = kwargs
        self._arg_cache = {}

    def get_name(self):
        """Return ``'test_' + fullname`` if given, else ``'test_<constructor name>[_<desc>]'``."""
        if self.fullname is not None:
            return 'test_' + self.fullname

        test_name = 'test_' + self.constructor.__name__
        if self.desc:
            test_name += '_' + self.desc
        return test_name

    def _unpack(self, value):
        """Recursively rebuild containers, passing tensors and strings through unchanged."""
        # str/bytes are iterable, but rebuilding one from a generator of its items
        # (``str(<generator>)``) yields the generator's repr, not the original text.
        if isinstance(value, (torch.Tensor, str, bytes)):
            return value
        elif is_iterable(value):
            return type(value)(self._unpack(v) for v in value)
        else:
            return value

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args', True)

    @property
    def extra_args(self):
        return self._get_arg('extra_args', True)

    def _get_arg(self, name, unpack):
        """Resolve (and cache) argument ``name`` from its value, ``_fn`` factory or ``_size`` spec."""
        assert name in self._required_arg_names

        if name not in self._arg_cache:
            fn_name = name + '_fn'
            size_name = name + '_size'

            if name in self._extra_kwargs:
                self._arg_cache[name] = self._extra_kwargs[name]
            elif fn_name in self._extra_kwargs:
                self._arg_cache[name] = self._extra_kwargs[fn_name]()
            else:
                assert size_name in self._extra_kwargs, \
                    f"Missing `{name}`, `{size_name}` or `{fn_name}` for {self.get_name()}"

                def map_tensor_sizes(sizes):
                    # A list means several tensors; a tensor is used as-is (as double);
                    # anything else is treated as a shape for torch.randn.
                    if isinstance(sizes, list):
                        return [map_tensor_sizes(s) for s in sizes]
                    elif isinstance(sizes, torch.Tensor):
                        return sizes.double()
                    else:
                        return torch.randn(sizes)

                self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name])

        return self._unpack(self._arg_cache[name]) if unpack else self._arg_cache[name]

    def _get_input(self, unpack=True):
        return self._get_arg('input', unpack)

    def __call__(self, test_case):
        raise NotImplementedError


class ModuleTest(TestBase):
    """Declarative test of an ``nn.Module`` driven through a test-case harness.

    Calling the instance does the following, in order:

    1. Builds the module and, if ``reference_fn`` is set, compares its output
       against it.
    2. Checks contiguous vs. non-contiguous inputs.
    3. Optionally round-trips the module through ``torch.save``/``torch.load``.
    4. Delegates to ``_do_test``.

    ``test_cuda`` compares CPU and CUDA forward, backward and double-backward
    results.
    """

    @abstractmethod
    def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any:
        raise NotImplementedError

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.jacobian_input = kwargs.get('jacobian_input', True)
        self.should_test_cuda = kwargs.get('test_cuda', True)
        self.should_test_pickle = kwargs.get('pickle', True)
        self.check_gradgrad = kwargs.get('check_gradgrad', True)
        self.FIXME_no_cuda_gradgrad_comparison = \
            kwargs.get('FIXME_no_cuda_gradgrad_comparison', False)
        self.precision = kwargs.get('precision', 2e-4)
        self.check_forward_only = kwargs.get('check_forward_only', False)
        self.default_dtype = kwargs.get('default_dtype', None)
        if self.default_dtype is None:
            self.default_dtype = torch.get_default_dtype()

    def __call__(self, test_case):
        with set_default_dtype(self.default_dtype):
            module = self.constructor(*self.constructor_args)
            input = self._get_input()

            if self.reference_fn is not None:
                out = test_case._forward(module, input)
                ref_input = deepcopy(input)
                ref_module = deepcopy(module)
                expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0], ref_module)
                test_case.assertEqual(out, expected_out, exact_dtype=False)
            if self.check_forward_only:
                return
            self.test_noncontig(test_case, module, input)

            if self.should_test_pickle:
                # TODO: do this with in-memory files as soon as torch.save will support it
                with tempfile.TemporaryFile() as f:
                    test_case._forward(module, input)
                    torch.save(module, f)
                    f.seek(0)
                    # weights_only=False as this is legacy code that saves the model
                    module_copy = torch.load(f, weights_only=False)
                    test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))

            self._do_test(test_case, module, input)

    def noncontiguize(self, obj):
        """Return a non-contiguous copy of a tensor (recursing into lists/tuples), preserving requires_grad."""
        if isinstance(obj, list):
            return [self.noncontiguize(o) for o in obj]
        elif isinstance(obj, tuple):
            return tuple(self.noncontiguize(o) for o in obj)
        tensor = obj
        ndim = tensor.dim()
        # Always making only the last dimension noncontiguous is easy to hide
        # bugs because .view(-1) will still work. So try to find a dim with size
        # > 1 and make that non-contiguous, i.e., stack + select on the
        # dimension directly after that.
        dim = ndim
        for d in range(ndim):
            if tensor.size(d) > 1:
                dim = d + 1
                break
        noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach()
        assert noncontig.numel() == 1 or noncontig.numel() == 0 or not noncontig.is_contiguous()
        noncontig.requires_grad = tensor.requires_grad
        return noncontig

    def test_noncontig(self, test_case, module, input):
        """Check that outputs and gradients match for every mix of (non-)contiguous input and grad_output."""
        # check no scalars, can't make non-contig
        if isinstance(input, torch.Tensor):
            if input.dim() == 0:
                return
        # Only look inside real containers: iterating a plain tensor would yield its
        # (possibly 0-d) slices and wrongly skip the check for every 1-D input.
        elif any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)):
            return

        test_case._zero_grad_parameters(module)
        test_case._zero_grad_input(input)
        with freeze_rng_state():
            output = test_case._forward(module, input)
            if getattr(module, "return_indices", False):
                output = output[0]
            grad_output = output.new(output.shape).normal_()
            output = output.clone()
            d_input = deepcopy(test_case._backward(module, input, output, grad_output))
            d_param = deepcopy(test_case._get_parameters(module)[1])

        nc_input = self.noncontiguize(input)
        nc_grad_output = self.noncontiguize(grad_output)
        for contig_i, contig_g in product((True, False), repeat=2):
            i = input if contig_i else nc_input
            # Some ops, e.g., nn.Flatten, return gradient that shares
            # storage with the grad_output. Hence we copy here.
            go = deepcopy(grad_output if contig_g else nc_grad_output)
            test_case._zero_grad_parameters(module)
            test_case._zero_grad_input(i)
            with freeze_rng_state():
                out = test_case._forward(module, i)
                if getattr(module, "return_indices", False):
                    out = out[0]
                grad = test_case._backward(module, i, out, go)

                test_case.assertEqual(out, output)
                test_case.assertEqual(grad, d_input, atol=1e-4, rtol=0)
                test_case.assertEqual(test_case._get_parameters(module)[1], d_param)

    def test_cuda(self, test_case):
        """Compare CPU and CUDA (float32) results of forward, backward and double backward."""
        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest('Excluded from CUDA tests')

        with set_default_dtype(self.default_dtype):
            cpu_input = self._get_input()

            type_map = {torch.double: torch.float}
            cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)

            is_any_input_complex = any(isinstance(t, torch.Tensor) and t.dtype.is_complex for t in cpu_input_tuple)

            gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)

            cpu_module = self.constructor(*self.constructor_args)
            gpu_module = self.constructor(*self.constructor_args).float().cuda()
            cpu_param = test_case._get_parameters(cpu_module)
            gpu_param = test_case._get_parameters(gpu_module)
            for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]):
                gpu_p.data.copy_(cpu_p)

            test_case._zero_grad_input(cpu_input_tuple)
            test_case._zero_grad_input(gpu_input_tuple)
            test_case._zero_grad_parameters(cpu_module)
            test_case._zero_grad_parameters(gpu_module)
            cpu_output = test_case._forward(cpu_module, cpu_input_tuple)
            gpu_output = test_case._forward(gpu_module, gpu_input_tuple)
            if getattr(cpu_module, "return_indices", False):
                cpu_output = cpu_output[0]
                gpu_output = gpu_output[0]
            test_case.assertEqual(cpu_output, gpu_output, atol=self.precision, rtol=0, exact_dtype=False)

            # Run backwards on CPU and GPU and compare results
            for _ in range(5):
                cpu_gradOutput = cpu_output.clone().normal_()
                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output)
                cpu_gradInput = test_case._backward(cpu_module, cpu_input_tuple, cpu_output, cpu_gradOutput)
                gpu_gradInput = test_case._backward(gpu_module, gpu_input_tuple, gpu_output, gpu_gradOutput)
                test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
                for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]):
                    test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0)

            # Run double-backwards on CPU and GPU and compare results
            if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
                cpu_output = cpu_module(*cpu_input_tuple)
                gpu_output = gpu_module(*gpu_input_tuple)
                if getattr(cpu_module, "return_indices", False):
                    cpu_output = cpu_output[0]
                    gpu_output = gpu_output[0]

                cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
                gpu_gradOutput.requires_grad = True

                cpu_gradInputs = torch.autograd.grad(
                    cpu_output,
                    cpu_input_tuple + tuple(cpu_module.parameters()),
                    cpu_gradOutput,
                    create_graph=True)
                gpu_gradInputs = torch.autograd.grad(
                    gpu_output,
                    gpu_input_tuple + tuple(gpu_module.parameters()),
                    gpu_gradOutput,
                    create_graph=True)

                for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs):
                    test_case.assertEqual(cpu_d_i, gpu_d_i, atol=self.precision, rtol=0, exact_dtype=False)

                # We mix output into the second backwards computation so that
                # torch.autograd.grad doesn't complain that some inputs
                # are unreachable (which can happen if you differentiate
                # only on the gradient.
                if is_any_input_complex:
                    outputs_cpu = cpu_output.sum().abs() + sum(x.sum().abs() for x in cpu_gradInputs)
                    outputs_gpu = gpu_output.sum().abs() + sum(x.sum().abs() for x in gpu_gradInputs)
                else:
                    outputs_cpu = cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs)
                    outputs_gpu = gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs)

                cpu_gg = torch.autograd.grad(
                    outputs_cpu,
                    cpu_input_tuple + (cpu_gradOutput,) + tuple(cpu_module.parameters()),
                    retain_graph=True)
                gpu_gg = torch.autograd.grad(
                    outputs_gpu,
                    gpu_input_tuple + (gpu_gradOutput,) + tuple(gpu_module.parameters()),
                    retain_graph=True)
                # (A stale re-assertion of the first-order gradInputs from the loop above
                # used to sit here; those values are already checked in that loop.)
                for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg):
                    test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0, exact_dtype=False)

            self.test_noncontig(test_case, gpu_module, gpu_input_tuple)


class InputVariableMixin:
    """Mixin whose ``_get_input`` turns on ``requires_grad`` for floating-point and complex input tensors."""

    def _get_input(self):
        raw = TestBase._get_input(self, False)  # type: ignore[arg-type]

        def require_grad(obj):
            if not isinstance(obj, torch.Tensor):
                # Containers are rebuilt with the same type, element by element.
                return type(obj)(require_grad(item) for item in obj)
            if obj.is_floating_point() or obj.is_complex():
                obj.requires_grad = True
            return obj

        return require_grad(raw)


class NewModuleTest(InputVariableMixin, ModuleTest):  # type: ignore[misc]
    """Module test adding gradient, in-place, repr and dtype/device-cast checks.

    Per-test options are read from keyword arguments in ``__init__``; the
    constructor, input and CUDA comparison machinery come from ``ModuleTest``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cudnn = kwargs.get('cudnn', False)
        self.check_inplace = kwargs.get('check_inplace', False)
        self.check_gradgrad = kwargs.get('check_gradgrad', True)
        self.skip_double = kwargs.get('skip_double', False)
        self.skip_half = kwargs.get('skip_half', False)
        self.with_tf32 = kwargs.get('with_tf32', False)
        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
        self.test_cpu = kwargs.get('test_cpu', True)
        # When set, _check_gradients uses test_case.check_jacobian instead of gradcheck.
        self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False)
        self.check_batched_grad = kwargs.get('check_batched_grad', True)
        self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None)
        self.supports_forward_ad = kwargs.get('supports_forward_ad', False)
        self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False)

    def _check_gradients(self, test_case, module, input_tuple):
        """Run gradcheck (and optionally gradgradcheck) over inputs and parameters.

        The parameters are appended to the checked inputs so gradients w.r.t.
        them are verified too; the wrapped function only forwards the first
        ``len(input_tuple)`` arguments because ``module`` already holds the
        same parameter tensors.
        """
        params = tuple(module.parameters())
        num_inputs = len(input_tuple)

        def fn_to_gradcheck(*inputs_and_params, **kwargs):
            assert not kwargs
            return test_case._forward(module, inputs_and_params[:num_inputs])

        # gradcheck doesn't support operators that take in dense inputs but
        # return sparse parameters. This only happens in the case of nn.Embedding
        # and nn.EmbeddingBag. Instead, we call `self.check_jacobian`, which
        # is a slightly different version of gradcheck that can handle this.
        if self.has_sparse_gradients:
            assert num_inputs == 1
            test_input_jacobian = torch.is_floating_point(input_tuple[0])
            test_case.check_jacobian(module, input_tuple[0], test_input_jacobian)
        else:
            test_case.assertTrue(gradcheck(fn_to_gradcheck, input_tuple + params,
                                           check_batched_grad=self.check_batched_grad,
                                           fast_mode=self.gradcheck_fast_mode,
                                           check_forward_ad=self.supports_forward_ad))

        if self.check_gradgrad:
            test_case.assertTrue(gradgradcheck(fn_to_gradcheck, input_tuple + params,
                                               check_batched_grad=self.check_batched_grad,
                                               fast_mode=self.gradcheck_fast_mode,
                                               check_fwd_over_rev=self.supports_fwgrad_bwgrad))

    def _do_test(self, test_case, module, input):
        """Run all per-module checks on ``module`` with ``input``.

        ``input`` may be a single tensor or a tuple of tensors. The checks run
        with ``torch.set_num_threads(1)``; the previous thread count is
        restored afterwards even if a check fails.
        """
        num_threads = torch.get_num_threads()
        torch.set_num_threads(1)
        try:
            input_tuple = input if isinstance(input, tuple) else (input,)

            self._check_gradients(test_case, module, input_tuple)

            # check if module can be printed
            module.__repr__()

            if self.check_inplace:
                self._check_inplace(test_case, module, input_tuple)

            self._check_type_conversions(test_case, module, input_tuple)
        finally:
            # Restore even on failure so one failing test does not leave later
            # tests in the same process running single-threaded.
            torch.set_num_threads(num_threads)

    def _check_inplace(self, test_case, module, input_tuple):
        """Check that the ``inplace=True`` variant agrees with the out-of-place module.

        Compares forward outputs (each under a frozen RNG state) and input
        gradients. It also asserts that the out-of-place call leaves the input's
        version counter unchanged, and that after the in-place call the cloned
        input's version differs from the original input's version.
        """
        # check_inplace doesn't support multiple input tensors, since we don't have any modules
        # that modify the inputs in-place and that accept more than one input
        assert len(input_tuple) == 1
        input = input_tuple[0]

        module_ip = self.constructor(*self.constructor_args, inplace=True)

        input_version = input._version
        with freeze_rng_state():
            output = module(input)
        test_case.assertEqual(input._version, input_version)

        input_ip = deepcopy(input)
        input_ip_clone = input_ip.clone()
        with freeze_rng_state():
            output_ip = module_ip(input_ip_clone)
        test_case.assertNotEqual(input_ip_clone._version, input_version)
        test_case.assertEqual(output, output_ip)
        grad = output.data.clone().normal_()
        if input.grad is not None:
            with torch.no_grad():
                input.grad.zero_()
        if input_ip.grad is not None:
            with torch.no_grad():
                input_ip.grad.zero_()
        output.backward(grad)
        output_ip.backward(grad)
        test_case.assertEqual(input.grad, input_ip.grad)

    def _check_type_conversions(self, test_case, module, input_tuple):
        """Check ``float()``/``double()``/``half()``/``cuda()``/``cpu()`` on ``module``.

        After each conversion the module is run on correspondingly converted
        inputs and every parameter's tensor type (and, on CUDA, device index)
        is asserted. CUDA checks only run when ``TEST_CUDA`` is set; the
        second-GPU checks additionally need more than one device. ``module``
        is converted in place and is left in the last tested type.
        """
        def assert_module_parameters_are(tensor_type, device_id=None):
            for p in module.parameters():
                test_case.assertIsInstance(p, tensor_type)
                if device_id is not None:
                    test_case.assertEqual(p.get_device(), device_id)

        if all(isinstance(t, torch.LongTensor) for t in input_tuple) and TEST_CUDA:
            # check that cuda() moves module parameters to correct GPU device,
            # and that float() casts parameters correctly
            input_tuple = tuple(t.cuda() for t in input_tuple)
            module.float().cuda()
            module(*input_tuple)
            assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]

            if torch.cuda.device_count() > 1:
                input_tuple = tuple(t.cuda(1) for t in input_tuple)
                module.cuda(1)
                with torch.cuda.device(1):
                    module(*input_tuple)
                assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
        else:
            # check that float()/double() casters work correctly
            def to_type(tensor, real, complex):
                if tensor.is_complex():
                    return tensor.to(complex)
                elif tensor.is_floating_point():
                    return tensor.to(real)
                else:
                    return tensor

            def to_half(x):
                # TODO: torch.complex32 when properly supported
                return to_type(x, torch.float16, None)

            def to_single(x):
                return to_type(x, torch.float32, torch.complex64)

            def to_double(x):
                return to_type(x, torch.float64, torch.complex128)

            # to float
            input_tuple = tuple(to_single(t) for t in input_tuple)
            module.float()
            module(*input_tuple)
            assert_module_parameters_are(torch.FloatTensor)

            # and back to double
            input_tuple = tuple(to_double(t) for t in input_tuple)
            module.double()
            module(*input_tuple)
            assert_module_parameters_are(torch.DoubleTensor)

            if TEST_CUDA and self.should_test_cuda:
                # check that cuda() moves module parameters to correct GPU device,
                # and that float() casts parameters correctly

                # to GPU0
                input_tuple = tuple(to_single(t).cuda() for t in input_tuple)
                module.float().cuda()
                module(*input_tuple)
                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]

                # to CPU
                input_tuple = tuple(t.cpu() for t in input_tuple)
                module.cpu()
                module(*input_tuple)
                assert_module_parameters_are(torch.FloatTensor)

                # back to GPU0
                input_tuple = tuple(t.cuda() for t in input_tuple)
                module.cuda()
                module(*input_tuple)
                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]

                # test that forwards of module runs correctly without cuDNN
                if self.cudnn:
                    with torch.backends.cudnn.flags(enabled=False):
                        module(*input_tuple)
                        assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]

                if torch.cuda.device_count() >= 2:
                    # test cross-GPU transfer works
                    # to GPU1
                    input_tuple = tuple(t.cuda(1) for t in input_tuple)
                    module.cuda(1)
                    with torch.cuda.device(1):
                        module(*input_tuple)
                    assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]

                if not self.skip_double:
                    # test double()
                    input_tuple = tuple(to_double(t).cuda() for t in input_tuple)
                    module.double().cuda()
                    module(*input_tuple)
                    assert_module_parameters_are(torch.cuda.DoubleTensor, 0)  # type: ignore[attr-defined]

                # test half()
                if not self.skip_half:
                    input_tuple = tuple(to_half(t).cuda() for t in input_tuple)
                    module.half().cuda()
                    module(*input_tuple)
                    assert_module_parameters_are(torch.cuda.HalfTensor, 0)  # type: ignore[attr-defined]

    def _get_target(self):
        return self._get_arg('target', False)

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args', False)


class CriterionTest(InputVariableMixin, TestBase):  # type: ignore[misc]
    """Test harness for loss modules (criterions).

    ``__call__`` compares the forward result against ``reference_fn`` (when
    given) and runs gradcheck/gradgradcheck. ``test_cuda`` compares CPU and
    CUDA forward/backward results for a given dtype.
    """
    # TODO: check that criterions don't ignore grad_output

    _required_arg_names = TestBase._required_arg_names.union({'target'})

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.should_test_cuda = kwargs.get('test_cuda', True)
        self.check_forward_only = kwargs.get('check_forward_only', False)
        self.check_gradgrad = kwargs.get('check_gradgrad', True)
        self.check_half = kwargs.get('check_half', True)
        self.check_bfloat16 = kwargs.get('check_bfloat16', False)
        self.check_complex = kwargs.get('check_complex', False)
        self.test_cpu = kwargs.get('test_cpu', True)
        self.with_tf32 = kwargs.get('with_tf32', True)
        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
        self.check_batched_grad = kwargs.get('check_batched_grad', True)
        self.default_dtype = kwargs.get('default_dtype', None)
        if self.default_dtype is None:
            self.default_dtype = torch.get_default_dtype()

    def __call__(self, test_case):
        """Run the CPU checks for this criterion under ``self.default_dtype``.

        Checks that ``repr``/``str`` work, compares against ``reference_fn``
        when one is set, and (unless ``check_forward_only``) runs gradcheck
        and optionally gradgradcheck over the inputs, target and parameters.
        """
        with set_default_dtype(self.default_dtype):
            module = self.constructor(*self.constructor_args)
            input = self._get_input()

            # Check that these methods don't raise errors
            module.__repr__()
            str(module)

            target = self._get_target()

            if self.reference_fn is not None:
                out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args)
                ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
                expected_out = self.reference_fn(*ref_args)
                test_case.assertEqual(out, expected_out)

            if self.check_forward_only:
                return

            params = tuple(module.parameters())
            input_tuple = input if isinstance(input, tuple) else (input,)
            num_inputs = len(input_tuple)
            # The layout must match how apply_fn unpacks its arguments: module
            # inputs, then target, then parameters. Parameters are passed only
            # so gradcheck differentiates with respect to them; the module holds
            # the same tensor objects, so apply_fn does not forward them.
            # (Previously params were placed before target, which made the first
            # parameter be passed to the module as the target.)
            inputs = input_tuple + (target,) + params

            def apply_fn(*args):
                return module(*args[:num_inputs + 1])

            gradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)

            if self.check_gradgrad:
                gradgradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)

    def test_cuda(self, test_case, dtype, extra_args=None):
        """Compare CPU and CUDA forward and backward results for ``dtype``.

        Raises ``unittest.SkipTest`` when CUDA is unavailable or disabled for
        this test. For half/bfloat16 the CPU side is rebuilt in the default
        dtype, and a looser tolerance (1e-1 instead of 4e-4) is used.
        """
        def convert_dtype(obj, dtype, requires_grad=False):
            if isinstance(obj, torch.Tensor):
                return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
            elif isinstance(obj, tuple):
                return tuple(convert_dtype(o, dtype, requires_grad) for o in obj)
            else:
                return obj

        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest('Excluded from CUDA tests')

        with set_default_dtype(self.default_dtype):
            cpu_input = self._get_input()
            cpu_target = self._get_target()
            cpu_module = self.constructor(*self.constructor_args)
            gpu_module = self.constructor(*self.constructor_args)

            # Convert input, target and module parameters to dtype
            cpu_input = convert_dtype(cpu_input, dtype, True)
            if cpu_target.is_floating_point() or cpu_target.is_complex():
                cpu_target = convert_dtype(cpu_target, dtype)
            cpu_module.type(dtype)
            gpu_module.type(dtype)

            # GPU setup
            gpu_input = to_gpu(cpu_input)
            gpu_target = to_gpu(cpu_target)
            gpu_module.cuda()

            # torch.HalfTensor doesn't support most operations, converting back to default
            if dtype in {torch.half, torch.bfloat16}:
                cpu_input = self._get_input()
                cpu_target = self._get_target()
                # Loss modules with weights require consistent input/module weight types
                cpu_module = self.constructor(*self.constructor_args)

            cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args)
            gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args)
            # dtype used to be able to be None, so set precision in this way instead of a precision map
            test_case.assertEqual(cpu_output, gpu_output,
                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)

            cpu_gradInput = test_case._backward_criterion(
                cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args)
            gpu_gradInput = test_case._backward_criterion(
                gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args)
            # dtype used to be able to be None, so set precision in this way instead of a precision map
            test_case.assertEqual(cpu_gradInput, gpu_gradInput,
                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)

    def _get_target(self):
        return self._get_arg('target', False)

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args', False)

    @property
    def extra_args(self):
        return self._get_arg('extra_args', False)


def _test_bfloat16_ops(test_case, op, device, inp_dims=(), prec=1e-2, scale_factor=None):
    """Compare ``op`` computed in float32 against ``op.bfloat16()``.

    Runs forward and backward in float32. It then calls ``op.bfloat16()`` and
    repeats with the input and upstream gradient cast to bfloat16. Outputs and
    input gradients must agree within ``atol=rtol=prec``.

    Note: for an ``nn.Module``, ``bfloat16()`` converts the module in place
    (per the PyTorch docs), so the caller's ``op`` is left in bfloat16.

    Args:
        test_case: provides ``assertEqual(a, b, atol=, rtol=, exact_dtype=)``.
        op: callable with a ``bfloat16()`` method (e.g. an ``nn.Module``).
        device: device on which inputs are created.
        inp_dims: input shape.
        prec: absolute and relative tolerance of both comparisons.
        scale_factor: if given, the input is ``rand`` drawn in bfloat16 and
            multiplied by this factor, instead of a float32 ``randn`` sample.
    """
    # fp32 compute
    if scale_factor is None:
        input1 = torch.randn(inp_dims, dtype=torch.float32, device=device, requires_grad=True)
    else:
        # Values are produced in bfloat16, so input2 below equals input1 exactly.
        # (Previously an unused float32 randn input was also allocated here.)
        input1 = (torch.rand(inp_dims, dtype=torch.bfloat16, device=device) * scale_factor).float().requires_grad_()
    out1 = op(input1)
    grad_input1 = torch.randn_like(out1, device=device)
    out1.backward(grad_input1)

    # bfloat16 compute
    op_bfp16 = op.bfloat16()
    input2 = input1.detach().bfloat16().requires_grad_()
    grad_input2 = grad_input1.bfloat16()
    out2 = op_bfp16(input2)
    out2.backward(grad_input2)

    test_case.assertEqual(out1, out2, atol=prec, rtol=prec, exact_dtype=False)
    test_case.assertEqual(input1.grad, input2.grad, atol=prec, rtol=prec, exact_dtype=False)

def _test_module_empty_input(test_case, module, inp, check_size=True, inference=False):
    """Run ``module`` on ``inp`` and check the output size and zero gradients.

    Unless ``inference`` is set, ``inp`` is marked as requiring grad and a
    random upstream gradient is backpropagated. Every parameter that requires
    grad, and ``inp`` itself, must then have an all-zero gradient. This
    presumably only holds for an empty (zero-element) ``inp`` -- confirm
    with callers. With ``check_size`` the output must have ``inp``'s size.
    """
    track_grad = not inference
    if track_grad:
        inp.requires_grad_(True)
    out = module(inp)
    if track_grad:
        out.backward(torch.rand_like(out))
    if check_size:
        test_case.assertEqual(out.size(), inp.size())
    if not track_grad:
        return
    for param in module.parameters():
        if not param.requires_grad:
            continue
        test_case.assertEqual(param.grad, torch.zeros_like(param.grad))
    test_case.assertEqual(inp.grad, torch.zeros_like(inp))


def _create_basic_net():
    """Build a small module hierarchy used by tests.

    Returns ``(layer, net, seq)``:
      * ``layer`` -- a standalone ``Layer`` with one (3, 5) parameter and a
        zero-filled (1, 3, 3, 7) buffer;
      * ``net`` -- a ``Net`` owning its own ``Layer`` as ``l1`` plus a (3, 5)
        parameter and a zero-filled (7, 3, 3, 1) buffer;
      * ``seq`` -- an ``nn.Sequential`` containing the same ``net`` twice.
    Parameters are created with ``torch.empty`` and are left uninitialized.
    """
    param_shape = (3, 5)

    class Layer(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.layer_dummy_param = nn.Parameter(torch.empty(*param_shape))
            self.layer_dummy_buf = nn.Buffer(torch.zeros(1, 3, 3, 7))

    class Net(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.l1 = Layer()
            self.dummy_param = nn.Parameter(torch.empty(*param_shape))
            self.dummy_buf = nn.Buffer(torch.zeros(7, 3, 3, 1))

    layer = Layer()
    net = Net()
    return layer, net, nn.Sequential(net, net)
