# mypy: allow-untyped-defs
import contextlib
import dataclasses
import functools
import itertools
import logging
import math
import operator
import re
from enum import auto, Enum
from itertools import chain
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import sympy
from sympy.printing.printer import Printer

import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges

from .. import config, metrics
from ..utils import (
    DeferredLineBase,
    generate_assert,
    IndentedBuffer,
    sympy_dot,
    sympy_subs,
    unique,
)
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V


schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")


def data_type_logger(msg):
    """Log ``msg`` on ``schedule_log`` at DEBUG level.

    The message is prefixed with "Data type propagation: ". Nothing is
    emitted when DEBUG is not enabled for the logger.
    """
    if not schedule_log.isEnabledFor(logging.DEBUG):
        return
    schedule_log.debug("Data type propagation: %s", msg)


@dataclasses.dataclass
class WorkspaceArg:
    """A temporary buffer used for a single kernel, then discarded.

    Not registered as a traditional buffer since there are no users,
    so it would be dead code eliminated.
    """

    # Symbolic size of the workspace; presumably measured in bytes, going by
    # the field name -- TODO confirm against the allocation site.
    nbytes: sympy.Expr
    # NOTE(review): presumably requests that the buffer be zero-initialized
    # before the kernel uses it -- confirm against the wrapper codegen.
    zero_fill: bool


@dataclasses.dataclass
class TensorArg:
    """Tensor member of the ``KernelArgType`` union.

    NOTE(review): ``name`` is presumably the argument's name inside the
    kernel and ``buffer`` the graph buffer bound to it -- confirm against the
    code that constructs these records.
    """

    name: str
    buffer: str
    dtype: torch.dtype
    offset: sympy.Expr = sympy.Integer(0)  # c++ only
    alias_of: Optional[str] = None  # halide only


@dataclasses.dataclass
class SizeArg:
    """Member of the ``KernelArgType`` union pairing a name with a sympy expression."""

    name: str
    expr: sympy.Expr

    @property
    def alias_of(self):
        """Always ``None``; mirrors the ``alias_of`` attribute that ``TensorArg`` has."""
        return None


@dataclasses.dataclass
class DeviceCodegen:
    """Code generation entry points registered for one device type.

    Instances are created by ``register_backend_for_device`` and stored in
    ``device_codegens``.
    """

    # Scheduling class, or any callable acting as a scheduling factory.
    scheduling: Any
    # Returned by get_wrapper_codegen_for_device when cpp_wrapper is False.
    wrapper_codegen: type
    # Returned by get_wrapper_codegen_for_device when cpp_wrapper is True;
    # defaults to NoneType when nothing else is supplied.
    cpp_wrapper_codegen: type = type(None)


# Union of the argument record dataclasses defined above.
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]

# Registry of DeviceCodegen entries keyed by device type string (e.g. "cpu",
# "cuda"); written by register_backend_for_device.
device_codegens: Dict[str, DeviceCodegen] = {}


class DeviceOpOverrides:
    """Interface for device-specific operations; every method must be overridden.

    The base class implements nothing: each method raises
    ``NotImplementedError``. Instances are registered per device with
    ``register_device_op_overrides`` and retrieved with
    ``get_device_op_overrides``.

    NOTE(review): the methods presumably return source snippets for the
    generated wrapper code -- confirm against the concrete subclasses.
    """

    def import_get_raw_stream_as(self, name):
        # Subclass hook; deliberately unimplemented in the base class.
        raise NotImplementedError

    def set_device(self, device_idx):
        # Subclass hook; deliberately unimplemented in the base class.
        raise NotImplementedError

    def synchronize(self):
        # Subclass hook; deliberately unimplemented in the base class.
        raise NotImplementedError

    def device_guard(self, device_idx):
        # Subclass hook; deliberately unimplemented in the base class.
        raise NotImplementedError


# Registry of DeviceOpOverrides instances keyed by device type string;
# written by register_device_op_overrides, read by get_device_op_overrides.
device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}


# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# For any new backend looking to integrate with Inductor, customization of these two main
# parts is necessary to generate its specific code.
#
# Kernel code generation is determined by different Scheduling. Consequently, a new
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
# and override specific member functions to create backend-specific Python wrapper code.
#
# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
# or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
# register_backend_for_device, to equip a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
# This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str,
    device_scheduling: Any,
    device_wrapper_codegen: type,
    device_cpp_wrapper_codegen: type = type(None),
):
    """Register (or replace) the code generators used for ``device``.

    Args:
        device: device type string used as the registry key.
        device_scheduling: scheduling class or factory for kernel codegen.
        device_wrapper_codegen: class generating the Python wrapper code.
        device_cpp_wrapper_codegen: class generating the C++ wrapper code;
            defaults to ``NoneType``.
    """
    codegen_entry = DeviceCodegen(
        scheduling=device_scheduling,
        wrapper_codegen=device_wrapper_codegen,
        cpp_wrapper_codegen=device_cpp_wrapper_codegen,
    )
    device_codegens[device] = codegen_entry


class BackendFeature(Enum):
    """Feature flags reported by a backend's scheduling object.

    Queried through ``get_backend_features`` / ``has_backend_feature``.
    Member values come from ``auto()`` and therefore follow declaration order.
    """

    FOREACH = auto()
    BUCKETIZE = auto()
    INPLACE_BUFFERS = auto()
    MASKED_SCATTER_WITH_INDEX = auto()
    SCAN = auto()
    SORT = auto()
    TUPLE_REDUCTION = auto()
    PREFER_STORE_LOOP_ORDER = auto()
    TRITON_TEMPLATES = auto()
    REDUCE_TO_SINGLE_ELEMENT = auto()


def get_backend_features(device: Union[torch.device, str]):
    """Return the backend features reported by the scheduling for ``device``.

    ``device`` may be a ``torch.device`` or a device type string; a string is
    converted to a ``torch.device`` before being handed to the scheduling
    object. Built-in backends are registered first if needed.
    """
    init_backend_registration()
    if not isinstance(device, torch.device):
        assert isinstance(device, str)
        device_type = device
        device = torch.device(device_type)
    else:
        device_type = device.type
    scheduling_factory = get_scheduling_for_device(device_type)
    # The scheduling object is instantiated with no scheduler attached.
    backend_scheduling = scheduling_factory(None)
    return backend_scheduling.get_backend_features(device)


def has_backend_feature(device, feature):
    """Tell whether the backend for ``device`` reports ``feature``.

    ``feature`` must be a ``BackendFeature`` member.
    See also V.graph.has_feature
    """
    assert isinstance(feature, BackendFeature)
    supported_features = get_backend_features(device)
    return feature in supported_features


def get_scheduling_for_device(device: str):
    """Return the scheduling registered for ``device``, or ``None`` if there is none."""
    if device not in device_codegens:
        return None
    return device_codegens[device].scheduling


def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
    """Return the wrapper codegen class registered for ``device``.

    Selects the C++ wrapper class when ``cpp_wrapper`` is true and the Python
    wrapper class otherwise. Returns ``None`` for an unregistered device.
    """
    if device not in device_codegens:
        return None
    registered: DeviceCodegen = device_codegens[device]
    if cpp_wrapper:
        return registered.cpp_wrapper_codegen
    return registered.wrapper_codegen


