import contextlib
import os
from typing import Union

from torch._dynamo.test_case import (
    run_tests as dynamo_run_tests,
    TestCase as DynamoTestCase,
)
from torch._functorch import config as functorch_config
from torch._inductor import config
from torch._inductor.utils import fresh_inductor_cache


def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
    dynamo_run_tests(needs)


class TestCase(DynamoTestCase):
    """
    A base TestCase for inductor tests. Enables FX graph caching and isolates
    the cache directory for each test.
    """

    def setUp(self) -> None:
        super().setUp()
        self._inductor_test_stack = contextlib.ExitStack()
        self._inductor_test_stack.enter_context(
            functorch_config.patch(
                {
                    "enable_autograd_cache": True,
                }
            )
        )

        if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
            self._inductor_test_stack.enter_context(
                config.patch({"fx_graph_cache": True})
            )

        if (
            os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1"
            and os.environ.get("TORCH_COMPILE_DEBUG") != "1"
        ):
            self._inductor_test_stack.enter_context(fresh_inductor_cache())

    def tearDown(self) -> None:
        super().tearDown()
        self._inductor_test_stack.close()
