# mypy: allow-untyped-defs
from typing import Optional, Union

import sympy

import torch

from .ir import Pointwise, TensorBox
from .lowering import fallback_handler, is_integer_type, register_lowering
from .virtualized import ops


# pyre-ignore[2,3]
def dense_idx_to_jagged_idx(batch_idx, seq_idx, offsets_loader, jagged_len):
    # jagged_len + 1 is used as the upper bound,
    # because the last sequence length may be zero
    begin_idx = ops.indirect_indexing(
        offsets_loader([batch_idx]),
        jagged_len + 1,
    )
    end_idx = offsets_loader([batch_idx + 1])
    jagged_idx = begin_idx + seq_idx
    return jagged_idx, end_idx


def get_inverse_offsets(
    offsets: TensorBox,
    jagged_len: Union[int, sympy.Expr],
    realize: bool = True,
) -> TensorBox:
    """
    Returns "inverse_offsets" - the inverse of the offsets array.
    offsets maps batch index (dense) to jagged index (i.e. offset into jagged tensor).
    inverse_offsets maps jagged index to batch index.

    e.g. for offsets [0, 3, 4, 9, 10] this will return
    inverse_offsets = [0, 0, 0, 1, 2, 2, 2, 2, 2, 3]

    For the given offsets, the computed inverse_offsets are cached
    on the first call and reused in the further calls.
    """

    if hasattr(offsets, "inverse_offsets"):
        # inverse_offsets are already computed
        # for these offsets: can reuse
        return offsets.inverse_offsets

    # ops.bucketize takes offsets.get_name() which doesn't exist on Pointwise
    # kernels, i.e. we need to realize it before using. In other words, we need
    # offsets to be in global memory so that we can binary search over the
    # entire tensor
    offsets.realize()
    device: torch.device = offsets.get_device_or_error()
    dtype: torch.dtype = offsets.get_dtype()

    # pyre-ignore[2,3]
    def inner_fn(index):
        idx = index[0]
        bucket = ops.bucketize(
            values=ops.index_expr(idx, dtype),
            boundaries=(
                offsets.get_name(),
                offsets.get_size()[-1],
                offsets.get_size()[0] * offsets.get_stride()[0],
                offsets.get_stride()[-1],
            ),
            boundary_indices=0,
            indexing_dtype=dtype,
            right=True,
        )
        # ops.bucketize above returns 1-based bucket indices,
        # but we need 0-based, hence we subtract 1 from batch
        return bucket - 1

    inverse_offsets = Pointwise.create(
        device=device,
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=[jagged_len],
    )

    if realize:
        # "freeze" the node so that it doesn't get inlined downstream.
        inverse_offsets.realize()

    # cache inverse_offsets for further reuse
    offsets.inverse_offsets = inverse_offsets  # type: ignore[attr-defined]

    return inverse_offsets


def jagged_idx_to_dense_idx(
    jagged_idx,  # pyre-ignore[2]
    inverse_offsets_loader,  # pyre-ignore[2]
    offsets_loader,  # pyre-ignore[2]
    batch_size: Union[int, sympy.Expr],
    max_seq_len: Union[int, sympy.Expr],
    offsets_dtype: torch.dtype,
) -> tuple[sympy.Expr, sympy.Expr]:
    batch_idx = ops.indirect_indexing(
        inverse_offsets_loader([jagged_idx]),
        batch_size + 1,
    )
    batch_start = offsets_loader([batch_idx])
    seq = ops.index_expr(jagged_idx, offsets_dtype) - batch_start
    # check=False because there may be sequences longer than max_seq_len
    seq_idx = ops.indirect_indexing(seq, max_seq_len, check=False)
    return batch_idx, seq_idx