@functools.lru_cache(None)
def init_backend_registration():
    """Register the built-in backends for every device that has none yet.

    Handles "cpu", "cuda", "xpu" and, when one has been named, the
    PrivateUse1 backend. A device whose scheduling is already registered is
    left untouched. The ``lru_cache`` decorator memoizes the call, so
    repeated calls do not re-run the body.
    """
    # Imported inside the function rather than at module top -- presumably to
    # avoid circular imports with the backend modules; TODO confirm.
    from .cpp import CppScheduling
    from .cpp_wrapper_cpu import CppWrapperCpu
    from .cpp_wrapper_cuda import CppWrapperCuda
    from .cuda_combined_scheduling import CUDACombinedScheduling
    from .halide import HalideScheduling
    from .triton import TritonScheduling
    from .wrapper import WrapperCodeGen

    if get_scheduling_for_device("cpu") is None:
        cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling}
        # The lambda reads config.cpu_backend on every call, so the backend is
        # chosen when a scheduling object is created, not at registration.
        register_backend_for_device(
            "cpu",
            lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs),
            WrapperCodeGen,
            CppWrapperCpu,
        )

    if get_scheduling_for_device("cuda") is None:
        # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
        cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling}
        # As for "cpu", config.cuda_backend is looked up at call time.
        register_backend_for_device(
            "cuda",
            lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs),
            WrapperCodeGen,
            CppWrapperCuda,
        )

    if get_scheduling_for_device("xpu") is None:
        # No C++ wrapper is passed, so the cpp wrapper defaults to NoneType.
        register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen)

    private_backend = torch._C._get_privateuse1_backend_name()
    # NOTE(review): "privateuseone" is presumably the name reported when the
    # PrivateUse1 backend has not been renamed -- confirm.
    if (
        private_backend != "privateuseone"
        and get_scheduling_for_device(private_backend) is None
    ):
        from torch.utils.backend_registration import _get_custom_mod_func

        try:
            device_scheduling = _get_custom_mod_func("Scheduling")
            wrapper_codegen = _get_custom_mod_func("WrapperCodeGen")
            cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodeGen")
            # Register only when all three hooks are provided.
            if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
                register_backend_for_device(
                    private_backend,
                    device_scheduling,
                    wrapper_codegen,
                    cpp_wrapper_codegen,
                )
        except RuntimeError:
            # Best effort: a RuntimeError raised while looking up or
            # registering the custom hooks leaves the backend unregistered.
            pass


def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    """Return ``index`` with one extra contiguous indexing expression appended.

    The extra entry is the dot product of ``index_vars`` with the contiguous
    strides of ``sizes``. The input list is not modified.
    """
    from ..ir import FlexibleLayout

    extended = list(index)
    # added contiguous index prevents reordering
    extended.append(sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes)))
    return extended


def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    """Store ``device_op_overrides`` as the overrides for ``device``, replacing any previous entry."""
    device_op_overrides_dict[device] = device_op_overrides


def get_device_op_overrides(device: str):
    """Return the ``DeviceOpOverrides`` registered for ``device``.

    Returns ``None`` when nothing is registered for ``device``.
    """
    assert isinstance(device, str)

    if not device_op_overrides_dict:
        # Registry is empty: import the built-in override modules for their
        # side effects (presumably each registers itself on import -- verify
        # in .cuda / .xpu).
        from .cuda import device_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

    if device not in device_op_overrides_dict:
        return None
    return device_op_overrides_dict[device]


@functools.lru_cache(None)
def boolean_ops():
    """Return the names of the ops treated as producing boolean results.

    The tuple lists the unary predicates first, then the comparisons.
    """
    unary_predicates = ("isinf", "isnan", "logical_not", "signbit")
    comparisons = ("le", "lt", "ge", "gt", "eq", "ne")
    return unary_predicates + comparisons


DTYPE_TO_COMPUTATION_DTYPE = {
    torch.bfloat16: torch.float,
    torch.float16: torch.float,
    **{
        dtype: dtype
        for dtype in [
            torch.bool,
            torch.float32,
            torch.float64,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.uint8,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        ]
    },
}


def deduce_output_dtype_by_name(
    op_name: str,
    *args,
    **kwargs,
) -> Optional[torch.dtype]:
    """
    Deduce an op's output dtype from its name and its call arguments.

    ``args`` / ``kwargs`` are the arguments the op was invoked with. Returns
    ``None`` when the name and arguments alone do not determine the dtype.
    """

    def dtype_argument(position):
        # An explicit ``dtype`` keyword wins; otherwise read it positionally.
        if "dtype" in kwargs:
            return kwargs["dtype"]
        return args[position]

    if op_name in boolean_ops():
        return torch.bool
    if op_name in ("to_dtype", "index_expr"):
        return dtype_argument(-1)
    if op_name in ("rand", "randn"):
        return torch.float
    if op_name in ("get_index", "randint64", "load_seed"):
        return torch.int64
    if op_name == "reduction":
        return dtype_argument(1)
    if op_name == "constant":
        # Constants are reported in their computation dtype.
        return DTYPE_TO_COMPUTATION_DTYPE[dtype_argument(-1)]  # type: ignore[index]
    if op_name in ("load", "store", "store_reduction"):
        # The dtype is that of the buffer named by the second argument.
        return V.graph.get_dtype(args[1])  # type: ignore[arg-type]
    if op_name == "to_dtype_bitcast":
        return dtype_argument(-2)
    return None


class DataTypePropagation:
    """Annotate the FX nodes of a loop body with a deduced dtype.

    Each visited node gets an ``OptimizationContext`` stored under
    ``node.meta[OptimizationContext.key]`` whose ``dtype`` is the deduced
    dtype, or ``None`` when it cannot be determined.
    """

    def __init__(self, body) -> None:
        self.body = body
        # The root graph lives under "root"; subblock graphs keep their own keys.
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        self.graphs.update(
            (name, block.graph) for name, block in body.subblocks.items()
        )

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        """Promote the dtypes of ``node``'s non-placeholder inputs.

        Returns ``None`` when there are no such inputs or when any of them
        has no propagated dtype yet.
        """
        producers = [
            inp
            for inp in node.all_input_nodes
            if isinstance(inp, torch.fx.Node) and inp.op != "placeholder"
        ]
        if not producers:
            return None

        input_dtypes = []
        for inp in producers:
            if OptimizationContext.key not in inp.meta:
                return None
            inp_dtype = inp.meta[OptimizationContext.key].dtype
            if inp_dtype is None:
                return None
            input_dtypes.append(inp_dtype)

        return functools.reduce(torch.promote_types, input_dtypes)

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        """Propagate through the subgraph ``node`` targets and return its dtype."""
        subgraph_dtype = self.propagate_graph(self.graphs[node.target])
        assert subgraph_dtype
        return subgraph_dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        """Deduce the dtype of a single node, or ``None`` if it is unknown."""
        if node.op == "placeholder":
            return None

        # An output node can only be inferred when it has exactly one arg.
        if node.target == "output" and len(node.args) != 1:
            return None

        # getitem takes the dtype deduced for the value being indexed.
        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        dtype_from_name = deduce_output_dtype_by_name(
            node.target,
            *node.args,
            **node.kwargs,
        )
        if dtype_from_name is None:
            # The op name was not decisive; fall back to the inputs.
            return self.deduce_node_dtype_by_inputs(node)
        return dtype_from_name

    def propagate_graph(self, graph: torch.fx.Graph):
        """Annotate every node of ``graph`` and return the graph's dtype.

        For masked_subblock, the output node's dtype represents the dtype of
        the subgraph. In other cases the returned dtype might be ``None``.
        """
        assert graph.nodes
        graph_dtype = None
        for node in graph.nodes:
            # Reuse an existing context so its other fields are preserved.
            ctx = (
                node.meta[OptimizationContext.key]
                if OptimizationContext.key in node.meta
                else OptimizationContext()
            )
            ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = ctx
            if node.target == "output":
                graph_dtype = ctx.dtype
        return graph_dtype

    def propagate(self):
        """Run propagation starting from the root graph."""
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        """Build a propagator for ``body`` and run it."""
        propagator = cls(body)
        return propagator.propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        """Run propagation over the loop body of a ``SchedulerNode``."""
        from ..loop_body import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)


