# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import List, Optional  # noqa: F401

from .utils import _hide_packed_params_repr, _quantize_weight


__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]


class EmbeddingPackedParams(torch.nn.Module):
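    r"""Holds the packed weight for a quantized embedding module.

    The weight is kept in the packed format expected by the
    ``torch.ops.quantized.embedding_bag_*`` operators; for (de)serialization
    it is round-tripped through the unpacked quantized tensor.
    """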
    _version = 1

    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
        super().__init__()
        self.dtype = dtype
        if self.dtype in [torch.quint8, torch.quint4x2]:
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            wq = torch._empty_per_channel_affine_quantized(
                [num_embeddings, embedding_dim],
                scales=scales,
                zero_points=zero_points,
                axis=0,
                dtype=self.dtype,
            )
            self.set_weight(wq)
        else:
            raise NotImplementedError(
                f"Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}"
            )

    @torch.jit.export
    def set_weight(self, weight: torch.Tensor) -> None:
        if self.dtype in [torch.quint8, torch.quint4x2]:
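            # Pack the per-channel quantized tensor into the opaque format
            # consumed by the quantized embedding kernels.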
            self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
        else:
            raise NotImplementedError(
                "Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2."
            )

    @torch.jit.export
    def _weight(self):
        if self.dtype in [torch.quint8, torch.quint4x2]:
            return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)
        else:
            raise NotImplementedError(
                "Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2."
            )

    def forward(self, x):
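        # Identity pass-through; the owning module reads the packed weight
        # directly in its own forward.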
        return x

    # Version 1
    #   self
    #   |--- _packed_weight : Tensor representing weight of EmbeddingPackedParams
    #   |--- dtype : torch.dtype

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "dtype"] = self.dtype
        destination[prefix + "_packed_weight"] = self._weight()

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        self.dtype = state_dict[prefix + "dtype"]
        state_dict.pop(prefix + "dtype")

        weight = state_dict[prefix + "_packed_weight"]
        state_dict.pop(prefix + "_packed_weight")
        self.set_weight(weight)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,  # the custom keys above were already consumed; skip strict checks
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def __repr__(self):
        return self._weight().__repr__()


class Embedding(torch.nn.Module):
    r"""
    A quantized Embedding module backed by quantized packed weights.
    We adopt the same interface as `torch.nn.Embedding`; please see
    https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html for documentation.

    Similar to :class:`~torch.nn.Embedding`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
        >>> output = m(indices)
        >>> print(output.size())
        torch.Size([9, 12])

    """
    _version = 1

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        dtype=torch.quint8,
    ) -> None:
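        # padding_idx, max_norm, norm_type, scale_grad_by_freq and sparse are
        # accepted for interface compatibility with nn.Embedding but are not
        # used by this quantized module.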
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.dtype = dtype

        if _weight is None:
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            qweight = torch._empty_per_channel_affine_quantized(
                [num_embeddings, embedding_dim],
                scales=scales,
                zero_points=zero_points,
                axis=0,
                # Use the requested dtype so the placeholder weight matches the
                # packed params (quint8 or quint4x2).
                dtype=dtype,
            )
        else:
            assert list(_weight.shape) == [
                num_embeddings,
                embedding_dim,
            ], "Shape of weight does not match num_embeddings and embedding_dim"
            qweight = _weight

        self._packed_params = EmbeddingPackedParams(
            num_embeddings, embedding_dim, dtype
        )
        self._packed_params.set_weight(qweight)

    def forward(self, indices: Tensor) -> Tensor:
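        # Dispatch on the weight dtype: quint4x2 uses the 4-bit lookup kernel,
        # otherwise the byte (8-bit) kernel is used.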
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_4bit(
                self._packed_params._packed_weight, indices
            )
        else:
            return torch.ops.quantized.embedding_byte(
                self._packed_params._packed_weight, indices
            )

    def _get_name(self):
        return "QuantizedEmbedding"

    def __repr__(self):
        return _hide_packed_params_repr(self, EmbeddingPackedParams)

    def extra_repr(self):
        extra_repr_str = (
            f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, "
            f"dtype={self._packed_params.dtype}, qscheme={self.weight().qscheme()}"
        )

        return extra_repr_str

    def set_weight(self, w: torch.Tensor) -> None:
        self._packed_params.set_weight(w)

    def weight(self):
        return self._packed_params._weight()

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a quantized embedding module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
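
        A minimal usage sketch (assuming eager mode quantization with
        ``float_qparams_weight_only_qconfig``; variable names are
        illustrative)::

            >>> # xdoctest: +SKIP
            >>> float_emb = nn.Embedding(num_embeddings=10, embedding_dim=12)
            >>> float_emb.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
            >>> q_emb = Embedding.from_float(float_emb)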
        """
        if hasattr(mod, "weight_fake_quant"):
            assert type(mod) == torch.ao.nn.qat.Embedding, (
                "nnq."
                + cls.__name__
                + ".from_float "
                + "with fake quant only works for "
                + torch.ao.nn.qat.Embedding.__name__
            )
            weight_observer = mod.weight_fake_quant
        else:
            assert type(mod) == nn.Embedding, (
                "nnq."
                + cls.__name__
                + ".from_float only works for "
                + nn.Embedding.__name__
            )
            assert hasattr(
                mod, "qconfig"
            ), "Embedding input float module must have qconfig defined"
            from torch.ao.quantization import float_qparams_weight_only_qconfig

            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = (
            weight_observer.qscheme == torch.per_channel_affine_float_qparams
        )
        assert (
            is_float_qparams_qconfig
        ), "Embedding quantization is only supported with float_qparams_weight_only_qconfig."

        assert (
            dtype == torch.quint8 or dtype == torch.quint4x2
        ), f"The only supported dtypes for nnq.Embedding are torch.quint8 and torch.quint4x2, got {dtype}"

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create the quantized Embedding module, propagating the observer dtype
        # (quint8 or quint4x2), and pass in the quantized weight.
        qembedding = Embedding(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
        qembedding.set_weight(qweight)
        return qembedding

    @classmethod
    def from_reference(cls, ref_embedding):
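        r"""Create a quantized embedding module from a reference quantized
        embedding module (e.g. one produced by FX graph mode quantization).
        """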
        qembedding = cls(
            ref_embedding.num_embeddings,
            ref_embedding.embedding_dim,
            ref_embedding.padding_idx,
            ref_embedding.max_norm,
            ref_embedding.norm_type,
            ref_embedding.scale_grad_by_freq,
            ref_embedding.sparse,
            ref_embedding.get_quantized_weight(),
            ref_embedding.weight_dtype,
        )
        return qembedding


class EmbeddingBag(Embedding):
    r"""
    A quantized EmbeddingBag module backed by quantized packed weights.
    We adopt the same interface as `torch.nn.EmbeddingBag`; please see
    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html for documentation.

    Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        >>> output = m(indices, offsets)
        >>> print(output.size())
        torch.Size([5, 12])

    """
    _version = 1

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        mode: str = "sum",
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        include_last_offset: bool = False,
        dtype=torch.quint8,
    ) -> None:
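        # max_norm, norm_type, scale_grad_by_freq and sparse are accepted for
        # interface compatibility with nn.EmbeddingBag but are not used by
        # this quantized module.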
        super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)

        self.mode = mode
        self.pruned_weights = False
        self.include_last_offset = include_last_offset
        self.dtype = dtype

    def forward(
        self,
        indices: Tensor,
        offsets: Optional[Tensor] = None,
        per_sample_weights: Optional[Tensor] = None,
        compressed_indices_mapping: Optional[Tensor] = None,
    ) -> Tensor:
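        # Dispatch on the weight dtype, mirroring Embedding.forward.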
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_bag_4bit(
                self._packed_params._packed_weight,
                indices,
                offsets,
                False,  # scale_grad_by_freq
                0,  # mode: 0 corresponds to "sum"
                self.pruned_weights,
                per_sample_weights,
                compressed_indices_mapping,
                self.include_last_offset,
            )
        else:
            return torch.ops.quantized.embedding_bag_byte(
                self._packed_params._packed_weight,
                indices,
                offsets,
                False,  # scale_grad_by_freq
                0,  # mode: 0 corresponds to "sum"
                self.pruned_weights,
                per_sample_weights,
                compressed_indices_mapping,
                self.include_last_offset,
            )

    def _get_name(self):
        return "QuantizedEmbeddingBag"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a quantized embedding_bag module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
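
        A minimal usage sketch (assuming eager mode quantization with
        ``float_qparams_weight_only_qconfig``; variable names are
        illustrative)::

            >>> # xdoctest: +SKIP
            >>> float_eb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode="sum")
            >>> float_eb.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
            >>> q_eb = EmbeddingBag.from_float(float_eb)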
        """
        if hasattr(mod, "weight_fake_quant"):
            weight_observer = mod.weight_fake_quant
        else:
            assert type(mod) == nn.EmbeddingBag, (
                "nnq."
                + cls.__name__
                + ".from_float only works for "
                + nn.EmbeddingBag.__name__
            )
            assert hasattr(
                mod, "qconfig"
            ), "EmbeddingBag input float module must have qconfig defined"
            from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig

            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = (
            weight_observer.qscheme == torch.per_channel_affine_float_qparams
        )
        assert (
            is_float_qparams_qconfig
        ), "EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig."

        assert (
            dtype == torch.quint8 or dtype == torch.quint4x2
        ), f"The only supported dtypes for nnq.EmbeddingBag are torch.quint8 and torch.quint4x2, got {dtype}"

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized EmbeddingBag module and pass in the quantized weight
        qembedding_bag = EmbeddingBag(
            mod.num_embeddings,
            mod.embedding_dim,
            max_norm=mod.max_norm,
            norm_type=mod.norm_type,
            scale_grad_by_freq=mod.scale_grad_by_freq,
            mode=mod.mode,
            sparse=mod.sparse,
            include_last_offset=mod.include_last_offset,
            dtype=dtype,
        )
        qembedding_bag.set_weight(qweight)
        return qembedding_bag

    @classmethod
    def from_reference(cls, ref_embedding_bag):
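        r"""Create a quantized embedding_bag module from a reference quantized
        embedding_bag module (e.g. one produced by FX graph mode quantization).
        """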
        qembedding_bag = cls(
            ref_embedding_bag.num_embeddings,
            ref_embedding_bag.embedding_dim,
            ref_embedding_bag.max_norm,
            ref_embedding_bag.norm_type,
            ref_embedding_bag.scale_grad_by_freq,
            ref_embedding_bag.mode,
            ref_embedding_bag.sparse,
            ref_embedding_bag.get_quantized_weight(),
            ref_embedding_bag.include_last_offset,
            ref_embedding_bag.weight_dtype,
        )
        return qembedding_bag
