# mypy: allow-untyped-defs
import abc
import copy
from collections import defaultdict
from typing import Any, Optional

import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrize import type_before_parametrizations

from .utils import (
    FakeSparsity,
    get_arg_info_from_tensor_fqn,
    module_contains_param,
    module_to_fqn,
    swap_module,
)


__all__ = ["BaseSparsifier"]

SUPPORTED_MODULES = {nn.Linear}

KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"]


# TODO update desc with new config args
class BaseSparsifier(abc.ABC):
    r"""Base class for all sparsifiers.

    Abstract methods that need to be implemented:

    - update_mask: Function to compute a new mask for all keys in the
        `groups`.

    Args:
        - model [nn.Module]: model to configure. The model itself is not saved
            but used for the state_dict saving / loading.
        - config [list]: configuration elements should be a dict map that includes
            `tensor_fqn` of tensors to sparsify
        - defaults [dict]: default configurations will be attached to the
            configuration. Only the keys that don't exist in the `config` will
            be updated.

    Example::

        >>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask")
        >>> config = [{'tensor_fqn': 'layer1.weight', 'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5}]
        >>> defaults = {'sparsity_level': 0.7}
        >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default)
        >>> sparsifier = BaseSparsifier(config, defaults)
    """

    def __init__(self, defaults: Optional[dict[str, Any]] = None):
        super().__init__()
        self.defaults: dict[str, Any] = defaults or {}

        self.state: dict[str, dict] = defaultdict(dict)
        self.groups: list[dict[str, Any]] = []
        self.enable_mask_update = True

    def __getstate__(self) -> dict[str, Any]:
        return {
            "defaults": self.defaults,
            "state": self.state,
            "groups": self.groups,
        }

    def __setstate__(self, state: dict[str, dict[str, Any]]) -> None:
        self.__dict__.update(state)

    def __repr__(self):
        format_string = self.__class__.__name__ + " ("
        for i, sparse_args in enumerate(self.groups):
            module = sparse_args["module"]
            format_string += "\n"
            format_string += f"\tGroup {i}\n"
            format_string += f"\t    module: {module}\n"
            for key in sorted(sparse_args.keys()):
                if key == "module":
                    continue
                format_string += f"\t    {key}: {sparse_args[key]}\n"
        format_string += ")"
        return format_string

    def state_dict(self) -> dict[str, Any]:
        r"""Returns the state of the optimizer as a :class:`dict`.

        It contains:
        * state - current state of the sparsification.
        * groups - a list containing all sparsity configuration groups
            with the key 'tensor_fqn' specifying the path to the sparsified tensor within a model

        TODO: Need a clean way of loading the state of the "prepared" module
        """

        groups: list[dict[str, Any]] = [
            dict(
                filter(
                    lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT,
                    mg.items(),
                )
            )
            for mg in self.groups
        ]

        return {
            "state": self.state,
            "groups": groups,
        }

    def load_state_dict(self, state_dict: dict[str, Any], strict: bool = True):
        groups = copy.deepcopy(state_dict["groups"])
        states = state_dict["state"]
        for tensor_fqn, s in states.items():
            arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn)
            module = arg_info["module"]
            tensor_name = arg_info["tensor_name"]
            if strict and module is None:
                raise RuntimeError(f"Error loading {tensor_fqn} into the model")

            found = False
            for p in module.parametrizations[tensor_name]:
                if isinstance(p, FakeSparsity):
                    found = True
                    break
            if not found:
                p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape))
                parametrize.register_parametrization(module, tensor_name, p)
            if s.get("mask", None) is not None:
                mask = s.pop("mask")
                p.mask = mask

            for mg in groups:
                if mg["tensor_fqn"] == tensor_fqn:
                    mg.update(arg_info)
        self.__setstate__({"state": states, "groups": groups})

    def make_config_from_model(
        self,
        model: nn.Module,
        SUPPORTED_MODULES: set[type[nn.Linear]] = SUPPORTED_MODULES,
    ) -> None:
        self.config = []
        stack = [model]
        while stack:
            module = stack.pop()
            for _name, child in module.named_children():
                if type(child) in SUPPORTED_MODULES:
                    module_fqn = module_to_fqn(model, child)
                    assert isinstance(module_fqn, str)  # for mypy
                    self.config.append({"tensor_fqn": module_fqn + ".weight"})
                else:
                    stack.append(child)

    def prepare(self, model, config):
        r"""Prepares a model, by adding the parametrizations.

        Note::

            The model is modified inplace. If you need to preserve the original
            model, use copy.deepcopy.
        """
        self.model = model  # TODO: Need to figure out how to load without this.
        self.config = config

        # If no config -- try getting all the supported layers
        if self.config is None:
            self.make_config_from_model(model)

        # TODO: Remove the configuration by reference ('module')
        for module_config in self.config:
            assert isinstance(module_config, dict), (
                "config elements should be dicts not modules i.e.:"
                "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]"
            )

            assert isinstance(self.defaults, dict)  # for mypy
            local_args = copy.deepcopy(self.defaults)
            local_args.update(module_config)

            tensor_fqn = local_args.get("tensor_fqn", None)
            assert tensor_fqn is not None, (
                "tensor_fqn is a required argument in the sparsity config which"
                "replaces previous `module` and [module]`fqn` arguments"
            )

            # populate all information from tensor_fqn
            info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)

            # check that whatever was put into local_args agrees with what was obtained
            # from tensor_fqn
            for key in info_from_tensor_fqn.keys():
                if key in local_args:
                    assert (
                        info_from_tensor_fqn[key] == local_args[key]
                        or (
                            key == "tensor_fqn"
                            and "." + info_from_tensor_fqn[key] == local_args[key]
                        )
                        # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that
                    ), f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!"
            local_args.update(info_from_tensor_fqn)
            self.groups.append(local_args)
        self._prepare()

    def _prepare(self, *args, **kwargs):
        r"""Adds mask parametrization to the layer weight"""
        for config in self.groups:
            module = config["module"]
            tensor_name = config["tensor_name"]
            parametrization = config.get("parametrization", FakeSparsity)
            mask = config.get("mask", torch.ones_like(getattr(module, tensor_name)))
            self.state[config["tensor_fqn"]]["mask"] = mask
            parametrize.register_parametrization(
                module, tensor_name, parametrization(mask)
            )

    def squash_mask(
        self,
        params_to_keep: Optional[tuple[str, ...]] = None,
        params_to_keep_per_layer: Optional[dict[str, tuple[str, ...]]] = None,
        *args,
        **kwargs,
    ):
        r"""Squashes the sparse masks into the appropriate tensors.

        If either the `params_to_keep` or `params_to_keep_per_layer` is set,
        the module will have a `sparse_params` dict attached to it.

        Args:
            params_to_keep: List of keys to save in the module or a dict
                            representing the modules and keys that will have
                            sparsity parameters saved
            params_to_keep_per_layer: Dict to specify the params that should be
                            saved for specific layers. The keys in the dict
                            should be the module fqn, while the values should
                            be a list of strings with the names of the variables
                            to save in the `sparse_params`

        Examples:
            >>> # xdoctest: +SKIP("locals are undefined")
            >>> # Don't save any sparse params
            >>> sparsifier.squash_mask()
            >>> hasattr(model.submodule1, 'sparse_params')
            False

            >>> # Keep sparse params per layer
            >>> sparsifier.squash_mask(
            ...     params_to_keep_per_layer={
            ...         'submodule1.linear1': ('foo', 'bar'),
            ...         'submodule2.linear42': ('baz',)
            ...     })
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'baz': 0.1}

            >>> # Keep sparse params for all layers
            >>> sparsifier.squash_mask(params_to_keep=('foo', 'bar'))
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'foo': 42, 'bar': 24}

            >>> # Keep some sparse params for all layers, and specific ones for
            >>> # some other layers
            >>> sparsifier.squash_mask(
            ...     params_to_keep=('foo', 'bar'),
            ...     params_to_keep_per_layer={
            ...         'submodule2.linear42': ('baz',)
            ...     })
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'foo': 42, 'bar': 24, 'baz': 0.1}
        """
        for config in self.groups:
            module = config["module"]
            tensor_name = config["tensor_name"]
            parametrize.remove_parametrizations(
                module, tensor_name, leave_parametrized=True
            )
            sparse_params = {}
            if params_to_keep is not None:
                global_params = {k: config[k] for k in params_to_keep}
                sparse_params.update(global_params)
            if params_to_keep_per_layer is not None:
                params = params_to_keep_per_layer.get(config["module_fqn"], None)
                if params is not None:
                    per_layer_params = {k: config[k] for k in params}
                    sparse_params.update(per_layer_params)
            if sparse_params:
                # TODO handle multiple tensor being quantized on a single module, where to store sparse_params?
                module.sparse_params = sparse_params

    def convert(
        self,
        module: nn.Module,
        mapping: Optional[dict[type[nn.Module], type[nn.Module]]] = None,
        inplace: bool = False,
        parameterization: type[nn.Module] = FakeSparsity,
    ):
        r"""Converts submodules in input module to a different module according to `mapping`
        by calling `from_dense` method on the target module class
        Args:
            module: input module
            mapping: a dictionary that maps from source module type to target
                module type, can be overwritten to allow swapping user defined
                Modules
            inplace: carry out model transformations in-place, the original module
                is mutated
        """
        if mapping is None:
            raise NotImplementedError("Need to auto generate mapping ")
        if not inplace:
            module = copy.deepcopy(module)

        reassign = {}
        for name, mod in module.named_children():
            # leaf node
            if (
                module_contains_param(mod, parameterization)
                and type_before_parametrizations(mod) in mapping
            ):
                reassign[name] = swap_module(mod, mapping)
            else:
                # recurse
                reassign[name] = self.convert(
                    mod,
                    mapping=mapping,
                    inplace=True,
                    parameterization=parameterization,
                )

        for key, value in reassign.items():
            module._modules[key] = value

        return module

    def step(self, use_path: bool = True) -> None:
        if not self.enable_mask_update:
            return
        with torch.no_grad():
            for config in self.groups:
                self.update_mask(**config)

    @abc.abstractmethod
    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs):
        pass
