# mypy: ignore-errors

import os

import torch
from torch.testing import make_tensor  # noqa: F401
from torch.testing._internal.opinfo.core import (  # noqa: F401
    BinaryUfuncInfo,
    ErrorInput,
    generate_elementwise_binary_tensors,
    ReductionOpInfo,
    sample_inputs_reduction,
    SampleInput,
)


def _check_validate(op_info, sample):
    def _check_fail(sample):
        try:
            op_info(
                sample.sample_input.input,
                *sample.sample_input.args,
                **sample.sample_input.kwargs,
            )
        except sample.error_type:
            pass
        except Exception as msg:
            raise AssertionError(  # noqa: B904
                f"{op_info.name} on {sample.sample_input=} expected exception "
                f"{sample.error_type}: {sample.error_regex}, got {type(msg).__name__}: {msg}"
            )
        else:
            raise AssertionError(
                f"{op_info.name} on {sample.sample_input=} expected exception "
                f"{sample.error_type}: {sample.error_regex}, got none."
            )

    def _check_success(sample):
        try:
            op_info(sample.input, *sample.args, **sample.kwargs)
        except Exception as msg:
            raise AssertionError(  # noqa: B904
                f"{op_info.name} on {sample=} expected to succeed "
                f", got {type(msg).__name__}: {msg}"
            )

    if isinstance(sample, ErrorInput):
        _check_fail(sample)
    else:
        _check_success(sample)


def _sample_inputs_sparse(
    sample_inputs,
    maybe_failing_sample_inputs,
    validate_sample_input,
    op_info,
    *args,
    **kwargs,
):
    check_validate = (
        os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1"
    )
    for sample in sample_inputs(op_info, *args, **kwargs):
        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
        if isinstance(sample, SampleInput):
            yield sample
        # Error inputs are handled in error_inputs_sparse

    for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs):
        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
        if isinstance(sample, SampleInput):
            yield sample


def _error_inputs_sparse(
    maybe_failing_sample_inputs, validate_sample_input, op_info, *args, **kwargs
):
    check_validate = (
        os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1"
    )
    for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs):
        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
        if isinstance(sample, ErrorInput):
            yield sample
        # Sample inputs are handled in sample_inputs_sparse


def _apply_requires_grad_to_samples(sample_inputs):
    """Decorator to _maybe_failing_sample_inputs_... generator functions
    that clones and sets requires_grad argument to tensors in sample
    input arguments. This is needed when the generated samples share
    tensor instances.
    """

    def wrapper(op_info, device, dtype, requires_grad, layout, **kwargs):
        def apply_requires_grad(x):
            if (
                not isinstance(x, torch.Tensor)
                or x.requires_grad
                or not requires_grad
                or not (x.is_floating_point() or x.is_complex())
            ):
                return x
            return x.detach().clone().requires_grad_(requires_grad)

        if requires_grad:
            for sample_input in sample_inputs(
                op_info, device, dtype, requires_grad, layout, **kwargs
            ):
                yield sample_input.transform(apply_requires_grad)
        else:
            yield from sample_inputs(
                op_info, device, dtype, requires_grad, layout, **kwargs
            )

    return wrapper


