# mypy: ignore-errors

import unittest
from functools import partial

import numpy as np

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import SM53OrLater
from torch.testing._internal.common_device_type import precisionOverride
from torch.testing._internal.common_dtype import (
    all_types_and,
    all_types_and_complex_and,
)
from torch.testing._internal.common_utils import TEST_SCIPY, TEST_WITH_ROCM
from torch.testing._internal.opinfo.core import (
    DecorateInfo,
    ErrorInput,
    OpInfo,
    sample_inputs_spectral_ops,
    SampleInput,
    SpectralFuncInfo,
    SpectralFuncType,
)
from torch.testing._internal.opinfo.refs import (
    _find_referenced_opinfo,
    _inherit_constructor_args,
    PythonRefInfo,
)


# scipy.fft is optional: it supplies the reference implementations used by the
# hfft2 / hfftn / ihfft2 / ihfftn entries below. The import is only attempted
# when SciPy testing is enabled; if the module is missing, has_scipy_fft stays
# False and those entries get ``ref=None``.
has_scipy_fft = False
if TEST_SCIPY:
    try:
        import scipy.fft
    except ModuleNotFoundError:
        pass
    else:
        has_scipy_fft = True


class SpectralFuncPythonRefInfo(SpectralFuncInfo):
    """
    An OpInfo for a Python reference of a spectral (torch.fft) operation.

    The torch OpInfo named ``torch_opinfo_name`` is looked up in the
    module-level ``op_db`` (which therefore must already be defined when an
    instance is created) and must be a ``SpectralFuncInfo``. The constructor
    arguments recorded on that OpInfo (``_original_spectral_func_args``) are
    inherited, and any additional keyword arguments passed here override the
    inherited ones.
    """

    def __init__(
        self,
        name,  # the string name of the callable Python reference
        *,
        op=None,  # the function variant of the operation, populated as torch.<name> if None
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant="",  # forwarded to _find_referenced_opinfo when locating the torch opinfo
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant, op_db=op_db
        )
        # Only spectral OpInfos record _original_spectral_func_args, which is
        # read below.
        assert isinstance(self.torch_opinfo, SpectralFuncInfo)

        inherited = self.torch_opinfo._original_spectral_func_args
        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)

        super().__init__(**ukwargs)


def error_inputs_fft(op_info, device, **kwargs):
    """Yield the error inputs shared by the 1-D FFT OpInfos.

    A zero-dimensional float32 tensor has no dimension to take the FFT of, so
    the op is expected to raise ``IndexError`` complaining about dimension -1.
    """
    scalar = make_tensor(device=device, dtype=torch.float32)
    bad_sample = SampleInput(scalar)
    yield ErrorInput(
        bad_sample,
        error_type=IndexError,
        error_regex="Dimension specified as -1 but tensor has no dimensions",
    )


def error_inputs_fftn(op_info, device, **kwargs):
    """Yield the error inputs shared by the 2-D and N-D FFT OpInfos.

    Explicitly requesting ``dim=(0,)`` on a zero-dimensional float32 tensor is
    expected to raise ``IndexError`` because dimension 0 does not exist.
    """
    scalar = make_tensor(device=device, dtype=torch.float32)
    bad_sample = SampleInput(scalar, dim=(0,))
    yield ErrorInput(
        bad_sample,
        error_type=IndexError,
        error_regex="Dimension specified as 0 but tensor has no dimensions",
    )


def sample_inputs_fft_with_min(
    op_info, device, dtype, requires_grad=False, *, min_size, **kwargs
):
    """Yield the generic spectral samples plus one minimum-size sample.

    Args:
        op_info: the OpInfo being sampled (forwarded unchanged).
        device, dtype, requires_grad: forwarded to the tensor factories.
        min_size: shape of the smallest input the transform must accept.
        **kwargs: forwarded to ``sample_inputs_spectral_ops``.
    """
    yield from sample_inputs_spectral_ops(
        op_info, device, dtype, requires_grad, **kwargs
    )

    # FIXME: the minimum-size sample causes a floating point exception on ROCm,
    # so it is only produced elsewhere.
    if not TEST_WITH_ROCM:
        # Check the "Invalid number of data points" error isn't too strict
        # https://github.com/pytorch/pytorch/pull/109083
        smallest = make_tensor(
            min_size, dtype=dtype, device=device, requires_grad=requires_grad
        )
        yield SampleInput(smallest)


