# mypy: ignore-errors

r"""Importing this file includes common utility methods and base clases for
checking quantization api and properties of resulting modules.
"""

from functorch.experimental import control_flow

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
from torch.ao.nn.intrinsic import _FusedModule
import torch.distributed as dist
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM

from torch.export import export_for_training
from torch.ao.quantization import (
    QuantType,
    default_dynamic_qat_qconfig,
    default_embedding_qat_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
)
from torch.ao.quantization.quantize_pt2e import (
    _convert_to_reference_decomposed_fx,
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)
from torch.ao.quantization.backend_config import (
    get_executorch_backend_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization import (
    QuantWrapper, QuantStub, DeQuantStub,
    default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig,
    QConfig, default_observer, default_weight_observer,
    propagate_qconfig_, convert, get_default_qconfig,
    quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig,
    get_default_qat_qconfig, PerChannelMinMaxObserver,
    default_dynamic_quant_observer, quantize,
    QConfigMapping, get_default_qconfig_mapping, get_default_qat_qconfig_mapping,
)
from torch.ao.quantization.quantization_mappings import (
    get_default_dynamic_quant_module_mappings,
    get_default_qconfig_propagation_list,
    get_default_qat_module_mappings,
)
from torch.testing._internal.common_quantized import (
    override_quantized_engine,
)
from torch.jit.mobile import _load_for_lite_interpreter

try:
    # graph mode quantization based on fx
    from torch.ao.quantization.quantize_fx import (
        prepare_fx,
        prepare_qat_fx,
        convert_fx,
        convert_to_reference_fx,
    )
    from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph
    from torch.fx.graph import Node
    from torch.fx import GraphModule
    HAS_FX = True
except ImportError:
    HAS_FX = False

import copy
import io
import functools
import os

import unittest
import numpy as np
from torch.testing import FileCheck
from typing import Callable, Any, Union, Optional
import torch._dynamo as torchdynamo
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer
import contextlib

class NodeSpec:
    ''' Used for checking GraphModule nodes
    '''
    def __init__(self, op, target):
        '''
        op: call_function | call_method | call_module
        target:
          for call_function, target would be a function
          for call_method, target would be the name of the method (a str)
          for call_module, target would be the type of the PyTorch module
        '''
        self.op = op
        self.target = target

    @classmethod
    def call_function(cls, target):
        return NodeSpec('call_function', target)

    @classmethod
    def call_method(cls, target):
        return NodeSpec('call_method', target)

    @classmethod
    def call_module(cls, target):
        return NodeSpec('call_module', target)

    def __hash__(self):
        return hash((self.op, self.target))

    def __eq__(self, other):
        if not isinstance(other, NodeSpec):
            return NotImplemented

        return self.op == other.op and self.target == other.target

    def __repr__(self):
        return repr(self.op) + " " + repr(self.target)
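
# Example usage (illustrative): specs matching a quantize call, a quantized
# conv module, and a dequantize call in a GraphModule:
#   expected_node_list = [NodeSpec.call_function(torch.quantize_per_tensor),
#                         NodeSpec.call_module(nnq.Conv2d),
#                         NodeSpec.call_method('dequantize')]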

def get_supported_device_types():
    return ['cpu', 'cuda'] if torch.cuda.is_available() and not TEST_WITH_ROCM else ['cpu']

def test_only_eval_fn(model, calib_data):
    r"""
    Default evaluation function: takes a torch.utils.data.Dataset or a list of
    input Tensors and runs the model on the dataset
    """
    for inp in calib_data:
        model(*inp)

_default_loss_fn = torch.nn.CrossEntropyLoss()
def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn):
    r"""
    Default train function: takes a torch.utils.data.Dataset and trains the
    model on the dataset
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_loss, correct, total = 0, 0, 0
    for _ in range(10):
        model.train()

        for data, target in train_data:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return train_loss, correct, total

class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
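
# Usage sketch (illustrative):
#   meter = AverageMeter('loss', ':.4f')
#   meter.update(0.5)
#   meter.update(0.3)
#   str(meter)  # 'loss 0.3000 (0.4000)'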


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
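
# For example (illustrative): with a batch of 2 samples where exactly one
# top-1 prediction is correct, accuracy(output, target, topk=(1,)) returns
# [tensor([50.])].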

def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
    model.train()
    for cnt, (image, target) in enumerate(data_loader, start=1):
        print('.', end='')
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        accuracy(output, target, topk=(1, 5))
        if cnt >= ntrain_batches:
            return
    return

def ddp_setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def ddp_cleanup():
    dist.destroy_process_group()

def run_ddp(rank, world_size, prepared):
    ddp_setup(rank, world_size)
    prepared.cuda()
    prepared = torch.nn.parallel.DistributedDataParallel(prepared, device_ids=[rank])
    prepared.to(rank)
    model_with_ddp = prepared
    optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001)
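    # `criterion` and `dataset` are not defined in this file; they are assumed
    # to be provided at module scope by the caller (hence the noqa below)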
    train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1)  # noqa: F821
    ddp_cleanup()


def convert_dynamic(module):
    convert(module, get_default_dynamic_quant_module_mappings(), inplace=True)

def prepare_dynamic(model, qconfig_dict=None):
    propagate_qconfig_(model, qconfig_dict)

def _make_conv_test_input(
    batch_size, in_channels_per_group, input_feature_map_size,
    out_channels_per_group, groups, kernel_size, X_scale, X_zero_point, W_scale,
    W_zero_point, use_bias, use_channelwise,
):
    in_channels = in_channels_per_group * groups
    out_channels = out_channels_per_group * groups

    (X_value_min, X_value_max) = (0, 4)
    X_init = torch.randint(
        X_value_min, X_value_max,
        (batch_size, in_channels,) + input_feature_map_size)
    X = X_scale * (X_init - X_zero_point).float()
    X_q = torch.quantize_per_tensor(
        X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8)

    W_scale = W_scale * out_channels
    W_zero_point = W_zero_point * out_channels
    # Resize the W_scale and W_zero_point arrays to length out_channels
    W_scale = W_scale[:out_channels]
    W_zero_point = W_zero_point[:out_channels]
    # For testing, we use small values for weights and for activations so that
    # no overflow occurs in the vpmaddubsw instruction. If overflow occurs in
    # the qconv implementation but not in the reference implementation, we
    # cannot exactly match the results with the reference.
    # Please see the comment in the qconv implementation file
    #   aten/src/ATen/native/quantized/cpu/qconv.cpp for more details.
    (W_value_min, W_value_max) = (-5, 5)
    # The operator expects them in the format
    # (out_channels, in_channels/groups,) + kernel_size
    W_init = torch.randint(
        W_value_min, W_value_max,
        (out_channels, in_channels_per_group,) + kernel_size)
    b_init = torch.randint(0, 10, (out_channels,))

    if use_channelwise:
        W_shape = (-1, 1) + (1,) * len(kernel_size)
        W_scales_tensor = torch.tensor(W_scale, dtype=torch.float)
        W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float)
        W = W_scales_tensor.reshape(*W_shape) * (
            W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
        b = X_scale * W_scales_tensor * b_init.float()
        W_q = torch.quantize_per_channel(
            W, W_scales_tensor.double(), W_zero_points_tensor.long(), 0,
            dtype=torch.qint8)
    else:
        W = W_scale[0] * (W_init - W_zero_point[0]).float()
        b = X_scale * W_scale[0] * b_init.float()
        W_q = torch.quantize_per_tensor(
            W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8)

    return (X, X_q, W, W_q, b if use_bias else None)
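
# Usage sketch (hypothetical values): build inputs for a grouped 2d conv with
# per-tensor weight quantization; W_scale and W_zero_point are lists that get
# tiled to length out_channels above:
#   X, X_q, W, W_q, b = _make_conv_test_input(
#       batch_size=1, in_channels_per_group=2, input_feature_map_size=(4, 4),
#       out_channels_per_group=2, groups=2, kernel_size=(3, 3),
#       X_scale=1.2, X_zero_point=2, W_scale=[0.2], W_zero_point=[0],
#       use_bias=True, use_channelwise=False)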

def _make_conv_add_extra_input_tensor(scale, zero_point, sizes):
    (X_value_min, X_value_max) = (0, 4)
    X_init = torch.randint(
        X_value_min,
        X_value_max,
        sizes  # Infer the size of tensor to do the add
    )
    X = scale * (X_init - zero_point).float()
    X_q = torch.quantize_per_tensor(
        X, scale=scale, zero_point=zero_point, dtype=torch.quint8)
    return X, X_q

def skipIfNoFBGEMM(fn):
    reason = 'Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs that support AVX2 or newer instruction sets.'
    if isinstance(fn, type):
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper
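
# Typical usage (illustrative): skipIfNoFBGEMM and the similar decorators
# below work on both test methods and whole TestCase classes:
#   @skipIfNoFBGEMM
#   def test_quantized_linear(self):
#       ...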

def skipIfNoQNNPACK(fn):
    reason = 'Quantized operations require QNNPACK.'
    if isinstance(fn, type):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def withQNNPACKBackend(fn):
    # TODO(future PR): consider combining with skipIfNoQNNPACK,
    # will require testing of existing callsites
    reason = 'Quantized operations require QNNPACK.'
    if isinstance(fn, type):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        with override_quantized_engine('qnnpack'):
            fn(*args, **kwargs)

    return wrapper

def skipIfNoONEDNN(fn):
    reason = 'Quantized operations require ONEDNN.'
    if isinstance(fn, type):
        if 'onednn' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'onednn' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNoONEDNNBF16(fn):
    reason = 'Quantized operations require BF16 support.'
    if isinstance(fn, type):
        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNoX86(fn):
    reason = 'Quantized operations require X86.'
    if isinstance(fn, type):
        if 'x86' not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if 'x86' not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNoDynamoSupport(fn):
    reason = "dynamo doesn't support."
    if isinstance(fn, type):
        if not torchdynamo.is_dynamo_supported():
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not torchdynamo.is_dynamo_supported():
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

def skipIfNoInductorSupport(fn):
    reason = "inductor doesn't support."
    if isinstance(fn, type):
        if not torchdynamo.is_inductor_supported():
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not torchdynamo.is_inductor_supported():
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)
    return wrapper

try:
    import torchvision  # noqa: F401
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

def get_script_module(model, tracing, data):
    return torch.jit.trace(model, data) if tracing else torch.jit.script(model)

def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True):
    """
    Convert lengths to offsets for embedding_bag
    """
    tt = np.zeros((t.shape[0] + 1,), dtype=offset_type)
    tt[1:] = t
    tt = torch.from_numpy(np.cumsum(tt, dtype=offset_type))
    if use_begin_offset:
        return tt[:-1]
    return tt[1:]
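
# For example (illustrative): lengths [2, 3, 1] yield begin offsets
# tensor([0, 2, 5]) with use_begin_offset=True, and end offsets
# tensor([2, 5, 6]) otherwise.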


def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
    assert w.dim() == 2
    w = w.transpose(0, 1).contiguous()
    assert q_group_size > 1
    assert w.shape[-1] % q_group_size == 0

    to_quant = w.reshape(-1, q_group_size)
    assert torch.isnan(to_quant).sum() == 0

    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    max_int = 2 ** n_bit - 1
    min_int = 0
    scales = (max_val - min_val).clamp(min=1e-6) / max_int
    assert torch.isnan(scales).sum() == 0

    zeros = min_val + scales * (2 ** (n_bit - 1))
    assert torch.isnan(zeros).sum() == 0

    out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
    assert torch.isnan(out).sum() == 0

    out = out.to(dtype=torch.int32).reshape(w.shape)
    if out.device != torch.device('cpu'):
        out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8)

    # Scales and zeros for the same q-group should be contiguous, so we can
    # load as a 32-bit word
    scales = scales.view(w.shape[0], -1)
    zeros = zeros.view(w.shape[0], -1)
    scales_and_zeros = (
        torch.cat(
            [
                scales.reshape(scales.size(0), scales.size(1), 1),
                zeros.reshape(zeros.size(0), zeros.size(1), 1),
            ],
            2,
        ).transpose(0, 1).contiguous()
    )

    return out, scales_and_zeros
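
# Dequantization sketch (illustrative, n_bit=4): since zeros are stored above
# as min_val + 8 * scale, the original weight is recovered up to rounding as
# w ~= (q - 8) * scale + zero.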


def _group_quantize_tensor_symmetric(
    w, n_bit=4, groupsize=32
):
    # W is of shape [K x N]
    # We transpose W since quantization is applied on [N x K]
    w = w.transpose(0, 1).contiguous()
    assert w.dim() == 2
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    # Calculate scale and zeros
    to_quant = w.reshape(-1, groupsize)
    max_val = to_quant.abs().amax(dim=1, keepdim=True)
    eps = torch.finfo(max_val.dtype).eps
    max_int = 2 ** (n_bit - 1) - 1  # For 4-bit, this is 7
    scales = max_val.clamp(min=eps) / max_int
    zeros = torch.zeros_like(scales)

    # Quantize the weight
    scales = scales.to(torch.float32).reshape(w.shape[0], -1)
    zeros = zeros.to(torch.float32).reshape(w.shape[0], -1)
    scales = scales.reshape(-1, 1)
    zeros = zeros.reshape(-1, 1)
    max_int = 2**n_bit - 1
    w_int8 = to_quant.div(scales).add(8.5).to(torch.int8).clamp(max=max_int)
    # We pack two signed int4 values into an unsigned uint8 container.
    # This halves the weight size and improves load performance.
    out_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8)

    scales_and_zeros = scales.squeeze().contiguous()

    return out_uint8, scales_and_zeros
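
# Dequantization sketch (illustrative): the symmetric scheme above stores
# q ~= round(w / scale) + 8, so the weight is recovered up to rounding as
# w ~= (q - 8) * scale.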


def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py
    # default setup for affine quantization of activations
    x_dtype = x.dtype
    x = x.float()
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)

    return quant, scales.to(x_dtype), zero_points
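
# Usage sketch (illustrative): symmetric int8 per-channel (per-row)
# quantization and the matching dequantization:
#   w = torch.randn(8, 16)
#   w_q, scales, zps = _dynamically_quantize_per_channel(w, -128, 127, torch.int8)
#   w_dq = (w_q.float() - zps.unsqueeze(-1)) * scales.unsqueeze(-1)  # ~ w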


# QuantizationTestCase used as a base class for testing quantization on modules
class QuantizationTestCase(TestCase):
    def setUp(self):
        super().setUp()
        self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)]
        self.train_data = [[torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)] for _ in range(2)]
        self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)]
                            for _ in range(2)]
        self.img_data_2d = [[torch.rand(1, 3, 10, 10, dtype=torch.float)]
                            for _ in range(2)]
        self.img_data_3d = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float)]
                            for _ in range(2)]
        self.img_data_1d_train = [[torch.rand(2, 3, 10, dtype=torch.float),
                                   torch.randint(0, 1, (1,), dtype=torch.long)]
                                  for _ in range(2)]
        self.img_data_2d_train = [[torch.rand(1, 3, 10, 10, dtype=torch.float),
                                   torch.randint(0, 1, (1,), dtype=torch.long)]
                                  for _ in range(2)]
        self.img_data_3d_train = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float),
                                   torch.randint(0, 1, (1,), dtype=torch.long)]
                                  for _ in range(2)]

        self.img_data_dict = {1 : self.img_data_1d,
                              2 : self.img_data_2d,
                              3 : self.img_data_3d}

        # Quant types that produce statically quantized ops
        self.static_quant_types = [QuantType.STATIC, QuantType.QAT]
        # All quant types for (fx based) graph mode quantization
        self.all_quant_types = [QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT]

    def checkNoPrepModules(self, module):
        r"""Checks the module does not contain child
            modules for quantization preparation, e.g.
            quant, dequant and observer
        """
        self.assertFalse(hasattr(module, 'quant'))
        self.assertFalse(hasattr(module, 'dequant'))

    def checkNoQconfig(self, module):
        r"""Checks the module does not contain qconfig
        """
        self.assertFalse(hasattr(module, 'qconfig'))

        for child in module.children():
            self.checkNoQconfig(child)

    def checkHasPrepModules(self, module):
        r"""Checks the module contains child
            modules for quantization preparation, e.g.
            quant, dequant and observer
        """
        self.assertTrue(hasattr(module, 'module'))
        self.assertTrue(hasattr(module, 'quant'))
        self.assertTrue(hasattr(module, 'dequant'))

    def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None):
        r"""Checks the module or module's leaf descendants
            have observers in preparation for quantization
        """
        if propagate_qconfig_list is None:
            propagate_qconfig_list = get_default_qconfig_propagation_list()
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}
        float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

        # check if a module is a leaf module, ignoring activation_post_process attribute
        def is_leaf_module(module):
            submodule_name_count = 0
            for name, _ in module.named_children():
                if name != 'activation_post_process':
                    submodule_name_count += 1
            return submodule_name_count == 0

        if hasattr(module, 'qconfig') and module.qconfig is not None and \
           ((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
            and type(module) in propagate_qconfig_list) or
           type(module) in float_to_observed_module_class_mapping.keys()) and \
           not isinstance(module, torch.ao.quantization.DeQuantStub):
            self.assertTrue(hasattr(module, 'activation_post_process'),
                            'module: ' + str(type(module)) + ' does not have an observer')
        # we don't need to check observers for child modules of the
        # qat modules
        if type(module) not in get_default_qat_module_mappings().values() and \
           type(module) not in float_to_observed_module_class_mapping.values() and \
           not isinstance(module, _FusedModule):
            for child in module.children():
                if type(child) in [nn.Dropout]:
                    continue
                self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict)

    def checkQuantDequant(self, mod):
        r"""Checks that mod has nn.Quantize and
            nn.DeQuantize submodules inserted
        """
        self.assertEqual(type(mod.quant), nnq.Quantize)
        self.assertEqual(type(mod.dequant), nnq.DeQuantize)

    def checkWrappedQuantizedLinear(self, mod):
        r"""Checks that mod has been swapped for an nnq.Linear
            module, the bias is qint32, and that the module
            has Quantize and DeQuantize submodules
        """
        self.assertEqual(type(mod.module), nnq.Linear)
        self.checkQuantDequant(mod)

    def checkQuantizedLinear(self, mod):
        self.assertEqual(type(mod), nnq.Linear)

    def checkDynamicQuantizedLinear(self, mod, dtype):
        r"""Checks that mod has been swapped for an nnqd.Linear
            module, the bias is float.
        """
        self.assertEqual(type(mod), nnqd.Linear)
        self.assertEqual(mod._packed_params.dtype, dtype)

    def checkDynamicQuantizedLinearRelu(self, mod, dtype):
        r"""Checks that mod has been swapped for an nnqd.Linear
            module, the bias is float.
        """
        self.assertEqual(type(mod), nniqd.LinearReLU)
        self.assertEqual(mod._packed_params.dtype, dtype)

    def check_eager_serialization(self, ref_model, loaded_model, x):
        # Check state dict serialization and torch.save APIs
        model_dict = ref_model.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        # weights_only=False as we sometimes get a ScriptObject here (weird)
        loaded_dict = torch.load(b, weights_only=False)
        loaded_model.load_state_dict(loaded_dict)
        ref_out = ref_model(*x)
        load_out = loaded_model(*x)

        def check_outputs(ref_out, load_out):
            self.assertEqual(ref_out[0], load_out[0])
            if isinstance(ref_out[1], tuple):
                self.assertEqual(ref_out[1][0], load_out[1][0])
                self.assertEqual(ref_out[1][1], load_out[1][1])
            else:
                self.assertEqual(ref_out[1], load_out[1])

        check_outputs(ref_out, load_out)
        b = io.BytesIO()
        torch.save(ref_model, b)
        b.seek(0)
        # weights_only=False as this is legacy code that saves the model
        loaded = torch.load(b, weights_only=False)
        load_out = loaded(*x)
        check_outputs(ref_out, load_out)

    def check_weight_bias_api(self, ref_model, weight_keys, bias_keys):
        weight = ref_model.get_weight()
        bias = ref_model.get_bias()
        self.assertEqual(weight_keys ^ weight.keys(), set())
        self.assertEqual(bias_keys ^ bias.keys(), set())

    def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype):
        r"""Checks that mod has been swapped for an nnqd.LSTM type
            module, the bias is float.
        """
        wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'}
        self.assertEqual(type(mod), reference_module_type)
        for packed_params in mod._all_weight_values:
            self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype])

    def checkLinear(self, mod):
        self.assertEqual(type(mod), torch.nn.Linear)

    def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype):
        r"""Checks that mod has been swapped for an nnqd.Linear
            module, the bias is float.
        """
        wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'}
        self.assertEqual(type(mod), reference_module_type)
        if hasattr(mod, '_all_weight_values'):
            for packed_params in mod._all_weight_values:
                self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype])

    def checkScriptable(self, orig_mod, calib_data, check_save_load=False):
        scripted = torch.jit.script(orig_mod)
        self._checkScriptable(orig_mod, scripted, calib_data, check_save_load)

        # Use first calib_data entry as trace input
        traced = torch.jit.trace(orig_mod, calib_data[0])
        self._checkScriptable(orig_mod, traced, calib_data, check_save_load)

    # Call this twice: once for a scripted module and once for a traced module
    def _checkScriptable(self, orig_mod, script_mod, calib_data, check_save_load):
        self._checkModuleCorrectnessAgainstOrig(orig_mod, script_mod, calib_data)

        # Test save/load
        buffer = io.BytesIO()
        torch.jit.save(script_mod, buffer)

        buffer.seek(0)
        loaded_mod = torch.jit.load(buffer)
        # Pending __getstate__ and __setstate__ support
        # See tracking task https://github.com/pytorch/pytorch/issues/23984
        if check_save_load:
            self._checkModuleCorrectnessAgainstOrig(orig_mod, loaded_mod, calib_data)

    def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data):
        for inp in calib_data:
            ref_output = orig_mod(*inp)
            scripted_output = test_mod(*inp)
            self.assertEqual(scripted_output, ref_output)


    def checkGraphModeOp(self, module, inputs, quantized_op, tracing=False, debug=False,
                         check=True, eval_mode=True, dynamic=False, qconfig=None):
        if debug:
            print('Testing:', str(module))
        qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}

        if eval_mode:
            module = module.eval()
        if dynamic:
            qconfig_dict = {'': default_dynamic_qconfig if qconfig is None else qconfig}
        model = get_script_module(module, tracing, inputs[0]).eval()
        if debug:
            print('input graph:', model.graph)
        models = {}
        outputs = {}
        # note: the loop variable must not shadow the `debug` argument, or the
        # debug printing below would never trigger
        for dbg in [True, False]:
            if dynamic:
                models[dbg] = quantize_dynamic_jit(model, qconfig_dict, debug=dbg)
                # make sure it runs
                outputs[dbg] = models[dbg](inputs)
            else:
                # module under test can contain in-place ops, and we depend on
                # input data staying constant for comparisons
                inputs_copy = copy.deepcopy(inputs)
                models[dbg] = quantize_jit(
                    model, qconfig_dict, test_only_eval_fn, [inputs_copy], inplace=False,
                    debug=dbg)
                # make sure it runs
                outputs[dbg] = models[dbg](*inputs[0])

        if debug:
            print('debug graph:', models[True].graph)
            print('non debug graph:', models[False].graph)

        if check:
            # debug and non-debug options should have the same numerics
            self.assertEqual(outputs[True], outputs[False])

            # non debug graph should produce quantized op
            FileCheck().check(quantized_op) \
                       .run(models[False].graph)

        return models[False]

    def checkGraphModuleNodes(
            self, graph_module,
            expected_node=None,
            expected_node_occurrence=None,
            expected_node_list=None):
        """ Check if GraphModule contains the target node
        Args:
            graph_module: the GraphModule instance we want to check
            expected_node, expected_node_occurrence, expected_node_list:
               see docs for checkGraphModeFxOp
        """
        nodes_in_graph = {}
        node_list = []
        modules = dict(graph_module.named_modules(remove_duplicate=False))
        for node in graph_module.graph.nodes:
            n = None
            if node.op == 'call_function' or node.op == 'call_method':
                n = NodeSpec(node.op, node.target)
            elif node.op == 'call_module':
                n = NodeSpec(node.op, type(modules[node.target]))

            if n is not None:
                node_list.append(n)
                if n in nodes_in_graph:
                    nodes_in_graph[n] += 1
                else:
                    nodes_in_graph[n] = 1

        if expected_node is not None:
            self.assertTrue(expected_node in nodes_in_graph, 'node:' + str(expected_node) +
                            ' not found in the graph module')

        if expected_node_occurrence is not None:
            for expected_node, occurrence in expected_node_occurrence.items():
                if occurrence != 0:
                    self.assertTrue(
                        expected_node in nodes_in_graph,
                        'Check failed for node:' + str(expected_node) +
                        ' not found')
                    self.assertTrue(
                        nodes_in_graph[expected_node] == occurrence,
                        'Check failed for node:' + str(expected_node) +
                        ' Expected occurrence:' + str(occurrence) +
                        ' Found occurrence:' + str(nodes_in_graph[expected_node]))
                else:
                    self.assertTrue(
                        expected_node not in nodes_in_graph,
                        'Check failed for node:' + str(expected_node) +
                        ' expected no occurrence but it was found')

        if expected_node_list is not None:
            cur_index = 0
            for n in node_list:
                if cur_index == len(expected_node_list):
                    return
                if n == expected_node_list[cur_index]:
                    cur_index += 1
            self.assertTrue(
                cur_index == len(expected_node_list),
                "Check failed for graph:" +
                self.printGraphModule(graph_module, print_str=False) +
                "Expected ordered list:" +
                str(expected_node_list))

    def printGraphModule(self, graph_module, print_str=True):
        modules = dict(graph_module.named_modules(remove_duplicate=False))
        node_infos = []
        for n in graph_module.graph.nodes:
            node_info = ' '.join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs]))
            if n.op == 'call_module':
                node_info += ' module type: ' + repr(type(modules[n.target]))
            node_infos.append(node_info)
        str_to_print = '\n'.join(node_infos)
        if print_str:
            print(str_to_print)
        return str_to_print

    if HAS_FX:

        def assert_types_for_matched_subgraph_pairs(
            self,
            matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]],
            expected_types: dict[str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]]],
            gm_a: GraphModule,
            gm_b: GraphModule,
        ) -> None:
            """
            Verifies that the types specified in expected_types match
            the underlying objects pointed to by the nodes in matched_subgraph_pairs.

            An example successful test case:

              matched_subgraph_pairs = {'x0': (graph_a_conv_0_node, graph_b_conv_0_node)}
              expected_types = {'x0': (nn.Conv2d, nnq.Conv2d)}

            The function tests for key equivalence, and verifies types with
            instance checks.
            """

            def _get_underlying_op_type(
                node: Node, gm: GraphModule
            ) -> Union[Callable, str]:
                if node.op == 'call_module':
                    mod = getattr(gm, node.target)
                    return type(mod)
                else:
                    assert node.op in ('call_function', 'call_method')
                    return node.target

            self.assertTrue(
                len(matched_subgraph_pairs) == len(expected_types),
                f'Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}'
            )
            for k, v in expected_types.items():
                expected_types_a, expected_types_b = v
                exp_type_start_a, exp_type_end_a = expected_types_a
                exp_type_start_b, exp_type_end_b = expected_types_b
                subgraph_a, subgraph_b = matched_subgraph_pairs[k]

                act_type_start_a = _get_underlying_op_type(subgraph_a.start_node, gm_a)
                act_type_start_b = _get_underlying_op_type(subgraph_b.start_node, gm_b)
                act_type_end_a = _get_underlying_op_type(subgraph_a.end_node, gm_a)
                act_type_end_b = _get_underlying_op_type(subgraph_b.end_node, gm_b)
                types_match = (exp_type_start_a is act_type_start_a) and \
                    (exp_type_end_a is act_type_end_a) and \
                    (exp_type_start_b is act_type_start_b) and \
                    (exp_type_end_b is act_type_end_b)
                self.assertTrue(
                    types_match,
                    f'Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, '
                    f'got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}'
                )

        def assert_ns_compare_dict_valid(
            self,
            act_compare_dict: dict[str, dict[str, dict[str, Any]]],
        ) -> None:
            """
            Verifies that the act_compare_dict (output of Numeric Suite APIs) is valid:
            1. for each layer, results are recorded for two models
            2. number of seen tensors match
            3. shapes of each pair of seen tensors match
            """
            for layer_name, result_type_to_data in act_compare_dict.items():
                for result_type, layer_data in result_type_to_data.items():
                    self.assertTrue(
                        len(layer_data) == 2,
                        f"Layer {layer_name} does not have exactly two model results.")
                    model_name_0, model_name_1 = layer_data.keys()
                    for res_idx in range(len(layer_data[model_name_0])):
                        layer_data_0 = layer_data[model_name_0][res_idx]
                        layer_data_1 = layer_data[model_name_1][res_idx]
                        self.assertTrue(
                            layer_data_0['type'] == layer_data_1['type'],
                            f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.")

                        self.assertTrue(
                            len(layer_data_0['values']) ==
                            len(layer_data_1['values']),
                            f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.")

                        # F.conv1d weight has rank 3, and toq.conv1d unpacked weight
                        # has rank 4. For now, skip the length check for conv1d only.
                        is_weight_functional_conv1d = (
                            result_type == NSSingleResultValuesType.WEIGHT.value and
                            (
                                'conv1d' in layer_data_0['prev_node_target_type'] or
                                'conv1d' in layer_data_1['prev_node_target_type']
                            )
                        )
                        if not is_weight_functional_conv1d:
                            for idx in range(len(layer_data_0['values'])):
                                values_0 = layer_data_0['values'][idx]
                                values_1 = layer_data_1['values'][idx]
                                if isinstance(values_0, torch.Tensor):
                                    self.assertTrue(
                                        values_0.shape == values_1.shape,
                                        f"Layer {layer_name}, {model_name_0} and {model_name_1} " +
                                        f"have a shape mismatch at idx {idx}.")
                                elif isinstance(values_0, list):
                                    values_0 = values_0[0]
                                    values_1 = values_1[0]
                                    self.assertTrue(
                                        values_0.shape == values_1.shape,
                                        f"Layer {layer_name}, {model_name_0} and {model_name_1} " +
                                        f"have a shape mismatch at idx {idx}.")
                                else:
                                    assert isinstance(values_0, tuple), \
                                        f"unhandled type {type(values_0)}"
                                    assert len(values_0) == 2
                                    assert len(values_0[1]) == 2
                                    assert values_0[0].shape == values_1[0].shape
                                    assert values_0[1][0].shape == values_1[1][0].shape
                                    assert values_0[1][1].shape == values_1[1][1].shape

                        # verify that ref_node_name is valid
                        ref_node_name_0 = layer_data_0['ref_node_name']
                        ref_node_name_1 = layer_data_1['ref_node_name']
                        prev_node_name_0 = layer_data_0['prev_node_name']
                        prev_node_name_1 = layer_data_1['prev_node_name']
                        if layer_data_0['type'] == NSSingleResultValuesType.NODE_OUTPUT.value:
                            self.assertTrue(ref_node_name_0 == prev_node_name_0)
                            self.assertTrue(ref_node_name_1 == prev_node_name_1)
                        elif layer_data_0['type'] == NSSingleResultValuesType.NODE_INPUT.value:
                            self.assertTrue(ref_node_name_0 != prev_node_name_0)
                            self.assertTrue(ref_node_name_1 != prev_node_name_1)

        def checkGraphModeFxOp(
                self,
                model,
                inputs,
                quant_type,
                expected_node=None,
                expected_node_occurrence=None,
                expected_node_list=None,
                is_reference=False,
                print_debug_info=False,
                custom_qconfig_dict=None,
                prepare_expected_node=None,
                prepare_expected_node_occurrence=None,
                prepare_expected_node_list=None,
                prepare_custom_config=None,
                backend_config=None):
            """ Quantizes model with graph mode quantization on fx and check if the
                quantized model contains the quantized_node

                Args:
                    model: floating point torch.nn.Module
                    inputs: one positional sample input arguments for model
                    expected_node: NodeSpec
                        e.g. NodeSpec.call_function(torch.quantize_per_tensor)
                    expected_node_occurrence: a dict from NodeSpec to
                        expected number of occurrences (int)
                        e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1,
                                NodeSpec.call_method('dequantize'): 1}
                    expected_node_list: a list of NodeSpec, used to check the order
                        of the occurrence of Node
                        e.g. [NodeSpec.call_function(torch.quantize_per_tensor),
                                NodeSpec.call_module(nnq.Conv2d),
                                NodeSpec.call_function(F.hardtanh_),
                                NodeSpec.call_method('dequantize')]
                    is_reference: if True, enables reference mode
                    print_debug_info: if True, prints debug info
                    custom_qconfig_dict: overrides default qconfig_dict
                    prepare_expected_node: same as expected_node, but for prepare
                    prepare_expected_node_occurrence: same as
                        expected_node_occurrence, but for prepare
                    prepare_expected_node_list: same as expected_node_list, but
                        for prepare

                Returns:
                    A dictionary with the following structure:
                   {
                       "prepared": ...,  # the prepared model
                       "quantized": ...,  # the quantized non-reference model
                       "quantized_reference": ...,  # the quantized reference model
                       "result": ...,  # the result for either quantized or
                                       # quantized_reference model depending on the
                                       # is_reference argument
                   }
            """
            # TODO: make img_data a single example instead of a list
            if type(inputs) == list:
                inputs = inputs[0]

            if quant_type == QuantType.QAT:
                qconfig_mapping = get_default_qat_qconfig_mapping(torch.backends.quantized.engine)
                model.train()
            elif quant_type == QuantType.STATIC:
                qconfig_mapping = get_default_qconfig_mapping(torch.backends.quantized.engine)
                model.eval()
            else:
                qconfig = default_dynamic_qconfig
                qconfig_mapping = QConfigMapping().set_global(qconfig)
                model.eval()

            if quant_type == QuantType.QAT:
                prepare = prepare_qat_fx
            else:
                prepare = prepare_fx

            # overwrite qconfig_mapping with custom_qconfig_dict
            if custom_qconfig_dict is not None:
                assert type(custom_qconfig_dict) in (QConfigMapping, dict), \
                    'custom_qconfig_dict should be a QConfigMapping or a dict'
                if isinstance(custom_qconfig_dict, QConfigMapping):
                    qconfig_mapping = custom_qconfig_dict
                else:
                    qconfig_mapping = QConfigMapping.from_dict(custom_qconfig_dict)
            prepared = prepare(
                model, qconfig_mapping,
                example_inputs=inputs,
                prepare_custom_config=prepare_custom_config,
                backend_config=backend_config)
            if quant_type != QuantType.DYNAMIC:
                prepared(*inputs)

            if print_debug_info:
                print()
                print('quant type:\n', quant_type)
                print('original model:\n', model)
                print()
                print('prepared model:\n', prepared)

            self.checkGraphModuleNodes(
                prepared, prepare_expected_node,
                prepare_expected_node_occurrence, prepare_expected_node_list)

            prepared_copy = copy.deepcopy(prepared)
            qgraph = convert_fx(copy.deepcopy(prepared))
            qgraph_reference = convert_to_reference_fx(copy.deepcopy(prepared))
            result = qgraph(*inputs)
            result_reference = qgraph_reference(*inputs)
            qgraph_copy = copy.deepcopy(qgraph)
            qgraph_reference_copy = copy.deepcopy(qgraph_reference)

            qgraph_to_check = qgraph_reference if is_reference else qgraph
            if print_debug_info:
                print()
                print('quantized model:\n', qgraph_to_check)
                self.printGraphModule(qgraph_to_check)
                print()
            self.checkGraphModuleNodes(
                qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list)
            return {"prepared": prepared_copy,
                    "quantized": qgraph_copy,
                    "quantized_reference": qgraph_reference_copy,
                    "quantized_output": result,
                    "quantized_reference_output": result_reference}


    def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets,
                                    set_qconfig, is_emb_bag, dtype=torch.quint8):
        # Test serialization of dynamic EmbeddingBag module using state_dict
        if is_emb_bag:
            inputs = [indices, offsets]
        else:
            inputs = [indices]
        emb_dict = qemb.state_dict()
        b = io.BytesIO()
        torch.save(emb_dict, b)
        b.seek(0)
        # weights_only=False since the saved dict may contain ScriptObjects
        loaded_dict = torch.load(b, weights_only=False)
        embedding_unpack = torch.ops.quantized.embedding_bag_unpack
        # Check unpacked weight values explicitly
        for key in emb_dict:
            if isinstance(emb_dict[key], torch._C.ScriptObject):
                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                emb_weight = embedding_unpack(emb_dict[key])
                loaded_weight = embedding_unpack(loaded_dict[key])
                self.assertEqual(emb_weight, loaded_weight)

        # Check state dict serialization and torch.save APIs
        if is_emb_bag:
            loaded_qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                                           include_last_offset=True, mode='sum', dtype=dtype)
        else:
            loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype)
        self.check_eager_serialization(qemb, loaded_qemb, inputs)

        loaded_qemb.load_state_dict(loaded_dict)
        self.assertEqual(embedding_unpack(qemb._packed_params._packed_weight),
                         embedding_unpack(loaded_qemb._packed_params._packed_weight))


        # Test JIT serialization
        self.checkScriptable(qemb, [inputs], check_save_load=True)

        # Test from_float call
        if is_emb_bag:
            float_embedding = torch.nn.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                                                    include_last_offset=True, scale_grad_by_freq=False, mode='sum')
        else:
            float_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

        if set_qconfig:
            float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
                                                                        qscheme=torch.per_channel_affine_float_qparams,
                                                                        ch_axis=0)
            float_embedding.qconfig = QConfig(activation=default_dynamic_quant_observer,
                                              weight=float_qparams_observer)

        prepare_dynamic(float_embedding)

        float_embedding(*inputs)
        if is_emb_bag:
            q_embeddingbag = nnq.EmbeddingBag.from_float(float_embedding)
            expected_name = "QuantizedEmbeddingBag"
        else:
            q_embeddingbag = nnq.Embedding.from_float(float_embedding)
            expected_name = "QuantizedEmbedding"

        q_embeddingbag(*inputs)

        self.assertTrue(expected_name in str(q_embeddingbag))

class QuantizationLiteTestCase(QuantizationTestCase):
    def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs):
        # Creates quantized model for testing mobile script modules
        qengine = "qnnpack"
        with override_quantized_engine(qengine):
            # FIXME(rec): shouldn't qconfig be passed to quantize?
            qconfig = torch.ao.quantization.get_default_qconfig(qengine)  # noqa: F841
            model = model_class(**kwargs)
            model = quantize(model, test_only_eval_fn, [self.calib_data])

        return model

    def _compare_script_and_mobile(self,
                                   model: torch.nn.Module,
                                   input: torch.Tensor):
        # Compares the numerical outputs for script and lite modules
        qengine = "qnnpack"
        with override_quantized_engine(qengine):
            script_module = torch.jit.script(model)
            script_module_result = script_module(input)

            max_retry = 5
            for retry in range(1, max_retry + 1):
                # retries up to `max_retry` times; breaks as soon as it
                # succeeds, otherwise re-raises the last failure
                try:
                    buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
                    buffer.seek(0)
                    mobile_module = _load_for_lite_interpreter(buffer)

                    mobile_module_result = mobile_module(input)

                    torch.testing.assert_close(script_module_result, mobile_module_result)
                    mobile_module_forward_result = mobile_module.forward(input)
                    torch.testing.assert_close(script_module_result, mobile_module_forward_result)

                    mobile_module_run_method_result = mobile_module.run_method("forward", input)
                    torch.testing.assert_close(script_module_result, mobile_module_run_method_result)
                except AssertionError as e:
                    if retry == max_retry:
                        raise e
                    else:
                        continue
                break


class PT2EQuantizationTestCase(QuantizationTestCase):
    """
    Base QuantizationTestCase for PT2 with some helper methods.
    """
    _MAP_TO_FX_TRACED_OPS = {
        torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default,
        torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        torch.ops.quantized_decomposed.quantize_per_channel: torch.ops.quantized_decomposed.quantize_per_channel.default,
        torch.ops.quantized_decomposed.dequantize_per_channel: torch.ops.quantized_decomposed.dequantize_per_channel.default,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    }

    def _test_quantizer(
        self,
        model,
        example_inputs,
        quantizer,
        expected_node_occurrence,
        expected_node_list=None,
        check_against_fx_quant=False,
        fx_qconfig_mapping=None,
        export_with_dynamic_shape=False,
        is_qat=False,
        is_debug_mode=False,
        training_ir_node_occurrence=None,
    ):
        # resetting dynamo cache
        torch._dynamo.reset()
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        dynamic_shapes = tuple(
            {0: torch.export.Dim("dim")} if i == 0 else None
            for i in range(len(example_inputs))
        )
        m = export_for_training(
            m,
            example_inputs,
            dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
        ).module()

        if is_qat:
            m = prepare_qat_pt2e(m, quantizer)
        else:
            m = prepare_pt2e(m, quantizer)
        if is_debug_mode:
            print("prepared model:", m)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)
        if is_debug_mode:
            print("quantized model", m)

        pt2_quant_output = m(*example_inputs)
        ns = NodeSpec
        node_occurrence = {
            ns.call_function(k): v for k, v in expected_node_occurrence.items()
        }
        if expected_node_list is None:
            expected_node_list = []
        node_list = [ns.call_function(n) for n in expected_node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )
        if check_against_fx_quant:
            qconfig_mapping = fx_qconfig_mapping
            backend_config = get_executorch_backend_config()
            m_copy = copy.deepcopy(m_eager)
            m_fx = prepare_fx(
                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
            )
            m_fx(*example_inputs)
            m_fx = _convert_to_reference_decomposed_fx(
                m_fx, backend_config=backend_config
            )
            m_fx = export_for_training(
                m_fx,
                example_inputs,
                dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
            ).module()
            node_occurrence = {}
            for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
                if k in expected_node_occurrence:
                    node_occurrence[ns.call_function(v)] = expected_node_occurrence[k]
            if training_ir_node_occurrence is not None:
                node_occurrence = {
                    ns.call_function(k): v for k, v in training_ir_node_occurrence.items()
                }
            self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence)
            fx_quant_output = m_fx(*example_inputs)
            self.assertEqual(fx_quant_output, pt2_quant_output)
        return m

    def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False):
        # resetting dynamo cache
        torch._dynamo.reset()

        m = export_for_training(
            m,
            example_inputs,
        ).module()
        if is_qat:
            m = prepare_qat_pt2e(m, quantizer)
        else:
            m = prepare_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        return m

    def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule:
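        """Quantize a single 2x2 ``nn.Linear`` with ``XNNPACKQuantizer``
        (per-channel when ``is_per_channel`` is set) and return the resulting
        ``GraphModule``."""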
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=is_per_channel)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(2, 2),)
        m = M().eval()
        return self._quantize(m, quantizer, example_inputs)

# Below is a series of toy models used for testing quantization

class SingleLayerLinearModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5),)

class AnnotatedSingleLayerLinearModel(torch.nn.Module):
    def __init__(self, qengine='fbgemm'):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.fc1 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))

    def forward(self, x):
        x = self.fc1(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5),)
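
# A minimal eager-mode sketch of how the annotated models above are typically
# driven (illustrative only; real tests add their own checks):
#
#   model = AnnotatedSingleLayerLinearModel().eval()
#   model = torch.ao.quantization.prepare(model)
#   model(*model.get_example_inputs())  # calibrate observers
#   model = torch.ao.quantization.convert(model)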

class SingleLayerLinearDynamicModel(torch.nn.Module):
    def __init__(self, qengine='fbgemm'):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5),)

class LinearAddModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.add(x, 5)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5),)

class RNNDynamicModel(torch.nn.Module):
    def __init__(self, mod_type):
        super().__init__()
        self.qconfig = default_dynamic_qconfig
        if mod_type == 'GRU':
            self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float)
        if mod_type == 'LSTM':
            self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float)

    def forward(self, x):
        x = self.mod(x)
        return x
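
# The dynamic RNN wrappers above are typically driven through
# torch.ao.quantization.quantize_dynamic (illustrative sketch):
#
#   m = RNNDynamicModel('LSTM')
#   qm = torch.ao.quantization.quantize_dynamic(
#       m, {torch.nn.LSTM}, dtype=torch.qint8)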

class RNNCellDynamicModel(torch.nn.Module):
    def __init__(self, mod_type):
        super().__init__()
        self.qconfig = default_dynamic_qconfig
        if mod_type == 'GRUCell':
            self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float)
        if mod_type == 'LSTMCell':
            self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float)
        if mod_type == 'RNNReLU':
            self.mod = torch.nn.RNNCell(2, 2, nonlinearity='relu').to(dtype=torch.float)
        if mod_type == 'RNNTanh':
            self.mod = torch.nn.RNNCell(2, 2, nonlinearity='tanh').to(dtype=torch.float)

    def forward(self, x):
        x = self.mod(x)
        return x

class LSTMwithHiddenDynamicModel(torch.nn.Module):
    def __init__(self, qengine='fbgemm'):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)

    def forward(self, x, hid):
        x, hid = self.lstm(x, hid)
        return x, hid

class ConvModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)

class ConvTransposeModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)

class AnnotatedConvModel(torch.nn.Module):
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)

class AnnotatedConvTransposeModel(torch.nn.Module):
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)

class ConvBnModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)

class AnnotatedConvBnModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.qconfig = default_qconfig
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.dequant(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)

class ConvBnReLUModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)

class AnnotatedConvBnReLUModel(torch.nn.Module):
    def __init__(self, qengine='fbgemm'):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
        self.relu = nn.ReLU(inplace=True)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        # TODO: remove this check and define two fuse_modules functions on this module
        if self.training:
            torch.ao.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True)
        else:
            torch.ao.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True)

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)
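
# Sketch of the eager-mode QAT flow that fuse_model() above supports
# (illustrative; real tests run training steps between prepare and convert):
#
#   m = AnnotatedConvBnReLUModel().train()
#   m.fuse_model()
#   m = torch.ao.quantization.prepare_qat(m)
#   # ... training / calibration ...
#   m = torch.ao.quantization.convert(m.eval())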

class TwoLayerConvModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.conv2 = torch.nn.Conv2d(5, 5, 1, bias=False).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)

class TwoLayerLinearModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5),)

class LinearModelWithSubmodule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.subm = TwoLayerLinearModel()
        self.fc = nn.Linear(5, 5)

    def forward(self, x):
        x = self.subm(x)
        x = self.fc(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return self.subm.get_example_inputs()

class AnnotatedTwoLayerLinearModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.fc2 = QuantWrapper(torch.nn.Linear(8, 5).to(dtype=torch.float))
        self.fc2.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5),)

class ActivationsTestModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        self.quant = torch.ao.quantization.QuantStub()
        self.hardswish = torch.nn.Hardswish().to(dtype=torch.float)
        self.elu = torch.nn.ELU().to(dtype=torch.float)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.hardswish(x)
        x = self.elu(x)
        x = self.dequant(x)
        return x

class LinearReluModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc(x))
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5),)


class LinearReluLinearModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5),)

class LinearReluAddModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = torch.add(x, 5)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5),)

class LinearBnLeakyReluModel(torch.nn.Module):
    def __init__(self, with_bn=True):
        super().__init__()
        self.linear = nn.Linear(5, 5)
        self.bn1d = nn.BatchNorm1d(5)
        self.leaky_relu = nn.LeakyReLU(0.01)
        self.with_bn = with_bn

    def forward(self, x):
        x = self.linear(x)
        if self.with_bn:
            x = self.bn1d(x)
        x = self.leaky_relu(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5),)

class LinearTanhModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(5, 5)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.linear(x)
        x = self.tanh(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5),)

class ConvBnAddReluModel(torch.nn.Module):
    def __init__(self,
                 with_bn=True,
                 with_relu=True,
                 left_conv=True,
                 two_conv=True,
                 use_torch_add=True):
        super().__init__()
        self.conv = nn.Conv2d(5, 5, (2, 2))
        self.conv2 = nn.Conv2d(5, 5, (2, 2))
        self.bn = nn.BatchNorm2d(5)
        self.relu = nn.ReLU()
        self.with_bn = with_bn
        self.with_relu = with_relu
        self.two_conv = two_conv
        self.left_conv = left_conv
        self.use_torch_add = use_torch_add

    def forward(self, x1, x2):
        if self.two_conv:
            if self.use_torch_add:
                if self.with_bn:
                    x = torch.add(self.bn(self.conv(x1)), self.conv2(x1))
                else:
                    x = torch.add(self.conv(x1), self.conv2(x1))
            else:
                if self.with_bn:
                    x = self.bn(self.conv(x1)) + self.conv2(x1)
                else:
                    x = self.conv(x1) + self.conv2(x1)
        else:
            if self.use_torch_add:
                if self.left_conv:
                    if self.with_bn:
                        x = torch.add(self.bn(self.conv(x1)), x2)
                    else:
                        x = torch.add(self.conv(x1), x2)
                else:
                    if self.with_bn:
                        x = torch.add(x2, self.bn(self.conv(x1)))
                    else:
                        x = torch.add(x2, self.conv(x1))
            else:
                if self.left_conv:
                    if self.with_bn:
                        x = self.bn(self.conv(x1)) + x2
                    else:
                        x = self.conv(x1) + x2
                else:
                    if self.with_bn:
                        x = x2 + self.bn(self.conv(x1))
                    else:
                        x = x2 + self.conv(x1)
        if self.with_relu:
            x = self.relu(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5, 3, 3), torch.rand(1, 5, 2, 2))

# TODO: self.fc should be self.conv
class ConvReluModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc(x))
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)

# TODO: self.fc should be self.conv
class ConvReluConvModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)

# TODO: self.fc should be self.conv
class ConvReluAddModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = torch.add(x, 5)
        x = self.fc2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)

class NormalizationTestModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.layer_norm = torch.nn.LayerNorm(8)
        self.group_norm = torch.nn.GroupNorm(2, 8)
        self.instance_norm1d = torch.nn.InstanceNorm1d(8)
        self.instance_norm2d = torch.nn.InstanceNorm2d(8)
        self.instance_norm3d = torch.nn.InstanceNorm3d(8)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.layer_norm(x)
        x = self.group_norm(x.unsqueeze(-1).repeat(1, 1, 3))
        x = self.instance_norm1d(x)
        x = self.instance_norm2d(x.unsqueeze(-1))
        x = self.instance_norm3d(x.unsqueeze(-1))
        return x

class NestedModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x

class AnnotatedNestedModel(torch.nn.Module):
    def __init__(self, qengine):
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
        self.fc3.qconfig = default_qconfig
        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
        if qengine == 'fbgemm':
            self.sub2.fc1.qconfig = default_per_channel_qconfig
        else:
            self.sub2.fc1.qconfig = default_qconfig

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x

class AnnotatedSubNestedModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = QuantWrapper(TwoLayerLinearModel())
        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
        self.fc3.qconfig = default_qconfig
        self.sub2.qconfig = default_qconfig

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x

class AnnotatedCustomConfigNestedModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
        self.fc3.qconfig = default_qconfig
        self.sub2.qconfig = default_qconfig

        custom_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_qconfig = QConfig(activation=default_observer.with_args(**custom_options),
                                 weight=default_weight_observer)
        self.sub2.fc1.qconfig = custom_qconfig

        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
        self.sub2.fc2 = QuantWrapper(self.sub2.fc2)

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x

class QuantSubModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = QuantWrapper(TwoLayerLinearModel())
        self.sub2.qconfig = default_qconfig
        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.fc3.qconfig = default_qconfig

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x

class InnerModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.relu1 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        return self.relu2(self.fc2(self.relu1(self.fc1(x))))

    def fuse_modules(self):
        fusable_layers = []
        named_children = list(self.named_children())
        for idx, (current_name, layer) in enumerate(named_children):
            if isinstance(layer, torch.nn.Linear):
                if idx >= len(named_children) - 1:
                    break
                if isinstance(named_children[idx + 1][1], torch.nn.ReLU):
                    fusable_layers.append([current_name,
                                           named_children[idx + 1][0]])
        # TODO: remove this check and define two fuse_modules functions on this module
        if self.training:
            torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True)
        else:
            torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True)

class FunctionalLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.rand((5, 5))
        self.bias = torch.zeros(5)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 5),)

class SingleLayerFunctionalLinearModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = FunctionalLinear()

    def forward(self, x):
        x = self.linear1(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return self.linear1.get_example_inputs()

class TwoLayerFunctionalLinearModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = FunctionalLinear()
        self.linear2 = FunctionalLinear()

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return self.linear1.get_example_inputs()

class FunctionalLinearAddModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = FunctionalLinear()
        self.linear2 = FunctionalLinear()

    def forward(self, x):
        x = self.linear1(x)
        x = torch.add(x, 5)
        x = self.linear2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return self.linear1.get_example_inputs()

class FunctionalLinearReluModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = FunctionalLinear()

    def forward(self, x):
        x = self.linear(x)
        x = F.relu(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return self.linear.get_example_inputs()

class FunctionalLinearReluLinearModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = FunctionalLinear()
        self.relu = nn.ReLU()
        self.linear2 = FunctionalLinear()

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return self.linear1.get_example_inputs()

class FunctionalConv2d(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.rand(3, 3, 3, 3)
        self.bias = torch.rand(3)
        self.stride = (1, 1)
        self.padding = (0, 0)
        self.dilation = (1, 1)
        self.groups = 1

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def get_example_inputs(self) -> tuple[Any, ...]:
        return (torch.rand(1, 3, 5, 5),)

class SingleLayerFunctionalConvModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = FunctionalConv2d()

    def forward(self, x):
        x = self.conv1(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return self.conv1.get_example_inputs()

class TwoLayerFunctionalConvModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = FunctionalConv2d()
        self.conv2 = FunctionalConv2d()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return self.conv1.get_example_inputs()

class FunctionalConvReluModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = FunctionalConv2d()

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return self.conv.get_example_inputs()

class FunctionalConvReluConvModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = FunctionalConv2d()
        self.relu = nn.ReLU()
        self.conv2 = FunctionalConv2d()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x

    def get_example_inputs(self) -> tuple[Any, ...]:
        return self.conv1.get_example_inputs()

class SkipQuantModel(torch.nn.Module):
    r"""We can skip quantization by explicitly
    setting the qconfig of a submodule to None
    """
    def __init__(self) -> None:
        super().__init__()
        self.sub = InnerModule()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        return self.fc(self.sub(x))

    def fuse_modules(self):
        self.sub.fuse_modules()

class AnnotatedSkipQuantModel(torch.nn.Module):
    r"""We can skip quantization by explicitly
    setting the qconfig of a submodule to None
    """
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        self.sub = QuantWrapper(InnerModule())
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        # don't quantize this fc
        self.fc.qconfig = None

    def forward(self, x):
        return self.fc(self.sub(x))

    def fuse_modules(self):
        self.sub.module.fuse_modules()

class QuantStubModel(torch.nn.Module):
    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
    """
    def __init__(self) -> None:
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)

class ManualLinearQATModel(torch.nn.Module):
    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
    """
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return self.dequant(x)

class ManualDropoutQATModel(torch.nn.Module):
    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
    """
    def __init__(self, qengine):
        super().__init__()
        self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.dropout(x)
        return self.dequant(x)

class ManualLinearDynamicQATModel(torch.nn.Module):
    r"""A Module that uses a dynamic QAT by default.
    """
    def __init__(self, qconfig=None):
        super().__init__()
        self.qconfig = qconfig or default_dynamic_qat_qconfig
        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

class ManualConvLinearQATModel(torch.nn.Module):
    r"""A module with manually inserted `QuantStub` and `DeQuantStub`
    and contains both linear and conv modules
    """
    def __init__(self, qconfig=None):
        super().__init__()
        self.qconfig = qconfig if qconfig else torch.ao.quantization.get_default_qat_qconfig("qnnpack")
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float)
        self.fc1 = torch.nn.Linear(64, 10).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(10, 10).to(dtype=torch.float)

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = x.view(-1, 64).contiguous()
        x = self.fc1(x)
        x = self.fc2(x)
        return self.dequant(x)

class ManualConvLinearSymmQATModel(ManualConvLinearQATModel):
    r"""Same as ManualConvLinearQATModule but with Symmetric Quantization.
    Supported only with qnnpack.
    """
    def __init__(self) -> None:
        super().__init__(default_symmetric_qnnpack_qat_qconfig)

class ManualEmbeddingBagLinear(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')
        self.emb.qconfig = default_embedding_qat_qconfig
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.linear = nn.Linear(12, 1).to(dtype=torch.float)
        self.qconfig = get_default_qat_qconfig("qnnpack")

    def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None,
                per_sample_weights: Optional[torch.Tensor] = None):
        x = self.emb(input, offsets, per_sample_weights)
        x = self.quant(x)
        x = self.linear(x)
        return self.dequant(x)

class DeFusedEmbeddingBagLinear(nn.Module):
    r"""A module to simulate QAT embedding bag with a linear layer,
    this module uses a separate embedding and bagging op, similar
    to that which is described in the EmbeddingBag documentation.

    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html
    """
    def __init__(self) -> None:
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12)
        self.emb.qconfig = default_embedding_qat_qconfig
        self.bagging_op = torch.sum
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.linear = nn.Linear(12, 1).to(dtype=torch.float)
        self.qconfig = get_default_qat_qconfig("qnnpack")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = self.bagging_op(self.emb(input), dim=1)
        x = self.quant(x)
        x = self.linear(x)
        return self.dequant(x)

class SubModelForFusion(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float)
        self.bn = nn.BatchNorm2d(2).to(dtype=torch.float)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class SubModelWithoutFusion(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float)
        self.relu = nn.ReLU(inplace=False).to(dtype=torch.float)

    def forward(self, x):
        return self.relu(self.conv(x))

class ModelForFusion(nn.Module):
    def __init__(self, qconfig):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 2, 1, bias=None).to(dtype=torch.float)
        self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float)
        self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float)
        self.sub1 = SubModelForFusion()
        self.sub2 = SubModelWithoutFusion()
        self.fc = nn.Linear(36, 10).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.qconfig = qconfig
        self.conv2 = nn.Conv3d(3, 2, (1, 1, 1), bias=None).to(dtype=torch.float)
        self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float)
        self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float)
        self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float)
        self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float)
        self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float)
        self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float)
        # don't quantize sub2
        self.sub2.qconfig = None
        self.fc.qconfig = None

    def forward(self, x):
        x = x.squeeze(2)
        x = self.quant(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu4(x)
        x = x.unsqueeze(2)
        y = x.unsqueeze(2)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.sub1(x)
        x = self.dequant(x)
        x = self.sub2(x)
        x = x.reshape(-1, 36).contiguous()
        x = self.fc(x)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.bn2(y)
        y = self.relu3(y)
        y = self.dequant(y)
        return x

class ConvBNReLU(nn.Sequential):
    def __init__(self) -> None:
        super().__init__(
            nn.Conv2d(3, 3, 1, 1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU(inplace=False)
        )

class ModelWithSequentialFusion(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, 1)
        self.relu1 = nn.ReLU(inplace=False)
        layers = [ConvBNReLU() for _ in range(3)]
        self.features = nn.Sequential(*layers)
        head = [nn.Linear(300, 10), nn.ReLU(inplace=False)]
        self.classifier = nn.Sequential(*head)
        self.seq = nn.Sequential()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.features(x)
        x = torch.reshape(x, (-1, 3 * 10 * 10))
        x = self.classifier(x)
        x = self.seq(x)
        x = self.dequant(x)
        return x

class ModelForFusionWithBias(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 2, 5, bias=True).to(dtype=torch.float)
        self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float)
        self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float)
        self.conv2 = nn.Conv2d(2, 2, 1, bias=True).to(dtype=torch.float)
        self.bn2 = nn.BatchNorm2d(2).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.dequant(x)
        return x

class ModelForLinearBNFusion(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 10)
        self.bn = nn.BatchNorm1d(10)
        nn.init.uniform_(self.bn.weight)
        nn.init.uniform_(self.bn.bias)

    def forward(self, x):
        return self.bn(self.fc(x))
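
# Eager-mode Linear + BatchNorm1d fusion sketch for the model above
# (illustrative):
#
#   m = ModelForLinearBNFusion().eval()
#   m = torch.ao.quantization.fuse_modules(m, [['fc', 'bn']])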

class DummyObserver(torch.nn.Module):
    def calculate_qparams(self):
        return 1.0, 0

    def forward(self, x):
        return x


class ModelForConvTransposeBNFusion(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.ConvTranspose1d(3, 3, 1)
        self.bn1 = nn.BatchNorm1d(3)
        self.conv2 = nn.ConvTranspose2d(3, 3, 1)
        self.bn2 = nn.BatchNorm2d(3)
        self.conv3 = nn.ConvTranspose3d(3, 3, 1)
        self.bn3 = nn.BatchNorm3d(3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = x.unsqueeze(2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = x.unsqueeze(2)
        x = self.conv3(x)
        x = self.bn3(x)
        return x


class ModelWithFunctionals(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mycat = nnq.FloatFunctional()
        self.myadd = nnq.FloatFunctional()
        self.myadd_relu = nnq.FloatFunctional()
        self.mymatmul = nnq.FloatFunctional()
        # Tracing doesn't work yet for c10 ops with scalar inputs
        # https://github.com/pytorch/pytorch/issues/27097
        # self.my_scalar_add = nnq.FloatFunctional()
        # self.my_scalar_mul = nnq.FloatFunctional()

    def forward(self, x):
        y = self.mycat.cat([x, x, x])
        z = self.myadd.add(y, y)
        w = self.myadd_relu.add_relu(z, z)
        u = self.mymatmul.matmul(w, w.T)
        # Tracing doesn't work yet for c10 ops with scalar inputs
        # https://github.com/pytorch/pytorch/issues/27097
        # w = self.my_scalar_add.add_scalar(w, -0.5)
        # w = self.my_scalar_mul.mul_scalar(w, 0.5)
        return u


class ResNetBase(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d
        inplanes = 3
        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.bn1 = norm_layer(inplanes)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.downsample = torch.nn.Identity()
        self.myop = nn.quantized.FloatFunctional()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(inplanes, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        identity = self.downsample(x)
        out = self.myop.add(out, identity)
        out = self.relu2(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

    def fuse_model(self):
        # TODO: remove this check and define two fuse_model functions on this module
        if self.training:
            torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True)
        else:
            torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)

class ModelMultipleOps(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d
        inplanes = 3
        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.bn1 = norm_layer(inplanes)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.downsample = torch.nn.Identity()
        self.skip_add = nn.quantized.FloatFunctional()
        self.cat = nn.quantized.FloatFunctional()
        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Linear(12, 6)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        identity = self.downsample(x)
        out = self.skip_add.add(out, identity)
        out = self.relu2(out)
        out = self.avgpool(out)
        out = self.conv2(out)
        out = torch.nn.functional.max_pool2d(out, 2, 2)
        out = self.cat.cat([out, out])
        out = out.reshape(-1, 3 * 2 * 2)
        out = self.fc(out)
        return out

# Model to ensure consistency of fake quant with true quant
# Average pooling and mean operations are not modelled
# accurately with fake-quant so this model does not
# contain those operations
class ModelMultipleOpsNoAvgPool(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d
        inplanes = 3
        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
        self.bn1 = norm_layer(inplanes)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.skip_add = nn.quantized.FloatFunctional()
        self.cat = nn.quantized.FloatFunctional()
        self.maxpool = nn.MaxPool2d((4, 4))
        self.fc = nn.Linear(12, 6)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        skip = self.conv2(x)
        out = self.skip_add.add(out, skip)
        out = self.relu2(out)
        out = self.maxpool(out)
        out = self.conv2(out)
        out = torch.nn.functional.max_pool2d(out, 2, 2)
        out = self.cat.cat([out, out])
        out = out.reshape(-1, 3 * 2 * 2)
        out = self.fc(out)
        return out

class EmbeddingBagModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                         include_last_offset=True, scale_grad_by_freq=False, mode='sum')

    def forward(self, indices, offsets, per_sample_weights):
        return self.emb(indices, offsets, per_sample_weights)

class EmbeddingModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

    def forward(self, indices):
        return self.emb(indices)

class EmbeddingWithStaticLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12)
        self.fc = torch.nn.Linear(4, 2)
        self.emb.qconfig = float_qparams_weight_only_qconfig
        self.qconfig = default_qconfig
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, indices, offsets, linear_in):
        emb = self.emb(indices, offsets)
        q_x = self.quant(linear_in)
        fc = self.fc(q_x)
        fc = self.dequant(fc)
        features = torch.cat([fc] + [emb], dim=1)
        return features

class DenseTopMLP(nn.Module):

    def __init__(self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out) -> None:
        super().__init__()

        self.dense_mlp = nn.Sequential(
            nn.Linear(dense_dim, dense_out),
        )
        self.top_mlp = nn.Sequential(
            nn.Linear(dense_out + embedding_dim, top_out_in),
            nn.Linear(top_out_in, top_out_out),
        )

    def forward(
        self,
        sparse_feature: torch.Tensor,
        dense: torch.Tensor,
    ) -> torch.Tensor:
        dense_feature = self.dense_mlp(dense)
        features = torch.cat([dense_feature] + [sparse_feature], dim=1)

        out = self.top_mlp(features)
        return out

# Thin wrapper around nn.EmbeddingBag, because tracing inside nn.EmbeddingBag
# is not supported at the moment and this wrapper is the top-level module
class EmbBagWrapper(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode='sum')

    def forward(self, indices, offsets):
        return self.emb_bag(indices, offsets)

class SparseNNModel(nn.Module):
    _NUM_EMBEDDINGS = 10
    _EMBEDDING_DIM = 5
    _DENSE_DIM = 4
    _DENSE_OUTPUT = 2
    _TOP_OUT_IN = 2
    _TOP_OUT_OUT = 2
    _TOP_MLP_DIM = 1

    def __init__(self) -> None:
        super().__init__()

        self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM)
        self.dense_top = DenseTopMLP(
            self._DENSE_DIM, self._DENSE_OUTPUT, self._EMBEDDING_DIM, self._TOP_OUT_IN,
            self._TOP_OUT_OUT)

    def forward(
        self,
        sparse_indices: torch.Tensor,
        sparse_offsets: torch.Tensor,
        dense: torch.Tensor,
    ) -> torch.Tensor:

        sparse_feature = self.model_sparse(sparse_indices, sparse_offsets)
        out = self.dense_top(sparse_feature, dense)

        return out

class TestHelperModules:
    class ControlFlow(torch.nn.Module):
        def forward(
            self,
            xs: torch.Tensor,
            pred1: torch.Tensor,
            pred2: torch.Tensor,
            y: torch.Tensor,
        ) -> torch.Tensor:

            def true_nested(y: torch.Tensor) -> torch.Tensor:
                y = y + y
                y = torch.mm(y, y)
                return y

            def false_nested(y: torch.Tensor) -> torch.Tensor:
                return torch.mm(y, y)

            def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
                z = control_flow.cond(pred2, true_nested, false_nested, [x])
                return x + z

            def false_fn(x: torch.Tensor, _) -> torch.Tensor:
                return x.cos()

            def map_fn(
                x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
            ) -> torch.Tensor:
                x = x.cos()
                y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
                x = x + y
                return x.sin()

            y = torch.mm(y, y)
            return control_flow.map(map_fn, xs, pred1, pred2, y)

        def example_inputs(self):
            return (torch.ones(2, 2), torch.tensor([False]), torch.tensor([False]), torch.ones(2, 2),)

    class Conv2dPropAnnotaton(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = x.view(-1, 3)
            x = torch.nn.functional.hardtanh(x, -0.5, 0.5)
            x = self.linear(x)
            return x

    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = torch.mean(x)
            return x

    class Conv2dWithTwoLinearPermute(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linear1 = torch.nn.Linear(16, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear2(self.linear1(permute_out))

    class Conv2dWithTwoLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linear1 = torch.nn.Linear(64, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            reshape_out = torch.reshape(conv_out, (2, 64))
            return self.linear2(self.linear1(reshape_out))

    class ConvLinearWPermute(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.linear1 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear1(permute_out)

    class TwoLinearModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = torch.nn.Linear(8, 16, bias=False)
            self.linear2 = torch.nn.Linear(16, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))

        def example_inputs(self):
            return (torch.randn(2, 8),)

    class ConvMaxPool2d(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(2, 2, 1)
            self.pool = torch.nn.MaxPool2d(1, 1)

        def forward(self, x):
            x = self.conv(x)
            x = self.pool(x)
            return x

    class ConvWithAdaptiveAvgPool2d(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            return x

    class ConvWithBNRelu(torch.nn.Module):
        def __init__(self, relu, dim=2, bn=True, bias=True):
            super().__init__()
            convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d}
            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
            self.conv = convs[dim](3, 3, 3, bias=bias)

            if bn:
                self.bn = bns[dim](3)
            else:
                self.bn = torch.nn.Identity()
            if relu:
                self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return self.relu(x)

    class ConvTWithBNRelu(torch.nn.Module):
        def __init__(self, relu, dim=2, bn=True, bias=True):
            super().__init__()
            convts = {1: torch.nn.ConvTranspose1d, 2: torch.nn.ConvTranspose2d}
            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
            self.convt = convts[dim](3, 3, 3, bias=bias)

            if bn:
                self.bn = bns[dim](3)
            else:
                self.bn = torch.nn.Identity()
            if relu:
                self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            x = self.convt(x)
            x = self.bn(x)
            return self.relu(x)

    class Conv2dThenConv1d(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1d = torch.nn.Conv1d(3, 3, 3)
            self.conv2d = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            x = self.conv2d(x)
            x = x.squeeze(0)
            x = self.conv1d(x)
            return x

        def example_inputs(self):
            return (torch.randn(1, 3, 5, 5),)

    class Conv2dWithCat(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            z = torch.cat([x, y], dim=1)
            return z

    class Conv2dWithTwoCat(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x1, x2, x3, x4):
            x1 = self.conv1(x1)
            x2 = self.conv2(x2)
            y = torch.cat([x1, x2], dim=1)
            z = x3 + x4
            w = torch.cat([z, y])
            return w

    class Conv2dWithSplit(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            x = self.conv1(x)
            # use split so we get a list of Tensors
            x1, x2 = torch.split(x, 2, dim=1)
            y = torch.cat([x1, x2], dim=1)
            return y

        def example_inputs(self):
            return (torch.randn(1, 3, 16, 16),)

    class ThreeAdd(torch.nn.Module):
        def forward(self, x1, x2, x3, x4):
            y = x1 + x2
            z = x3 + x4
            w = y + z
            return w

    class EmbeddingModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

        def forward(self, indices):
            return self.emb(indices)

    class EmbeddingConvLinearModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
            self.conv = torch.nn.Conv2d(8, 16, (1, 3))
            self.linear = torch.nn.Linear(16, 8)

        def forward(self, indices):
            embeddings = self.emb(indices)
            embeddings = torch.unsqueeze(embeddings, dim=0)
            embeddings = torch.permute(embeddings, (0, 3, 1, 2))
            conv_out = self.conv(embeddings)
            conv_out = torch.permute(conv_out, (0, 2, 3, 1))
            conv_out = torch.squeeze(conv_out, dim=0)
            return self.linear(conv_out)

    class AddInplaceAdd(torch.nn.Module):
        def forward(self, x, y):
            x = x + y
            x += y
            return x

    class MulInplaceMul(torch.nn.Module):
        def forward(self, x, y):
            x = x * y
            x *= y
            return x

    class AddMulScalar(torch.nn.Module):
        def forward(self, x):
            x = x + 3
            x = x * 3
            x += 3
            x *= 3
            return x

    class ConvBnReLU2dAndLinearReLU(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv_bn_relu = TestHelperModules.ConvWithBNRelu(relu=True)
            self.linear = torch.nn.Linear(3, 8, bias=False)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv_bn_relu(x)
            permute_out = torch.permute(x, (0, 2, 3, 1))
            linear_out = self.linear(permute_out)
            return linear_out

    class GroupwiseConv2d(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(4, 4, 3, groups=2)

        def forward(self, x):
            return self.conv(x)

        def example_inputs(self):
            return (torch.randn(2, 4, 10, 10),)

    class LinearReluModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.relu(self.fc(x))
            return x

def _generate_qdq_quantized_model(
    mod, inputs, is_qat=False, is_dynamic=False, quantizer=None
):
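    """Quantize ``mod`` through the PT2E flow and return the converted QDQ
    model.

    When ``quantizer`` is not given, an ``XPUInductorQuantizer`` is used if
    any input tensor lives on an XPU device; otherwise an
    ``X86InductorQuantizer`` configured for ``is_qat``/``is_dynamic`` is used.
    """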

    def get_default_quantizer(is_qat, is_dynamic, inputs):
        has_xpu = any(
            isinstance(inp, torch.Tensor) and inp.device.type == "xpu"
            for inp in inputs
        )
        if has_xpu:
            quantizer = XPUInductorQuantizer()
            assert (not is_qat) and (not is_dynamic), (
                "QAT and dynamic quantization are not currently supported "
                "on the XPU backend"
            )
            quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config())
        else:
            quantizer = X86InductorQuantizer()
            quantizer.set_global(
                xiq.get_default_x86_inductor_quantization_config(
                    is_qat=is_qat, is_dynamic=is_dynamic
                )
            )
        return quantizer

    maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
    with maybe_no_grad:
        export_model = export_for_training(
            mod,
            inputs,
        ).module()
        quantizer = (
            quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic, inputs)
        )
        prepare_model = (
            prepare_qat_pt2e(export_model, quantizer)
            if is_qat
            else prepare_pt2e(export_model, quantizer)
        )
        prepare_model(*inputs)
        torch.ao.quantization.move_exported_model_to_eval(prepare_model)
        convert_model = convert_pt2e(prepare_model)
        return convert_model
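
# Illustrative usage of _generate_qdq_quantized_model with one of the toy
# models defined above:
#
#   m = ConvModel().eval()
#   example_inputs = m.get_example_inputs()
#   qdq_model = _generate_qdq_quantized_model(m, example_inputs)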
