# mypy: allow-untyped-defs
import copy
from typing import Any, Union

import torch
from torch.fx import GraphModule
from torch.fx.graph import Graph


# Public API of this module: the GraphModule subclasses defined below.
# The module-level helper functions are underscore-prefixed and not exported.
__all__ = [
    "FusedGraphModule",
    "ObservedGraphModule",
    "ObservedStandaloneGraphModule",
    "QuantizedGraphModule",
]


class FusedGraphModule(GraphModule):
    """GraphModule that keeps a caller-chosen set of attributes from ``root``.

    Attributes of ``root`` whose names appear in ``preserved_attr_names`` are
    read before ``GraphModule.__init__`` runs and set on the new module
    afterwards. Names that ``root`` does not have are silently skipped.
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, dict[str, Any]],
        graph: Graph,
        preserved_attr_names: set[str],
    ):
        # Keep a private copy so later mutation of the caller's set cannot
        # change which attributes this module (and its deep copies) preserve.
        self.preserved_attr_names = set(preserved_attr_names)
        # Capture the values before GraphModule.__init__ builds the module.
        preserved_attrs = {
            attr: getattr(root, attr)
            for attr in self.preserved_attr_names
            if hasattr(root, attr)
        }
        super().__init__(root, graph)
        for attr, value in preserved_attrs.items():
            setattr(self, attr, value)

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module.  So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        """Rebuild a FusedGraphModule from a deep copy of this module's state.

        ``memo`` is threaded through every nested copy. Objects shared with
        the rest of the structure being copied therefore keep their identity
        in the result. The graph reachable from ``__dict__`` should also be
        reused rather than copied a second time.
        """
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__, memo)
        return FusedGraphModule(
            fake_mod,
            copy.deepcopy(self.graph, memo),
            copy.deepcopy(self.preserved_attr_names, memo),
        )


class ObservedGraphModule(GraphModule):
    """GraphModule that keeps observation-related attributes from ``root``.

    A fixed set of quantization attribute names (``_patterns``,
    ``_qconfig_mapping``, ``_is_qat``, ...) is always preserved, together with
    any names passed in ``preserved_attr_names``. Each one that exists on
    ``root`` is read before ``GraphModule.__init__`` runs and set on the new
    module afterwards.
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, dict[str, Any]],
        graph: Graph,
        preserved_attr_names: set[str],
    ):
        # ``union`` builds a fresh set, so the caller's set is never aliased.
        self.preserved_attr_names = {
            "_activation_post_process_map",
            "_activation_post_process_indexes",
            "_patterns",
            "_node_name_to_qconfig",
            "_prepare_custom_config",
            "_equalization_node_name_to_qconfig",
            "_node_name_to_scope",
            "_qconfig_mapping",
            "_is_qat",
            "_observed_node_names",
        }.union(preserved_attr_names)
        # Capture the values before GraphModule.__init__ builds the module.
        preserved_attrs = {
            attr: getattr(root, attr)
            for attr in self.preserved_attr_names
            if hasattr(root, attr)
        }
        super().__init__(root, graph)
        for attr, value in preserved_attrs.items():
            setattr(self, attr, value)

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module.  So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        """Rebuild an ObservedGraphModule from a deep copy of this module's state.

        ``memo`` is threaded through every nested copy. Objects shared with
        the rest of the structure being copied therefore keep their identity
        in the result. The graph reachable from ``__dict__`` should also be
        reused rather than copied a second time.
        """
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__, memo)
        return ObservedGraphModule(
            fake_mod,
            copy.deepcopy(self.graph, memo),
            copy.deepcopy(self.preserved_attr_names, memo),
        )


def _is_observed_module(module: Any) -> bool:
    """Return True if ``module.meta`` has an ``_observed_graph_module_attrs`` entry.

    Objects that have no ``meta`` attribute are reported as not observed.
    """
    if not hasattr(module, "meta"):
        return False
    return "_observed_graph_module_attrs" in module.meta


def _get_observed_graph_module_attr(
    model: Union[torch.nn.Module, GraphModule], attr_name: str
) -> Any:
    """Read ``attr_name`` from ``model.meta["_observed_graph_module_attrs"]``.

    Returns None when ``model`` has no ``meta`` attribute, or when its ``meta``
    has no ``"_observed_graph_module_attrs"`` entry. If the entry exists but
    lacks ``attr_name``, the ``AttributeError`` from ``getattr`` propagates.
    """
    has_observed_attrs = hasattr(model, "meta") and "_observed_graph_module_attrs" in model.meta  # type: ignore[operator, index]
    if not has_observed_attrs:
        return None
    observed_attrs = model.meta["_observed_graph_module_attrs"]  # type: ignore[index]
    return getattr(observed_attrs, attr_name)