# This printer contains rules that are supposed to be generic for both C/C++ and
# Python
class ExprPrinter(Printer):
    """Base sympy printer shared by the language-specific expression printers.

    Operations whose spelling depends on the target language raise
    ``NotImplementedError`` here and must be provided by subclasses.
    """

    @staticmethod
    def paren(string):
        """Wrap ``string`` in parentheses unless it is already atomic.

        Returned unchanged: a ``CSEVariable``, a single token made only of
        letters, digits, ``_`` and ``.``, a ``(...)`` group that contains no
        ``)`` before its final character, the empty string, and any string
        whose first ``(`` is matched by its last character.
        """

        def all_in_parens(string):
            # True when the "(" at index 0 is closed only by the last character.
            if string[0] != "(" or len(string) < 2:
                return False
            count = 1
            # ``i`` indexes string[1:], so the final character is i == len(string) - 2.
            for i, char in enumerate(string[1:]):
                if char == "(":
                    count += 1
                elif char == ")":
                    count -= 1
                # The opening paren closed early, e.g. "(a) + (b)".
                if count == 0 and i != len(string) - 2:
                    return False
            # Input is expected to have balanced parentheses.
            assert count == 0
            return True

        if (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
            or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
            or string == ""
        ):
            return string
        # don't put extra parens for strings that are already wrapped in parens
        if all_in_parens(string):
            return string
        return f"({string})"

    def _print_Relational(self, expr):
        # Operands are parenthesized and joined with the relation's operator.
        return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mul(self, expr):
        return "*".join(map(self.paren, map(self._print, expr.args)))

    def _print_Add(self, expr):
        return " + ".join(map(self.paren, map(self._print, expr.args)))

    # NB: this is OK to put here, because Mod is only defined for positive
    # numbers, and so across C/Python its behavior is consistent
    def _print_Mod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_FloatTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _print_CleanDiv(self, expr):
        # Printed exactly like FloorDiv, using the subclass's implementation.
        return self._print_FloorDiv(expr)

    def _print_Identity(self, expr):
        # Identity is transparent: print only the wrapped expression.
        return self._print(expr.args[0])

    def _print_GreaterThan(self, expr):
        # GreaterThan:          >=
        # StrictlyGreaterThan:  >
        # Go figure...
        return " >= ".join(map(self.paren, map(self._print, expr.args)))

    # NB: The C implementation is injected into codegen at
    # torch/_inductor/codegen/wrapper.py
    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"

    # This must be implemented because sympy will collect x * x into Pow(x, 2), without
    # any explicit intervention.  We print it just like x * x, notably, we
    # never generate sympy.Pow with floats.
    #
    # NB: this is pow by natural; you should never have used builtin sympy.pow
    # for FloatPow, and a symbolic exponent should be PowByNatural.  This
    # means exp is guaranteed to be an integer.
    def _print_Pow(self, expr):
        base, exp = expr.args
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        # Negative exponents are not supported here.
        assert exp >= 0
        if exp > 0:
            # Expand into repeated multiplication: base*base*...*base.
            return "*".join([self.paren(base)] * exp)
        else:  # exp == 0
            return "1"

    # Explicit NotImplemented functions are to prevent default sympy printing
    # behavior, which will just barf out ToFloat(...) to your IR.  The error
    # message is better here because it tells you which printer class it needs
    # to go in.

    def _print_ToFloat(self, expr):
        raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")

    def _print_Infinity(self, expr):
        raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")

    def _print_NegativeInfinity(self, expr):
        raise NotImplementedError(
            f"_print_NegativeInfinity not implemented for {type(self)}"
        )

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_PythonMod(self, expr):
        raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")

    def _print_IntTrueDiv(self, expr):
        raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")

    def _print_PowByNatural(self, expr):
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatPow(self, expr):
        raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")

    def _print_TruncToInt(self, expr):
        raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")

    def _print_RoundToInt(self, expr):
        raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")

    def _print_RoundDecimal(self, expr):
        raise NotImplementedError(
            f"_print_RoundDecimal not implemented for {type(self)}"
        )

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers.  You can implement them as a quick unblock, but it is better
    # to ask yourself why we haven't done this computation in the Tensor
    # universe instead

    def _print_TruncToFloat(self, expr):
        raise NotImplementedError(
            f"_print_TruncToFloat not implemented for {type(self)}"
        )

    def doprint(self, expr, *, simplify: bool = True):
        """Print ``expr``, first simplifying it through ``V.graph.sizevars``.

        Simplification is skipped when ``simplify`` is false, when ``expr`` is
        not a ``sympy.Expr``, or when ``V.graph`` has no ``sizevars``.
        """
        # TODO: why are people passing strings to the printer here :think:
        if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
            expr = V.graph.sizevars.simplify(expr)
        return super().doprint(expr)