def sample_inputs_sparse_reduction(
    op_info, device, dtype, requires_grad, layout, blocksize=None, **kwargs
):
    """Sample inputs for reduction operations on sparse tensors."""
    layout_name = str(layout).split(".", 1)[-1].rsplit("_coo", 1)[0]
    op_supports_layout = getattr(op_info, "supports_" + layout_name)
    if not op_supports_layout:
        return

    for sample_input in sample_inputs_reduction(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        if sample_input.input.ndim == 0:
            # scalar sparse tensors are not supported
            continue

        if layout in {
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        }:
            if sample_input.input.ndim < 2:
                # conversion to sparse compressed tensors requires at
                # least 2 dimensional tensors
                continue
            if sample_input.input.ndim > 2 and (sample_input.input == 0).any():
                # Skip batched sparse compressed samples that contain
                # explicit zeros because to_sparse(layout=..) will
                # fail, see gh-98495.
                # TODO: remove this if-block after gh-98495 is fixed.
                continue

        if layout in {torch.sparse_bsr, torch.sparse_bsc} and blocksize is None:
            blocksize = (1, 1)

        yield SampleInput(
            sample_input.input.detach()
            .to_sparse(layout=layout, blocksize=blocksize)
            .requires_grad_(requires_grad),
            args=sample_input.args,
            kwargs=sample_input.kwargs,
        )

        if layout is torch.sparse_coo and (dtype.is_floating_point or dtype.is_complex):
            # uncoalesced samples
            inp = sample_input.input.detach().to_sparse(layout=layout)
            inp = torch.sparse_coo_tensor(
                inp.indices().repeat(1, 2),
                inp.values().repeat(2),
                inp.shape,
                dtype=inp.dtype,
                device=inp.device,
            )
            assert not inp.is_coalesced()
            yield SampleInput(
                inp.requires_grad_(requires_grad),
                args=sample_input.args,
                kwargs=sample_input.kwargs,
            )

        if sample_input.input.ndim > 2:
            # hybrid samples
            yield SampleInput(
                sample_input.input.detach()
                .to_sparse(
                    layout=layout,
                    blocksize=blocksize,
                    dense_dim=sample_input.input.ndim - 2,
                )
                .requires_grad_(requires_grad),
                args=sample_input.args,
                kwargs=sample_input.kwargs,
            )


def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=False):
    """Return the specified sample when it is valid and supported by the
    operation. Otherwise, return the sample as ErrorInput instance.

    Only ``sum`` (delegated to ``_validate_sample_input_sparse_reduction_sum``)
    and the masked reductions ``masked.sum``, ``masked.amax``,
    ``masked.amin``, ``masked.mean`` and ``masked.prod`` are classified here.
    Samples of any other reduction are returned unchanged. Within each
    ``if``/``elif`` chain the first matching condition decides the expected
    error, so the order of the branches matters.

    When check_validate is True, the result is validated against
    calling the op on the sample.
    """
    # Sentinel that tells an absent ``mask`` kwarg apart from ``mask=None``.
    UNSPECIFIED = object()
    if op_info.name == "sum":
        sample = _validate_sample_input_sparse_reduction_sum(sample)

    if op_info.name in {"masked.sum"}:
        mask = sample.kwargs.get("mask", UNSPECIFIED)
        if (
            mask not in {None, UNSPECIFIED}
            and mask.ndim > 2
            and mask.layout is torch.strided
            and (mask == 0).any()
        ):
            # Batched strided mask containing zeros.
            # TODO: remove this if-block after gh-98495 is fixed.
            sample = ErrorInput(
                sample,
                error_regex="Expect the same number of specified elements per batch.",
            )
        elif not sample.kwargs.get("keepdim"):
            # keepdim missing or False is rejected for compressed inputs.
            sample = ErrorInput(
                sample,
                error_type=(AssertionError, RuntimeError),
                error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported",
            )
        elif mask is UNSPECIFIED:
            # No mask kwarg at all (an explicit mask=None is not caught here).
            sample = ErrorInput(
                sample,
                error_type=ValueError,
                error_regex="masked (.*) expects explicit mask for sparse_csr tensor input",
            )
        elif sample.input.ndim > 2:
            # Batched inputs are expected to fail.
            sample = ErrorInput(
                sample,
                error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.",
            )

    if op_info.name in {"masked.amax", "masked.amin", "masked.mean", "masked.prod"}:
        t_inp = sample.input
        # Unlike masked.sum above, an absent mask and mask=None are treated alike.
        mask = sample.kwargs.get("mask")
        if (
            mask is not None
            and mask.ndim > 2
            and mask.layout is torch.strided
            and (mask == 0).any()
        ):
            # Batched strided mask containing zeros.
            # TODO: remove this if-block after gh-98495 is fixed.
            sample = ErrorInput(
                sample,
                error_regex="Expect the same number of specified elements per batch.",
            )
        elif mask is None:
            sample = ErrorInput(
                sample,
                error_type=ValueError,
                error_regex="masked (.*) expects explicit mask for sparse_csr tensor input",
            )
        elif (
            mask.layout is sample.input.layout
            and mask.ndim > 2
            and op_info.name == "masked.mean"
        ):
            # masked.mean with a batched mask of the same sparse layout.
            sample = ErrorInput(
                sample,
                error_type=TypeError,
                error_regex=(
                    "where[(][)] received an invalid combination of arguments"
                    " - got [(]Tensor, Tensor, NoneType[)]"
                ),
            )
        elif not sample.kwargs.get("keepdim"):
            sample = ErrorInput(
                sample,
                error_type=(AssertionError, RuntimeError),
                error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported",
            )
        elif (
            sample.input.ndim > 2
            # NOTE(review): set membership requires a hashable ``dim``; a
            # list-valued dim would raise TypeError here.
            and (sample.kwargs.get("dim") not in {0, 1})
            and mask.ndim > 2
            and mask.layout is not torch.strided
        ):
            # Batched input reduced over a dim other than 0 or 1 with a
            # batched sparse mask; the expected error depends on dim and op.
            if sample.kwargs.get("dim") == (0, -1):
                sample = ErrorInput(
                    sample,
                    error_regex="tensor dimensionality must be sum of batch, base, and dense dimensionalities",
                )
            elif op_info.name == "masked.prod":
                sample = ErrorInput(
                    sample,
                    error_regex="input_dim == 2 INTERNAL ASSERT FAILED at",
                )
            else:
                sample = ErrorInput(
                    sample,
                    error_type=AssertionError,
                    error_regex="Sparse CSR tensors are 2D and only support reduction along dim 0 or 1.",
                )
        elif sample.input.ndim > 2:
            sample = ErrorInput(
                sample,
                error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.",
            )
        elif (
            mask.layout is t_inp.layout
            and mask._nnz() != t_inp._nnz()
            and t_inp.dense_dim() > 0
        ):
            # Hybrid input whose same-layout mask has a different nnz.
            sample = ErrorInput(
                sample,
                error_regex="Index tensor must have the same number of dimensions as src tensor",
            )

    if check_validate:
        _check_validate(op_info, sample)

    return sample


