# mypy: allow-untyped-defs
import functools
import itertools
import logging
from collections.abc import Iterable, Sequence
from typing import Any, Callable, cast, Optional, Union

import sympy
from sympy import Expr

from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, ShapeEnv
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, IntInfinity, ValueRanges

from .runtime.runtime_utils import is_power_of_2
from .utils import (
    has_free_symbols,
    sympy_index_symbol,
    sympy_index_symbol_with_prefix,
    sympy_subs,
    VarRanges,
)
from .virtualized import V


log = logging.getLogger(__name__)


def evaluate_expr(
    shape_env: ShapeEnv,
    expr: Union[sympy.Basic, bool],
    axioms: Optional[tuple[sympy.Expr]] = None,
    var_to_range: Optional[tuple[tuple[sympy.Symbol, ValueRanges[Any]]]] = None,
) -> bool:
    if expr in (True, False):
        return bool(expr)

    try:
        simplified = shape_env._maybe_evaluate_static(
            expr,
            axioms=axioms,
            var_to_range=var_to_range,
        )
        if simplified is not None:
            return bool(simplified)
    except Exception:
        log.debug("Could not simplify  %s", expr, exc_info=True)

    return False


# This class is a little awkward, because ShapeEnv is doing most of the heavy
# lifting and in some cases we should be directly passing through to ShapeEnv,
# but there is some extra inductor logic that needs to be handled here
class SizeVarAllocator:
    def __init__(self, shape_env=None) -> None:
        super().__init__()
        if shape_env is None:
            shape_env = ShapeEnv()
        self.shape_env = shape_env
        self.var_to_val = self.shape_env.var_to_val
        self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements
        # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
        # The basic idea is if we have some complicated sympy expression
        # f(s0), we may choose to precompute it on the host and then replace
        # all occurrences of that sympy expression with ps0, so that when we
        # codegen we simply reference ps0 directly without repeating
        # f(s0).  Unlike regular size variables, ps variables cannot be
        # guarded upon; so if we are asked to guard on a Sympy expression
        # which potentially could have already had a precomputed replacement
        # on it, we are obligated to invert the precomputed replacements
        # (inv_precomputed_replacements).
        self.precomputed_replacements: dict[Expr, sympy.Symbol] = {}
        self.inv_precomputed_replacements: dict[sympy.Symbol, Expr] = {}
        self.stride_vars = self.make_stride_vars_cache()
        self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
        self._simplify_loops = self.make_simplify_loops_cache()

    def simplify(self, expr: Expr):
        return sympy.expand(expr).xreplace(self.replacements)

    def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]:
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache: dict[tuple[Any, ...], Expr] = {}
        replacement_count = len(self.replacements)

        def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr:
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (expr, *var_ranges.items())
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
                if result != expr:
                    cache[(result, *var_ranges.items())] = result
            return result

        return simplify_with_ranges

    def make_simplify_loops_cache(self):
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache: dict[tuple[Any, ...], Any] = {}
        replacement_count = len(self.replacements)

        def simplify_loops(index_vars, sizes, index_formulas):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (*index_vars, *sizes, *index_formulas)
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                cache[key] = result
            return result

        return simplify_loops

    def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr:
        """
        Simplify indexing expression with knowledge of the ranges of
        iteration variables.

        ``var_ranges`` maps each iteration variable to its (exclusive) upper
        bound. The expression is first passed through ``self.simplify`` and
        ``join_dimensions``; afterwards ``ModularIndexing`` and ``FloorDiv``
        sub-expressions are rewritten where the ranges prove a term is zero or
        a modulus is redundant. The method calls itself again until a pass
        leaves the expression unchanged.
        """

        expr = join_dimensions(self.simplify(expr))
        # Snapshot used at the end to detect whether this pass changed anything.
        original_expr = expr

        # Start from the ShapeEnv's known ranges and overlay the iteration
        # variables: each lies in [0, bound - 1] when the bound has no free
        # symbols, otherwise in [0, +inf).
        var_to_range = dict(self.shape_env.var_to_range)
        var_to_range.update(
            {
                k: ValueRanges(
                    0, max(0, v - 1) if not has_free_symbols([v]) else IntInfinity()
                )
                for k, v in var_ranges.items()
            }
        )
        # Any remaining symbol in the expression without a range is given
        # the range [0, +inf).
        for var in expr.free_symbols:
            if var not in var_to_range:
                var_to_range[var] = ValueRanges(0, IntInfinity())

        # Converted to a tuple of (symbol, range) pairs before being passed on.
        var_to_range_tuple = cast(
            tuple[tuple[sympy.Symbol, ValueRanges[sympy.Expr]]],
            tuple(var_to_range.items()),
        )

        # Axioms 0 <= var < upper_bound for every iteration variable, followed
        # by whatever axioms the ShapeEnv already provides.
        axioms = []
        for var, upper_bound in var_ranges.items():
            axioms.append(0 <= var)
            axioms.append(var < upper_bound)
        axioms = tuple(axioms) + self.shape_env.get_axioms()

        def statically_known(expr):
            # A None result (undecided) is falsy, so it counts as "not known".
            evaluated = self.shape_env._maybe_evaluate_static(
                expr,
                axioms=axioms,
                var_to_range=var_to_range_tuple,
            )
            return bool(evaluated)

        def remove_zero_terms(base, divisor):
            """Symbols smaller than the divisor are zero"""
            if not statically_known(base >= 0):
                return base

            for v in base.free_symbols:
                if v in var_ranges:
                    # var smaller than divisor can be removed
                    # if the rest is guaranteed to be multiple of divisor
                    rest = sympy.Wild("_rest", exclude=[v])
                    m = base.match(v + rest)
                    if m and v not in m[rest].free_symbols:
                        gcd = sympy.gcd(m[rest], divisor)
                        if gcd == divisor:
                            if statically_known(v < divisor):
                                base = m[rest]
            return base

        def visit_indexing_div(base, divisor):
            return FloorDiv(remove_zero_terms(base, divisor), divisor)

        def visit_modular_indexing(base, divisor, modulus):
            base = remove_zero_terms(base, divisor)

            # The modulus can be dropped when base is known to lie in
            # [0, modulus * divisor).
            can_remove_mod = statically_known(base >= 0) and statically_known(
                base < modulus * divisor
            )

            if can_remove_mod:
                return FloorDiv(base, divisor)
            return ModularIndexing(base, divisor, modulus)

        if expr.has(ModularIndexing):
            expr = expr.replace(
                ModularIndexing(
                    sympy.Wild("base", integer=True),
                    sympy.Wild("divisor", integer=True),
                    sympy.Wild("modulus", integer=True),
                ),
                visit_modular_indexing,
            )

        if expr.has(FloorDiv):
            expr = expr.replace(
                FloorDiv(
                    sympy.Wild("base", integer=True),
                    sympy.Wild("divisor", integer=True),
                ),
                visit_indexing_div,
            )

        # A change may expose further simplifications, so run another pass
        # until the expression stops changing.
        if expr != original_expr:
            return self._simplify_with_ranges(expr, var_ranges)
        return expr

    def _simplify_loops_impl(
        self, index_vars: list[sympy.Symbol], sizes, index_formulas
    ):
        """
        Try to remove as many axis from loop iterations as possible, by:
            1) removing size==1 dimensions
            2) fuse contiguous dimensions into a single loop

        Returns a triple ``(new_sizes, reindex, prune)``:
            - ``new_sizes``: the sizes of the dimensions that remain
            - ``reindex``: maps an index over the remaining dimensions to a
              full-length index, with zero in every removed position
            - ``prune``: drops the entries of a full-length index that
              correspond to removed dimensions
        """
        sizes = list(map(self.simplify, sizes))

        strides = [
            # index_formulas may contain boolean expressions (e.g. s0 < 10),
            # for which "strides" don't make sense so we ignore them here.
            # NOTE: These expressions may still block merging dims in the sound
            # substitution test performed in can_merge_dims.
            (
                self.stride_vars(x, index_vars)
                if isinstance(x, sympy.Expr)
                else [0] * len(index_vars)
            )
            for x in index_formulas
        ]
        # NOTE(review): strides[0] assumes index_formulas is non-empty.
        assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))

        # A removed dimension is marked by setting its size to None.
        for i in range(len(sizes)):
            if sizes[i] == 1:
                # remove dim
                sizes[i] = None

        def can_merge_dims(a, b):
            # Dims a and b can merge only if the test passes for every formula.
            for k in range(len(strides)):
                if self.simplify(strides[k][a] * sizes[a]) == self.simplify(
                    strides[k][b]
                ):
                    # approximate test passed, try sound version
                    va = index_vars[a]
                    vb = index_vars[b]
                    m1 = sympy_index_symbol("_merge_tester1")
                    m2 = sympy_index_symbol("_merge_tester2")
                    # NOTE: can't sub vb=0 here in case va * vb appears in the expression,
                    # in which case both expr1 and expr2 would be zero!
                    expr1 = sympy_subs(index_formulas[k], {va: m1 * sizes[a], vb: m2})
                    expr2 = sympy_subs(index_formulas[k], {va: 0, vb: (m1 + m2)})
                    if self.simplify(expr1) == self.simplify(expr2):
                        continue
                return False
            return True

        # Keep merging until a full sweep over all (i, j) pairs changes nothing.
        changed = True
        while changed:
            changed = False
            for i, j in itertools.product(
                reversed(range(len(sizes))), reversed(range(len(sizes)))
            ):
                if i == j or sizes[i] is None or sizes[j] is None:
                    continue
                if can_merge_dims(i, j):
                    changed = True
                    # dim i absorbs dim j; dim j is marked as removed
                    sizes[i] = sizes[i] * sizes[j]
                    sizes[j] = None

        def reindex(index):
            # Reversed so that pop() hands back the entries in original order.
            it = list(reversed(index))
            new_index = []
            for size in sizes:
                if size is None:
                    new_index.append(sympy.S.Zero)
                else:
                    new_index.append(it.pop())
            assert not it
            return new_index

        def prune(index):
            assert len(index) == len(sizes)
            return [i for i, s in zip(index, sizes) if s is not None]

        return [x for x in sizes if x is not None], reindex, prune

    # Note - [On Statically Known]
    #
    # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system
    # operated by asking a question that was evaluated on the size-hinted values. If the condition was true, it added
    # a guard and returned True; otherwise it returned False.
    #
    # def maybe_guard_foo(args):
    #   if not size_hinted_check(args):
    #       return False # No guard, no optim
    #   guard(args) # Make a guard
    #   return True # Safe to apply optimization
    #
    # The prior system incurred a guard, and green lit an optimization.
    #
    # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the
    # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we
    # return False.
    #
    # def statically_known_foo(args):
    #   if all_static(args):
    #       return True # Safe to apply optimization
    #   else:
    #       return False # No guard, no optim

    # See Note - [On Statically Known]

    def is_expr_static_and_true(self, expr: Union[sympy.Basic, bool]) -> bool:
        """
        Return whether ``expr`` is statically known to be true, by delegating
        to the module-level ``evaluate_expr`` with this allocator's shape_env.
        """
        env = self.shape_env
        return evaluate_expr(env, expr)

    def statically_known_equals(
        self, left: Union[Expr, int], right: Union[Expr, int]
    ) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right are equal.
        """
        return self.is_expr_static_and_true(sympy.Eq(left, right))  # type: ignore[arg-type]

    # See Note - [On Statically Known]
    def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
        """
        return len(left) == len(right) and all(
            self.statically_known_equals(l, r) for l, r in zip(left, right)
        )

    # See Note - [On Statically Known]
    def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than or equal to right.
        """
        expr = left <= right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_geq(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right.
        """
        expr = left >= right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_lt(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than right.
        """
        expr = left < right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_gt(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is greater than right.
        """
        expr = left > right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_multiple_of(
        self, numerator: Expr, denominator: Union[Expr, int]
    ) -> bool:
        """
        Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
        """
        if free_unbacked_symbols(numerator) or free_unbacked_symbols(denominator):
            return False
        expr = sympy.Eq(numerator % denominator, 0)
        return self.is_expr_static_and_true(expr)  # type: ignore[arg-type]

    # See Note - [On Statically Known]
    def statically_known_power_of_2(self, expr: Expr) -> bool:
        """
        Returns a bool indicating if x is known to be a power of 2.
        """
        return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr))

    # The guard functions require you to ALREADY KNOW that a particular
    # condition holds.  If you don't know (you want to guard on an expression
    # being a particular value, and then get access to that value), use
    # the evaluate functions.

    def guard_equals(self, left: Expr, right: Expr) -> Expr:
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]

        expr = sympy.Eq(left, right)
        static_expr = self.shape_env._maybe_evaluate_static(expr)

        if static_expr is not None:
            assert bool(static_expr)
            return left

        assert self.shape_env.defer_runtime_assert(expr, "guard_equals")
        return left

    def guard_leq(self, left: Expr, right: Expr) -> None:
        return self.guard_lt(left, right + 1)

    def guard_lt(self, left: Expr, right: Expr) -> None:
        expr = sympy.Lt(left, right)
        static_expr = self.shape_env._maybe_evaluate_static(expr)

        if static_expr is not None:
            assert bool(static_expr)
            return

        assert self.shape_env.defer_runtime_assert(expr, "guard_lt")

    def guarded_order(self, seq):
        """
        Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing.
        """
        seq = [*map(self.remove_precomputed_replacements, seq)]
        seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)]
        seq.sort()
        order = [-1] * len(seq)
        last_var = None
        for new_index, (_, orig_index, var) in enumerate(seq):
            order[orig_index] = new_index
            if last_var is not None:
                self.guard_leq(last_var, var)
            last_var = var
        return order

    # The evaluate functions evaluate some symbolic sympy expression
    # (NB: not necessarily an Expr) and return what the concrete result
    # is, guarding on the expression being that result

    # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b)
    # as this will ensure that you actually have a sympy'ified expression,
    # and will prevent you from incorrectly writing evaluate_expr(a == b)
    # which does the wrong thing if a or b is a sympy expression
    def evaluate_expr(
        self,
        left: Union[Expr, sympy.logic.boolalg.Boolean],
        size_oblivious: bool = False,
    ) -> bool:
        assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
        return self.shape_env.evaluate_expr(
            sympy.sympify(left), size_oblivious=size_oblivious
        )

    def evaluate_min(self, left: Expr, right: Expr) -> Expr:
        """return the smaller of left and right, and guard on that choice"""
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        try:
            lv = self.size_hint(left)
            rv = self.size_hint(right)
        except TypeError:  # unbacked symints
            if left == right or self.statically_known_leq(left, right):
                return left
            if self.statically_known_leq(right, left):
                return right
            gcd = sympy.gcd(left, right)
            if left == gcd:  # handle `min(10*u0, u0)` etc
                return left
            if right == gcd:
                return right
            raise TypeError(
                f"evaluate_min({left}, {right}) with unbacked symints"
            ) from None
        if lv <= rv:
            self.guard_leq(left, right)
            return left
        else:
            self.guard_leq(right, left)
            return right

    def evaluate_max(self, left: Expr, right: Expr) -> Expr:
        """return the larger of left and right, and guard on that choice"""
        # Always choose the opposite of eval min for consistency
        # This means min(a, b) and max(a, b) produce the same guards
        min_val = self.evaluate_min(left, right)
        return right if min_val is left else left

    def evaluate_static_shape(self, left: Union[Expr, int]) -> int:
        if isinstance(left, int):
            return left
        right = self.size_hint(left)
        self.guard_equals(left, sympy.Integer(right))
        return int(right)

    def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]:
        return [self.evaluate_static_shape(x) for x in left]

    def remove_precomputed_replacements(self, expr: Expr) -> Expr:
        if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols):  # type: ignore[attr-defined]
            return sympy_subs(expr, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        return expr

    def symbolic_hint(self, expr: Union[Expr, int]) -> Union[Expr, int]:
        if isinstance(expr, int):
            return expr
        # Substitute all hints into expr, but leave unbacked symints alone
        expr = self.simplify(expr)
        if not isinstance(expr, Expr):
            assert isinstance(expr, int)
            return expr
        free_symbols = expr.free_symbols
        if not free_symbols:
            try:
                return int(expr)  # type: ignore[return-value]
            except TypeError:
                return expr  # inf/nan/I
        expr = self.remove_precomputed_replacements(expr)
        return sympy_subs(expr, self.var_to_val)

    def size_hint(
        self, expr: Union[Expr, int], *, fallback: Optional[int] = None
    ) -> int:
        out = self.symbolic_hint(expr)
        if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
            # Use the provided heuristic fallback hint
            unbacked_sym_vrs = {
                s: self.shape_env.var_to_range.get(s, None) for s in out.free_symbols
            }
            if all(vr is not None for vr in unbacked_sym_vrs.values()):
                hint_vr = bound_sympy(out, unbacked_sym_vrs)  # type: ignore[arg-type]
                if isinstance(hint_vr.lower, (int, sympy.Integer)):
                    fallback = max(fallback, int(hint_vr.lower))
                if isinstance(hint_vr.upper, (int, sympy.Integer)):
                    fallback = min(fallback, int(hint_vr.upper))
            return fallback

        try:
            return int(out)
        except Exception:
            log.debug("failed on: %s", out)
            raise

    def size_hints(
        self,
        exprs: Iterable[Expr],
        *,
        fallback: Optional[int] = None,
    ) -> tuple[int, ...]:
        return tuple(self.size_hint(x, fallback=fallback) for x in exprs)

    def _lru_cache(self, fn, maxsize=None):
        """
        Wrapper around functools.lru_cache that clears when replacements
        has been invalidated.
        """
        fn_cache = functools.lru_cache(maxsize)(fn)
        prior_len = len(self.replacements)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal prior_len
            if prior_len != len(self.replacements):
                prior_len = len(self.replacements)
                fn_cache.cache_clear()
            return fn_cache(*args, **kwargs)

        return wrapper

    def make_stride_vars_cache(self):
        cache = self._lru_cache(self._stride_vars)

        def stride_vars(
            index: Expr,
            vars: Sequence[sympy.Symbol],
            support_vars: Optional[Sequence[sympy.Symbol]] = None,
        ) -> list[Expr]:
            if not support_vars:
                support_vars = vars
            return cache(index, tuple(vars), tuple(support_vars))

        return stride_vars

    def _stride_vars(
        self,
        index: Expr,
        vars: Sequence[sympy.Symbol],
        support_vars: Sequence[sympy.Symbol],
    ) -> list[Expr]:
        """Convert an indexing expression back into strides

        NOTE: This is only valid if the index is a standard strided offset
        calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a
        stride of -10 because the index wraps around after the first element

        """
        strides = []
        index = self.simplify(index)
        # remove any offset
        index = index - sympy_subs(
            index, {v: sympy.S.Zero for v in support_vars if v != 0}
        )
        for i in range(len(vars)):
            # drop all the other dims
            index_dim = sympy_subs(
                index,
                {
                    support_vars[j]: sympy.S.Zero
                    for j in range(len(support_vars))
                    if vars[i] != support_vars[j] and support_vars[j] != 0
                },
            )
            v = vars[i]
            if v == 0:
                strides.append(sympy.S.Zero)
            else:
                # TODO(jansel): should we use sympy.diff here?
                strides.append(
                    sympy_subs(index_dim, {v: sympy.S.One})
                    - sympy_subs(index_dim, {v: sympy.S.Zero})
                )
        return strides

    def atomically_apply_size_hint(
        self, expr: Union[Expr, int], *, fallback: Optional[int] = None
    ) -> Union[Expr, int]:
        if isinstance(expr, int):
            return int(expr)

        # For multiple expressions that depend on an unbacked symint,
        # we want to compute them consistently for a size hint we have chosen.
        # So, recursively compute expressions via size hints of contained symbols.
        # For example: u1 * u2 - 10 ==> fallback * fallback - 10
        assert isinstance(expr, Expr), type(expr)
        free_symbols = expr.free_symbols
        size_dict = {
            symbol: V.graph.sizevars.size_hint(symbol, fallback=fallback)
            for symbol in free_symbols
        }
        return expr.subs(size_dict)

    def offset_var(self, index: Expr, vars: list[sympy.Symbol]) -> Expr:
        """Extract offset part of an indexing expression"""
        index = self.simplify(index)
        return sympy_subs(index, {v: sympy.S.Zero for v in vars if v != 0})

    def stride_hints(
        self,
        index: Expr,
        vars: Sequence[sympy.Symbol],
        support_vars: Optional[Sequence[sympy.Symbol]] = None,
    ) -> list[int]:
        for v in index.free_symbols:
            if symbol_is_type(v, SymT.INDIRECT):  # type: ignore[attr-defined]
                index = sympy_subs(index, {v: 0})  # type: ignore[dict-item]
        result = []
        for s in self.stride_vars(index, vars, support_vars):
            try:
                result.append(self.size_hint(s))
            except TypeError:
                result.append(0)
        return result

    def stride_order(self, index: Expr, vars: list[sympy.Symbol]) -> list[int]:
        strides = tuple(map(abs, self.stride_hints(index, vars)))
        order = list(range(len(strides)))
        order.sort(key=lambda x: (strides[x] == 0, strides[x]))
        return order

    def lookup_precomputed_size(self, expr: Expr) -> Expr:
        if (
            isinstance(expr, (int, sympy.Symbol, sympy.Number))
            or expr.is_number
            or expr.is_symbol
        ):
            return expr
        expr = self.remove_precomputed_replacements(expr)
        if expr not in self.precomputed_replacements:
            sym = sympy_index_symbol_with_prefix(
                SymT.PRECOMPUTED_SIZE, len(self.precomputed_replacements)
            )
            self.precomputed_replacements[expr] = sym
            self.inv_precomputed_replacements[sym] = expr
        return self.precomputed_replacements[expr]

    def free_symbols(self) -> OrderedSet[sympy.Symbol]:
        return OrderedSet(self.var_to_val.keys()) - OrderedSet(self.replacements.keys())

    def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr:
        """
        A pair of special ModularIndexing can be combined.

        E.g. ModularIndexing(ModularIndexing(x, 1, a), 1, b)
        We can simplify this to ModuleIndexing(x, 1, b), if
        1. x is non negative integer
        2. a and b are positive integers
        3. a is a multiple of b.
        """

        def _check_args(x, div, mod, is_first):
            if not isinstance(div, sympy.Integer) or not isinstance(mod, sympy.Integer):
                return False
            if div != 1:
                return False
            if mod <= 0:
                return False

            if is_first:
                # first ModularIndexing should conatins a nested ModularIndex
                if not isinstance(x, ModularIndexing):
                    return False
            else:
                # second ModularIndexing should constains a non-negative
                # symbol
                if not isinstance(x, sympy.Symbol) or not self.statically_known_geq(
                    x, 0
                ):
                    return False
            return True

        if isinstance(index, ModularIndexing):
            x, div, mod = index.args

            if not _check_args(x, div, mod, True):
                return index

            x2, div2, mod2 = x.args

            if not _check_args(x2, div2, mod2, False):
                return index

            if mod2 % mod != 0:
                return index

            return ModularIndexing(x2, 1, mod)

        return index

    def expand_floor_div(
        self, index: sympy.Expr
    ) -> Union[bool, tuple[sympy.Expr, sympy.Expr]]:
        """
        Expand the FloorDiv to the entire expression so that the expression may
        be simplfied.

        E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables
        x1, x2, index expression 'x1 * 2b + x2' can be easily combined.
        But index expression 'x1 * b + x2 // 2' can not.
        By expanding the FloorDiv to the entire expression, we get
        '(x1 * 2b + x2) // 2'. This transformation allows us to merge loops
        for the numerator!

        Return false if this optimization can be applied;
        Return the new expression and the denominator otherwise.
        The original expression will be equivalent to 'new_expression // denominator'
        """
        if not isinstance(index, sympy.Add):
            return False
        terms = index.args

        if len(terms) < 2:
            return False
        floor_div_index = -1
        varlist = []
        factorlist = []
        for idx, term in enumerate(terms):
            if isinstance(term, sympy.Mul):
                # For dynamic shape, term like '2*s1*x1' has 3 child nodes.
                # - A integer for 2
                # - A symbol for s1
                # - A symbol for x1
                # Skip for now.
                if len(term.args) != 2:
                    return False
                factor, var = term.args
                varlist.append(var)
                factorlist.append(factor)
                if not isinstance(factor, sympy.Integer) or not isinstance(
                    var, sympy.Symbol
                ):
                    return False
                # It's easier to reason about the correceness of the transformation
                # for non-negative integers.
                if not self.statically_known_geq(var, 0):
                    return False
            elif isinstance(term, FloorDiv):
                var, factor = term.args
                if not isinstance(factor, sympy.Integer) or not isinstance(
                    var, sympy.Symbol
                ):
                    return False
                if not self.statically_known_geq(var, 0):
                    return False
                if floor_div_index >= 0:
                    # can not handle multi FloorDiv yet
                    return False

                floor_div_index = idx
                varlist.append(var)
                # this factor is denominator
                factorlist.append(factor)
            else:
                return False

        if floor_div_index < 0:
            return False

        # Construct the new expression and remember the denominator
        denominator = factorlist[floor_div_index]
        new_index = sympy.S.Zero

        for var, factor, idx in zip(varlist, factorlist, itertools.count()):
            if idx == floor_div_index:
                new_index += var
            else:
                new_index += (factor * denominator) * var

        return new_index, denominator


def join_dimensions(expr: Expr) -> Expr:
    if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing):
        return expr  # fast exit path
    return _join_dimensions_cached(expr)