class PythonPrinter(ExprPrinter):
    """Expression printer that emits Python source.

    Fills in the operations ``ExprPrinter`` leaves unimplemented using Python
    operators, builtins and ``math`` module functions.
    """

    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"float({self._print(expr.args[0])})"

    def _print_ModularIndexing(self, expr):
        # Printed as "(x // div) % mod"; the floor division is omitted when
        # div prints as "1".
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        mod = self.paren(self.doprint(mod))
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _helper_sqrt(self, expr):
        # Shared sqrt spelling; takes the operand, not the sqrt expression.
        return f"math.sqrt({self._print(expr)})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return self._helper_sqrt(expr.args[0])

    def _print_FloatPow(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_FloorToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_TruncToInt(self, expr):
        assert len(expr.args) == 1
        # This also could have been int(), they'll do the same thing for float
        return f"math.trunc({self._print(expr.args[0])})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_CeilToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    # NB: It's expected that we've made explicit any promotion in the sympy
    # expression, so it doesn't matter that Python max/min doesn't perform
    # promotion
    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    # The OpaqueUnaryFn_* printers below each map to the ``math`` module
    # function of the same name.

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"math.cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"math.acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"math.sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"math.sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"math.asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"math.tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"math.tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        # The digit count must be a concrete integer; it is emitted verbatim.
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"


class OpOverrides:
    """Default implementations of ops, layered over a parent handler.

    Ops defined here are either composed from other ``ops.*`` calls or
    rendered directly as expression strings. Any attribute not defined on
    this class is forwarded to the wrapped ``parent`` via ``__getattr__``.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        # Only reached when normal lookup fails: delegate to the parent.
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        # ``dtype`` is ignored here; the constant is rendered with repr().
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def erfc(x):
        # erfc(x) = 1 - erf(x)
        return ops.sub(ops.constant(1, torch.float32), ops.erf(x))

    @staticmethod
    def erfcx(x):
        # erfcx(x) = exp(x^2) * erfc(x)
        return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))

    @staticmethod
    def expm1(x):
        return ops.sub(ops.exp(x), ops.constant(1, torch.float32))

    @staticmethod
    def log10(x):
        # log10(x) = ln(x) / ln(10)
        return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))

    @staticmethod
    def log2(x):
        # log2(x) = ln(x) / ln(2)
        return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))

    @staticmethod
    def exp2(x):
        # 2**x = exp(x * ln(2))
        return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))

    @staticmethod
    def log1p(x):
        return ops.log(ops.add(x, ops.constant(1, torch.int32)))

    @staticmethod
    def sigmoid(x):
        # sigmoid(x) = 1 / (1 + exp(-x))
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))

    @staticmethod
    def libdevice_sigmoid(x):
        # Same formula as sigmoid, but built on ops.libdevice_exp.
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))

    @staticmethod
    def relu(x):
        return ops.maximum(x, ops.constant(0, torch.int32))

    # The libdevice_* ops below default to the plain op of the same name.

    @staticmethod
    def libdevice_abs(x):
        return ops.abs(x)

    @staticmethod
    def libdevice_sqrt(x):
        return ops.sqrt(x)

    @staticmethod
    def libdevice_cos(x):
        return ops.cos(x)

    @staticmethod
    def libdevice_sin(x):
        return ops.sin(x)

    @staticmethod
    def libdevice_log(x):
        return ops.log(x)

    @staticmethod
    def libdevice_exp(x):
        return ops.exp(x)

    # The bitwise/logical ops below render expression strings directly,
    # parenthesizing operands with ExprPrinter.paren.

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        # Start from ops.mod, then add ``b`` when the result is non-zero and
        # its sign differs from the sign of ``b``.
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def trunc_to_int(a, dtype):
        return ops.to_dtype(ops.trunc(a), dtype)

    @staticmethod
    def floor_to_int(a, dtype):
        return ops.to_dtype(ops.floor(a), dtype)

    @staticmethod
    def ceil_to_int(a, dtype):
        return ops.to_dtype(ops.ceil(a), dtype)

    @staticmethod
    def round_to_int(a, dtype):
        return ops.to_dtype(ops.round(a), dtype)

    @staticmethod
    def int_truediv(a, b):
        # TODO: this is wrong
        # TODO: an easy bandaid is to generate runtime asserts that it's
        # <= 2**53, which is when this equation is correct
        return ops.truediv(a, b)

    @staticmethod
    def load_seed(name, offset):
        # Implemented as a plain load at the integer offset.
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        """Install the ``target`` implementations from ``pointwise_overrides_data`` on ``cls``.

        ``target`` must be "triton", "cpp" or "cppvec". Entries whose
        implementation for ``target`` is ``None`` are skipped; the others are
        attached to ``cls`` as static methods under the entry's dict key.
        """
        assert target in {"triton", "cpp", "cppvec"}, target

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                continue
            setattr(cls, funcname, staticmethod(impl))


@dataclasses.dataclass
class OverridesData:
    """Per-backend code generators for one pointwise function.

    ``cpp``, ``triton`` and ``cppvec`` are callables that take the already
    printed operand(s) and return the expression string for that backend.
    """

    # NOTE(review): presumably the name of the corresponding op (e.g.
    # "special_airy_ai") -- confirm against the consumers of this field.
    name: str
    cpp: Callable[..., str]
    # None when not impl in libdevice/triton
    triton: Optional[Callable[..., str]] = None
    # None when not impl in aten/.../vec
    cppvec: Optional[Callable[..., str]] = None
    # Type promotion rule associated with this function; DEFAULT unless overridden.
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )


# Table of pointwise special functions, keyed by the method name that
# OpOverrides._initialize_pointwise_overrides installs on the overrides class.
# Each entry gives the code generator for every backend that supports the
# function; a backend that is not listed keeps OverridesData's default of None
# and is skipped for that target.
# NB: if you add a new special function, don't forget to update
# torch._inductor.ops_handler too
pointwise_overrides_data: Dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j0_forward({x})",
        triton=lambda x: f"libdevice.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j1_forward({x})",
        triton=lambda x: f"libdevice.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y0_forward({x})",
        triton=lambda x: f"libdevice.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y1_forward({x})",
        triton=lambda x: f"libdevice.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_digamma({x})",
        cppvec=lambda x: f"{x}.digamma()",
        name="digamma",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_erfcx({x})",
        triton=lambda x: f"libdevice.erfcx({x})",
        name="special_erfcx",
    ),
    fma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
        cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
        triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
        name="fma",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="igammac",
    ),
    # gammainc / gammaincc emit the same calls as igamma / igammac above
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        cppvec=lambda x: f"{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0e({x})",
        cppvec=lambda x: f"{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i0_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i1_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtri({x})",
        name="special_ndtri",
    ),
    # note: the generated call passes the operands in swapped order (y, x)
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_polygamma({y}, {x})",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"zeta({x}, {y})",
        name="special_zeta",
    ),
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)


# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    """Return ``h`` unchanged.

    The function exists for its signature: returning an ``OpOverrides`` where
    an ``OpsHandler[str]`` is declared makes the type checker verify that the
    former satisfies the latter.
    """
    return h


class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        """Return the line, or ``None`` if ``self.name`` is in any removal set."""
        removal_sets = (
            V.graph.removed_buffers,
            V.kernel.removed_buffers,
            V.graph.inplaced_to_remove,
            V.kernel.inplaced_to_remove,
        )
        if any(self.name in removed for removed in removal_sets):
            return None
        return self.line

    def _new_line(self, line):
        """Wrap ``line`` in a new ``DeferredLine`` tied to the same buffer name."""
        return DeferredLine(self.name, line)


class BracesBuffer(IndentedBuffer):
    """IndentedBuffer whose indentation levels are delimited by ``{`` / ``}`` lines."""

    def indent(self, offset=1):
        """Return a context manager that shifts indentation by ``offset`` levels.

        For a positive ``offset`` one ``{`` line is written per level on entry
        and one ``}`` line per level on exit.  For a negative ``offset`` the
        roles are swapped: ``}`` lines are written on entry and the scopes are
        re-opened with ``{`` lines on exit.

        The exit half runs in a ``finally`` clause, so the indentation level
        is restored and the braces stay balanced even when the body of the
        ``with`` statement raises.
        """

        @contextlib.contextmanager
        def ctx():
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            try:
                yield
            finally:
                # Undo exactly what was done on entry, whether or not the
                # managed block completed normally.
                for _ in range(-offset):
                    self.writeline("{")
                    self._indent += 1
                for _ in range(offset):
                    self._indent -= 1
                    self.writeline("}")

        return ctx()


class InplacedBuffer(NamedTuple):
    """Record describing one in-place kernel argument shared by several buffers."""

    # Kernel-side parameter name; KernelArgs.make_inplace generates "in_out_ptr<N>".
    inner_name: str
    # Buffer names that share this argument.  make_inplace starts the list as
    # [input_name, output_name] and appends further output names to it.
    other_names: List[str]


class KernelArgs:
    """Bookkeeping for the arguments of a single generated kernel.

    Maps each outer name (buffer name or size expression) to the inner
    parameter name used inside the kernel, and produces the argument
    definitions and call arguments from those maps.
    """

    @staticmethod
    def _lookup(prefix, odict, name):
        """Return the inner name for ``name`` in ``odict``, creating it if missing.

        A new inner name is ``prefix`` followed by the current number of
        entries in ``odict``.
        """
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        # outer buffer name -> inner parameter name
        self.input_buffers = {}
        self.output_buffers = {}
        # buffer name -> InplacedBuffer (see make_inplace)
        self.inplace_buffers = {}
        # size expression (or seed value) -> inner parameter name.
        # NOTE(review): an empty dict passed in is replaced by a fresh one.
        self.sizevars = sizevars or {}
        # Single WorkspaceArg accumulated by workspace(), or None if unused.
        self.workspace_arg = None

    def __repr__(self):
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    def _buffer_is_marked_removed(self, name):
        """Return True if ``name`` is a string starting with ``"REMOVED"``.

        NOTE(review): the marker is presumably written into the buffer maps by
        code outside this class -- confirm against the caller.
        """
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        """Return the inner name under which buffer ``name`` is read.

        An existing output or in-place argument for the same buffer is reused;
        otherwise a new ``seed<N>`` (for names starting with ``"seed"``) or
        ``in_ptr<N>`` entry is created in ``input_buffers``.
        """
        if V.graph.scheduler:
            # Resolve a mutated buffer to its real name when one is recorded.
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        """Return the inner name under which buffer ``name`` is written.

        An existing in-place argument is reused; otherwise an ``out_ptr<N>``
        entry is created in ``output_buffers``.
        """
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        """Make ``output_name`` share one in-place argument with ``input_name``.

        If ``input_name`` already has an InplacedBuffer, ``output_name`` is
        appended to it; otherwise a new ``in_out_ptr<N>`` record is created
        and registered under both names.
        """
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                # N is the number of distinct values already registered.
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        """Reserve ``nbytes`` of workspace and return ``("ws_ptr", offset)``.

        The first request creates the workspace at offset 0.  Each later
        request grows the single workspace by ``nbytes`` and returns the
        previous total size as its offset; ``zero_fill`` is OR-ed across
        requests.
        """
        if self.workspace_arg is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0

        offset = self.workspace_arg.nbytes
        zero_fill = zero_fill or self.workspace_arg.zero_fill
        self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
        return "ws_ptr", offset

    def seed_offset(self, name, value):
        """Register ``value`` in ``sizevars`` and return its inner name.

        ``name`` is the preferred inner name; if it is already in use, the
        number of existing inner names starting with ``name`` is appended.
        """
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name):
        """Return the inner name for size ``name`` (``"seed"`` or ``ks<N>``)."""
        if str(name) == "seed":
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self):
        """Iterate over the outer names of inputs, outputs and size variables, in that order."""
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def wrap_ptr_arg(self, buf, dtype):
        """Format a buffer call argument; this base version returns ``buf`` unchanged."""
        return buf

    def wrap_size_arg(self, size):
        """Format a size call argument; this base version returns ``str(size)``."""
        return str(size)

    def cpp_argdefs(self):
        """Build the C++ argument lists for this kernel.

        Returns ``(arg_defs, call_args, arg_types)``.  Arguments are emitted
        in the order: in-place buffers, inputs, outputs, size variables.
        Asserts that no workspace was requested.
        """
        from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            # The last registered name is the one passed at the call site.
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            # Buffers covered by an in-place argument were already emitted above.
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        """Build the Python-side argument lists for this kernel.

        Returns ``(arg_defs, call_args, precompile_args, arg_types)``.
        Arguments are emitted in the order: in-place buffers, inputs,
        outputs, size variables, then the workspace if one was requested.
        """
        arg_defs: List[str] = []
        call_args: List[str] = []
        arg_types: List[torch.dtype] = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(V.graph.get_dtype(outer))
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(type(outer))  # type: ignore[arg-type]
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        if self.workspace_arg is not None:
            # Note: the workspace adds no entry to arg_types.
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)
        return arg_defs, call_args, precompile_args, arg_types

    def aliases(self):
        """Yield ``(inner input/output name, in-place inner name)`` pairs.

        A pair is produced for each name of a live in-place buffer that also
        has an entry in ``input_buffers`` or ``output_buffers``; names listed
        in either ``inplaced_to_remove`` set are skipped.
        """
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name

    def is_removed(self, name):
        """Return True if ``name`` is absent from, or marked removed in, both the output and in-place maps."""

        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers.  Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self):
        """Return an OrderedSet of outer buffer names written by this kernel."""
        live_outs = OrderedSet()  # type: ignore[var-annotated]
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs


class CSEVariable:
    """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
    To do so, the backends can simply overload `Kernel.create_cse_var`
    The "CSEVariable.update_on_args" method gives you a hook for annotations
    See example of TritonCSEVariable in triton.py
    """

    def __init__(self, name, bounds: ValueRanges[Any]):
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds
        self.use_count = 1  # track how many tims this expression is used

    def __str__(self):
        return self.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other) -> bool:
        return type(other) == type(self) and other.name == self.name

    def update_on_args(self, name, args, kwargs):
        pass

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name!r})"


class CppWrapperKernelArgs(KernelArgs):
    """KernelArgs variant that formats call arguments for the C++ wrapper."""

    def wrap_ptr_arg(self, buf, dtype):
        """Return the call-site expression for buffer ``buf`` of type ``dtype``."""
        from .cpp_utils import DTYPE_TO_CPP

        if not config.abi_compatible:
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"
        # In the abi_compatible model, we just return the buf here.
        # We will form correct call args later in wrapper.generate_kernel_all.
        return buf

    def wrap_size_arg(self, size):
        """Return the call-site text for a size argument."""
        return f"{size}"


class CSE:
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        # Text emitted around each generated assignment line.
        self.prefix = prefix
        self.suffix = suffix
        self.name_prefix = name_prefix
        # Counter supplying the numeric part of new variable names.
        self.iter_buffer_ids = iter_buffers or itertools.count()
        # expression text -> variable holding it
        self.cache = {}
        # Falsy arguments (None or an empty mapping) are replaced by new dicts.
        self.store_cache = store_cache or {}
        self.reduction_cache = reduction_cache or {}
        self.varname_map = varname_map or {}
        self.invalidated_stores = OrderedSet()  # type: ignore[var-annotated]

    def invalidate(self, keep_vars: OrderedSet[str]):
        """Forget every cached entry whose variable is not in ``keep_vars``.

        Names dropped from ``store_cache`` are recorded in
        ``invalidated_stores``.
        """
        stale_stores = [
            store_name
            for store_name, var in self.store_cache.items()
            if var not in keep_vars
        ]
        for store_name in stale_stores:
            del self.store_cache[store_name]
            self.invalidated_stores.add(store_name)
        self.cache = {
            text: var for text, var in self.cache.items() if var in keep_vars
        }

    def clone(self):
        """Return a new CSE built from this one's settings, counter and maps."""
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        """Return the variable holding ``expr``, emitting its definition if new.

        An expression already present in the cache reuses its variable (with
        tightened bounds and an incremented use count).  A new expression gets
        a fresh variable and, when ``write`` is true, its code is written to
        ``buffer`` -- as an assignment when ``assignment`` is true, otherwise
        as a bare statement.
        """
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment

        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            expr.use_count += 1
            return expr

        if isinstance(expr, IndentedBuffer):
            cache_key = expr.getvalue()
        else:
            cache_key = expr

        cached = self.cache.get(cache_key, None)
        if cached:
            cached.bounds = cached.bounds.tighten(bounds)
            cached.use_count += 1
            return cached

        var = self.newvar(bounds)
        self.cache[cache_key] = var
        if not write:
            return var

        if V.kernel.current_node:
            V.kernel.current_node.codegen_originating_info(buffer, only_once=True)
        if isinstance(expr, IndentedBuffer):
            # Multi-line expression: optional "var =" header, the spliced body,
            # then the suffix on its own line.
            if assignment:
                buffer.writeline(f"{self.prefix}{var} =")
            buffer.splice(expr)
            buffer.writeline(self.suffix)
        elif assignment:
            buffer.writeline(f"{self.prefix}{var} = {expr}{self.suffix}")
        else:
            buffer.writeline(f"{expr}{self.suffix}")
        return var

    def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
        """Create a variable with the next free name and record it in ``varname_map``."""
        fresh_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        fresh_var = V.kernel.create_cse_var(fresh_name, bounds)
        self.varname_map[fresh_name] = fresh_var
        return fresh_var


class CodeGen:
    """Context-manager base class that owns a ``contextlib.ExitStack``.

    Entering the object enters the stack; leaving it unwinds everything that
    was registered on ``self.exit_stack`` in the meantime.
    """

    def __init__(self) -> None:
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        stack = self.exit_stack
        stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The stack's return value is deliberately not forwarded, so this
        # method always returns None.
        stack = self.exit_stack
        stack.__exit__(exc_type, exc_val, exc_tb)


class ScopedDict:
    """Write-through-free overlay on top of an existing mapping.

    Reads consult the overlay first and fall back to ``original_dict``;
    writes only ever touch the overlay, so ``original_dict`` is never
    modified through this object.
    """

    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}

    def __getitem__(self, key):
        layer = self.new_items if key in self.new_items else self.original_dict
        return layer[key]

    def __setitem__(self, key, value):
        self.new_items[key] = value

    def __contains__(self, key):
        return any(key in layer for layer in (self.new_items, self.original_dict))

    def get(self, key, default=None):
        if key not in self.new_items:
            return self.original_dict.get(key, default)
        return self.new_items[key]