def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs):
    """Yield sample inputs for fft.fftshift / fft.ifftshift.

    The first sample relies on the op's default ``dim``; the rest pass an
    explicit ``dim`` as an int or as a tuple of one or more dimensions.
    """
    # (shape, dim) pairs; ``None`` means "do not pass dim at all".
    cases = (
        ((9, 10), None),
        ((50,), 0),
        ((5, 11), (1,)),
        ((5, 6), (0, 1)),
        ((5, 6, 2), (0, 2)),
    )
    for shape, dim in cases:
        tensor = make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )
        if dim is None:
            yield SampleInput(tensor)
        else:
            yield SampleInput(tensor, kwargs=dict(dim=dim))


# Operator database
#
# OpInfo entries for the torch.fft transforms plus fft.fftshift/fft.ifftshift.
# The SpectralFuncInfo entries follow a common pattern:
#   * ``ref`` is a NumPy function, or a SciPy one guarded by ``has_scipy_fft``
#     (``None`` when scipy.fft could not be imported);
#   * ``dtypesIfCUDA`` adds torch.half (and torch.complex32 where complex
#     dtypes are listed) only when ``SM53OrLater``;
#   * the rfft* and ihfft* entries list only real dtypes (``all_types_and``);
#   * ``sample_inputs_fft_with_min`` is given the smallest input shape the
#     transform should accept via ``min_size``;
#   * OneD entries use ``error_inputs_fft``; TwoD / ND entries use
#     ``error_inputs_fftn``.
op_db: list[OpInfo] = [
    SpectralFuncInfo(
        "fft.fft",
        aten_name="fft_fft",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.fft,
        ndimensional=SpectralFuncType.OneD,
        dtypes=all_types_and_complex_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
        ),
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1),
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
    ),
    SpectralFuncInfo(
        "fft.fft2",
        aten_name="fft_fft2",
        ref=np.fft.fft2,
        decomp_aten_name="_fft_c2c",
        ndimensional=SpectralFuncType.TwoD,
        dtypes=all_types_and_complex_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
        ),
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
        skips=(
            # complex32 reference testing is skipped on CUDA only when running on ROCm.
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_complex_half_reference_testing",
                device_type="cuda",
                dtypes=[torch.complex32],
                active_if=TEST_WITH_ROCM,
            ),
        ),
    ),
    SpectralFuncInfo(
        "fft.fftn",
        aten_name="fft_fftn",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.fftn,
        ndimensional=SpectralFuncType.ND,
        dtypes=all_types_and_complex_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
        ),
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
    ),
    SpectralFuncInfo(
        "fft.hfft",
        aten_name="fft_hfft",
        decomp_aten_name="_fft_c2r",
        ref=np.fft.hfft,
        ndimensional=SpectralFuncType.OneD,
        dtypes=all_types_and_complex_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
        ),
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=2),
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        check_batched_gradgrad=False,
        skips=(
            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestSchemaCheckModeOpInfo",
                "test_schema_correctness",
                dtypes=(torch.complex64, torch.complex128),
            ),
        ),
    ),
    SpectralFuncInfo(
        "fft.hfft2",
        aten_name="fft_hfft2",
        decomp_aten_name="_fft_c2r",
        ref=scipy.fft.hfft2 if has_scipy_fft else None,
        ndimensional=SpectralFuncType.TwoD,
        dtypes=all_types_and_complex_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
        ),
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(2, 2)),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_gradgrad=False,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
                "TestFFT",
                "test_reference_nd",
            ),
        ],
        skips=(
            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestSchemaCheckModeOpInfo",
                "test_schema_correctness",
            ),
            # FIXME: errors are too large; needs investigation
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_complex_half_reference_testing",
                device_type="cuda",
            ),
        ),
    ),
    SpectralFuncInfo(
        "fft.hfftn",
        aten_name="fft_hfftn",
        decomp_aten_name="_fft_c2r",
        ref=scipy.fft.hfftn if has_scipy_fft else None,
        ndimensional=SpectralFuncType.ND,
        dtypes=all_types_and_complex_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
        ),
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(2, 2)),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_gradgrad=False,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
                "TestFFT",
                "test_reference_nd",
            ),
        ],
        skips=(
            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestSchemaCheckModeOpInfo",
                "test_schema_correctness",
            ),
        ),
    ),
    SpectralFuncInfo(
        "fft.rfft",
        aten_name="fft_rfft",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.rfft,
        ndimensional=SpectralFuncType.OneD,
        dtypes=all_types_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
        ),
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1),
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_grad=False,
        skips=(),
        check_batched_gradgrad=False,
    ),
    SpectralFuncInfo(
        "fft.rfft2",
        aten_name="fft_rfft2",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.rfft2,
        ndimensional=SpectralFuncType.TwoD,
        dtypes=all_types_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
        ),
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=[
            precisionOverride({torch.float: 1e-4}),
        ],
    ),
    SpectralFuncInfo(
        "fft.rfftn",
        aten_name="fft_rfftn",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.rfftn,
        ndimensional=SpectralFuncType.ND,
        dtypes=all_types_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
        ),
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=[
            precisionOverride({torch.float: 1e-4}),
        ],
    ),
    SpectralFuncInfo(
        "fft.ifft",
        aten_name="fft_ifft",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.ifft,
        ndimensional=SpectralFuncType.OneD,
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1),
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
        ),
    ),
    SpectralFuncInfo(
        "fft.ifft2",
        aten_name="fft_ifft2",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.ifft2,
        ndimensional=SpectralFuncType.TwoD,
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
        ),
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncInfo(
        "fft.ifftn",
        aten_name="fft_ifftn",
        decomp_aten_name="_fft_c2c",
        ref=np.fft.ifftn,
        ndimensional=SpectralFuncType.ND,
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
        ),
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncInfo(
        "fft.ihfft",
        aten_name="fft_ihfft",
        decomp_aten_name="_fft_r2c",
        ref=np.fft.ihfft,
        ndimensional=SpectralFuncType.OneD,
        # NOTE(review): fft/ifft/rfft use a scalar min_size, while this 1-D
        # transform uses (1, 1); it is also the only SpectralFuncInfo here
        # that does not set gradcheck_fast_mode=True -- confirm both are intended.
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
        error_inputs_func=error_inputs_fft,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
        ),
        skips=(),
        check_batched_grad=False,
    ),
    SpectralFuncInfo(
        "fft.ihfft2",
        aten_name="fft_ihfft2",
        decomp_aten_name="_fft_r2c",
        # NOTE(review): the reference is scipy.fft.ihfftn (not an ihfft2
        # variant) -- confirm this matches the dims exercised by the tests.
        ref=scipy.fft.ihfftn if has_scipy_fft else None,
        ndimensional=SpectralFuncType.TwoD,
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
        ),
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=(
            # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]).
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
            DecorateInfo(
                precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd"
            ),
            # Mismatched elements!
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"),
            # NOTE(review): "test_out_warnings" may duplicate "test_out_warning"
            # above (fft.ihfftn only lists the latter) -- confirm the test exists.
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warnings"),
        ),
    ),
    SpectralFuncInfo(
        "fft.ihfftn",
        aten_name="fft_ihfftn",
        decomp_aten_name="_fft_r2c",
        ref=scipy.fft.ihfftn if has_scipy_fft else None,
        ndimensional=SpectralFuncType.ND,
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and(
            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
        ),
        check_batched_grad=False,
        check_batched_gradgrad=False,
        decorators=[
            # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]).
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
            # Mismatched elements!
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"),
            DecorateInfo(
                precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd"
            ),
        ],
    ),
    SpectralFuncInfo(
        "fft.irfft",
        aten_name="fft_irfft",
        decomp_aten_name="_fft_c2r",
        ref=np.fft.irfft,
        ndimensional=SpectralFuncType.OneD,
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)),
        error_inputs_func=error_inputs_fft,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
        ),
        check_batched_gradgrad=False,
    ),
    SpectralFuncInfo(
        "fft.irfft2",
        aten_name="fft_irfft2",
        decomp_aten_name="_fft_c2r",
        ref=np.fft.irfft2,
        ndimensional=SpectralFuncType.TwoD,
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
        ),
        check_batched_gradgrad=False,
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncInfo(
        "fft.irfftn",
        aten_name="fft_irfftn",
        decomp_aten_name="_fft_c2r",
        ref=np.fft.irfftn,
        ndimensional=SpectralFuncType.ND,
        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)),
        error_inputs_func=error_inputs_fftn,
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        dtypes=all_types_and_complex_and(torch.bool),
        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
        dtypesIfCUDA=all_types_and_complex_and(
            torch.bool,
            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
        ),
        check_batched_gradgrad=False,
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    # fftshift / ifftshift are plain OpInfos (not SpectralFuncInfo) and have
    # no out= variant.
    OpInfo(
        "fft.fftshift",
        dtypes=all_types_and_complex_and(
            torch.bool, torch.bfloat16, torch.half, torch.chalf
        ),
        sample_inputs_func=sample_inputs_fftshift,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
    OpInfo(
        "fft.ifftshift",
        dtypes=all_types_and_complex_and(
            torch.bool, torch.bfloat16, torch.half, torch.chalf
        ),
        sample_inputs_func=sample_inputs_fftshift,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
]

# OpInfos for the Python references in torch._refs.fft.
# Each SpectralFuncPythonRefInfo is built from the torch OpInfo in op_db named
# by ``torch_opinfo_name``. Keyword arguments given here (``decorators``) are
# passed on top of the inherited constructor arguments.
# NOTE(review): several entries re-declare the precisionOverride used by their
# torch OpInfo -- presumably because decorators are not inherited; confirm.
# The fftshift/ifftshift references use the generic PythonRefInfo and are
# given op_db explicitly.
python_ref_db: list[OpInfo] = [
    SpectralFuncPythonRefInfo(
        "_refs.fft.fft",
        torch_opinfo_name="fft.fft",
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ifft",
        torch_opinfo_name="fft.ifft",
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.rfft",
        torch_opinfo_name="fft.rfft",
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.irfft",
        torch_opinfo_name="fft.irfft",
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.hfft",
        torch_opinfo_name="fft.hfft",
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ihfft",
        torch_opinfo_name="fft.ihfft",
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.fftn",
        torch_opinfo_name="fft.fftn",
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ifftn",
        torch_opinfo_name="fft.ifftn",
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.rfftn",
        torch_opinfo_name="fft.rfftn",
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.irfftn",
        torch_opinfo_name="fft.irfftn",
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.hfftn",
        torch_opinfo_name="fft.hfftn",
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ihfftn",
        torch_opinfo_name="fft.ihfftn",
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 2e-4}),
                "TestFFT",
                "test_reference_nd",
            ),
            # AssertionError: Reference result was farther (0.09746177145360499) from the precise
            # computation than the torch result was (0.09111555632069855)
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_python_ref_torch_fallback",
                dtypes=(torch.float16,),
                device_type="cuda",
            ),
            # AssertionError: Reference result was farther (0.0953431016138116) from the precise
            # computation than the torch result was (0.09305490684430734)
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_python_ref_executor",
                dtypes=(torch.float16,),
                device_type="cuda",
            ),
        ],
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.fft2",
        torch_opinfo_name="fft.fft2",
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ifft2",
        torch_opinfo_name="fft.ifft2",
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.rfft2",
        torch_opinfo_name="fft.rfft2",
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.irfft2",
        torch_opinfo_name="fft.irfft2",
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.hfft2",
        torch_opinfo_name="fft.hfft2",
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
                "TestFFT",
                "test_reference_nd",
            )
        ],
    ),
    SpectralFuncPythonRefInfo(
        "_refs.fft.ihfft2",
        torch_opinfo_name="fft.ihfft2",
        decorators=[
            DecorateInfo(
                precisionOverride({torch.float: 2e-4}),
                "TestFFT",
                "test_reference_nd",
            ),
            # FIXME:
            # Reference result was farther (0.0953431016138116) from the precise computation
            # than the torch result was (0.09305490684430734)!
            # Unlike the _refs.fft.ihfftn skips, this one is not limited to float16.
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestCommon",
                "test_python_ref_executor",
                device_type="cuda",
            ),
        ],
    ),
    PythonRefInfo(
        "_refs.fft.fftshift",
        op_db=op_db,
        torch_opinfo_name="fft.fftshift",
    ),
    PythonRefInfo(
        "_refs.fft.ifftshift",
        op_db=op_db,
        torch_opinfo_name="fft.ifftshift",
    ),
]