def register_jagged_ops():
    # pyre-ignore[56]
    @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default)
    def _jagged_to_padded_dense_forward(
        jagged_values: TensorBox,
        jagged_offsets: list[TensorBox],
        max_lengths: list[int],  # list of ints/SymInts
        padding_value: float = 0.0,
    ) -> TensorBox:
        device = jagged_values.get_device_or_error()
        dtype = jagged_values.get_dtype()

        jagged_values_size = jagged_values.get_size()

        # only handle the common case of a single jagged dimension
        if (
            len(jagged_offsets) != 1
            or device.type != "cuda"
            or device != jagged_offsets[0].get_device()
            or len(jagged_values_size) != 2
            or len(jagged_offsets[0].get_size()) != 1
            or len(max_lengths) != len(jagged_offsets)
            or not is_integer_type(jagged_offsets[0])
        ):
            return fallback_handler(
                torch.ops.aten._jagged_to_padded_dense_forward.default,
                add_to_fallback_set=False,
            )(
                jagged_values,
                jagged_offsets,
                max_lengths,
                padding_value,
            )

        offsets: TensorBox = jagged_offsets[0]
        offsets_len = offsets.get_size()[0]
        offsets_dtype = offsets.get_dtype()
        batch_size = offsets_len - 1
        max_seq_len = max_lengths[0]
        embedding_len = jagged_values_size[1]
        jagged_len = jagged_values_size[0]

        output_size = [batch_size, max_seq_len, embedding_len]

        values_loader = jagged_values.make_loader()
        offsets_loader = offsets.make_loader()

        # pyre-ignore[2,3,53]
        def inner_fn(index):
            # dense tensor size: [B, N, D]
            batch_idx, seq_idx, emb_idx = index
            jagged_idx, end_idx = dense_idx_to_jagged_idx(
                batch_idx=batch_idx,
                seq_idx=seq_idx,
                offsets_loader=offsets_loader,
                jagged_len=jagged_len,
            )
            return ops.masked(
                ops.lt(
                    ops.index_expr(jagged_idx, offsets_dtype),
                    end_idx,
                ),
                lambda: values_loader([jagged_idx, emb_idx]),
                padding_value,
            )

        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=output_size,
        )

    def _dense_to_jagged_forward_impl(
        fallback_op,  # pyre-ignore[2]
        dense: TensorBox,
        jagged_offsets: list[TensorBox],
        jagged_len: Optional[int] = None,
    ) -> TensorBox:
        device = dense.get_device_or_error()
        dtype = dense.get_dtype()

        dense_size = dense.get_size()

        # only handle the common case of a single jagged dimension
        if (
            len(jagged_offsets) != 1
            or device.type != "cuda"
            or device != jagged_offsets[0].get_device()
            or len(jagged_offsets[0].get_size()) != 1
            or len(dense_size) != 3
            or jagged_len is None
            or not is_integer_type(jagged_offsets[0])
        ):
            return fallback_handler(fallback_op, add_to_fallback_set=False)(
                dense,
                jagged_offsets,
                jagged_len,
            )

        offsets: TensorBox = jagged_offsets[0]
        offsets_dtype = offsets.get_dtype()
        batch_size = dense_size[0]
        max_seq_len = dense_size[1]
        embedding_len = dense_size[-1]

        output_size = [jagged_len, embedding_len]

        dense_loader = dense.make_loader()
        offsets_loader = offsets.make_loader()

        inverse_offsets = get_inverse_offsets(
            offsets=offsets,
            jagged_len=jagged_len,
        )
        inverse_offsets_loader = inverse_offsets.make_loader()

        # pyre-ignore[2,3,53]
        def inner_fn(index):
            # jagged tensor size: [sum_B(N_B), D]
            jagged_idx, emb_idx = index
            batch_idx, seq_idx = jagged_idx_to_dense_idx(
                jagged_idx=jagged_idx,
                offsets_loader=offsets_loader,
                inverse_offsets_loader=inverse_offsets_loader,
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                offsets_dtype=offsets_dtype,
            )
            return ops.masked(
                ops.lt(
                    ops.index_expr(seq_idx, offsets_dtype),
                    ops.index_expr(max_seq_len, offsets_dtype),
                ),
                lambda: dense_loader([batch_idx, seq_idx, emb_idx]),
                0.0,  # jagged sequence longer than max_seq_len
            )

        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=output_size,
        )

    # pyre-ignore[56]
    @register_lowering(torch.ops.aten._padded_dense_to_jagged_forward)
    def _dense_to_jagged_forward(
        dense: TensorBox,
        jagged_offsets: list[TensorBox],
        jagged_len: Optional[int] = None,
    ) -> TensorBox:
        return _dense_to_jagged_forward_impl(
            fallback_op=torch.ops.aten._padded_dense_to_jagged_forward.default,
            dense=dense,
            jagged_offsets=jagged_offsets,
            jagged_len=jagged_len,
        )