def _validate_sample_input_sparse_reduction_sum(sample, check_validate=False):
    # NOTE: When fixing a failing sample case, remove the
    #       corresponding if-block
    t_inp, t_kwargs = sample.input, sample.kwargs
    dim = t_kwargs.get("dim")
    keepdim = t_kwargs.get("keepdim")
    layout = t_inp.layout
    if isinstance(dim, (int, list, tuple)):
        if layout in {
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        }:
            if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
                return ErrorInput(
                    sample,
                    error_regex=(
                        "Currently the only compressed sparse format supported for sum.dim_IntList is CSR, but got layout"
                    ),
                )
            if layout in {torch.sparse_csr, torch.sparse_csc} and not keepdim:
                return ErrorInput(
                    sample,
                    error_regex=(
                        "reduction operations on CSR tensors with keepdim=False is unsupported"
                    ),
                )
            if t_inp.dim() != 2:
                return ErrorInput(
                    sample,
                    error_regex=("input_dim == 2 INTERNAL ASSERT"),
                )
            if layout == torch.sparse_csr:
                if t_inp.dtype == torch.bool:
                    return ErrorInput(
                        sample,
                        error_regex=("_sparse_csr_sum_cpu not implemented for 'Bool'"),
                    )
                if t_inp.dtype == torch.complex32:
                    return ErrorInput(
                        sample,
                        error_regex=(
                            "_sparse_csr_sum_cuda not implemented for 'ComplexHalf'"
                        ),
                    )
    return sample


