# mypy: allow-untyped-defs

from __future__ import annotations


"""
This file does three things:
- Contains the definition of SymNode
- Installs all the magic methods into SymBool, SymInt and SymFloat at import time
- Does not depend on sympy at import time

As this file is imported from within torch/__init__.py we do not want it to depend on SymPy
to avoid having to load SymPy at import time, as doing so is *very* slow.
"""


import builtins
import functools
import inspect
import itertools
import logging
import math
import operator
import sys
from functools import lru_cache, update_wrapper
from typing import Optional, TYPE_CHECKING, Union

import torch
import torch._logging.structured as structured

# NB: The sym_* functions are used via getattr() and must be imported here.
from torch import (  # noqa: F401
    sym_float,
    sym_ite,
    sym_max,
    sym_min,
    sym_not,
    SymBool,
    SymFloat,
    SymInt,
)
from torch._logging import dtrace_structured


if TYPE_CHECKING:
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

# Module logger; the guard_* methods of SymNode use it to warn when a guard
# result cannot be converted to the requested Python type.
log = logging.getLogger(__name__)
# Artifact logger registered under the "sym_node" logging artifact name.
sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")


# Public API of this module.  NB: this must stay a list, because the math-op
# registration loop later in this file appends "sym_<op>" names to it.
__all__ = ["SymNode", "method_to_operator", "magic_methods"]


from torch.types import py_sym_types as SymTypes


def _to_symtype(t):
    """Map a Python scalar type to its symbolic counterpart.

    ``bool`` -> ``SymBool``, ``int`` -> ``SymInt``, ``float`` -> ``SymFloat``.
    Any other object is returned unchanged.  Matching is by identity.
    """
    for py_type, sym_type in ((bool, SymBool), (int, SymInt), (float, SymFloat)):
        if t is py_type:
            return sym_type
    return t


# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
class SymNode:
    """
    This is a type erased SymInt/SymFloat which we use to do actual operations.
    End users don't touch this.  Magic methods are NOT defined on this object.
    """

    # Note [optimized_summation]: indicates that SymNode is an Add expression of the form
    # a + b + c + d... etc where all terms are unique symbols. This allows us to do some optimizations
    # for common patterns see _optimized_add.

    # The unfortunate reason we have this here is because sympy sets  __slots__ = () for add expression,
    # so we cannot add the attribute directly to the sympy expression. Furthermore, we cannot use it as
    # a weak dictionary key either! So instead, we attach the attribute here to the SymNode.
    _optimized_summation: bool = False

    def __init__(
        self,
        expr,
        shape_env,
        pytype,
        hint: Optional[Union[int, float, bool]],
        constant=None,
        fx_node=None,
        optimized_summation=False,
    ):
        self._expr = expr
        self.shape_env = shape_env
        self.pytype = pytype
        self._optimized_summation = optimized_summation

        # What's the difference between hint and constant?
        #
        # - A constant is known to be invariant across invocations of the model;
        #   it will always be this value.  We only really know this when we
        #   encounter an honest-to-goodness literal (when wrapping it into
        #   a SymNode, we set constant.)  Most of the time, constant is None
        #
        # - A hint is a *particular* value from the particular run we are
        #   tracing, but it may vary the next time around.  It's useful to
        #   keep this around, as if we need a concrete value from a SymNode,
        #   we will return the hint and guard on the expression that produced
        #   it giving the same hint next time around.  The hint is not
        #   guaranteed to be set either: if you have an unbacked SymNode,
        #   there won't be any hint; it was the result of some tensor-dependent
        #   computation, but we don't know what it actually is because we
        #   haven't actually run the tensor computation.
        #
        # If _hint is None, we will query maybe_evaluate_static(compute_hint=True)
        # in hopes that we've learned enough about the unbacked symints to
        # discharge the hint; otherwise, you're likely to just error out.
        #
        # (A previous version of this system had some optimizations to only
        # recompute when it was possible we had learned enough about the
        # unbacked symint that a hint was now possible, but as we added more
        # potential refinements to unbacked symints this got harder to keep
        # in sync, so we've deleted it for now.)

        def compute_hint():
            from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols

            # This occasionally gets exercised by, e.g.,
            # convert_shape_to_symint.  It's just a nicety so you don't HAVE
            # to have a correct hint on hand when making a SymNode.
            # Don't attempt to compute for unbacked, this can be quite
            # expensive.
            if has_free_unbacked_symbols(self.expr):
                return None
            hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
            if hint is not None:
                hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint
            return hint

        if hint is not None:
            assert type(hint) is pytype or type(hint) is _to_symtype(pytype), (
                "Cannot create SymNode of type "
                f"{pytype} with incompatible hint of type {type(hint)}"
            )
            if self.shape_env and self.shape_env._translation_validation_enabled:
                # This is technically not TV, but this assert is expensive so
                # let's only do it when we're already doing expensive things
                computed_hint = compute_hint()
                assert (
                    hint == computed_hint
                ), f"{hint} != {computed_hint} (for {self.expr})"
        else:
            hint = compute_hint()
        self._hint = hint
        self.constant: Optional[Union[int, float, bool]] = constant

        # Record the FX node of the current node if we are doing translation
        # validation. They will be used for building the input assertions for
        # the translation validation problem.
        tx_validation_en = (
            self.shape_env and self.shape_env._translation_validation_enabled
        )
        self.fx_node = tx_validation_en and fx_node

    def with_shape_env(self, shape_env: ShapeEnv) -> SymNode:
        return SymNode(
            self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
        )

    def _value_eq(self, other: SymNode) -> bool:
        # Purposely don't include the shape_env in the eq.
        return (
            self._expr == other._expr
            and self.pytype == other.pytype
            and self._hint == other._hint
            and self.constant == other.constant
            and self.fx_node == other.fx_node
        )

    def _value_hash(self) -> int:
        # Purposely don't include the shape_env in the hash.
        return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))

    @property
    def expr(self):
        return self.shape_env.replace(self._expr)

    @property
    def hint(self):
        return self._hint

    def has_hint(self):
        return self._hint is not None

    def require_hint(self, fallback=None):
        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

        if self._hint is None:
            if fallback is not None:
                # Say we have some expr like 2*u0 + s0
                # The hint will be None, since the expr contains at least 1 unbacked.
                # We will:
                # - replace every backed free symbol with its corresponding hint
                # - replace every unbacked free symbol with the fallback
                # - regenerate the expression with those symbol replacements
                # Note: this is not really complete either, since right now
                # this logic does not take into account any value ranges
                # for the unbacked symints, we may need to beef it up at some point.
                unbacked_symbols = free_unbacked_symbols(self.expr)
                replacements = {
                    s: 4096 if s in unbacked_symbols else self.shape_env.var_to_val[s]
                    for s in self.expr.free_symbols
                }
                return self.expr.xreplace(replacements)
            # NB: we expect this to raise
            return self.shape_env.size_hint(self.expr)
        return self._hint

    def maybe_as_int(self):
        if self.expr.is_number:
            return int(self.expr)
        else:
            return None

    # NB: This does conversions, not sure if this is good or not
    def maybe_as_float(self):
        import sympy

        if isinstance(self.expr, sympy.Float):
            return float(self.expr)
        else:
            return None

    def maybe_as_bool(self):
        import sympy

        if self.expr is sympy.true:
            return True
        elif self.expr is sympy.false:
            return False
        else:
            return None

    def is_int(self):
        return self.pytype is int

    def is_float(self):
        return self.pytype is float

    def is_bool(self):
        return self.pytype is bool

    def is_nested_int(self):
        # Unbacked SymInts cannot be nested int today
        return (
            self._hint is not None
            and isinstance(self._hint, SymInt)
            and self._hint.node.is_nested_int()
        )

    def wrap_int(self, num):
        assert type(num) is int
        import sympy

        return SymNode(
            sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num
        )

    def wrap_float(self, num):
        assert type(num) is float
        import sympy

        return SymNode(
            sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num
        )

    def wrap_bool(self, num):
        assert type(num) is bool
        import sympy

        return SymNode(
            sympy.true if num else sympy.false,
            self.shape_env,
            bool,
            num,
            constant=num,
            fx_node=num,
        )

    def clone(self):
        return self

    def str(self):
        return f"{self.expr}"

    def __str__(self):
        return self.str()

    def __repr__(self):
        rep = [
            f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}",
        ]
        if self._hint is not None:
            rep.append(f"hint={self._hint}")
        if self.constant is not None:
            rep.append(f"constant={self.constant}")
        if self.fx_node is not None:
            rep.append(f"fx_node={self.fx_node}")
        return ", ".join(rep) + ")"

    def _graph_repr(self) -> builtins.str:
        # Representation used by GraphModule to create a pythonic version of a graph
        return self.str()

    # These methods call the metaprogrammed methods, they're hand written
    # here so we get good stack traces
    def abs(self) -> SymNode:
        return self._abs()  # type: ignore[attr-defined]

    def pos(self) -> SymNode:
        return self._pos()  # type: ignore[attr-defined]

    def round(self, ndigits=None) -> SymNode:
        return self._round(ndigits)  # type: ignore[attr-defined]

    def trunc(self) -> SymNode:
        return self._trunc()  # type: ignore[attr-defined]

    def add(self, other) -> SymNode:
        return self._add(other)  # type: ignore[attr-defined]

    def sub(self, other) -> SymNode:
        return self._sub(other)  # type: ignore[attr-defined]

    def mul(self, other) -> SymNode:
        return self._mul(other)  # type: ignore[attr-defined]

    def mod(self, other) -> SymNode:
        return self._mod(other)  # type: ignore[attr-defined]

    def float_pow(self, other) -> SymNode:
        return self._float_pow(other)  # type: ignore[attr-defined]

    def pow_by_natural(self, other) -> SymNode:
        return self._pow_by_natural(other)  # type: ignore[attr-defined]

    def and_(self, other) -> SymNode:
        return self._and_(other)  # type: ignore[attr-defined]

    def or_(self, other) -> SymNode:
        return self._or_(other)  # type: ignore[attr-defined]

    def float_truediv(self, other) -> SymNode:
        return self._float_truediv(other)  # type: ignore[attr-defined]

    def int_truediv(self, other) -> SymNode:
        return self._int_truediv(other)  # type: ignore[attr-defined]

    def int_floordiv(self, other) -> SymNode:
        return self._int_floordiv(other)  # type: ignore[attr-defined]

    def lshift(self, other) -> SymNode:
        return self._lshift(other)  # type: ignore[attr-defined]

    def rshift(self, other) -> SymNode:
        return self._rshift(other)  # type: ignore[attr-defined]

    def sym_not(self) -> SymNode:  # noqa: F811
        return self._sym_not()  # type: ignore[attr-defined]

    def eq(self, other) -> SymNode:
        return self._eq(other)  # type: ignore[attr-defined]

    def ne(self, other) -> SymNode:
        return self._ne(other)  # type: ignore[attr-defined]

    def gt(self, other) -> SymNode:
        return self._gt(other)  # type: ignore[attr-defined]

    def lt(self, other) -> SymNode:
        return self._lt(other)  # type: ignore[attr-defined]

    def le(self, other) -> SymNode:
        return self._le(other)  # type: ignore[attr-defined]

    def ge(self, other) -> SymNode:
        return self._ge(other)  # type: ignore[attr-defined]

    def floor(self) -> SymNode:
        return self._floor()  # type: ignore[attr-defined]

    def is_integer(self) -> SymNode:
        return self._is_integer()  # type: ignore[attr-defined]

    def sym_float(self) -> SymNode:  # noqa: F811
        return self._sym_float()  # type: ignore[attr-defined]

    def sym_int(self) -> SymNode:
        return self._sym_int()  # type: ignore[attr-defined]

    def ceil(self) -> SymNode:
        return self._ceil()  # type: ignore[attr-defined]

    def neg(self) -> SymNode:
        return self._neg()  # type: ignore[attr-defined]

    def sym_min(self, other) -> SymNode:  # noqa: F811
        return self._sym_min(other)  # type: ignore[attr-defined]

    def sym_max(self, other) -> SymNode:  # noqa: F811
        return self._sym_max(other)  # type: ignore[attr-defined]

    def sym_ite(self, then_val, else_val) -> SymNode:
        return self._sym_ite(then_val, else_val)  # type: ignore[attr-defined]

    def is_contiguous(self, sizes, strides) -> SymNode:
        return self._is_contiguous(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode:
        return self._is_channels_last_contiguous_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode:
        return self._is_channels_last_contiguous_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_2d(self, sizes, strides) -> SymNode:
        return self._is_channels_last_strides_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_3d(self, sizes, strides) -> SymNode:
        return self._is_channels_last_strides_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode:
        return self._is_non_overlapping_and_dense_indicator(sizes, strides)  # type: ignore[attr-defined]

    # Make C++ happy
    def sym_or(self, other):
        return self.or_(other)

    def sym_and(self, other):
        return self.and_(other)

    # Integer bitwise ops
    def bitwise_and(self, other):
        return self._bitwise_and(other)  # type: ignore[attr-defined]

    def bitwise_or(self, other):
        return self._bitwise_or(other)  # type: ignore[attr-defined]

    # There is no int_truediv available from C++
    def truediv(self, other):
        return self.float_truediv(other)

    def floordiv(self, other) -> SymNode:
        return self.int_floordiv(other)

    # We didn't bind integer pow in C++
    def pow(self, other):
        return self.float_pow(other)

    def is_non_overlapping_and_dense(self, sizes, strides):
        return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1))  # type: ignore[attr-defined]

    def int_(self):
        return self.guard_int("", 0)  # NB: uses Python backtrace

    # This one is currently done by hand, but if we add other variadic
    # functions consider factoring it out to be metaprogrammed too.  Note that
    # some load bearing logic is directly in torch.sym_sum

    def sym_sum(self, args) -> SymNode:
        import sympy

        # Inner impl
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        if get_proxy_mode():
            return to_node(
                self,
                handle_sym_dispatch(
                    torch.sym_sum,
                    (tuple(wrap_node(a) for a in args),),
                    {},
                ),
            )
        exprs = [a.expr for a in args]
        out = sympy.Add(*exprs)

        size_hints = []
        out_hint = None
        for a in args:
            if a.hint is None:
                break
            size_hints.append(a.hint)
        else:
            out_hint = sum(size_hints)

        fx_node, _ = self.shape_env._create_fx_call_function(
            torch.sym_sum, (tuple(a.fx_node for a in args),)
        )

        # NB: Only for integers!
        return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node)

    def evaluate(self, size_oblivious=False):
        return self.shape_env.evaluate_sym_node(self, size_oblivious)

    # You can manually trigger a guard with this function
    def guard_int(self, file, line):
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.evaluate()
        try:
            return int(r)
        except Exception:
            log.warning("Failed to convert to int: %s", r)
            raise

    def guard_float(self, file, line):
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.evaluate()
        try:
            return float(r)
        except Exception:
            log.warning("Failed to convert to float: %s", r)
            raise

    def guard_bool(self, file, line):
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.evaluate()
        try:
            return bool(r)
        except Exception:
            log.warning("Failed to convert to bool: %s", r)
            raise

    def expect_true(self, file, line):
        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

        if (
            self.has_hint()
            and not free_unbacked_symbols(self.expr)
            and not self.shape_env.prefer_deferred_runtime_asserts_over_guards
        ):
            # OK to generate guards
            return self.guard_bool(file, line)
        # Generate a deferred runtime assert (this might actually end up doing
        # a regular guard if we can!)
        # TODO: file/line here is very important, because the assert has been
        # deferred so you can't backtrace easily
        return self.shape_env.defer_runtime_assert(
            self.expr, f"{file}:{line}", fx_node=self.fx_node
        )

    def expect_size(self, file, line):
        from torch.fx.experimental.symbolic_shapes import _advise_is_size

        b = self.ge(self.wrap_int(0))
        # Generate a deferred runtime assert
        r = b.expect_true(file, line)
        # Refine compile time range, but only if it's unbacked.
        # If you refine range for hinted variables, you can end up making
        # improper deductions since compile time reasoning may be
        # incompatible with runtime reasoning.
        if r and not self.has_hint():
            _advise_is_size(SymInt(self))
        return r

    def guard_size_oblivious(self, file, line):
        """
        Like guard_bool, but if we encounter unbacked symbols, if those symbols
        are size-like, we will treat them as >= 2 for the purposes of the analysis.

        This CHANGES the runtime semantics, but all size-oblivious sites have been
        audited to ensure that the runtime semantics don't change in a material way.
        Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping
        an unbacked one size, or a tensor reporting as non-contiguous even if it's
        contiguous if it would have been reported contiguous due to being empty.
        """
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.evaluate(size_oblivious=True)
        try:
            return bool(r)
        except Exception:
            log.warning("Failed to convert to bool: %s", r)
            raise

    def bool_(self):
        return self.guard_bool("", 0)

    def is_symbolic(self):
        return True

    def nested_int(self):
        return None

    def is_constant(self):
        return False


# TODO: this probably needs the sizes-strides eval functions
# Maps a SymNode method name to the Python operator/function that performs the
# same operation on ordinary Python values.  Note that the boolean and bitwise
# variants ("and"/"bitwise_and", "or"/"bitwise_or") share operator.and_/or_.
# The math-op loop later in this file adds "sym_<op>" -> torch._sym_<op> entries.
METHOD_TO_OPERATOR = {
    "pos": operator.pos,
    "abs": operator.abs,
    "add": operator.add,
    "and": operator.and_,
    "bitwise_and": operator.and_,
    "ceil": math.ceil,
    "eq": operator.eq,
    "floor": math.floor,
    "trunc": math.trunc,
    "int_floordiv": operator.floordiv,
    "ge": operator.ge,
    "gt": operator.gt,
    "is_integer": lambda x: x.is_integer(),
    "le": operator.le,
    "lshift": operator.lshift,
    "lt": operator.lt,
    "mod": operator.mod,
    "mul": operator.mul,
    "ne": operator.ne,
    "neg": operator.neg,
    "or": operator.or_,
    "bitwise_or": operator.or_,
    "float_pow": operator.pow,
    "pow_by_natural": operator.pow,
    "round": builtins.round,
    "rshift": operator.rshift,
    "sub": operator.sub,
    "sym_float": sym_float,
    "sym_ite": sym_ite,
    "sym_max": sym_max,
    "sym_min": sym_min,
    "sym_not": sym_not,
    "float_truediv": operator.truediv,
    "int_truediv": operator.truediv,
}

# Magic methods that take a single operand.  The math-op loop that follows
# adds a "sym_<op>" entry for each math op name.
unary_magic_methods = {
    "abs",
    "sym_float",
    "sym_int",
    "ceil",
    "floor",
    "neg",
    "sym_not",
    "pos",
    "trunc",
}


# Adding math ops: sqrt, cos, sin, ...
def _get_sym_node_fn(name):
    """Build a method that forwards to ``self._sym_<name>()``.

    The returned function is meant to be installed as ``SymNode.sym_<name>``.
    """
    target_attr = f"_sym_{name}"

    def fn(self):
        impl = getattr(self, target_attr)
        return impl()

    return fn


# Unary math ops exposed symbolically as SymNode.sym_<op>.
math_op_names = (
    "sqrt",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "asin",
    "acos",
    "atan",
    "log2",
)
# For each op:
# - install SymNode.sym_<op>, which forwards to self._sym_<op>();
# - map "sym_<op>" to torch._sym_<op> for evaluation on plain values;
# - register it as a unary magic method;
# - export it via __all__.
# NB: the loop variables are module globals; some are deleted later in the file.
for name in math_op_names:
    sym_name = f"sym_{name}"
    priv_sym_name = f"_{sym_name}"
    setattr(SymNode, sym_name, _get_sym_node_fn(name))
    METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
    unary_magic_methods.add(sym_name)
    __all__.append(sym_name)


# Unary methods that are not magic methods
unary_nonmagic_methods = {
    "is_integer",
}

# All single-operand methods, magic or not.
unary_methods = unary_magic_methods | unary_nonmagic_methods

# Most methods are only registered on SymInt and SymFloat
# Some methods are only registered on SymBool
only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
# Methods that implicitly convert SymBool into SymInt
bool_becomes_int_magic_methods = {"add", "sub", "mul"}
# Methods that are also on SymBool, in addition to on SymInt and SymFloat
also_bool_magic_methods = {"eq"}
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods

# Methods that are only for float
only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"}


# Method names whose `operator` module function carries a trailing underscore
# (operator.and_ / operator.or_), since `and` / `or` are Python keywords.
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
# remap necessary because an op name can have a bitwise and boolean implementation
bitwise_ops = {
    "bitwise_and": "and",
    "bitwise_or": "or",
}


# Result-type sets: methods whose result is always float / int / bool,
# regardless of the operand types.
always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}

# Every math op (sym_sqrt, sym_cos, ...) is also float-valued.
for name in math_op_names:
    sym_name = f"sym_{name}"
    always_float_magic_methods.add(sym_name)


always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"}
always_bool_magic_methods = {
    "eq",
    "ne",
    "gt",
    "lt",
    "le",
    "ge",
    "and",
    "or",
    "sym_not",
    "is_non_overlapping_and_dense",
    "is_integer",
}

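# Binary methods that have a `__foo__` as well as a reflected `__rfoo__` variant
# are collected in `reflectable_magic_methods` further down; the helpers below
# build their sympy expressions.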


def _sympy_float_truediv(a, b):
    from torch.utils._sympy.functions import FloatTrueDiv

    return FloatTrueDiv(a, b)


def _sympy_int_truediv(a, b):
    from torch.utils._sympy.functions import IntTrueDiv

    return IntTrueDiv(a, b)


def _sympy_floordiv(a, b):
    from torch.utils._sympy.functions import FloorDiv

    return FloorDiv(a, b)


def _sympy_mod(a, b):
    from torch.utils._sympy.functions import Mod, PythonMod

    if a.is_nonnegative and b.is_nonnegative:
        return Mod(a, b)
    else:
        return PythonMod(a, b)


def _sympy_pow_by_natural(a, b):
    from torch.utils._sympy.functions import PowByNatural

    return PowByNatural(a, b)


def _sympy_float_pow(a, b):
    from torch.utils._sympy.functions import FloatPow

    return FloatPow(a, b)


def _sympy_and(a, b):
    """Build the sympy logical conjunction ``And(a, b)``."""
    from sympy import And

    return And(a, b)


def _sympy_or(a, b):
    """Build the sympy logical disjunction ``Or(a, b)``."""
    from sympy import Or

    return Or(a, b)


def _sympy_lshift(a, b):
    from torch.utils._sympy.functions import LShift

    return LShift(a, b)


def _sympy_rshift(a, b):
    from torch.utils._sympy.functions import RShift

    return RShift(a, b)


def _binary_search_insert_arg(ordered_args, new_arg):
    """
    If new_arg is found in ordered_args None is returned, else the new
    ordered_args with new_arg inserted
    """
    if len(ordered_args) == 0:
        return [new_arg]

    from sympy.core.basic import _args_sortkey as sort_key, Basic

    # Fast path when new_arg > ordered_args[-1].
    if sort_key(ordered_args[-1]) < sort_key(new_arg):
        return ordered_args + [new_arg]

    # Fast path when new_arg < ordered_args[0].
    if sort_key(ordered_args[0]) > sort_key(new_arg):
        return [new_arg] + ordered_args

    low, high = 0, len(ordered_args) - 1

    while low <= high:
        mid = (low + high) // 2
        compare_result = Basic.compare(ordered_args[mid], new_arg)
        if compare_result == 0:
            return None
        elif compare_result < 0:
            low = mid + 1
        else:
            high = mid - 1

    ordered_args.insert(low, new_arg)
    return ordered_args


def _optimized_add(
    lhs, rhs, lhs_is_optimized_summation=False, rhs_is_optimized_summation=False
):
    """
    Custom optimization for Add used to optimize incremental binary summations of certain properties. The idea
    is when we know the expression is a summation of unique symbols all we need to know is the correct order of symbols,
    and no other optimizations are needed. We pass evaluate=false, with the correct order of args and save the following.
    1. Avoid running other optimizations when the Add is constructed.
    2. Manually figure out the order of the args for the new expression in log(n) comparisons instead of nLog(n)
    (comparing terms is expensive and shows in the profiles).
    The function returns a tuple of (1) a boolean that indicates whether the output is a summation of unique symbols,
    (2) the result sympy expression.
    """
    import sympy
    from sympy.core.basic import _args_sortkey as sortkey

    def make_optimized(ordered_args):
        result = sympy.Add(*ordered_args, evaluate=False)
        return (True, result)

    from torch.utils._sympy.functions import _is_symbols_binary_summation

    lhs_is_optimized_summation |= _is_symbols_binary_summation(lhs)
    rhs_is_optimized_summation |= _is_symbols_binary_summation(rhs)

    if lhs_is_optimized_summation and rhs_is_optimized_summation:
        # (a0+a1..) + (a2+a3..) => (a0+a1+a2+a3)
        if sortkey(lhs._args[-1]) < sortkey(rhs._args[0]):
            return make_optimized(lhs._args + rhs._args)
        #  (a2+a3..) + (a0+a1..) => (a0+a1+a2+a3)
        if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]):
            return make_optimized(rhs._args + lhs._args)

    # (a0+a2) + a1 => (a0+a1+a2)
    if lhs_is_optimized_summation and rhs.is_symbol:
        new_args = _binary_search_insert_arg(list(lhs._args), rhs)
        if new_args is not None:
            return make_optimized(new_args)

    # a1 + (a0+a2)=> (a0+a1+a2)
    if rhs_is_optimized_summation and lhs.is_symbol:
        new_args = _binary_search_insert_arg(list(rhs._args), lhs)
        if new_args is not None:
            return make_optimized(new_args)

    result = sympy.Add(lhs, rhs)
    return (_is_symbols_binary_summation(result), result)


def _bitwise_and(a, b):
    from torch.utils._sympy.functions import BitwiseFn_bitwise_and

    return BitwiseFn_bitwise_and(a, b)


def _bitwise_or(a, b):
    from torch.utils._sympy.functions import BitwiseFn_bitwise_or

    return BitwiseFn_bitwise_or(a, b)


# Binary magic methods that also get a reflected (__r<op>__) variant, mapped
# to the function that builds their sympy expression from two operands.
# NB: "add" maps to _optimized_add, which returns a (is_optimized_summation,
# expr) tuple rather than a bare expression like the other entries.
reflectable_magic_methods = {
    "add": _optimized_add,
    "sub": operator.sub,
    "mul": operator.mul,
    "mod": _sympy_mod,
    "pow_by_natural": _sympy_pow_by_natural,
    "float_pow": _sympy_float_pow,
    "and": _sympy_and,
    "bitwise_and": _bitwise_and,
    "or": _sympy_or,
    "bitwise_or": _bitwise_or,
    "float_truediv": _sympy_float_truediv,
    "int_truediv": _sympy_int_truediv,
    "int_floordiv": _sympy_floordiv,
    "lshift": _sympy_lshift,
    "rshift": _sympy_rshift,
}


def _floor_ceil_helper(a, fn):
    """
    Apply the rounding function ``fn`` to ``a``, short-circuiting when the
    result is already integral.

    - ``c * x`` where ``c`` is a Float structurally equal to its Integer
      truncation and ``x`` is an integer becomes ``Integer(c) * x``.
    - An Integer, or a Float equal to its Integer truncation, becomes
      ``sympy.Integer(a)``.

    Otherwise ``fn(a)`` is returned.
    """
    import sympy

    if isinstance(a, sympy.Mul):
        factors = a.args
        if (
            len(factors) == 2
            and isinstance(factors[0], sympy.Float)
            and factors[1].is_integer
        ):
            int_coef = sympy.Integer(factors[0])
            if factors[0] == int_coef:  # structural equality test
                return int_coef * factors[1]

    # Float and Integer are disjoint, so this ordering is equivalent.
    already_integral = isinstance(a, sympy.Integer) or (
        isinstance(a, sympy.Float) and a == sympy.Integer(a)
    )
    if already_integral:
        return sympy.Integer(a)
    return fn(a)


def _sympy_floor(a):
    from torch.utils._sympy.functions import FloorToInt

    return FloorToInt(a)


# NB: this is Python trunc semantics which returns an int.  Do NOT use this to
# represent torch.trunc (which is float to float)
def _sympy_trunc(a):
    from torch.utils._sympy.functions import TruncToInt

    return TruncToInt(a)


def _sympy_ceil(a):
    from torch.utils._sympy.functions import CeilToInt

    return CeilToInt(a)


def _sympy_eq(a, b):
    """Build the sympy relational ``Eq(a, b)``."""
    from sympy import Eq

    return Eq(a, b)


def _sympy_ne(a, b):
    """Build the sympy relational ``Ne(a, b)``."""
    from sympy import Ne

    return Ne(a, b)


def _sympy_gt(a, b):
    """Build the sympy relational ``a > b``."""
    from sympy import Gt

    return Gt(a, b)


def _sympy_lt(a, b):
    """Build the sympy relational ``a < b``."""
    from sympy import Lt

    return Lt(a, b)


def _sympy_le(a, b):
    """Build the sympy relational ``a <= b``."""
    from sympy import Le

    return Le(a, b)


def _sympy_ge(a, b):
    """Build the sympy relational ``a >= b``."""
    from sympy import Ge

    return Ge(a, b)


def _sympy_min(a, b):
    from torch.utils._sympy.functions import Min

    return Min(a, b)


def _sympy_max(a, b):
    from torch.utils._sympy.functions import Max

    return Max(a, b)


def _sympy_ite(a, t, f):
    """Build ``t if a else f`` as a sympy ``Piecewise``."""
    from sympy import Piecewise

    then_branch = (t, a)
    else_branch = (f, True)
    return Piecewise(then_branch, else_branch)


# This module object, used to install the generated _sympy_<op> functions as
# module attributes via setattr (it is deleted once registration is done).
current_module = sys.modules[__name__]


def _get_sym_math_fn(name):
    def fn(a):
        import torch.utils._sympy.functions

        return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a)

    return fn


# Generate a module-level _sympy_<op> builder (e.g. _sympy_sqrt) for every math
# op; each applies torch.utils._sympy.functions.OpaqueUnaryFn_<op>.
for name in math_op_names:
    priv_sympy_name = f"_sympy_{name}"
    fn = _get_sym_math_fn(name)
    # Rename the closure so its repr/qualname reads _sympy_<op>.
    fn.__qualname__ = fn.__name__ = priv_sympy_name
    setattr(current_module, priv_sympy_name, fn)

# Drop the loop temporaries so they don't linger as module attributes.
del fn, name, priv_sympy_name  # type: ignore[possibly-undefined]


def _sympy_abs(a):
    """Build the sympy absolute value ``Abs(a)``."""
    from sympy import Abs

    return Abs(a)


def _sympy_round(number, ndigits=None):
    from torch.utils._sympy.functions import RoundDecimal, RoundToInt

    if ndigits is None:
        return RoundToInt(number)
    else:
        return RoundDecimal(number, ndigits)


def _sympy_sym_float(a):
    from torch.utils._sympy.functions import ToFloat

    # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly
    # reports that it is an integer
    return ToFloat(a)


def _sympy_is_integer(a):
    import sympy

    from torch.utils._sympy.functions import ToFloat

    return sympy.Eq(ToFloat(sympy.floor(a)), a)


# Mapping from magic-method name to its sympy-level implementation.  It is
# seeded with ``reflectable_magic_methods`` and extended with the entries below
# plus one ``sym_<op>`` entry per math op.  Every entry is later installed on
# SymNode (``_make_node_magic``) and on the user-facing Sym types
# (``_make_user_magic``) by the installation loops further down.
magic_methods = {
    **reflectable_magic_methods,
    "sym_not": operator.invert,
    "pos": operator.pos,
    "eq": _sympy_eq,
    "ne": _sympy_ne,
    "gt": _sympy_gt,
    "lt": _sympy_lt,
    "le": _sympy_le,
    "ge": _sympy_ge,
    "floor": _sympy_floor,
    "trunc": _sympy_trunc,
    "sym_float": _sympy_sym_float,
    "ceil": _sympy_ceil,
    "neg": operator.neg,
    "sym_min": _sympy_min,
    "sym_max": _sympy_max,
    "sym_ite": _sympy_ite,
    "abs": _sympy_abs,
    "round": _sympy_round,
    "is_integer": _sympy_is_integer,
}


# Every math op gets a ``sym_<name>`` entry backed by the ``_sympy_<name>``
# helper that was attached to this module by name earlier.
for name in math_op_names:
    sym_name = f"sym_{name}"
    magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")

# These names were only needed to build the table; don't leak them as module
# attributes.
del name, sym_name, math_op_names, current_module  # type: ignore[possibly-undefined]


def sympy_is_contiguous(sizes, strides):
    """Symbolic row-major contiguity check.

    The last dimension is treated as innermost (expected stride 1) and the
    first as outermost.
    """
    innermost_first = list(reversed(range(len(sizes))))
    return sympy_is_contiguous_generic(sizes, strides, innermost_first)


def sympy_is_contiguous_generic(sizes, strides, dim_order):
    """Build a sympy boolean: is the tensor contiguous in the given *dim_order*?

    *dim_order* lists dimensions from innermost (expected stride 1) to
    outermost.  Every dimension must either have size 1, or have a stride equal
    to the product of the sizes of the dimensions that precede it in
    *dim_order*.  Separately, a tensor with any zero-size dimension is always
    considered contiguous.

    Returns ``sympy.false`` when *dim_order* does not have exactly one entry per
    dimension.
    """
    import sympy

    ndim = len(sizes)
    if ndim != len(dim_order):
        return sympy.false

    one = sympy.S.One
    result = sympy.true
    expected_stride = one
    for axis in dim_order:
        size = sizes[axis]
        # A size-1 dimension may have any stride.
        result &= sympy.Eq(size, one) | sympy.Eq(strides[axis], expected_stride)
        expected_stride *= size
    # A tensor with no elements is trivially contiguous.
    for size in sizes:
        result |= sympy.Eq(size, sympy.S.Zero)
    return result


# NB: There is a TODO in C++ to allow omitting the batch dim.  If that
# happens you will need to refactor this


def sympy_is_channels_last_contiguous_2d(sizes, strides):
    """Symbolic contiguity check for the 4-d channels-last layout.

    Dimensions are visited innermost-first as 1, 3, 2, 0 (C, W, H, N, assuming
    NCHW indexing).
    """
    innermost_first = [1, 3, 2, 0]
    return sympy_is_contiguous_generic(sizes, strides, innermost_first)


def sympy_is_channels_last_contiguous_3d(sizes, strides):
    """Symbolic contiguity check for the 5-d channels-last layout.

    Dimensions are visited innermost-first as 1, 4, 3, 2, 0 (C, W, H, D, N,
    assuming NCDHW indexing).
    """
    innermost_first = [1, 4, 3, 2, 0]
    return sympy_is_contiguous_generic(sizes, strides, innermost_first)


def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
    import sympy

    from torch.utils._sympy.functions import Max

    dim = len(sizes)

    if dim != len(dim_order):
        return sympy.false

    m = sympy.S.Zero
    r = sympy.true

    # special case for trivial C dimension. default to NCHW
    r &= sympy.Ne(strides[1], 0)

    for d in dim_order:
        r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m)
        # Fallback to NCHW as default layout for ambiguous cases
        # This is the flaw of implicit memory_format from strides.
        # N111 tensor with identical strides for size 1 dimension;
        # Two cases could lead us here:
        # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
        # b. N11W contiguous Tensor sliced on the W-dimension.
        # ([N,1,1,1]@[W,W,W,W])
        if d == 0:
            r &= sympy.Ne(m, strides[1])
        # This is necessary to:
        # 1. distinguish the memory_format of N1H1;
        #     [H, 1, 1, 1] channels_last stride
        #     [H, H, 1, 1] contiguous stride
        # 2. permutation of 1C1W:
        #     [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
        #     [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as
        #     channels_last
        m = strides[d] * Max(sizes[d], 1)

    return r


def sympy_is_channels_last_strides_2d(sizes, strides):
    """Stride heuristic for the 4-d channels-last layout (order C, W, H, N)."""
    innermost_first = [1, 3, 2, 0]
    return sympy_is_channels_last_strides_generic(sizes, strides, innermost_first)


def sympy_is_channels_last_strides_3d(sizes, strides):
    """Stride heuristic for the 5-d channels-last layout (order C, W, H, D, N)."""
    innermost_first = [1, 4, 3, 2, 0]
    return sympy_is_channels_last_strides_generic(sizes, strides, innermost_first)


def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
    from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator

    return IsNonOverlappingAndDenseIndicator(*sizes, *strides)


# Predicates over (sizes, strides) lists, mapped to their sympy implementations.
# Each one is installed by ``_make_node_sizes_strides`` as ``SymNode._<name>``
# and, if this module does not already define ``<name>``, as a module-level
# function.
sizes_strides_methods = {
    # TODO: These could also be done with indicators, maybe it is better
    # for reasoning to do it that way
    "is_contiguous": sympy_is_contiguous,
    "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
    "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
    "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
    "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
    "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
}

# Binary methods that, when both operands have hints, are evaluated by applying
# this builtin directly to the wrapped Sym values instead of building a new
# sympy expression (used by the binary implementation in _make_node_magic).
alternate_impl_if_hinted_methods = {
    "sym_min": builtins.min,
    "sym_max": builtins.max,
}


def to_node(self, num):
    if isinstance(num, SymTypes):
        return num.node
    elif type(num) is bool:
        return self.wrap_bool(num)
    elif type(num) is int:
        return self.wrap_int(num)
    elif type(num) is float:
        return self.wrap_float(num)
    else:
        # NotImplemented is important so that Python tries the
        # other magic method
        return NotImplemented


def wrap_node(x):
    """Wrap the node *x* in the user-facing Sym type matching its kind.

    A Python ``SymNode`` carrying a constant is unwrapped to that plain
    constant.  Otherwise the node is wrapped as SymInt, SymFloat or SymBool
    according to ``is_int()``/``is_float()``/``is_bool()``.

    Raises:
        AssertionError: if the node reports none of those kinds.
    """
    # TODO: let C++ also take advantage of this
    if isinstance(x, SymNode) and x.constant is not None:
        return x.constant
    if x.is_int():
        return SymInt(x)
    if x.is_float():
        return SymFloat(x)
    if x.is_bool():
        return SymBool(x)
    raise AssertionError(f"unrecognized return type {x}")


def method_to_operator(method):
    """Return the Python operator callable registered for *method*.

    The lookup is ``METHOD_TO_OPERATOR[method]``; presumably a dict, so an
    unknown name raises ``KeyError`` (confirm against its definition).
    """
    operator_fn = METHOD_TO_OPERATOR[method]
    return operator_fn


def _make_node_magic(method, func):
    """Install the SymNode-level implementation of the magic method *method*.

    Args:
        method: Name of the magic method (a key of ``magic_methods``).
        func: sympy-level implementation of *method*; it is wrapped in an
            ``lru_cache`` with maxsize 256.

    The implementation is attached to ``SymNode`` as ``_<method_attr>``.
    ``method_attr`` is ``<method>_`` for names in
    ``magic_methods_on_operator_with_trailing_underscore`` and *method*
    otherwise.  Unary methods, ``sym_ite``, ``round`` and the remaining (binary)
    methods each get a dedicated implementation.

    While a proxy mode is active, each implementation defers to
    ``handle_sym_dispatch``.  Otherwise it builds the new sympy expression, a
    hint (when the inputs carry hints), the result pytype and an FX node, and
    returns a new ``SymNode``.
    """
    func = lru_cache(256)(func)

    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"{method}_"
    else:
        method_attr = method

    def capture_provenance(fn):
        """Wrap a unary/binary impl so that, when structured dtrace logging is
        enabled, an ``expression_created`` event records the resulting
        expression together with its arguments and the ids of its
        non-constant inputs."""

        @functools.wraps(fn)
        def wrapper(self, other=None):
            if other is None:
                result = fn(self)
            else:
                result = fn(self, other)
            if torch._logging._internal.GET_DTRACE_STRUCTURED:
                if other is not None:
                    arguments = [self, other]
                else:
                    arguments = [self]

                def get_id(sym_node) -> Optional[int]:
                    # We don't want to return an ID if the input is a constant
                    import sympy

                    if sym_node.constant is not None:
                        return None
                    elif id(sym_node) == id(result):
                        return None
                    elif isinstance(sym_node.expr, (sympy.Integer, sympy.Float)):
                        return None
                    elif sym_node.expr in (sympy.true, sympy.false):
                        return None
                    return id(sym_node)

                dtrace_structured(
                    "expression_created",
                    metadata_fn=lambda: {
                        "method": method,
                        "result": str(result),
                        "result_id": id(result),
                        "arguments": [str(a) for a in arguments],
                        # Evaluate get_id once per argument.
                        "argument_ids": [
                            arg_id
                            for arg_id in map(get_id, arguments)
                            if arg_id is not None
                        ],
                        "user_stack": structured.get_user_stack(3),
                        "stack": structured.get_framework_stack(3),
                    },
                )

            return result

        return wrapper

    @capture_provenance
    def binary_magic_impl(self, other):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        op = method_to_operator(method)

        # Concrete result hint, available only when both operands are hinted.
        out_hint = None
        if self.hint is not None and other.hint is not None:
            out_hint = op(self.hint, other.hint)

        # Methods listed in alternate_impl_if_hinted_methods (sym_min/sym_max)
        # are computed directly on the wrapped Sym values when a hint exists.
        alternate_impl = alternate_impl_if_hinted_methods.get(method)
        if alternate_impl and out_hint is not None:
            return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))

        if get_proxy_mode():
            return to_node(
                self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
            )
        assert isinstance(other, SymNode)
        optimized_summation = False
        try:
            if method == "mod":
                from torch.utils._sympy.functions import Mod, PythonMod

                # Special handling for mod that requires access to the value
                # ranges: use Mod only when both operands are provably
                # non-negative (sympy assumptions or ShapeEnv bounds), and
                # PythonMod otherwise.
                shape_env = self.shape_env
                if (
                    self.expr.is_nonnegative
                    or shape_env.bound_sympy(self.expr).lower >= 0
                ) and (
                    other.expr.is_nonnegative
                    or shape_env.bound_sympy(other.expr).lower >= 0
                ):
                    out = Mod(self.expr, other.expr)
                else:
                    out = PythonMod(self.expr, other.expr)
            elif method == "add":
                # see Note [optimized_summation]
                (optimized_summation, out) = func(
                    self.expr,
                    other.expr,
                    self._optimized_summation,
                    other._optimized_summation,
                )
            else:
                # TODO: consider constant prop here
                out = func(self.expr, other.expr)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
            raise
        sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out)
        pytype: type
        # This is not strictly correct. In Python, a**b may return complex when
        # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
        # returns a float while both arguments are ints: 2**(-1). Also, max and
        # min do not type promote. To avoid having data-dependent control flow
        # here, we just set the type to float if one of the args is a float. In
        # case of a type mismatch, we assume that it will be detected during
        # evaluation.
        if method in always_float_magic_methods:
            pytype = float
        elif method in always_bool_magic_methods:
            pytype = bool
        elif self.pytype is float or other.pytype is float:
            pytype = float
        else:
            pytype = self.pytype

        if (
            pytype is not None
            and out_hint is not None
            and not isinstance(out_hint, SymTypes)
        ):
            out_hint = pytype(out_hint)

        # Create a FX node that corresponds to the operation being applied to
        # this node.
        fx_node, _ = self.shape_env._create_fx_call_function(
            op, (self.fx_node, other.fx_node)
        )

        result = SymNode(
            out,
            self.shape_env,
            pytype,
            out_hint,
            fx_node=fx_node,
            optimized_summation=optimized_summation,  # see Note [optimized_summation]
        )
        return result

    @capture_provenance
    def unary_magic_impl(self):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        op = method_to_operator(method)
        if get_proxy_mode():
            return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
        # TODO: consider constant prop here
        expr = self.expr
        # NOTE(review): magic_methods registers "ceil", not "ceiling", so only
        # "floor" ever reaches this branch.  Switching to "ceil" is not a
        # drop-in fix: ShapeEnv._simplify_floor_div (not visible here) may have
        # side effects such as installing guards; confirm before changing.
        if method == "floor" or method == "ceiling":
            expr = self.shape_env._simplify_floor_div(expr)

        try:
            out = func(expr)
        except Exception:
            log.warning("failed to eval %s(%s)", method, expr)
            raise
        sym_node_log.debug("%s %s -> %s", func, expr, out)
        out_hint = None
        if self.hint is not None:
            out_hint = op(self.hint)
        pytype: type
        if method in always_int_magic_methods:
            pytype = int
        elif method in always_bool_magic_methods:
            pytype = bool
        elif method in always_float_magic_methods:
            pytype = float
        else:
            pytype = self.pytype

        fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,))
        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

    if method in unary_methods:
        setattr(SymNode, f"_{method_attr}", unary_magic_impl)
    elif method == "sym_ite":

        def sym_ite_impl(pred_node, then_node, else_node):
            from torch.fx.experimental.proxy_tensor import (
                get_proxy_mode,
                handle_sym_dispatch,
            )

            # Without a predicate hint we cannot know which branch is taken,
            # so the result must not borrow either branch's hint.
            if pred_node.hint is None:
                out_hint = None
            else:
                out_hint = then_node.hint if pred_node.hint else else_node.hint
            if get_proxy_mode():
                return to_node(
                    pred_node,
                    handle_sym_dispatch(
                        sym_ite,
                        (
                            wrap_node(pred_node),
                            wrap_node(then_node),
                            wrap_node(else_node),
                        ),
                        {},
                    ),
                )

            try:
                out = func(pred_node.expr, then_node.expr, else_node.expr)
            except Exception:
                log.warning(
                    "failed to eval %s(%s, %s, %s)",
                    method,
                    pred_node.expr,
                    then_node.expr,
                    else_node.expr,
                )
                raise

            fx_node, _ = pred_node.shape_env._create_fx_call_function(
                sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node)
            )
            return SymNode(
                out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node
            )

        setattr(SymNode, f"_{method_attr}", sym_ite_impl)
    elif method == "round":

        def round_impl(self, ndigits=None):
            from torch.fx.experimental.proxy_tensor import (
                get_proxy_mode,
                handle_sym_dispatch,
            )

            op = builtins.round
            if get_proxy_mode():
                return to_node(
                    self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
                )

            expr = self.expr
            try:
                out = func(expr, ndigits)
            except Exception:
                log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
                raise

            # round(x) yields an int; round(x, ndigits) keeps x's type.
            if ndigits is None:
                pytype = int
            else:
                pytype = self.pytype

            out_hint = None
            if self.hint is not None:
                out_hint = op(self.hint, ndigits)

            # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the
            # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here
            # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The
            # hack down below works, because all round function down the line all take ndigits=None as default in their
            # signature.
            # TODO: Remove the args construction below if a different sentinel is used by FX.
            # ezyang(May 2024): LOL
            args = [self.fx_node]
            if ndigits is not None:
                args.append(ndigits)
            fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args))
            return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

        setattr(SymNode, f"_{method_attr}", round_impl)
    else:
        setattr(SymNode, f"_{method_attr}", binary_magic_impl)


def _make_node_sizes_strides(method, func):
    """Install the sizes/strides predicate *method*, implemented by *func*.

    Two functions are created:

    * ``SymNode._<method>(self, sizes, strides)`` takes lists of nodes.  Under
      an active proxy mode it dispatches through ``handle_sym_dispatch``.
      Otherwise it evaluates *func* on the node expressions and returns a new
      SymNode with pytype ``int`` for ``*_indicator`` methods and ``bool``
      otherwise.  A hint is attached only if every size and stride node has one.
    * ``<method>(sizes, strides)`` at module level, which accepts ints and/or
      SymInts.  It is installed only if the module does not already define a
      name *method*.
    """
    # NB: don't LRU cache, lots of arguments

    def collect_hints(nodes):
        # Every node's hint, or None as soon as one node has no hint.
        hints = []
        for node in nodes:
            if node.hint is None:
                return None
            hints.append(node.hint)
        return hints

    def sizes_strides_impl(self, sizes, strides):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        op = getattr(sys.modules[__name__], method)
        if get_proxy_mode():
            wrapped_sizes = [wrap_node(s) for s in sizes]
            wrapped_strides = [wrap_node(s) for s in strides]
            return to_node(
                self,
                handle_sym_dispatch(op, (wrapped_sizes, wrapped_strides), {}),
            )

        size_exprs = [node.expr for node in sizes]
        stride_exprs = [node.expr for node in strides]
        try:
            out = func(size_exprs, stride_exprs)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs)
            raise

        # The concrete hint is computed only when every input is hinted; stride
        # hints are not consulted once a size hint is missing.
        out_hint = None
        size_hints = collect_hints(sizes)
        if size_hints is not None:
            stride_hints = collect_hints(strides)
            if stride_hints is not None:
                out_hint = op(size_hints, stride_hints)

        # NB: *_indicator methods produce the int indicator, not the bool itself.
        pytype: type = int if method.endswith("_indicator") else bool
        return SymNode(out, self.shape_env, pytype, out_hint)

    setattr(SymNode, f"_{method}", sizes_strides_impl)

    # TODO: This is technically hotpath, but in the ideal end state
    # guards on this will resolve at a higher level so you never
    # spend time in this code
    def sizes_strides_user(sizes, strides):
        import sympy

        from torch.fx.experimental.symbolic_shapes import (
            eval_is_non_overlapping_and_dense,
        )

        # If any argument is symbolic, compute through the first SymInt's node.
        first_symint = next(
            (a for a in itertools.chain(sizes, strides) if isinstance(a, SymInt)),
            None,
        )
        if first_symint is not None:
            anchor = first_symint.node
            return wrap_node(
                getattr(anchor, method)(
                    [to_node(anchor, b) for b in sizes],
                    [to_node(anchor, b) for b in strides],
                )
            )

        # All-concrete inputs.
        if method == "is_non_overlapping_and_dense_indicator":
            return eval_is_non_overlapping_and_dense(sizes, strides)
        # TODO: this is an awful implementation
        return bool(
            func(
                [sympy.sympify(a) for a in sizes],
                [sympy.sympify(a) for a in strides],
            )
        )

    # Skip for is_non_overlapping_and_dense_indicator
    module = sys.modules[__name__]
    if not hasattr(module, method):
        setattr(module, method, sizes_strides_user)


# Attach the SymNode-level implementation ``SymNode._<method>`` of every magic
# method (see _make_node_magic).
for method, func in magic_methods.items():
    _make_node_magic(method, func)

# Likewise for every sizes/strides predicate (see _make_node_sizes_strides).
for method, func in sizes_strides_methods.items():
    _make_node_sizes_strides(method, func)


def _make_user_magic(method, user_type):
    """Install the user-facing magic method *method* on *user_type*.

    *user_type* is one of SymInt / SymFloat / SymBool.  The installed function
    wraps the other operand into a node (so SymNode-level code only ever sees
    nodes), applies the promotion rules documented below, short-circuits on
    constant operands, and finally calls the SymNode method ``method_attr``.
    ``method_attr`` is ``sym_<method>`` for names in
    ``magic_methods_on_operator_with_trailing_underscore`` and *method*
    otherwise.

    Where the function is installed depends on *method*:

    * ``__<method>__`` for unary magic methods, ``sym_ite`` and ``round``;
    * the attribute *method* itself for unary non-magic methods;
    * ``__<name>__`` for binary methods, plus ``__r<name>__`` for reflectable
      ones, where *name* is remapped through ``bitwise_ops``.
    """
    # User magic takes care of wrapping the other operand into a node,
    # so that our internal logic can assume everything is nodes

    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"sym_{method}"
    else:
        method_attr = method

    def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
        if isinstance(x, (int, float, bool)):
            return x
        if isinstance(x, SymBool):
            return x.node.guard_bool("", 0)
        raise AssertionError("expect to be called with constant SymBools")

    def is_constant(x):
        if isinstance(x, (int, float, bool)):
            return True
        if isinstance(x, (SymInt, SymFloat, SymBool)):
            return x.node.is_constant()
        return False

    # Promotion rules for binary operations.  NB: we preserve PYTHON semantics
    #   - if args are same type, do nothing
    #   - if one arg is float, promote other arg to float
    #       - nb: this applies to floordiv, even though output is integral
    #       (it's still float)
    #   - pow is funny business
    #       - if both ints
    #       - trigger a guard on exponent >= 0
    #           - if non-negative, output is int
    #           - otherwise, output is float
    #   - otherwise, promote other arg to float
    #       - nb: complex is impossible to handle correctly lol, with
    #       negative base and integral float need to diverge semantics and
    #       just always return complex.  Neener neener pretend this problem
    #       doesn't exist
    #   - equality is pain: Python does the fancy thing where it unpacks the
    #     mantissa from the float and then compares that against the int.
    #     Which means it is able to tell that
    #     9007199254740993 != 9007199254740992. (rather than if the LHS was
    #     promoted to float, in which case it would have truncated to the RHS
    #     and subsequently been equal).  We'll model this exactly by having
    #     special mixed type equality operations.  Unfortunately, we need to
    #     do this for all comparison operations (maybe I'll only implement
    #     compare)
    #   - sym_ite mumble mumble really shouldn't allow mixed but whatever

    if method in bool_becomes_int_magic_methods:

        def promote(x):
            """Implements True+True=2, which works in python but not sympy"""
            if isinstance(x, SymBool):
                return SymInt(x.node.wrap_int(int(x)))
            return x

    else:

        def promote(x):
            return x

    def promote2(self, other):
        # TODO: Remove eq and other relations from this list.
        # CPython has fancy implementations for these to get as much precision
        # as possible instead of just promoting to float64 and praying, so we
        # need to handle them specially too.
        # Also, note that int_truediv doesn't go through this path: both
        # arguments are "int" so there isn't any promotion
        if method not in [
            "add",
            "sub",
            "mul",
            "mod",
            "float_pow",
            "float_truediv",
            "int_floordiv",
            "sym_min",
            "sym_max",
            # TODO: remove these
            "eq",
            "ne",
            "gt",
            "lt",
            "le",
            "ge",
        ]:
            return self, other
        f_self = isinstance(self, (float, torch.SymFloat))
        f_other = isinstance(other, (float, torch.SymFloat))
        if f_self or f_other:
            if not f_self:
                self = torch.sym_float(self)
            if not f_other:
                other = torch.sym_float(other)
        return self, other

    # Before and after performing the operation, check if any operands are constant.
    # If so, extract out the constant values first. If `self` itself is a
    # constant, then "redispatch" by calling back into the operator. Sometimes
    # this means that operations involving SymBool return plain bools.
    # Alternatively, we could also rewrap into constant Symbool (i.e. by
    # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
    # today for no particular reason.
    def unary_magic_impl(self):
        self = promote(self)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self))
        return wrap_node(getattr(self.node, method_attr)())

    def binary_magic_impl(self, other):
        # Implements ``self <op> other``.
        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
            return NotImplemented
        sym_node_log.debug("MAGIC %s %s %s", method, self, other)
        self = promote(self)
        other = promote(other)
        self, other = promote2(self, other)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self), other)
        if is_constant(other):
            other = get_constant(other)
        other_node = to_node(self.node, other)
        if other_node is NotImplemented:
            return NotImplemented
        ret = wrap_node(getattr(self.node, method_attr)(other_node))
        return get_constant(ret) if is_constant(ret) else ret

    def rbinary_magic_impl(self, other):
        # Reflected form: implements ``other <op> self``.
        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
            return NotImplemented
        self = promote(self)
        other = promote(other)
        self, other = promote2(self, other)
        if is_constant(self):
            # Keep the reflected operand order so that non-commutative ops
            # (sub, truediv, pow, mod, shifts, ...) compute ``other <op> self``.
            return (method_to_operator(method))(other, get_constant(self))
        if is_constant(other):
            other = get_constant(other)
        other_node = to_node(self.node, other)
        if other_node is NotImplemented:
            return NotImplemented
        ret = wrap_node(getattr(other_node, method_attr)(self.node))
        return get_constant(ret) if is_constant(ret) else ret

    if method in unary_magic_methods:
        setattr(user_type, f"__{method}__", unary_magic_impl)
    elif method in unary_nonmagic_methods:
        orig = getattr(user_type, method)
        setattr(user_type, method, update_wrapper(unary_magic_impl, orig))
    elif method == "sym_ite":

        def sym_ite_magic_impl(pred, then_val, else_val):
            pred_node = pred.node
            then_node = to_node(pred_node, then_val)
            else_node = to_node(pred_node, else_val)
            if then_node is NotImplemented or else_node is NotImplemented:
                return NotImplemented
            assert (
                isinstance(then_node, SymNode)
                and isinstance(else_node, SymNode)
                and then_node.pytype == else_node.pytype
            )
            ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
            # wrap_node may already return a plain constant, which has no
            # ``.node``; is_constant handles both cases.
            return get_constant(ret) if is_constant(ret) else ret

        setattr(user_type, f"__{method}__", sym_ite_magic_impl)
    elif method == "round":

        def round_magic_impl(self, ndigits=None):
            if is_constant(self):
                return builtins.round(get_constant(self), ndigits)

            return wrap_node(getattr(self.node, method)(ndigits))

        setattr(user_type, f"__{method}__", round_magic_impl)
    else:
        method_name = method
        if method in bitwise_ops:
            method_name = bitwise_ops[method]
        setattr(user_type, f"__{method_name}__", binary_magic_impl)
        if method in reflectable_magic_methods:
            setattr(user_type, f"__r{method_name}__", rbinary_magic_impl)


# Install the user-facing magic methods.  only_bool / only_float methods go on
# exactly one Sym type.  Every other method goes on SymInt, additionally on
# SymBool for also_bool / bool_becomes_int methods (installed first), and on
# SymFloat unless it is a bitwise op.
for method, func in magic_methods.items():  # type: ignore[assignment]
    if method in only_bool_magic_methods:
        _make_user_magic(method, SymBool)
    elif method in only_float_magic_methods:
        _make_user_magic(method, SymFloat)
    else:
        if (
            method in also_bool_magic_methods
            or method in bool_becomes_int_magic_methods
        ):
            _make_user_magic(method, SymBool)
        _make_user_magic(method, SymInt)
        if method not in bitwise_ops:
            _make_user_magic(method, SymFloat)

# Don't leave the loop variables behind as module attributes.
del method, func