class ObservedStandaloneGraphModule(ObservedGraphModule):
    """ObservedGraphModule that also keeps the standalone-module index attributes.

    On top of what ``ObservedGraphModule`` preserves, the
    ``_standalone_module_input_quantized_idxs`` and
    ``_standalone_module_output_quantized_idxs`` attributes of ``root`` are
    carried over when ``root`` has them.
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, dict[str, Any]],
        graph: Graph,
        preserved_attr_names: set[str],
    ):
        # ``union`` returns a new set; the caller's set is left untouched.
        preserved_attr_names = preserved_attr_names.union(
            {
                "_standalone_module_input_quantized_idxs",
                "_standalone_module_output_quantized_idxs",
            }
        )
        super().__init__(root, graph, preserved_attr_names)

    def __deepcopy__(self, memo):
        """Rebuild an ObservedStandaloneGraphModule from a deep copy of its state.

        Overridden so the copy keeps this subclass's type. ``memo`` is threaded
        through every nested copy, so objects shared with the rest of the
        structure being copied keep their identity in the result.
        """
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__, memo)
        return ObservedStandaloneGraphModule(
            fake_mod,
            copy.deepcopy(self.graph, memo),
            copy.deepcopy(self.preserved_attr_names, memo),
        )


def _is_observed_standalone_module(module: Any) -> bool:
    """Return whether ``module`` is an observed module flagged as standalone.

    Non-observed modules yield False. For observed modules, the value of
    ``is_observed_standalone_module`` on
    ``module.meta["_observed_graph_module_attrs"]`` is returned as-is.
    """
    if not _is_observed_module(module):
        return False
    observed_attrs = module.meta["_observed_graph_module_attrs"]
    return observed_attrs.is_observed_standalone_module


def _is_packed_weight_name(name: str) -> bool:
    """Return True if ``name`` looks like a packed-weight attribute name.

    Shared by ``_save_packed_weight`` and ``QuantizedGraphModule._load_from_state_dict``
    so that every entry the save hook writes is also recognized on load.
    """
    return "_packed_weight" in name


def _save_packed_weight(self, destination, prefix, keep_vars):
    """State-dict hook that adds packed-weight ScriptObjects to ``destination``.

    Registered by ``QuantizedGraphModule.__init__``. Every attribute of
    ``self`` whose name passes ``_is_packed_weight_name`` and whose value is a
    ``torch._C.ScriptObject`` is stored under ``prefix + attr_name``.
    ``destination`` is updated in place and nothing is returned. The fourth
    argument is unused.
    """
    for attr_name in dir(self):
        if not _is_packed_weight_name(attr_name):
            continue
        value = getattr(self, attr_name)
        if isinstance(value, torch._C.ScriptObject):  # type: ignore[attr-defined]
            destination[prefix + attr_name] = value


class QuantizedGraphModule(GraphModule):
    """This class is created to make sure PackedParams
    (e.g. LinearPackedParams, Conv2dPackedParams) to appear in state_dict
    so that we can serialize and deserialize quantized graph module with
    torch.save(m.state_dict()) and m.load_state_dict(state_dict)

    Attributes of ``root`` named in ``preserved_attr_names`` are carried over
    onto the new module when ``root`` has them.
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, dict[str, Any]],
        graph: Graph,
        preserved_attr_names: set[str],
    ):
        # Keep a private copy so later mutation of the caller's set cannot
        # change which attributes this module (and its deep copies) preserve.
        self.preserved_attr_names = set(preserved_attr_names)
        # Capture the values before GraphModule.__init__ builds the module.
        preserved_attrs = {
            attr: getattr(root, attr)
            for attr in self.preserved_attr_names
            if hasattr(root, attr)
        }
        super().__init__(root, graph)
        for attr, value in preserved_attrs.items():
            setattr(self, attr, value)
        # Emit packed-weight ScriptObjects as part of state_dict().
        self._register_state_dict_hook(_save_packed_weight)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Restore packed-weight ScriptObjects, then defer to the base loader.

        A key is restored here if all of the following hold:

        - it belongs directly to this module (``prefix + name`` with no
          further ``.``);
        - ``name`` passes ``_is_packed_weight_name``;
        - its value is a ``torch._C.ScriptObject``.

        Each such key is set as attribute ``name`` on ``self`` and removed
        from ``state_dict`` before the base implementation runs.
        """
        packed_keys = []
        for key, value in state_dict.items():
            # Keys carry the full module path. Strip this module's prefix so
            # the name matches what ``_save_packed_weight`` wrote; previously
            # the prefix was ignored and submodule packed weights never loaded.
            if not key.startswith(prefix):
                continue
            attr_name = key[len(prefix):]
            # A remaining "." means the entry belongs to a submodule.
            if "." in attr_name or not _is_packed_weight_name(attr_name):
                continue
            if isinstance(value, torch._C.ScriptObject):  # type: ignore[attr-defined]
                setattr(self, attr_name, value)
                packed_keys.append(key)

        # Pop the packed param attributes after iterating; mutating a dict
        # while iterating over it raises.
        for key in packed_keys:
            state_dict.pop(key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def __deepcopy__(self, memo):
        """Rebuild a QuantizedGraphModule from a deep copy of this module's state.

        GraphModule does not copy attributes that are not in the ``__dict__``
        of a vanilla nn.Module, so the copy is built from a deep copy of
        ``__dict__``. ``memo`` is threaded through every nested copy, so
        objects shared with the rest of the structure being copied keep their
        identity in the result.
        """
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__, memo)
        return QuantizedGraphModule(
            fake_mod,
            copy.deepcopy(self.graph, memo),
            copy.deepcopy(self.preserved_attr_names, memo),
        )
