# mypy: ignore-errors

from __future__ import annotations

import contextlib
import dataclasses
import sys
import threading
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing_extensions import override, Self
from unittest.mock import patch

from torch._inductor import config
from torch._inductor.remote_cache import RemoteCacheBackend


if TYPE_CHECKING:
    from types import TracebackType


@dataclasses.dataclass
class Stats:
    num_put: int = 0
    num_get_hit: int = 0
    num_get_miss: int = 0

    def __iadd__(self, other: Stats) -> Self:
        self.num_put += other.num_put
        self.num_get_hit += other.num_get_hit
        self.num_get_miss += other.num_get_miss
        return self

    def reset(self) -> None:
        self.num_put = 0
        self.num_get_hit = 0
        self.num_get_miss = 0

    def __str__(self) -> str:
        return "".join(
            (
                f"puts: {self.num_put}, ",
                f"misses: {self.num_get_miss}, ",
                f"hits: {self.num_get_hit}, ",
            )
        )

    def __eq__(self, other: object) -> bool:
        # Dataclass's default __eq__ checks that the types are the same so can't
        # be used with _GlobalItemStats.
        return (
            isinstance(other, (Stats, _GlobalItemStats))
            and self.num_put == other.num_put
            and self.num_get_hit == other.num_get_hit
            and self.num_get_miss == other.num_get_miss
        )


class _GlobalItemStats(Stats):
    cache: dict[str, object]

    def __init__(self) -> None:
        super().__init__()
        self.cache = {}

    def reset(self) -> None:
        super().reset()
        self.cache = {}


# The cache states are thread-local so if we're running multiple tests at once
# they won't cross contaminate. However - it needs to be "global" because we
# allow code to create new cache clients which refer to the same cache (because
# it's a remote cache).


class _GlobalStats(threading.local):
    def __init__(self) -> None:
        self.autotune_local = _GlobalItemStats()
        self.autotune_remote = _GlobalItemStats()
        self.bundled_autotune = _GlobalItemStats()
        self.fx_graph = _GlobalItemStats()
        self.triton = _GlobalItemStats()
        self.aot_autograd = _GlobalItemStats()
        self.dynamo_pgo = _GlobalItemStats()

    def reset(self) -> None:
        self.autotune_local.reset()
        self.autotune_remote.reset()
        self.bundled_autotune.reset()
        self.fx_graph.reset()
        self.triton.reset()
        self.aot_autograd.reset()
        self.dynamo_pgo.reset()

    def get_stat(self, name: str) -> _GlobalItemStats:
        return getattr(self, name)

    def report(self):
        subs = (
            ("autotune_local", self.autotune_local),
            ("autotune_remote", self.autotune_remote),
            ("bundled_autotune", self.bundled_autotune),
            ("fx_graph", self.fx_graph),
            ("triton", self.triton),
            ("aot_autograd", self.aot_autograd),
            ("dynamo_pgo", self.dynamo_pgo),
        )

        print("Cache Stats:", file=sys.stderr)
        for name, sub in subs:
            print(f"  {name}: {sub}", file=sys.stderr)

        print("Cache Entries:", file=sys.stderr)
        for name, sub in subs:
            if sub.cache:
                print(f"  {name}:", file=sys.stderr)
                for k, v in sorted(sub.cache.items()):
                    v = repr(v)
                    if len(v) > 100:
                        v = v[:100] + "..."
                    print(f"    {k!r}: {v}", file=sys.stderr)


global_stats = _GlobalStats()


class MockBackend(RemoteCacheBackend[Any]):
    def __init__(self, name: str) -> None:
        self._name = name

    @staticmethod
    def with_name(name: str) -> Callable[[], MockBackend]:
        def wrapper() -> MockBackend:
            return MockBackend(name)

        return wrapper

    @override
    def _get(self, key: str) -> Optional[Any]:
        stat = global_stats.get_stat(self._name)
        if key in stat.cache:
            stat += Stats(num_get_hit=1)
            return stat.cache.get(key)
        else:
            stat += Stats(num_get_miss=1)
            return None

    @override
    def _put(self, key: str, data: Any) -> None:
        stat = global_stats.get_stat(self._name)
        stat += Stats(num_put=1)
        stat.cache[key] = data


# List of configs for each cache
_CACHE_CONFIG_EN = (
    "fx_graph_cache",
    "fx_graph_remote_cache",
    "autotune_local_cache",
    "autotune_remote_cache",
    "bundled_autotune_remote_cache",
)


class PatchCaches(contextlib.AbstractContextManager):
    @classmethod
    def setUp(cls):
        # If this test is using PatchCaches then disable all the caches by
        # default, letting the tests turn them on explicitly. This is because
        # tests using PatchCaches will often want to check stats explicitly.
        cls._savedCacheState = {}
        for name in _CACHE_CONFIG_EN:
            if hasattr(config, name):
                cls._savedCacheState[name] = getattr(config, name)
            setattr(config, name, False)

    @classmethod
    def tearDown(cls):
        # Restore cache defaults
        for name in _CACHE_CONFIG_EN:
            delattr(config, name)
            if name in cls._savedCacheState:
                setattr(config, name, cls._savedCacheState[name])

    def __init__(self) -> None:
        self._stack = contextlib.ExitStack()

    def __enter__(self) -> Self:
        global_stats.reset()
        self._stack.__enter__()

        ctx = patch(
            "torch._inductor.runtime.autotune_cache.LocalAutotuneCache.backend_override_cls",
            MockBackend.with_name("autotune_local"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteAutotuneCache.backend_override_cls",
            MockBackend.with_name("autotune_remote"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteBundledAutotuneCache.backend_override_cls",
            MockBackend.with_name("bundled_autotune"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteFxGraphCache.backend_override_cls",
            MockBackend.with_name("fx_graph"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteAOTAutogradCache.backend_override_cls",
            MockBackend.with_name("aot_autograd"),
        )
        self._stack.enter_context(ctx)

        ctx = patch(
            "torch._inductor.remote_cache.RemoteDynamoPGOCache.backend_override_cls",
            MockBackend.with_name("dynamo_pgo"),
        )
        self._stack.enter_context(ctx)

        if config.is_fbcode():
            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteAutotuneCache.backend_override_cls",
                MockBackend.with_name("autotune_remote"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteBundledAutotuneCache.backend_override_cls",
                MockBackend.with_name("bundled_autotune"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteFxGraphCache.backend_override_cls",
                MockBackend.with_name("fx_graph"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "triton.fb.fb_memcache.FbMemcacheRemoteKernelCache.backend_override_cls",
                MockBackend.with_name("triton"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteAOTAutogradCache.backend_override_cls",
                MockBackend.with_name("aot_autograd"),
            )
            self._stack.enter_context(ctx)

            ctx = patch(
                "torch._inductor.fb.remote_cache.FbRemoteDynamoPGOCache.backend_override_cls",
                MockBackend.with_name("dynamo_pgo"),
            )
            self._stack.enter_context(ctx)

        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        self._stack.__exit__(exc_type, exc_value, traceback)