@functools.lru_cache(256)
def _join_dimensions_cached(expr: Expr) -> Expr:
    """
    Merge pairs of terms in a sum that together index one larger dimension.

    ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
    becomes
    ModularIndexing(i0, 1, 128)
    ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32)
    becomes i0


    This type of pattern can come from view operations

    At most one pair is merged per call; the rewritten sum is then handed
    back to join_dimensions so further pairs can be merged.
    """
    assert isinstance(expr, sympy.Add)

    # Wildcards for the pattern `scale * ModularIndexing(base, divisor, modulus)`.
    scale = sympy.Wild("scale", exclude=[0], integer=True)
    base = sympy.Wild("base", integer=True)
    divisor = sympy.Wild("divisor", integer=True)
    mod1 = sympy.Wild("modulus", integer=True)
    mod2 = sympy.Wild("modulus2", integer=True)
    # First pass: two ModularIndexing terms over the same base, where the
    # second continues where the first leaves off (its divisor is
    # divisor * modulus and its scale is scale * modulus).
    for term1 in expr.args:
        m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
        if m1:
            for term2 in expr.args:
                m2 = term2.match(
                    m1[scale]
                    * m1[mod1]
                    * ModularIndexing(m1[base], m1[divisor] * m1[mod1], mod2)
                )
                if m2 and term1 != term2:
                    expr = join_dimensions(
                        expr
                        - term1
                        - term2
                        + m1[scale]
                        * ModularIndexing(m1[base], m1[divisor], m1[mod1] * m2[mod2])
                    )
                    return expr
    # Second pass: a ModularIndexing term paired with the FloorDiv term that
    # covers the rest of the same base.
    for term1 in expr.args:
        m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
        if m1:
            for term2 in expr.args:
                m2 = term2.match(
                    m1[scale] * m1[mod1] * FloorDiv(m1[base], m1[divisor] * m1[mod1])
                )
                if m2 is not None:  # in case of success we get an empty dict here
                    expr = join_dimensions(
                        expr
                        - term1
                        - term2
                        + m1[scale] * FloorDiv(m1[base], m1[divisor])
                    )
                    return expr
    return expr


