# mypy: allow-untyped-defs
import os
from typing import Callable, Optional, TypeVar

from torch.fx import Graph, Node
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.traceback import NodeSource, NodeSourceAction


# Return type of the pass callables accepted by
# GraphTransformObserver.apply_gm_pass / apply_graph_pass.
T = TypeVar("T")


# Renders a GraphModule to a dot graph; only used when a log_url is configured.
from .graph_drawer import FxGraphDrawer


__all__ = ["GraphTransformObserver"]


@compatibility(is_backward_compatible=False)
class GraphTransformObserver:
    """Observe a single transformation pass applied to a ``GraphModule``.

    Use it as a context manager, either directly or through
    ``apply_gm_pass`` / ``apply_graph_pass``. Observation is only active when
    ``torch._inductor.config.trace.enabled`` is set or a ``log_url`` is
    available. While active, entering the context registers
    create/erase/replace/deepcopy hooks on ``gm``. These hooks:

    * record the names of nodes created and erased by the pass,
    * add ``NodeSource`` provenance entries to ``node.meta["from_node"]``,
    * remember graph modules deep-copied from ``gm`` so their hooks can be
      removed again on exit.

    When ``log_url`` is not ``None``, the graph is rendered to dot at
    construction (before the pass) and again on exit (after the pass). If the
    pass created or erased any node, both renderings are written under
    ``log_url``. Affected nodes are filled yellow and all others grey.

    If ``subsystem`` is given, ``CompilerBisector`` may disable the pass. In
    that case the ``apply_*`` helpers skip it and return ``None``.
    """

    # Global counter of observers constructed with logging enabled; used to
    # number the written dot files. Name-mangled by Python to
    # ``_GraphTransformObserver__pass_count``.
    __pass_count = 0

    def __init__(
        self,
        gm: GraphModule,
        passname: str,
        subsystem: Optional[str] = None,
        log_url: Optional[str] = None,
    ):
        """
        Args:
            gm: The graph module the observed pass operates on.
            passname: Name of the pass. Recorded in provenance metadata and
                used in the names of the written dot files.
            subsystem: Optional subsystem name passed to
                ``CompilerBisector.disable_subsystem`` to decide whether the
                pass should be skipped.
            log_url: Directory path under which the before/after dot graphs
                are written. log_url is inferred to be
                torch._inductor.config.trace.log_url_for_graph_xform unless
                otherwise specified.
        """
        from torch._inductor.config import trace

        self.gm = gm
        self.passname = passname
        self.subsystem = subsystem

        if log_url is None:
            log_url = trace.log_url_for_graph_xform

        self.log_url = log_url

        # Hooks are installed only when tracing is enabled or a log directory
        # is configured; otherwise entering/exiting the context is a no-op.
        self.active = trace.enabled or self.log_url is not None

        if self.active:
            self.erased_nodes: set[str] = set()
            self.created_nodes: set[str] = set()
            self.name_to_node: dict[str, Node] = {}
            # record graph modules deepcopied from self.gm, so we can remove hooks on them when exiting the context
            self.copied_gms: list[GraphModule] = []

            self._node_creation_hook = self.get_node_creation_hook()
            self._node_erase_hook = self.get_node_erase_hook()
            self._node_replace_hook = self.get_node_replace_hook()
            self._deepcopy_hook = self.get_deepcopy_hook()

        # If log_url is None, we don't log anything
        if self.log_url is None:
            return
        GraphTransformObserver.__pass_count += 1
        # Pin this pass's index now so both dot files written in __exit__ use
        # it, even if other observers are constructed before this one exits.
        self._pass_index = GraphTransformObserver.__pass_count

        # Snapshot of the graph before the pass runs.
        self.input_dot_graph = FxGraphDrawer(
            self.gm,
            self.passname,
            ignore_getattr=True,
            ignore_parameters_and_buffers=True,
        ).get_dot_graph()

    @classmethod
    def get_current_pass_count(cls):
        """Return the number of observers constructed with logging enabled."""
        return cls.__pass_count

    def apply_gm_pass(self, pass_fn: Callable[[GraphModule], T]) -> Optional[T]:
        """Run ``pass_fn(self.gm)`` under observation.

        Returns the pass's result, or ``None`` if the bisector disabled it.
        """
        with self:
            if not self._check_disable_pass():
                return pass_fn(self.gm)

        return None

    def apply_graph_pass(self, pass_fn: Callable[[Graph], T]) -> Optional[T]:
        """Run ``pass_fn(self.gm.graph)`` under observation.

        Returns the pass's result, or ``None`` if the bisector disabled it.
        """
        with self:
            if not self._check_disable_pass():
                return pass_fn(self.gm.graph)

        return None

    def _check_disable_pass(self):
        """Return whether the compiler bisector disables this pass.

        Always ``False`` when no ``subsystem`` was given.
        """
        if self.subsystem is None:
            return False

        # Lazily evaluated description of the pass for the bisector.
        debug_info = lambda: self.passname  # noqa: E731
        from torch._inductor.compiler_bisector import CompilerBisector

        return CompilerBisector.disable_subsystem(
            "inductor", self.subsystem, debug_info
        )

    def __enter__(self):
        """Install the observation hooks on ``self.gm`` and reset tracking state."""
        if not self.active:
            return self
        self.gm._register_create_node_hook(self._node_creation_hook)
        self.gm._register_erase_node_hook(self._node_erase_hook)
        self.gm._register_replace_node_hook(self._node_replace_hook)
        self.gm._register_deepcopy_hook(self._deepcopy_hook)

        self.erased_nodes.clear()
        self.created_nodes.clear()
        self.name_to_node.clear()
        self.copied_gms.clear()

        # Seed the name lookup used by the replace hook with pre-existing nodes.
        for node in self.gm.graph.nodes:
            self.name_to_node[node.name] = node

        return self

    def __exit__(self, type, value, tb):
        """Remove the hooks and, if logging, write before/after dot graphs.

        Returns ``None``, so exceptions raised inside the context propagate.
        """
        if not self.active:
            return
        # Remove hooks from gm and from the deep copies recorded by the
        # deepcopy hook.
        for gm in self.copied_gms + [self.gm]:
            gm._unregister_create_node_hook(self._node_creation_hook)
            gm._unregister_erase_node_hook(self._node_erase_hook)
            gm._unregister_replace_node_hook(self._node_replace_hook)
            gm._unregister_deepcopy_hook(self._deepcopy_hook)

        if self.log_url is None:
            return

        # Only emit dot files when the pass actually changed the graph.
        if not (self.created_nodes or self.erased_nodes):
            return

        # Input graph: highlight the nodes the pass erased.
        self._write_dot_graph(self.input_dot_graph, self.erased_nodes, "input")

        # Output graph: highlight the nodes the pass created.
        output_dot_graph = FxGraphDrawer(
            self.gm,
            self.passname,
            ignore_getattr=True,
            ignore_parameters_and_buffers=True,
        ).get_dot_graph()
        self._write_dot_graph(output_dot_graph, self.created_nodes, "output")

    def _write_dot_graph(self, dot_graph, highlighted: set[str], suffix: str) -> None:
        """Color ``dot_graph``'s nodes and write it under ``self.log_url``.

        Nodes whose names are in ``highlighted`` are filled yellow; all others
        are filled grey. The file is named
        ``pass_{index}_{passname}_{suffix}_graph.dot``.
        """
        for dot_node in dot_graph.get_node_list():
            fill = "yellow" if dot_node.get_name() in highlighted else "grey"
            dot_node.obj_dict["attributes"]["fillcolor"] = fill
        assert self.log_url is not None
        dot_graph.write(
            os.path.join(
                self.log_url,
                f"pass_{self._pass_index}_{self.passname}_{suffix}_graph.dot",
            )
        )

    def get_node_creation_hook(self):
        """Return a hook that records created nodes and tags their provenance."""

        # We have to return a function instead of using a class method directly
        # to avoid max recursion issue when deepcopy a graph module within the context manager.
        def on_node_creation(node):
            self.created_nodes.add(node.name)
            self.name_to_node[node.name] = node
            source = NodeSource(None, self.passname, NodeSourceAction.CREATE)
            if "from_node" not in node.meta:
                node.meta["from_node"] = [source]
            else:
                node.meta["from_node"].append(source)

        return on_node_creation

    def get_node_erase_hook(self):
        """Return a hook that records erased nodes and drops them from the lookup."""

        def on_node_erase(node):
            self.erased_nodes.add(node.name)
            self.name_to_node.pop(node.name, None)

        return on_node_erase

    def get_node_replace_hook(self):
        """Return a hook that records ``old`` as the provenance of the replacement."""

        def on_node_replace(old: Node, new: str, user: Node):
            # Update node meta when replacing old node with new node
            new_node = self.name_to_node.get(new, None)

            # Replacement node unknown to this observer: nothing to annotate.
            if new_node is None:
                return

            assert isinstance(new_node, Node)

            action = [NodeSourceAction.REPLACE]
            if new_node.name in self.created_nodes:
                action.append(NodeSourceAction.CREATE)

            def created_this_pass(source):
                return source.pass_name == self.passname and source.action == [
                    NodeSourceAction.CREATE
                ]

            # remove redundant source added on node creation
            new_from_node = new_node.meta.get("from_node", [])
            new_from_node = [
                source for source in new_from_node if not created_this_pass(source)
            ]

            # add new source
            new_node_source = NodeSource(old, self.passname, action)
            new_from_node.append(new_node_source)
            new_node.meta["from_node"] = new_from_node

        return on_node_replace

    def get_deepcopy_hook(self):
        """Return a hook that remembers deep copies of ``gm`` for hook cleanup."""

        def on_deepcopy(gm):
            self.copied_gms.append(gm)

        return on_deepcopy
