# mypy: allow-untyped-defs
import torch


# Public API of this module: the quantized normalization modules defined
# below. Each subclasses its float ``torch.nn`` counterpart and adds output
# quantization parameters (``scale`` / ``zero_point``).
__all__ = [
    "LayerNorm",
    "GroupNorm",
    "InstanceNorm1d",
    "InstanceNorm2d",
    "InstanceNorm3d",
]


class LayerNorm(torch.nn.LayerNorm):
    r"""Quantized counterpart of :class:`~torch.nn.LayerNorm`.

    ``forward`` dispatches to ``torch.ops.quantized.layer_norm`` and asks it to
    produce an output quantized with the ``scale`` / ``zero_point`` captured
    at construction time.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(
        self,
        normalized_shape,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        elementwise_affine=True,
        device=None,
        dtype=None,
    ) -> None:
        tensor_opts = dict(device=device, dtype=dtype)
        super().__init__(
            normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            **tensor_opts,
        )
        # Replace the parent's freshly created affine tensors with the ones
        # handed in by the caller (weight first, then bias).
        self.weight, self.bias = weight, bias
        # Output qparams are stored as buffers (scale first, then zero_point).
        for buffer_name, raw_value in (("scale", scale), ("zero_point", zero_point)):
            self.register_buffer(buffer_name, torch.tensor(raw_value, **tensor_opts))

    def forward(self, input):
        """Apply quantized layer normalization to ``input``."""
        op = torch.ops.quantized.layer_norm
        return op(
            input,
            self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
            output_scale=self.scale,
            output_zero_point=self.zero_point,
        )

    def _get_name(self):
        return "QuantizedLayerNorm"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Create a quantized module from an observed float ``mod``.

        The output qparams come from ``mod.activation_post_process``;
        ``use_precomputed_fake_quant`` is accepted for API compatibility
        but not used.
        """
        observer = mod.activation_post_process
        out_scale, out_zero_point = observer.calculate_qparams()
        return cls(
            mod.normalized_shape,
            mod.weight,
            mod.bias,
            float(out_scale),
            int(out_zero_point),
            eps=mod.eps,
            elementwise_affine=mod.elementwise_affine,
        )

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Create a quantized module from a reference ``mod`` and explicit qparams."""
        return cls(
            mod.normalized_shape,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            elementwise_affine=mod.elementwise_affine,
        )


class GroupNorm(torch.nn.GroupNorm):
    r"""This is the quantized version of :class:`~torch.nn.GroupNorm`.

    ``forward`` dispatches to ``torch.ops.quantized.group_norm`` with the
    stored ``weight`` / ``bias`` / ``eps`` and requests an output quantized
    with the ``scale`` / ``zero_point`` captured at construction time.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """
    __constants__ = ["num_groups", "num_channels", "eps", "affine"]

    def __init__(
        self,
        num_groups,
        num_channels,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        affine=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
        # Adopt the caller-provided affine tensors in place of the ones the
        # parent constructor just created.
        self.weight = weight
        self.bias = bias
        # Output quantization parameters, kept as buffers.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply quantized group normalization to ``input``."""
        return torch.ops.quantized.group_norm(
            input,
            self.num_groups,
            self.weight,
            self.bias,
            self.eps,
            self.scale,
            self.zero_point,
        )

    def _get_name(self):
        return "QuantizedGroupNorm"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Create a quantized module from an observed float ``mod``.

        Output qparams are computed by ``mod.activation_post_process``;
        ``use_precomputed_fake_quant`` is accepted for API compatibility
        but not used.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_groups,
            mod.num_channels,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            mod.eps,
            mod.affine,
        )
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Create a quantized module from a reference ``mod`` and explicit qparams.

        Mirrors ``from_reference`` on the sibling LayerNorm / InstanceNorm
        modules so GroupNorm can go through the same lowering path.
        """
        return cls(
            mod.num_groups,
            mod.num_channels,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            mod.eps,
            mod.affine,
        )


class InstanceNorm1d(torch.nn.InstanceNorm1d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.

    ``forward`` dispatches to ``torch.ops.quantized.instance_norm`` using only
    ``weight`` / ``bias`` / ``eps`` and the output qparams below; running
    statistics (``momentum`` / ``track_running_stats``) are not consulted.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(
        self,
        num_features,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        momentum=0.1,
        affine=False,
        track_running_stats=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Adopt the caller-provided affine tensors.
        self.weight = weight
        self.bias = bias
        # Output quantization parameters, kept as buffers.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply quantized instance normalization to ``input``."""
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedInstanceNorm1d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Create a quantized module from an observed float ``mod``.

        Output qparams come from ``mod.activation_post_process``;
        ``use_precomputed_fake_quant`` is accepted for API compatibility
        but not used.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        # eps/momentum/affine go by keyword: the 7th positional constructor
        # parameter is ``momentum``, so passing ``affine`` positionally would
        # silently bind it to ``momentum`` and leave ``affine`` False.
        return cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            momentum=mod.momentum,
            affine=mod.affine,
        )

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Create a quantized module from a reference ``mod`` and explicit qparams."""
        # Keyword arguments for the same reason as in ``from_float``.
        return cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            momentum=mod.momentum,
            affine=mod.affine,
        )


class InstanceNorm2d(torch.nn.InstanceNorm2d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.

    ``forward`` dispatches to ``torch.ops.quantized.instance_norm`` using only
    ``weight`` / ``bias`` / ``eps`` and the output qparams below; running
    statistics (``momentum`` / ``track_running_stats``) are not consulted.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(
        self,
        num_features,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        momentum=0.1,
        affine=False,
        track_running_stats=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Adopt the caller-provided affine tensors.
        self.weight = weight
        self.bias = bias
        # Output quantization parameters, kept as buffers.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply quantized instance normalization to ``input``."""
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedInstanceNorm2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Create a quantized module from an observed float ``mod``.

        Output qparams come from ``mod.activation_post_process``;
        ``use_precomputed_fake_quant`` is accepted for API compatibility
        but not used.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        # eps/momentum/affine go by keyword: the 7th positional constructor
        # parameter is ``momentum``, so passing ``affine`` positionally would
        # silently bind it to ``momentum`` and leave ``affine`` False.
        return cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            momentum=mod.momentum,
            affine=mod.affine,
        )

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Create a quantized module from a reference ``mod`` and explicit qparams."""
        # Keyword arguments for the same reason as in ``from_float``.
        return cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            momentum=mod.momentum,
            affine=mod.affine,
        )


class InstanceNorm3d(torch.nn.InstanceNorm3d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.

    ``forward`` dispatches to ``torch.ops.quantized.instance_norm`` using only
    ``weight`` / ``bias`` / ``eps`` and the output qparams below; running
    statistics (``momentum`` / ``track_running_stats``) are not consulted.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(
        self,
        num_features,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        momentum=0.1,
        affine=False,
        track_running_stats=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Adopt the caller-provided affine tensors.
        self.weight = weight
        self.bias = bias
        # Output quantization parameters, kept as buffers.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply quantized instance normalization to ``input``."""
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedInstanceNorm3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Create a quantized module from an observed float ``mod``.

        Output qparams come from ``mod.activation_post_process``;
        ``use_precomputed_fake_quant`` is accepted for API compatibility
        but not used.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        # eps/momentum/affine go by keyword: the 7th positional constructor
        # parameter is ``momentum``, so passing ``affine`` positionally would
        # silently bind it to ``momentum`` and leave ``affine`` False.
        return cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            momentum=mod.momentum,
            affine=mod.affine,
        )

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Create a quantized module from a reference ``mod`` and explicit qparams."""
        # Keyword arguments for the same reason as in ``from_float``.
        return cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            momentum=mod.momentum,
            affine=mod.affine,
        )
