"""
We will recreate all the RNN modules as we require the modules to be decomposed
into its building blocks to be able to observe.
"""

# mypy: allow-untyped-defs

import numbers
import warnings
from typing import Optional

import torch
from torch import Tensor


# Public API of this module; `_LSTMSingleLayer` and `_LSTMLayer` are internal helpers.
__all__ = ["LSTMCell", "LSTM"]


class LSTMCell(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM) cell.

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`

    The cell is decomposed into separately observable pieces: the input and hidden
    projections (``igates`` / ``hgates``), the gate non-linearities, and the
    :class:`~torch.ao.nn.quantized.FloatFunctional` adds/muls that combine them.

    `split_gates`: specify True to compute the input/forget/cell/output gates separately
    to avoid an intermediate tensor which is subsequently chunk'd. This optimization can
    be beneficial for on-device inference latency. This flag is cascaded down from the
    parent classes.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)  # (time_steps, batch, input_size)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """
    _FLOAT_MODULE = torch.nn.LSTMCell
    __constants__ = ["split_gates"]  # for jit.script

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        bias: bool = True,
        device=None,
        dtype=None,
        *,
        split_gates=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_size = input_dim
        self.hidden_size = hidden_dim
        self.bias = bias
        self.split_gates = split_gates

        if not split_gates:
            # One fused projection per source; the 4 gates are chunk'd out later.
            self.igates: torch.nn.Module = torch.nn.Linear(
                input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
            )
            self.hgates: torch.nn.Module = torch.nn.Linear(
                hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
            )
            self.gates: torch.nn.Module = torch.ao.nn.quantized.FloatFunctional()
        else:
            # keep separate Linear layers for each gate
            self.igates = torch.nn.ModuleDict()
            self.hgates = torch.nn.ModuleDict()
            self.gates = torch.nn.ModuleDict()
            for g in ["input", "forget", "cell", "output"]:
                # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
                self.igates[g] = torch.nn.Linear(
                    input_dim, hidden_dim, bias=bias, **factory_kwargs
                )
                # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
                self.hgates[g] = torch.nn.Linear(
                    hidden_dim, hidden_dim, bias=bias, **factory_kwargs
                )
                # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
                self.gates[g] = torch.ao.nn.quantized.FloatFunctional()

        self.input_gate = torch.nn.Sigmoid()
        self.forget_gate = torch.nn.Sigmoid()
        self.cell_gate = torch.nn.Tanh()
        self.output_gate = torch.nn.Sigmoid()

        # Separate FloatFunctionals so each arithmetic op gets its own observer.
        self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
        self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
        self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()

        self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()

        # (scale, zero_point) and dtype used by `initialize_hidden` when it has to
        # quantize the zero initial states.
        self.initial_hidden_state_qparams: tuple[float, int] = (1.0, 0)
        self.initial_cell_state_qparams: tuple[float, int] = (1.0, 0)
        self.hidden_state_dtype: torch.dtype = torch.quint8
        self.cell_state_dtype: torch.dtype = torch.quint8

    def forward(
        self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None
    ) -> tuple[Tensor, Tensor]:
        """Run a single LSTM time step.

        Args:
            x: input for this step; ``x.shape[0]`` is used as the batch size when
                the zero initial state has to be created.
            hidden: optional ``(hx, cx)`` pair; zeros are used if it (or either
                element) is ``None``.

        Returns:
            The updated ``(hy, cy)`` pair.
        """
        if hidden is None or hidden[0] is None or hidden[1] is None:
            # Create the zero state on the input's device and, for float inputs, in
            # the input's dtype so it matches the `hgates` projection parameters.
            hidden = self.initialize_hidden(
                x.shape[0],
                x.is_quantized,
                device=x.device,
                dtype=None if x.is_quantized else x.dtype,
            )
        hx, cx = hidden

        if not self.split_gates:
            igates = self.igates(x)
            hgates = self.hgates(hx)
            gates = self.gates.add(igates, hgates)  # type: ignore[operator]

            input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

            input_gate = self.input_gate(input_gate)
            forget_gate = self.forget_gate(forget_gate)
            cell_gate = self.cell_gate(cell_gate)
            out_gate = self.output_gate(out_gate)
        else:
            # apply each input + hidden projection and add together
            gate = {}
            for (key, gates), igates, hgates in zip(
                self.gates.items(),  # type: ignore[operator]
                self.igates.values(),  # type: ignore[operator]
                self.hgates.values(),  # type: ignore[operator]
            ):
                gate[key] = gates.add(igates(x), hgates(hx))

            input_gate = self.input_gate(gate["input"])
            forget_gate = self.forget_gate(gate["forget"])
            cell_gate = self.cell_gate(gate["cell"])
            out_gate = self.output_gate(gate["output"])

        # cy = f * cx + i * g
        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
        fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
        cy = fgate_cx_igate_cgate

        # TODO: make this tanh a member of the module so its qparams can be configured
        tanh_cy = torch.tanh(cy)
        # hy = o * tanh(cy)
        hy = self.ogate_cy.mul(out_gate, tanh_cy)
        return hy, cy

    def initialize_hidden(
        self,
        batch_size: int,
        is_quantized: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> tuple[Tensor, Tensor]:
        """Create zero ``(h, c)`` states of shape ``(batch_size, hidden_size)``.

        Args:
            batch_size: leading dimension of the returned states.
            is_quantized: if True, the zeros are quantized per tensor using
                ``initial_hidden_state_qparams`` / ``hidden_state_dtype`` and
                ``initial_cell_state_qparams`` / ``cell_state_dtype``.
            device: device for the states (default device if ``None``).
            dtype: float dtype of the zeros before any quantization (default
                float dtype if ``None``).
        """
        h = torch.zeros((batch_size, self.hidden_size), device=device, dtype=dtype)
        c = torch.zeros((batch_size, self.hidden_size), device=device, dtype=dtype)
        if is_quantized:
            (h_scale, h_zp) = self.initial_hidden_state_qparams
            (c_scale, c_zp) = self.initial_cell_state_qparams
            h = torch.quantize_per_tensor(
                h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype
            )
            c = torch.quantize_per_tensor(
                c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype
            )
        return h, c

    def _get_name(self):
        return "QuantizableLSTMCell"

    @classmethod
    def from_params(cls, wi, wh, bi=None, bh=None, split_gates=False):
        """Uses the weights and biases to create a new LSTM cell.

        Args:
            wi, wh: Weights for the input and hidden layers, stacked along dim 0
                in gate order input/forget/cell/output (split into 4 chunks when
                ``split_gates`` is True).
            bi, bh: Biases for the input and hidden layers; both or neither
                must be given.
            split_gates: build the cell with one Linear per gate.
        """
        assert (bi is None) == (bh is None)  # Either both None or both have values
        input_size = wi.shape[1]
        hidden_size = wh.shape[1]
        cell = cls(
            input_dim=input_size,
            hidden_dim=hidden_size,
            bias=(bi is not None),
            split_gates=split_gates,
        )

        if not split_gates:
            cell.igates.weight = torch.nn.Parameter(wi)
            if bi is not None:
                cell.igates.bias = torch.nn.Parameter(bi)
            cell.hgates.weight = torch.nn.Parameter(wh)
            if bh is not None:
                cell.hgates.bias = torch.nn.Parameter(bh)
        else:
            # split weight/bias
            for w, b, gates in zip([wi, wh], [bi, bh], [cell.igates, cell.hgates]):
                for w_chunk, gate in zip(w.chunk(4, dim=0), gates.values()):  # type: ignore[operator]
                    gate.weight = torch.nn.Parameter(w_chunk)

                if b is not None:
                    for b_chunk, gate in zip(b.chunk(4, dim=0), gates.values()):  # type: ignore[operator]
                        gate.bias = torch.nn.Parameter(b_chunk)

        return cell

    @classmethod
    def from_float(cls, other, use_precomputed_fake_quant=False, split_gates=False):
        """Build an observable cell from a float ``torch.nn.LSTMCell``.

        ``other`` must carry a ``qconfig``; it is propagated to the new cell and
        its projection modules. ``use_precomputed_fake_quant`` is accepted for
        interface compatibility and is not used here.
        """
        assert type(other) is cls._FLOAT_MODULE
        assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
        observed = cls.from_params(
            other.weight_ih,
            other.weight_hh,
            other.bias_ih,
            other.bias_hh,
            split_gates=split_gates,
        )
        observed.qconfig = other.qconfig
        observed.igates.qconfig = other.qconfig
        observed.hgates.qconfig = other.qconfig
        if split_gates:
            # also apply qconfig directly to Linear modules
            for g in observed.igates.values():
                g.qconfig = other.qconfig
            for g in observed.hgates.values():
                g.qconfig = other.qconfig
        return observed


class _LSTMSingleLayer(torch.nn.Module):
    r"""One direction of an LSTM layer, unrolled over time.

    Where :class:`LSTMCell` consumes a single time step, this module walks the
    leading dimension of its input and feeds every step through its cell,
    threading the ``(h, c)`` state from one step to the next.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        bias: bool = True,
        device=None,
        dtype=None,
        *,
        split_gates=False,
    ) -> None:
        super().__init__()
        self.cell = LSTMCell(
            input_dim,
            hidden_dim,
            bias=bias,
            split_gates=split_gates,
            device=device,
            dtype=dtype,
        )

    def forward(self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None):
        """Return the stacked per-step hidden outputs and the final ``(h, c)``."""
        outputs = []
        for step in range(x.shape[0]):
            hidden = self.cell(x[step], hidden)
            outputs.append(hidden[0])  # type: ignore[index]
        return torch.stack(outputs, 0), hidden

    @classmethod
    def from_params(cls, *args, **kwargs):
        """Wrap an :meth:`LSTMCell.from_params` cell (same arguments) in a layer."""
        new_cell = LSTMCell.from_params(*args, **kwargs)
        single_layer = cls(
            new_cell.input_size,
            new_cell.hidden_size,
            new_cell.bias,
            split_gates=new_cell.split_gates,
        )
        single_layer.cell = new_cell
        return single_layer


class _LSTMLayer(torch.nn.Module):
    r"""A single LSTM layer, optionally bi-directional.

    The forward direction is handled by ``layer_fw``; when ``bidirectional`` is
    True, a second :class:`_LSTMSingleLayer`, ``layer_bw``, processes the
    time-reversed sequence.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        bias: bool = True,
        batch_first: bool = False,
        bidirectional: bool = False,
        device=None,
        dtype=None,
        *,
        split_gates=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _LSTMSingleLayer(
            input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs
        )
        # The reverse-direction sub-layer only exists for bidirectional layers.
        if self.bidirectional:
            self.layer_bw = _LSTMSingleLayer(
                input_dim,
                hidden_dim,
                bias=bias,
                split_gates=split_gates,
                **factory_kwargs,
            )

    def forward(self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None):
        """Run the layer over a whole sequence.

        Args:
            x: input sequence. Dim 0 is iterated as time; if ``batch_first`` is
                set, dims 0 and 1 are swapped on entry and swapped back on exit.
            hidden: optional ``(h, c)``. For a bidirectional layer, index 0 along
                the leading dim of each tensor is the forward-direction state and
                index 1 the backward-direction state.

        Returns:
            ``(output, (h, c))``. For a bidirectional layer the per-direction
            outputs are concatenated along the last dim and the final states are
            stacked along a new leading dim (forward first).
        """
        # Work time-major internally.
        if self.batch_first:
            x = x.transpose(0, 1)
        if hidden is None:
            hx_fw, cx_fw = (None, None)
        else:
            hx_fw, cx_fw = hidden
        hidden_bw: Optional[tuple[Tensor, Tensor]] = None
        if self.bidirectional:
            # Split the incoming states into their forward (index 0) and
            # backward (index 1) halves.
            if hx_fw is None:
                hx_bw = None
            else:
                hx_bw = hx_fw[1]
                hx_fw = hx_fw[0]
            if cx_fw is None:
                cx_bw = None
            else:
                cx_bw = cx_fw[1]
                cx_fw = cx_fw[0]
            if hx_bw is not None and cx_bw is not None:
                hidden_bw = hx_bw, cx_bw
        if hx_fw is None and cx_fw is None:
            hidden_fw = None
        else:
            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(
                cx_fw
            )
        result_fw, hidden_fw = self.layer_fw(x, hidden_fw)

        if hasattr(self, "layer_bw") and self.bidirectional:
            # Run the backward direction on the time-reversed sequence, then flip
            # its outputs back so step t lines up with the forward output at t.
            x_reversed = x.flip(0)
            result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
            result_bw = result_bw.flip(0)

            # Concatenate both directions along the last (feature) dimension.
            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
            if hidden_fw is None and hidden_bw is None:
                h = None
                c = None
            elif hidden_fw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_bw)
            elif hidden_bw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_fw)
            else:
                h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
                c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
        else:
            result = result_fw
            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]

        # Restore the caller's batch-first layout (in place on `result`).
        if self.batch_first:
            result.transpose_(0, 1)

        return result, (h, c)

    @classmethod
    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
        r"""
        There is no FP equivalent of this class. This function is here just to
        mimic the behavior of the `prepare` within the `torch.ao.quantization`
        flow.

        Args:
            other: float LSTM module whose per-layer parameters
                (``weight_ih_l{layer_idx}``, ``weight_hh_l{layer_idx}``, optional
                ``bias_*`` and ``*_reverse`` variants) are copied.
            layer_idx: index of the layer of ``other`` to extract.
            qconfig: used when ``other`` has no ``qconfig`` attribute.
            **kwargs: optional overrides for ``input_size``, ``hidden_size``,
                ``bias``, ``batch_first``, ``bidirectional`` and ``split_gates``.
        """
        assert hasattr(other, "qconfig") or (qconfig is not None)

        input_size = kwargs.get("input_size", other.input_size)
        hidden_size = kwargs.get("hidden_size", other.hidden_size)
        bias = kwargs.get("bias", other.bias)
        batch_first = kwargs.get("batch_first", other.batch_first)
        bidirectional = kwargs.get("bidirectional", other.bidirectional)
        split_gates = kwargs.get("split_gates", False)

        layer = cls(
            input_size,
            hidden_size,
            bias,
            batch_first,
            bidirectional,
            split_gates=split_gates,
        )
        layer.qconfig = getattr(other, "qconfig", qconfig)
        wi = getattr(other, f"weight_ih_l{layer_idx}")
        wh = getattr(other, f"weight_hh_l{layer_idx}")
        bi = getattr(other, f"bias_ih_l{layer_idx}", None)
        bh = getattr(other, f"bias_hh_l{layer_idx}", None)

        # Replace the freshly constructed sub-layer with one built from the
        # float weights; its sizes come from the weight shapes.
        layer.layer_fw = _LSTMSingleLayer.from_params(
            wi, wh, bi, bh, split_gates=split_gates
        )

        # NOTE: the reverse weights are copied based on `other.bidirectional`,
        # not on the `bidirectional` kwarg override.
        if other.bidirectional:
            wi = getattr(other, f"weight_ih_l{layer_idx}_reverse")
            wh = getattr(other, f"weight_hh_l{layer_idx}_reverse")
            bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None)
            bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None)
            layer.layer_bw = _LSTMSingleLayer.from_params(
                wi, wh, bi, bh, split_gates=split_gates
            )
        return layer


class LSTM(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM).

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`

    Attributes:
        layers : instances of the `_LSTMLayer`

    .. note::
        To access the weights and biases, you need to access them per layer.
        See examples below.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
        >>> # To get the weights of the first layer (forward direction,
        >>> # default ``split_gates=False``):
        >>> # xdoctest: +SKIP
        >>> print(rnn.layers[0].layer_fw.cell.igates.weight)
        tensor([[...]])
        >>> print(rnn.layers[0].layer_fw.cell.hgates.weight)
        tensor([[...]])
        >>> # The reverse direction is stored in ``layer_bw``, which only exists
        >>> # for bidirectional LSTMs.
    """
    _FLOAT_MODULE = torch.nn.LSTM

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        device=None,
        dtype=None,
        *,
        split_gates: bool = False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.training = False  # Default to eval mode. If we want to train, we will explicitly set to training.

        if (
            not isinstance(dropout, numbers.Number)
            or not 0 <= dropout <= 1
            or isinstance(dropout, bool)
        ):
            raise ValueError(
                "dropout should be a number in range [0, 1] "
                "representing the probability of an element being "
                "zeroed"
            )
        if dropout > 0:
            warnings.warn(
                "dropout option for quantizable LSTM is ignored. "
                "If you are training, please, use nn.LSTM version "
                "followed by `prepare` step."
            )
            if num_layers == 1:
                warnings.warn(
                    "dropout option adds dropout after all but last "
                    "recurrent layer, so non-zero dropout expects "
                    f"num_layers greater than 1, but got dropout={dropout} "
                    f"and num_layers={num_layers}"
                )

        # Every layer after the first consumes the previous layer's output, which
        # for a bidirectional LSTM is the forward and backward outputs
        # concatenated along the feature dimension.
        num_directions = 2 if bidirectional else 1
        layers = [
            _LSTMLayer(
                self.input_size,
                self.hidden_size,
                self.bias,
                batch_first=False,
                bidirectional=self.bidirectional,
                split_gates=split_gates,
                **factory_kwargs,
            )
        ]
        layers.extend(
            _LSTMLayer(
                num_directions * self.hidden_size,
                self.hidden_size,
                self.bias,
                batch_first=False,
                bidirectional=self.bidirectional,
                split_gates=split_gates,
                **factory_kwargs,
            )
            for _ in range(1, num_layers)
        )
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None):
        """Run the stacked LSTM over a sequence.

        Args:
            x: input sequence; time-major unless ``batch_first`` is set, in which
                case dims 0 and 1 are swapped on entry and on exit.
            hidden: optional ``(h0, c0)`` tensors that can be reshaped to
                ``(num_layers, num_directions, batch, hidden_size)``, or an
                already-split per-layer sequence of ``(h, c)`` pairs. Zeros are
                used when omitted.

        Returns:
            ``(output, (h_n, c_n))`` with the layer/direction dims of the final
            states collapsed into one leading dimension.
        """
        if self.batch_first:
            x = x.transpose(0, 1)

        max_batch_size = x.size(1)
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            zeros = torch.zeros(
                num_directions,
                max_batch_size,
                self.hidden_size,
                # Float inputs get zeros of their own dtype so they match the
                # layer parameters; quantized inputs need a float tensor to
                # quantize from.
                dtype=torch.float if x.is_quantized else x.dtype,
                device=x.device,
            )
            # Unidirectional layers expect (batch, hidden) states.
            zeros.squeeze_(0)
            if x.is_quantized:
                zeros = torch.quantize_per_tensor(
                    zeros, scale=1.0, zero_point=0, dtype=x.dtype
                )
            hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
        else:
            hidden_non_opt = torch.jit._unwrap_optional(hidden)
            if isinstance(hidden_non_opt[0], Tensor):
                hx = hidden_non_opt[0].reshape(
                    self.num_layers, num_directions, max_batch_size, self.hidden_size
                )
                cx = hidden_non_opt[1].reshape(
                    self.num_layers, num_directions, max_batch_size, self.hidden_size
                )
                # squeeze(0) drops the direction dim only when it has size 1.
                hxcx = [
                    (hx[idx].squeeze(0), cx[idx].squeeze(0))
                    for idx in range(self.num_layers)
                ]
            else:
                hxcx = hidden_non_opt

        hx_list = []
        cx_list = []
        for idx, layer in enumerate(self.layers):
            x, (h, c) = layer(x, hxcx[idx])
            hx_list.append(torch.jit._unwrap_optional(h))
            cx_list.append(torch.jit._unwrap_optional(c))
        hx_tensor = torch.stack(hx_list)
        cx_tensor = torch.stack(cx_list)

        # We are creating another dimension for bidirectional case
        # need to collapse it
        hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
        cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])

        if self.batch_first:
            x = x.transpose(0, 1)

        return x, (hx_tensor, cx_tensor)

    def _get_name(self):
        return "QuantizableLSTM"

    @classmethod
    def from_float(cls, other, qconfig=None, split_gates=False):
        """Build an observed quantizable LSTM from a float ``torch.nn.LSTM``.

        The weights of every layer are copied, the qconfig (from ``other`` or
        the ``qconfig`` argument) is attached, and the result is prepared with
        ``prepare_qat`` if ``other`` is training, otherwise with ``prepare``.
        """
        assert isinstance(other, cls._FLOAT_MODULE)
        assert hasattr(other, "qconfig") or qconfig
        observed = cls(
            other.input_size,
            other.hidden_size,
            other.num_layers,
            other.bias,
            other.batch_first,
            other.dropout,
            other.bidirectional,
            split_gates=split_gates,
        )
        observed.qconfig = getattr(other, "qconfig", qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _LSTMLayer.from_float(
                other, idx, qconfig, batch_first=False, split_gates=split_gates
            )

        # Prepare the model
        if other.training:
            observed.train()
            observed = torch.ao.quantization.prepare_qat(observed, inplace=True)
        else:
            observed.eval()
            observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        raise NotImplementedError(
            "It looks like you are trying to convert a "
            "non-quantizable LSTM module. Please, see "
            "the examples on quantizable LSTMs."
        )