def _maybe_failing_sample_inputs_sparse_reduction_sum(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Generator of samples that are known to fail or that were failing in past."""
    # NOTE: When fixing a failing case, remove the Exception comment
    #       but keep the `yield sample` statement.
    if layout in [
        torch.sparse_csr,
        torch.sparse_csc,
    ]:
        # NotImplementedError: Could not run 'aten::sum.IntList_out' with arguments from the 'SparseCsrCPU' backend.
        yield SampleInput(
            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
            .to_sparse(layout=layout)
            .requires_grad_(requires_grad),
            kwargs=dict(dim=0, keepdim=True),
        )
        yield SampleInput(
            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
            .to_sparse(layout=layout, dense_dim=1)
            .requires_grad_(requires_grad),
            kwargs=dict(dim=0),
        )
        yield SampleInput(
            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
            .to_sparse(layout=layout)
            .requires_grad_(requires_grad),
            kwargs=dict(dim=(0,)),
        )
        yield SampleInput(
            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
            .to_sparse(layout=layout)
            .requires_grad_(requires_grad),
            kwargs=dict(dim=(0,), keepdim=True),
        )
        yield SampleInput(
            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
            .to_sparse(layout=layout, dense_dim=1)
            .requires_grad_(requires_grad),
            kwargs=dict(dim=(0,)),
        )

        # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2]
        yield SampleInput(
            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
            .to_sparse(layout=layout)
            .requires_grad_(requires_grad),
            kwargs=dict(dim=0),
        )

    if layout in [
        torch.sparse_bsr,
        torch.sparse_bsc,
    ]:
        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr
        yield SampleInput(
            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
            .to_sparse(layout=layout, blocksize=(2, 2))
            .requires_grad_(requires_grad),
            kwargs=dict(dim=0, keepdim=True),
        )
        yield SampleInput(
            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
            .to_sparse(layout=layout, dense_dim=1, blocksize=(1, 1))
            .requires_grad_(requires_grad),
            kwargs=dict(dim=0),
        )
        yield SampleInput(
            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
            .to_sparse(layout=layout, blocksize=(1, 1))
            .requires_grad_(requires_grad),
            kwargs=dict(dim=(0,)),
        )
        yield SampleInput(
            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
            .to_sparse(layout=layout, blocksize=(1, 1))
            .requires_grad_(requires_grad),
            kwargs=dict(dim=(0,), keepdim=True),
        )
        yield SampleInput(
            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
            .to_sparse(layout=layout, blocksize=(1, 1), dense_dim=1)
            .requires_grad_(requires_grad),
            kwargs=dict(dim=(0,)),
        )

        # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2]
        yield SampleInput(
            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
            .to_sparse(layout=layout, blocksize=(1, 1))
            .requires_grad_(requires_grad),
            kwargs=dict(dim=0),
        )


def sample_inputs_sparse_reduction_sum(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Sample inputs for sum on sparse tensors.

    Combines the generic sparse reduction samples with the known-failing
    ``sum`` samples and yields only those that the reduction validator
    keeps as valid.
    """
    generators = (
        sample_inputs_sparse_reduction,
        _maybe_failing_sample_inputs_sparse_reduction_sum,
    )
    yield from _sample_inputs_sparse(
        *generators,
        _validate_sample_input_sparse_reduction,
        op_info,
        device,
        dtype,
        requires_grad,
        layout,
        **kwargs,
    )


def error_inputs_sparse_reduction_sum(op_info, device, layout, **kwargs):
    """Error inputs for sum on sparse tensors.

    Error inputs are always built as float64 tensors without autograd.
    """
    fixed_dtype, fixed_requires_grad = torch.float64, False
    yield from _error_inputs_sparse(
        _maybe_failing_sample_inputs_sparse_reduction_sum,
        _validate_sample_input_sparse_reduction,
        op_info,
        device,
        fixed_dtype,
        fixed_requires_grad,
        layout,
        **kwargs,
    )


def sample_inputs_sparse_elementwise_binary_operation(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Sample inputs for elementwise binary operations on sparse tensors.

    The samples include regular, zero-sized, batched, and hybrid
    sparse tensors as well as rhs scalars. All tensors are full tensors.
    """

    def _to_sparse(tensor, **kwargs):
        return tensor.detach().to_sparse(**kwargs).requires_grad_(requires_grad)

    for sample_input in generate_elementwise_binary_tensors(
        op_info,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        exclude_zero=True,
        **kwargs,
    ):
        lhs, rhs = sample_input.input, sample_input.args[0]
        min_dense_dim = 0
        max_dense_dim = lhs.ndim - 1
        if layout in {
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        }:
            if lhs.ndim < 2:
                # sparse compressed tensors sparse_dim must be 2
                continue
            max_dense_dim = lhs.ndim - 2

        for dense_dim in range(min_dense_dim, max_dense_dim + 1):
            if layout in {torch.sparse_bsr, torch.sparse_bsc}:
                blocksizes = [(1, 1)]
                if lhs.numel() > 0:
                    blocksizes.append(
                        (
                            lhs.shape[lhs.ndim - 2 - dense_dim],
                            lhs.shape[lhs.ndim - 1 - dense_dim],
                        )
                    )
            else:
                blocksizes = [None]
            for blocksize in blocksizes:
                to_sparse_kwargs = dict(
                    layout=layout, dense_dim=dense_dim, blocksize=blocksize
                )
                lhs_sparse = _to_sparse(lhs, **to_sparse_kwargs)
                rhs_sparse = _to_sparse(rhs, **to_sparse_kwargs)
                # op(sparse, sparse)
                yield SampleInput(
                    lhs_sparse,
                    args=(rhs_sparse, *sample_input.args[1:]),
                    kwargs=sample_input.kwargs,
                )
                # op(sparse, scalar)
                yield SampleInput(
                    lhs_sparse,
                    args=(
                        make_tensor(
                            (), dtype=dtype, device=device, requires_grad=requires_grad
                        ),
                        *sample_input.args[1:],
                    ),
                    kwargs=sample_input.kwargs,
                )


def _validate_sample_input_elementwise_binary_sparse_mul(sample):
    """Classify a ``mul`` sample on a sparse tensor as valid or known-failing.

    Returns ``sample`` unchanged when none of the known failure conditions
    below applies. Otherwise returns an ``ErrorInput`` wrapping it with the
    expected error-message regex; no ``error_type`` is passed, so
    ``ErrorInput``'s default is used. The ``if``/``elif`` chain is
    order-sensitive: the first matching condition determines the expected
    error.

    Throughout, ``t_args[0].ndim > 0`` excludes 0-dim (scalar) right-hand
    sides.
    """
    # NOTE: When fixing a failing sample case, remove the
    #       corresponding if-block
    t_inp, t_args = sample.input, sample.args
    # Leading dimensions that are neither sparse nor dense.
    batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
    layout = t_inp.layout
    dtype = t_inp.dtype
    if layout is torch.sparse_csr and batch_dim > 0 and t_args[0].ndim > 0:
        return ErrorInput(
            sample,
            error_regex=(
                "coo_to_sparse_csr: conversion from Sparse to SparseCsr for input"
                " tensors with sparse_dim[(][)]!=2 is not supported"
            ),
        )
    elif layout is torch.sparse_csc and t_args[0].ndim > 0:
        return ErrorInput(
            sample, error_regex="Expected result Tensor to be of format CSR"
        )
    elif layout is torch.sparse_bsr and t_args[0].ndim > 0:
        return ErrorInput(
            sample,
            error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsr",
        )
    elif layout is torch.sparse_bsc and t_args[0].ndim > 0:
        return ErrorInput(
            sample,
            error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsc",
        )
    elif (
        layout is torch.sparse_coo
        and dtype is torch.bool
        and t_args[0].ndim > 0
        and t_inp.is_cpu
        and t_inp.numel() > 0
        and t_inp.dense_dim() > 0
    ):
        # Non-empty hybrid bool COO on CPU.
        return ErrorInput(
            sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Bool'"
        )
    elif (
        layout in {torch.sparse_coo, torch.sparse_csr}
        and dtype is torch.bool
        and t_inp._nnz() > 0
        and t_args[0].ndim > 0
        and t_inp.is_cpu
        and t_inp.numel() > 0
    ):
        return ErrorInput(
            sample, error_regex="\"mul_out_sparse\" not implemented for 'Bool'"
        )
    elif (
        layout is torch.sparse_csr
        and t_args[0].layout is torch.strided
        and 0 < t_args[0].ndim
        and t_args[0].ndim < t_inp.ndim
    ):
        # Strided rhs with fewer dimensions than the CSR lhs.
        return ErrorInput(
            sample, error_regex="sparse_mask_sparse_csr expects self to be 2D"
        )
    elif layout is torch.sparse_csr and (
        (t_args[0].layout is torch.strided and 0 < t_args[0].ndim)
        or (t_args[0].layout is layout and t_inp.shape != t_args[0].shape)
    ):
        return ErrorInput(
            sample,
            error_regex=(
                "expects sparse inputs with equal dimensionality, number of sparse dimensions,"
                " and shape of sparse dimensions"
            ),
        )
    elif (
        layout is torch.sparse_csr
        and t_inp.dense_dim() > 0
        and t_inp._nnz() > 0
        and t_inp.is_cpu
        and dtype is torch.float16
        and t_args[0].ndim > 0
    ):
        # Non-empty hybrid float16 CSR on CPU.
        return ErrorInput(
            sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Half'"
        )
    return sample


@_apply_requires_grad_to_samples
def _maybe_failing_sample_inputs_sparse_elementwise_binary_mul(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Generator of samples that are known to fail or that were failing in past.

    The samples are built from a regular 2x2 tensor, a batched 2x2x2 tensor,
    and a hybrid 2x2x1 tensor (one dense dimension), all converted to
    ``layout`` on ``device`` with ``dtype``. BSR/BSC use
    ``blocksize=(1, 1)``. Which samples are yielded depends on ``layout``
    and ``dtype`` (see the inline error notes).
    """
    # NOTE: When fixing a failing case, remove the Exception comment
    #       but keep the `yield sample` statement.
    blocksize = None
    if layout in {torch.sparse_bsr, torch.sparse_bsc}:
        blocksize = (1, 1)

    def sparsify(data, dense_dim):
        dense = torch.tensor(data, device=device, dtype=dtype)
        return dense.to_sparse(layout=layout, dense_dim=dense_dim, blocksize=blocksize)

    regular = sparsify([[1, 2], [3, 4]], dense_dim=0)
    batch = sparsify([[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dense_dim=0)
    hybrid = sparsify([[[1], [2]], [[3], [4]]], dense_dim=1)

    if layout is torch.sparse_csr:
        # RuntimeError: crow_indices is supposed to be a vector, but got 2 dimensional tensor
        yield SampleInput(batch, args=(batch,))
        # RuntimeError: Only tensors with two sparse dimensions can be
        # converted to the SparseCsr layout, got self with 3 sparse
        # dimensions.
        lhs_zeros = torch.zeros_like(hybrid).requires_grad_(requires_grad)
        rhs_zeros = torch.zeros_like(hybrid).requires_grad_(requires_grad)
        yield SampleInput(lhs_zeros, args=(rhs_zeros,))
    elif layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
        # csc:     RuntimeError: Expected result Tensor to be of format CSR
        # bsr/bsc: RuntimeError: empty_sparse_compressed expected sparse compressed
        #          (non-block) tensor layout but got SparseBsr / SparseBsc
        yield SampleInput(regular, args=(regular,))

    if layout in {torch.sparse_csr, torch.sparse_coo}:
        if dtype is torch.complex32:
            # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf'
            yield SampleInput(regular, args=(regular,))
        if dtype is torch.bool and regular.is_cpu:
            # RuntimeError: "mul_out_sparse" not implemented for 'Bool'
            yield SampleInput(regular, args=(regular,))

    if (
        layout is torch.sparse_coo
        and dtype in {torch.bool, torch.float16}
        and regular.is_cpu
    ):
        # RuntimeError: "addcmul_cpu_out" not implemented for '(Bool|Half)'
        yield SampleInput(hybrid, args=(hybrid,))


def _validate_sample_input_sparse_elementwise_binary_operation(
    op_info, sample, check_validate=False
):
    if op_info.name == "mul":
        sample = _validate_sample_input_elementwise_binary_sparse_mul(sample)

    if check_validate:
        _check_validate(op_info, sample)
    return sample


def sample_inputs_sparse_mul(op_info, device, dtype, requires_grad, layout, **kwargs):
    """Sample inputs for mul operation on sparse tensors.

    Combines the generic elementwise-binary samples with the known-failing
    ``mul`` samples and yields only those that the validator keeps as
    valid.
    """
    generators = (
        sample_inputs_sparse_elementwise_binary_operation,
        _maybe_failing_sample_inputs_sparse_elementwise_binary_mul,
    )
    yield from _sample_inputs_sparse(
        *generators,
        _validate_sample_input_sparse_elementwise_binary_operation,
        op_info,
        device,
        dtype,
        requires_grad,
        layout,
        **kwargs,
    )


def error_inputs_sparse_mul(op_info, device, layout, **kwargs):
    """Error inputs for mul operation on sparse tensors.

    Error inputs are always built as float64 tensors without autograd.
    """
    fixed_dtype, fixed_requires_grad = torch.float64, False
    yield from _error_inputs_sparse(
        _maybe_failing_sample_inputs_sparse_elementwise_binary_mul,
        _validate_sample_input_sparse_elementwise_binary_operation,
        op_info,
        device,
        fixed_dtype,
        fixed_requires_grad,
        layout,
        **kwargs,
    )


def _sample_inputs_sparse_like_fns(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    from torch.testing._internal.common_utils import TestCase

    for tensor in TestCase().generate_simple_inputs(
        layout,
        device=device,
        dtype=dtype,
        enable_batch=True,
        enable_hybrid=True,
        enable_zero_sized=True,
        enable_non_contiguous_indices=False,
        enable_non_contiguous_values=False,
    ):
        yield SampleInput(tensor, args=(), kwargs={})
        yield SampleInput(
            tensor, args=(), kwargs=dict(device=device, dtype=dtype, layout=layout)
        )

        if dtype is not torch.float64:
            yield SampleInput(tensor, args=(), kwargs=dict(dtype=torch.float64))

        if torch.cuda.is_available():
            other_device = "cuda" if tensor.device.type == "cpu" else "cpu"
            yield SampleInput(tensor, args=(), kwargs=dict(device=other_device))

        if layout is torch.sparse_csr:
            other_layout = torch.sparse_csc
        elif layout is torch.sparse_csc:
            other_layout = torch.sparse_csr
        elif layout is torch.sparse_bsr:
            other_layout = torch.sparse_bsc
        elif layout is torch.sparse_bsc:
            other_layout = torch.sparse_bsr
        else:
            other_layout = torch.strided
        yield SampleInput(tensor, args=(), kwargs=dict(layout=other_layout))

        if layout is not torch.sparse_coo:
            yield SampleInput(tensor, args=(), kwargs=dict(layout=torch.sparse_coo))


def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
    if sample.input.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    } and op_info.name not in {"zeros_like"}:
        if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
            return ErrorInput(
                sample,
                error_regex=(
                    "empty_like with different sparse layout is not supported"
                    " \\(self is Sparse(Csc|Csr|Bsc|Bsr) but you requested Sparse(Csr|Csc|Bsr|Bsc)\\)"
                ),
            )
    if sample.input.layout is torch.sparse_coo:
        return ErrorInput(
            sample,
            error_regex=(
                "Could not run 'aten::normal_' with arguments from the 'Sparse(CPU|CUDA)' backend."
            ),
        )
    if check_validate:
        _check_validate(op_info, sample)
    return sample


def _maybe_failing_sample_inputs_sparse_like_fns(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    if torch.cuda.is_available() and layout is not torch.sparse_coo:
        other_device = "cuda" if torch.device(device).type == "cpu" else "cpu"
        if layout is torch.sparse_csr:
            other_layout = torch.sparse_csc
        elif layout is torch.sparse_csc:
            other_layout = torch.sparse_csr
        elif layout is torch.sparse_bsr:
            other_layout = torch.sparse_bsc
        elif layout is torch.sparse_bsc:
            other_layout = torch.sparse_bsr
        else:
            other_layout = torch.strided

        blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None

        yield SampleInput(
            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
                layout=layout, blocksize=blocksize
            ),
            kwargs=dict(device=other_device),
        )

        yield SampleInput(
            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
                layout=layout, blocksize=blocksize
            ),
            kwargs=dict(layout=other_layout),
        )


def sample_inputs_sparse_like_fns(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Sample inputs for like-functions on sparse tensors.

    Combines the generic ``*_like`` samples with the known-failing ones and
    yields only those that the like-fns validator keeps as valid.
    """
    generators = (
        _sample_inputs_sparse_like_fns,
        _maybe_failing_sample_inputs_sparse_like_fns,
    )
    yield from _sample_inputs_sparse(
        *generators,
        _validate_sample_input_sparse_like_fns,
        op_info,
        device,
        dtype,
        requires_grad,
        layout,
        **kwargs,
    )


def error_inputs_sparse_like_fns(op_info, device, layout, **kwargs):
    """Error inputs for like-functions on sparse tensors.

    Error inputs are always built as float64 tensors without autograd.
    """
    fixed_dtype, fixed_requires_grad = torch.float64, False
    yield from _error_inputs_sparse(
        _maybe_failing_sample_inputs_sparse_like_fns,
        _validate_sample_input_sparse_like_fns,
        op_info,
        device,
        fixed_dtype,
        fixed_requires_grad,
        layout,
        **kwargs,
    )


def _validate_sample_input_sparse_default(op_info, sample, check_validate=False):
    if op_info.name == "to_sparse":
        if (
            sample.input.layout
            in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
            and len(sample.args) == 1
            and isinstance(sample.args[0], int)
            and sample.args[0] != 2
        ):
            sample = ErrorInput(
                sample,
                error_regex="sparse dim argument must be 2 for sparse_compressed_to_sparse",
            )

    if check_validate:
        _check_validate(op_info, sample)
    return sample


def validate_sample_input_sparse(op_info, sample, check_validate=False):
    """Return the specified sample when it is valid and supported by the
    operation. Otherwise, return the sample as ErrorInput instance.

    Dispatches on the kind of ``op_info``: reduction ops, binary ufuncs,
    and everything else each have their own validator.

    When check_validate is True, the result is validated against
    calling the op on the sample.
    """
    if isinstance(op_info, ReductionOpInfo):
        validator = _validate_sample_input_sparse_reduction
    elif isinstance(op_info, BinaryUfuncInfo):
        validator = _validate_sample_input_sparse_elementwise_binary_operation
    else:
        validator = _validate_sample_input_sparse_default
    return validator(op_info, sample, check_validate=check_validate)