class Kernel(CodeGen):
    newvar_prefix = ""
    suffix = ""
    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
    # TODO: these look dead, but with all the getattr it's hard to tell...
    load_format: None = None
    store_format: None = None

    def __init__(self, args=None, increase_kernel_count=True):
        """Initialise empty code buffers, CSE state and buffer bookkeeping.

        Args:
            args: kernel-argument registry to use; a fresh ``KernelArgs`` is
                created when this is ``None`` (or otherwise falsy).
            increase_kernel_count: when true, increment
                ``metrics.generated_kernel_count``.
        """
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args if args else KernelArgs()

        # One buffer per code section.
        self.loads, self.compute, self.stores = (
            IndentedBuffer(),
            IndentedBuffer(),
            IndentedBuffer(),
        )

        # Counters incremented by the ops handler installed in __enter__.
        self.num_load = self.num_reduction = 0

        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = OrderedSet()  # type: ignore[var-annotated]
        self.store_buffer_names = OrderedSet()  # type: ignore[var-annotated]
        self._load_mask = self._load_other = None

        # Both of these are assigned by set_current_node.
        self.current_node = None
        self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None

        self.removed_buffers = OrderedSet()  # type: ignore[var-annotated]
        self.inplaced_to_remove = OrderedSet()  # type: ignore[var-annotated]

        # Maps a buffer that is written (key) to the buffer that is read and
        # whose memory can be reused for the written buffer (value).
        self.inplace_update_buffers = {}
        # Minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name = None

    @contextlib.contextmanager
    def set_current_node(self, node):
        """Temporarily make ``node`` the kernel's current node.

        While the context is active ``self.current_node`` is ``node`` and
        ``self.node_to_bounds`` holds ``node._body.bounds().get_bounds()``.
        On exit ``current_node`` is restored to its previous value;
        ``node_to_bounds`` is not reset.
        """
        prior = self.current_node
        try:
            # The state changes live inside the ``try`` so that a failure
            # while computing the bounds cannot leave ``current_node``
            # pointing at ``node`` after the exception propagates.
            self.current_node = node
            self.node_to_bounds = node._body.bounds().get_bounds()
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        """Temporarily redirect code generation into caller-supplied buffers.

        Args:
            lb: buffer installed as ``self.loads``.
            cb: buffer installed as ``self.compute``; defaults to ``lb``.
            sb: value installed as ``self.stores``; when omitted ``None`` is
                installed as-is.

        While the context is active ``self.cse`` is a clone of the current
        CSE whose ``cache``, ``reduction_cache`` and ``store_cache`` are
        ``ScopedDict`` overlays, so entries added inside the context are not
        written into the outer caches. The previous buffers and CSE are
        restored on exit.
        """

        def scope_cse(cse):
            new_cse = cse.clone()
            new_cse.cache = ScopedDict(cse.cache)
            new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
            new_cse.store_cache = ScopedDict(cse.store_cache)
            return new_cse

        if cb is None:
            cb = lb
        saved = (self.loads, self.compute, self.stores, self.cse)
        try:
            # Swap inside the ``try`` so that a failure while cloning the CSE
            # cannot leave the kernel with only some of its state replaced.
            self.loads = lb
            self.compute = cb
            self.stores = sb
            self.cse = scope_cse(saved[3])
            yield
        finally:
            self.loads, self.compute, self.stores, self.cse = saved

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        """Backend hook for a load from buffer ``name`` at ``index``.

        The base class provides no implementation; it is called by
        ``indirect_load`` and by the ops handler installed in ``__enter__``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def indirect_load(self, name: str, index: sympy.Expr):
        """A load that depends on an index we have read.

        ``self.loads`` is pointed at ``self.compute`` for the duration of the
        call to ``self.load`` and restored afterwards, even if ``load``
        raises.
        """
        saved_loads = self.loads
        # The load may depend on earlier computed values, so it has to be
        # emitted into the compute section rather than the loads section.
        self.loads = self.compute
        try:
            return self.load(name, index)
        finally:
            self.loads = saved_loads

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        """Backend hook for storing a reduction result ``value`` to ``name``.

        The ops handler installed in ``__enter__`` calls this only when
        ``name`` is not in ``V.graph.removed_buffers``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        """Backend hook for storing ``value`` to buffer ``name`` at ``index``.

        The ops handler installed in ``__enter__`` calls this (passing
        ``mode`` through) only when ``name`` is not in
        ``V.graph.removed_buffers``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        """Backend hook for a reduction of ``value``.

        The ops handler installed in ``__enter__`` forwards its ``reduction``
        op here after incrementing ``self.num_reduction``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def scan(
        self,
        dtypes: Tuple[torch.dtype, ...],
        combine_fn: Callable[
            [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
        ],
        values: Tuple[CSEVariable, ...],
    ) -> Tuple[CSEVariable, ...]:
        """Backend hook for a scan over ``values`` using ``combine_fn``.

        The ops handler installed in ``__enter__`` forwards its ``scan`` op
        here unchanged.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def sort(
        self,
        dtypes: Tuple[torch.dtype, ...],
        values: Tuple[CSEVariable, ...],
        stable: bool,
        descending: bool,
    ) -> Tuple[CSEVariable, ...]:
        """Backend hook for sorting ``values``.

        The ops handler installed in ``__enter__`` forwards its ``sort`` op
        here unchanged.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def var_ranges(self):
        """Backend hook; not implemented in this base class.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def bucketize(
        self,
        values: CSEVariable,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]

        Backend hook: the ops handler installed in ``__enter__`` forwards its
        ``bucketize`` op here unchanged.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @property
    def assert_function(self) -> str:
        """Text that ``indirect_assert`` emits as the callee of the assertion.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def indirect_assert(
        self,
        var: Union[CSEVariable, str],
        lower: Optional[str],
        upper: Optional[str],
        mask: Optional[Union[CSEVariable, str]] = None,
    ) -> str:
        """Build the source text of a bounds assertion for an indirect index.

        Args:
            var: the index being checked; a ``CSEVariable`` is converted with
                ``str``.
            lower: text of the inclusive lower bound, or ``None``.
            upper: text of the exclusive upper bound, or ``None``.
            mask: optional mask; when truthy the condition is OR-ed with the
                negated mask.

        Returns:
            A call to ``self.assert_function`` with the condition and an
            "index out of bounds" message. At least one of ``lower`` and
            ``upper`` must be given (enforced by an assertion).
        """
        if isinstance(var, CSEVariable):
            var = str(var)
        assert isinstance(var, str)
        assert lower is None or isinstance(lower, str)
        assert upper is None or isinstance(upper, str)

        if not lower:
            # Only an upper bound was requested.
            assert upper
            cond = f"{var} < {upper}"
            cond_print = cond
        elif not upper:
            # Only a lower bound was requested.
            cond = f"{lower} <= {var}"
            cond_print = cond
        else:
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is supported by triton
            cond = f"({lower} <= {var}) & ({var} < {upper})"
            cond_print = f"{lower} <= {var} < {upper}"

        if mask:
            cond = f"({cond}) | ~({mask})"

        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'

    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ):
        """Backend hook for a bounds check of ``expr`` against ``size``.

        The ops handler installed in ``__enter__`` calls this from
        ``indirect_indexing``, passing ``lower``/``upper`` flags derived from
        the value range of the index (each flag is true when that side could
        not be shown to be in range).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def index_to_str(self, index: sympy.Expr) -> str:
        """Backend hook that turns ``index`` into a string.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def __enter__(self):
        """Activate this kernel and install its CSE-aware ops handler.

        Wraps the current ops handler with ``self.overrides`` (which must be
        set), then registers two contexts on ``self.exit_stack``: a
        ``CSEProxy`` instance as the ops handler and ``self`` as the kernel
        handler. Both are undone when the exit stack is unwound.

        Returns:
            ``self``.
        """

        # TODO: hoist this to top level
        class CSEProxy:
            """Ops handler that routes ops through the kernel's CSE.

            Ops without an explicit method below are forwarded to
            ``parent_handler`` by ``__getattr__`` and their results are turned
            into CSE variables; the explicit methods forward to the enclosing
            kernel, adding bookkeeping where needed.
            """

            # Handler name. This is a class attribute of the proxy; it must
            # not be assigned through ``self``, which in this scope is the
            # enclosing kernel instance.
            name = "CSEProxy"
            vr_analysis = ValueRangeAnalysis()

            @staticmethod
            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                def inner(*args, **kwargs):
                    bounds = CSEProxy._bound_variable(name, *args, **kwargs)

                    value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                    def do_cse(v):
                        csevar = V.kernel.cse.generate(
                            V.kernel.compute, v, bounds=bounds
                        )
                        csevar.update_on_args(name, args, kwargs)
                        return csevar

                    # The parent may return a pytree of values; CSE each leaf.
                    return pytree.tree_map(do_cse, value)

                return inner

            @staticmethod
            def _bound_variable(name, *args, **kwargs):
                """
                If the variable comes from an FX node, we forward the bound we have already computed.
                Otherwise, if the variable is created while codegen'ing another op, we try to
                compute its bounds from the bounds of its arguments.
                """
                from ..select_algorithm import TritonTemplateKernel

                if isinstance(V.kernel, TritonTemplateKernel):
                    return ValueRanges.unknown()

                fx_node = V.interpreter.current_node
                if fx_node.target == name and self.node_to_bounds is not None:
                    assert isinstance(self.node_to_bounds, dict)
                    return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
                elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
                    # These create lots of inner strings. We would need to compute the bounds at the ops
                    # We will also likely not get much from computing VRs on these nodes
                    if any(
                        s in fx_node.target
                        for s in ("set_indirect", "reduction", "scan")
                    ):
                        return ValueRanges.unknown()

                    # We assume that the inputs come from `ops.` and are not strings. If you want to generate
                    # intermediary strings, wrap them in CSE variables with properly initialised bounds.

                    # If there is no FX bound but we know how to compute one we do so
                    assert not kwargs

                    def arg_to_bound(x):
                        if isinstance(x, CSEVariable):
                            return x.bounds
                        elif isinstance(x, sympy.Expr):
                            return bound_sympy(x)
                        else:
                            return x

                    arg_bounds = list(map(arg_to_bound, args))
                    return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
                else:
                    return ValueRanges.unknown()

            @staticmethod
            def indirect_indexing(
                var: CSEVariable,
                size: Union[sympy.Expr, int],
                check: bool = True,
                wrap_neg=True,
            ):
                """Turn ``var`` into an indexing symbol via the parent handler.

                When the lower bound of ``var`` is negative a new CSE variable
                is generated first (with ``size`` added to negative values if
                ``wrap_neg``), and a bounds check is requested from the kernel
                when ``generate_assert(check)`` is true.
                """
                if isinstance(size, int):
                    size = sympy.Integer(size)
                assert isinstance(size, sympy.Expr), size
                # Skip CSE since this doesn't return an expression

                if var.bounds.lower < 0:  # type: ignore[operator]
                    if wrap_neg:
                        stm = ops.add(var, ops.index_expr(size, torch.long))
                        # Mixed negative and non-negative
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            lt = ops.lt(var, 0)
                            stm = ops.where(lt, stm, var)
                    else:
                        stm = var

                    # Propagate bounds as we know how to compute them properly
                    new_bounds = ValueRanges.unknown()
                    if var.bounds != ValueRanges.unknown() and isinstance(
                        size, sympy.Number
                    ):
                        # Take the negative part of the bound and add size to it
                        # Then take union of that and the positive part
                        # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                        neg_bounds = var.bounds & ValueRanges(-int_oo, -1)
                        new_bounds = ValueRanges(
                            neg_bounds.lower + size, neg_bounds.upper + size
                        )
                        # We don't have a good way of representing the empty range
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            pos = var.bounds & ValueRanges(0, int_oo)
                            new_bounds = new_bounds | pos

                    var = self.cse.generate(self.compute, stm, bounds=new_bounds)

                sympy_var = parent_handler.indirect_indexing(var, size, check)
                if generate_assert(check):
                    assert_lower = not (var.bounds.lower >= 0)
                    # value ranges cannot establish x < s when x and s are symbols
                    assert_upper = not isinstance(size, sympy.Number) or not (
                        var.bounds.upper < size
                    )
                    self.check_bounds(sympy_var, size, assert_lower, assert_upper)
                return sympy_var

            @staticmethod
            def check_bounds(
                expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
            ):
                return self.check_bounds(expr, size, lower, upper)

            @staticmethod
            def load(name: str, index: sympy.Expr) -> CSEVariable:
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                if free_symbol_is_type(index, SymT.TMP):
                    return self.indirect_load(name, index)
                store_cache = self.cse.store_cache
                if name in store_cache:
                    # Reuse the value that was just stored instead of loading.
                    return store_cache[name]
                out = self.load(name, index)
                # count load that is not in the store_cache, and also not in the
                # cse cache.
                if out.use_count == 1:
                    self.num_load += 1
                return out

            @staticmethod
            def _update_store_cache(name: str, value: CSEVariable):
                self.cse.store_cache[name] = value
                if self.current_node and name in V.graph.name_to_buffer:
                    buf = self.current_node.get_output(name)
                    # Buffers mutated by this output see the same stored value.
                    for other_name in buf.get_mutations():
                        self.cse.store_cache[other_name] = value

            @staticmethod
            def store(
                name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
            ) -> None:
                self.store_buffer_names.add(name)
                if mode is None:
                    CSEProxy._update_store_cache(name, value)
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)
                else:
                    return None  # type: ignore[return-value]

            @staticmethod
            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
                self.store_buffer_names.add(name)
                CSEProxy._update_store_cache(name, value)

                if name not in V.graph.removed_buffers:
                    return self.store_reduction(name, index, value)

            @staticmethod
            def reduction(
                dtype: torch.dtype,
                src_dtype: torch.dtype,
                reduction_type: ReductionType,
                value: Union[CSEVariable, Tuple[CSEVariable, ...]],
            ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
                self.num_reduction += 1
                return self.reduction(dtype, src_dtype, reduction_type, value)

            @staticmethod
            def scan(
                dtypes: Tuple[torch.dtype, ...],
                combine_fn: Callable[
                    [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]],
                    Tuple[CSEVariable, ...],
                ],
                values: Tuple[CSEVariable, ...],
            ) -> Tuple[CSEVariable, ...]:
                return self.scan(dtypes, combine_fn, values)

            @staticmethod
            def sort(
                dtypes: Tuple[torch.dtype, ...],
                values: Tuple[CSEVariable, ...],
                stable: bool,
                descending: bool,
            ) -> Tuple[CSEVariable, ...]:
                return self.sort(dtypes, values, stable, descending)

            @staticmethod
            def bucketize(
                values: CSEVariable,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ) -> CSEVariable:
                """
                [Note: Inductor bucketize op]

                Given values (tensor) and offsets_name (reference to the name of a 1D
                tensor), calculate the bucket that each value belongs to.

                e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
                return =        [ 0, 1, 1, 1, 1, 3, 3, 4].

                When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
                When right == True,  bucket i refers to range [offsets[i], offsets[i+1]).

                Offsets must be non-decreasing or the result is undefined.
                """
                return self.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

        # Use mypy to check protocol implemented correctly
        def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
            return h

        super().__enter__()
        assert self.overrides
        # ``parent_handler`` is read lazily by the CSEProxy methods above.
        parent_handler = self.overrides(V.get_ops_handler())
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Remove kernel-local buffers from the scheduler, then unwind the
        contexts registered in ``__enter__``.

        Note that V.graph.scheduler can be None when codegening triton template
        kernels.

        Returns:
            The value returned by the parent ``__exit__``.
        """
        try:
            if V.graph.scheduler:
                V.graph.scheduler.remove_kernel_local_buffers()
        finally:
            # Always unwind the exit stack, even if the scheduler cleanup
            # raised; otherwise the ops/kernel handlers installed in
            # ``__enter__`` would stay active.
            result = super().__exit__(exc_type, exc_val, exc_tb)
        return result

    def rename_indexing(self, index) -> sympy.Expr:
        """Rewrite size symbols in ``index`` to kernel argument names.

        Adds the necessary kernel args for index expressions and renames
        variables in index expressions to kernel arg names. A list or tuple
        is processed element-wise and returned as a list.
        """
        if isinstance(index, (list, tuple)):
            return list(map(self.rename_indexing, index))  # type: ignore[return-value]

        simplified = V.graph.sizevars.simplify(index)
        renamed_kinds = (
            SymT.UNBACKED_INT,
            SymT.SIZE,
            SymT.PRECOMPUTED_SIZE,
        )
        replacements = {}
        # Visit symbols in name order so kernel args are registered
        # deterministically.
        for sym in sorted(simplified.free_symbols, key=operator.attrgetter("name")):
            if symbol_is_type(sym, renamed_kinds):
                replacements[sym] = self.args.size(sym)
        return sympy_subs(simplified, replacements)

    def create_cse_var(self, *args, **kwargs):
        """Construct a ``CSEVariable`` from the given arguments.

        All positional and keyword arguments are passed through unchanged.
        """
        cse_var = CSEVariable(*args, **kwargs)
        return cse_var


@dataclasses.dataclass
class OptimizationContext:
    """Small record holding an optional dtype and an op name.

    NOTE(review): presumably attached to FX nodes under ``key`` by backend
    code elsewhere -- confirm against the users of this class.
    """

    # Class-level constant shared by all instances (not a dataclass field).
    key: ClassVar[str] = "opt_ctx"

    # ``None`` until a dtype has been assigned.
    dtype: Optional[torch.dtype] = None
    # Empty string until an op name has been assigned.
    ops_name: str = ""


@functools.lru_cache(None)
def jinja2_env():
    """Return the jinja2 ``Environment``, or ``None`` if jinja2 is missing.

    The environment is created with ``StrictUndefined``. The result
    (including ``None``) is cached, so every caller gets the same object.
    """
    try:
        import jinja2

        environment = jinja2.Environment(
            undefined=jinja2.StrictUndefined,
        )
    except ImportError:
        environment = None
    return environment


class KernelTemplate:
    """
    Base class for defining kernel templates.

    Children classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def indent_except_first(source: str, num_indents: int, indents_spacing=4):
        """Indent every line of ``source`` except the first one.

        Args:
            source: text to indent; line endings are preserved.
            num_indents: number of indentation levels to add.
            indents_spacing: number of spaces per indentation level.

        Returns:
            The re-indented text. Input with at most one line is returned
            unchanged.
        """
        lines = source.splitlines(True)
        if len(lines) > 1:
            padding = " " * indents_spacing * num_indents
            lines[1:] = [padding + line for line in lines[1:]]
        return "".join(lines)

    @staticmethod
    def _template_from_string(source):
        """Compile ``source`` into a jinja2 template.

        Returns:
            The compiled template, or ``None`` when jinja2 is not available.

        Raises:
            jinja2.TemplateSyntaxError: (as a subclass carrying a more
                detailed message) when ``source`` has a syntax error.
        """
        env = jinja2_env()
        if env is None:
            return None

        env.filters["indent_except_first"] = KernelTemplate.indent_except_first
        from jinja2 import TemplateSyntaxError

        class DetailedTemplateSyntaxError(TemplateSyntaxError):
            """Syntax error whose message shows the surrounding source lines."""

            def __init__(self, original_error):
                super().__init__(
                    original_error.message,
                    original_error.lineno,
                    original_error.name,
                    original_error.filename,
                )
                self.original_error = original_error

            def __str__(self):
                error_info = f"Error in template at line {self.lineno}\n"
                error_info += f"Error message: {self.message}\n"
                # The attribute can exist but be ``None``; a plain ``hasattr``
                # check would then crash while formatting the error.
                template_source = getattr(self.original_error, "source", None)
                if template_source is not None:
                    lines = template_source.split("\n")
                    error_info += "Context:\n"
                    start = max(0, self.lineno - 2)
                    end = min(len(lines), self.lineno + 2)
                    for i in range(start, end):
                        if i == self.lineno - 1:
                            # The offending line gets an arrow and, when the
                            # column is known, a caret underneath it.
                            error_info += f"{i+1}: --> {lines[i]}\n"
                            if hasattr(self.original_error, "column"):
                                error_info += (
                                    "     "
                                    + " " * (self.original_error.column - 1)
                                    + "^\n"
                                )
                        else:
                            error_info += f"{i+1}:     {lines[i]}\n"
                return error_info

        try:
            return env.from_string(source)
        except TemplateSyntaxError as e:
            raise DetailedTemplateSyntaxError(e) from e

    @staticmethod
    def _fake_get_dtype(fake_out):
        """Return a dtype lookup that special-cases ``fake_out``.

        The returned function yields ``fake_out.get_dtype()`` for the name
        ``fake_out.get_name()`` and defers to the ``V.graph.get_dtype`` that
        was current when this method was called for every other name.
        """
        _get_dtype_real = V.graph.get_dtype

        def get_dtype(name):
            if name == fake_out.get_name():
                return fake_out.get_dtype()
            return _get_dtype_real(name)

        return get_dtype

    def __init__(self, name: str):
        self.name = name

    def maybe_append_choice(self, choices, **kwargs):
        """
        Maybe generates a new ChoiceCaller and appends it into existing choices.

        A ``NotImplementedError`` raised by ``self.generate`` means this
        template cannot produce a choice for the given arguments; it is
        logged at debug level and ``choices`` is left unchanged.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """

        try:
            choices.append(self.generate(**kwargs))
        except NotImplementedError as e:
            # Best-effort by design: skip the choice, but leave a trace so
            # the reason is discoverable when debugging.
            logging.getLogger(__name__).debug(
                "Cannot append choice for %s: %s", type(self).__name__, e
            )

    def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
        """
        Generates a ChoiceCaller instance from the given arguments.

        Raises:
            NotImplementedError: always, in this base class.
        """

        raise NotImplementedError