class SimplifyIndexing(V.WrapperHandler):  # type: ignore[name-defined]
    """
    Ops handler wrapper that simplifies every index expression, using the
    given variable ranges, before forwarding the call to the wrapped handler.
    This lets ModularIndexing/FloorDiv be simplified with range knowledge.
    """

    def __init__(self, inner, var_ranges: VarRanges) -> None:
        super().__init__(inner)
        self.name = "SimplifyIndexing"

        def simplify_index(index: Expr) -> Expr:
            # V.graph is looked up at call time, not at construction time
            return V.graph.sizevars.simplify_with_ranges(index, var_ranges)

        self._simplify: Callable[[Expr], Expr] = simplify_index

    def load(self, name: str, index: sympy.Expr):
        """Forward ``load`` with a simplified index."""
        return self._inner.load(name, self._simplify(index))

    def store(self, name, index, value, mode=None):
        """Forward ``store`` with a simplified index."""
        return self._inner.store(name, self._simplify(index), value, mode=mode)

    def store_reduction(self, name, index, value):
        """Forward ``store_reduction`` with a simplified index."""
        return self._inner.store_reduction(name, self._simplify(index), value)

    def index_expr(self, index, dtype):
        """Forward ``index_expr`` with a simplified index."""
        return self._inner.index_expr(self._simplify(index), dtype)

    def check_bounds(self, index, size, lower, upper):
        """Forward ``check_bounds`` with a simplified index."""
        return self._inner.check_bounds(self._simplify(index), size, lower, upper)
