# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import itertools
import re
from enum import auto, Enum
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple

import sympy

import torch.fx
from torch._dynamo.utils import identity
from torch.utils._sympy.symbol import SymT

from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
from .virtualized import ops, V


class InterpreterShim(torch.fx.Interpreter):
    """torch.fx.Interpreter variant that runs a bare Graph (no GraphModule)
    and resolves attribute fetches through a plain submodules dict."""

    @staticmethod
    @functools.lru_cache(None)
    def _dummy_gm():
        # Shared throwaway module; only needed to satisfy the base __init__.
        return torch.fx.symbolic_trace(identity)

    def __init__(self, graph, submodules):
        # Constructing a real GraphModule is very expensive (it does
        # codegen), so hand the base class a cached placeholder instead and
        # wire up the actual graph/submodules ourselves.
        super().__init__(self._dummy_gm(), garbage_collect_values=False)
        self.module = self  # type: ignore[assignment]
        self.graph = graph
        self.submodules = submodules
        self.current_node = None
        self.extra_traceback = False
        # look up call_module/get_attr targets directly in the dict
        self.fetch_attr = submodules.__getitem__  # type: ignore[method-assign]

    def run_node(self, n: torch.fx.Node) -> Any:
        # Remember which node is executing so callers can inspect it.
        self.current_node = n
        return super().run_node(n)

    def run(self, *args, **kwargs):
        # Make this interpreter visible through V for the duration of run().
        interpreter_guard = V.set_interpreter_handler(self)
        with interpreter_guard:
            return super().run(*args, **kwargs)


class MemoryEntry(NamedTuple):
    """One memory access recorded while tracing a LoopBody."""

    # Key into LoopBody.indexing_exprs for the indexing expression used.
    index_name: str
    # Buffer being read or written, when the op names one.
    buffer_name: Optional[str]
    # Store mode forwarded as V.ops.store(..., mode=mode); None otherwise.
    mode: Optional[str]


class MemoryUsageType(Enum):
    """Kind of memory access a MemoryEntry records.

    These are 1:1 with the opcode generating the usage.  Values are
    explicit but identical to what ``auto()`` would assign (1..7).
    """

    LOAD = 1
    LOAD_SEED = 2
    STORE = 3
    STORE_REDUCTION = 4
    INDEX_EXPR = 5
    CHECK_BOUNDS = 6
    BUCKETIZE = 7


class LoopBody:
    """
    Captures the body of a Loops subclass into an FX graph.  Persists any
    indexing simplifications and makes it easier to analyze loop bodies.
    """

    # named indexing expressions, e.g. "index0" -> some sympy expression
    indexing_exprs: Dict[str, sympy.Expr]
    # reverse map used to dedupe expressions; deleted once tracing finishes
    indexing_exprs_name: Dict[sympy.Expr, str]
    # callables backing the call_module nodes of the traced graphs
    submodules: Dict[str, Any]
    # extra blocks created by ops.masked(), keyed by submodule name
    subblocks: Dict[str, LoopBodyBlock]
    # symbols introduced by ops.indirect_indexing()
    # NOTE(review): add_indirect appends sympy symbols; annotation says str —
    # confirm which is intended before relying on it
    indirect_vars: List[str]
    # bound (size) recorded for each indirect variable
    indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr]
    root_block: LoopBodyBlock
    # memory accesses recorded during tracing, grouped by usage type
    memory_usage: Dict[MemoryUsageType, List[MemoryEntry]]

    def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars):
        super().__init__()

        # var_ranges is ordered with iteration vars first, then reduction
        # vars; split the flat size tuple back into (iter_sizes, reduce_sizes)
        _flat_sizes = tuple(var_ranges.values())
        self.sizes = (
            _flat_sizes[: len(iter_vars)],
            _flat_sizes[len(iter_vars) :],
        )

        self.iter_vars = iter_vars
        self.reduce_vars = reduce_vars
        self.var_ranges = var_ranges

        if isinstance(fn, LoopBody):
            # fast path: reuse the already-traced graphs of another LoopBody
            self._init_with_copy(fn, args)
        else:
            self._init_with_tracing(fn, args)

        # bound only while __call__ is executing; see indexing_from_args
        self.indexing = None

    def _init_with_tracing(self, fn, args):
        """Do an FX trace of an arbitrary callable to construct self"""
        self.indexing_exprs = {}
        self.indexing_exprs_name = {}
        self.submodules = {"get_index": self.get_index}
        self.subblocks = {}
        self.indirect_vars = []
        self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {}
        self.memory_usage = {t: [] for t in MemoryUsageType}
        self.root_block = LoopBodyBlock(self, fn, args)  # traces
        del self.indexing_exprs_name  # not used after _init_with_tracing

    def _init_with_copy(self, other: LoopBody, args):
        """
        _init_with_tracing() is slow, so this is a fast path in the case
        where we are just reordering/merging/splitting the args of an
        existing LoopBody.
        """
        # substitute the new args into the old indexing, then re-simplify
        # under this body's variable ranges
        indexing_exprs = other.indexing_from_args(args)
        self.indexing_exprs = {
            name: V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges)
            for name, expr in indexing_exprs.items()
        }
        self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()}
        self.indirect_vars = other.indirect_vars
        self.indirect_var_ranges = other.indirect_var_ranges
        self.memory_usage = other.memory_usage
        self.root_block = other.root_block.clone(self)

        # rebind every submodule shim to this LoopBody; get_index must point
        # at our own bound method, not other's
        submodules = {**other.submodules}
        submodules.pop("get_index")
        self.submodules = {
            "get_index": self.get_index,
            **{k: v.clone(self) for k, v in submodules.items()},  # type: ignore[attr-defined]
        }

    def merge_loops(self) -> LoopBody:
        """
        Merge both iteration and reduction loops and return a new LoopBody.
        """
        old_body = self
        old_sizes = self.sizes
        old_iter_vars, old_reduce_vars = old_body.vars
        old_iter_sizes, old_reduce_sizes = old_sizes

        index_exprs = [*old_body.indexing_exprs.values()]

        # compute merged sizes plus a reindexing function mapping new
        # (merged) variables back to the old ones, for each variable kind
        iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops(
            old_iter_vars,
            old_iter_sizes,
            index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes),
        )

        reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops(
            old_reduce_vars,
            old_reduce_sizes,
            index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes),
        )

        # if iter_sizes == old_iter_sizes:
        #     # no dimensions get merged.
        #     return old_sizes, old_body

        # Note: if no dimension get merges, the symbol prefix will
        # remain 'y'. But if we merge dimensions, we change prefix to
        # 'z'. If this is an issue, we can always retrace the LoopBody
        # to change symbol prefix to 'z'.
        #
        # There is indeed an issue due to symbol name conflicting.
        # y0 maybe reused for the y dimension later.
        (
            iter_vars,
            reduce_vars,
        ), var_ranges = dependencies.index_vars_no_squeeze(
            iter_sizes, reduce_sizes, prefix="t"
        )
        new_body = LoopBody(
            old_body,
            [iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
            var_ranges,
            iter_vars,
            reduce_vars,
        )

        # use the original symbol prefix
        # Can try to optimize if this is a bottleneck for compilation time
        (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
            iter_sizes, reduce_sizes, prefix="z"
        )
        new_body2 = LoopBody(
            new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
        )
        return new_body2

    def reorder_iter_loops(self, new_order) -> LoopBody:
        """
        Reorder iteration loops and return a new LoopBody.
        """
        from .ir import same_reorder

        old_body = self
        old_sizes = self.sizes
        assert len(old_sizes[0]) == len(new_order)
        reorder_fn = same_reorder(new_order)

        iter_size, reduce_size = old_sizes
        new_iter_size = reorder_fn(iter_size)

        # reduction loops keep their original order
        new_sizes = (new_iter_size, reduce_size)

        # temporary 't' prefix to avoid clashing with the old 'z' symbols;
        # retraced with prefix 'z' below
        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
            *new_sizes, prefix="t"  # type: ignore[arg-type]
        )

        # build the permutation that maps new loop positions back to old ones
        inverse_order = {b: a for a, b in enumerate(new_order)}
        inverse_order = [inverse_order[i] for i in range(len(new_order))]

        def new_body(*indices: Sequence[sympy.Expr]) -> Any:
            # adapter: accept indices in the new order and forward them to
            # the old body in its original order
            index = list(itertools.chain(*indices))
            assert len(index) == len(iter_size) + len(reduce_size)
            iter_idx = index[: len(iter_size)]
            reduce_idx = index[len(iter_size) :]
            iter_idx = [iter_idx[i] for i in inverse_order]
            return old_body(iter_idx, reduce_idx)

        loop_body = LoopBody(
            new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars
        )

        # use the original symbol prefix so we can do multiple round of reordering
        (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
            *new_sizes, prefix="z"  # type: ignore[arg-type]
        )
        new_body = LoopBody(
            loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
        )
        return new_body

    @property
    def vars(self):
        """The (iteration_vars, reduction_vars) pair for this body."""
        assert self.iter_vars is not None
        assert self.reduce_vars is not None
        return self.iter_vars, self.reduce_vars

    @cache_on_self
    def get_nodes(self):
        """Return all FX nodes from the root block and every subblock."""
        all_graphs = itertools.chain(
            (self.root_block.graph,),
            (block.graph for block in self.subblocks.values()),
        )
        return [node for graph in all_graphs for node in graph.nodes]

    @cache_on_self
    def bounds(self):
        """Value-range analysis over this body's nodes (cached)."""
        # Doing a local import to avoid dumping all the code here
        from .bounds import BoundVars

        return BoundVars(self)

    def get_read_expr(self, buffer_name):
        """Indexing expression of the last recorded load from buffer_name.

        Raises KeyError if no load from that buffer was recorded.
        """
        # reversed to match old behavior
        for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]):
            if entry.buffer_name == buffer_name:
                return self.indexing_exprs[entry.index_name]
        raise KeyError(buffer_name)

    def get_write_expr(self, buffer_name):
        """Indexing expression of the first store/store_reduction writing
        buffer_name.

        Raises KeyError if no write to that buffer was recorded.
        """
        for entry in itertools.chain(
            self.memory_usage[MemoryUsageType.STORE],
            self.memory_usage[MemoryUsageType.STORE_REDUCTION],
        ):
            if entry.buffer_name == buffer_name:
                return self.indexing_exprs[entry.index_name]
        raise KeyError(buffer_name)

    def get_read_exprs(self):
        """Indexing expressions of every recorded load, in trace order."""
        return [
            self.indexing_exprs[entry.index_name]
            for entry in self.memory_usage[MemoryUsageType.LOAD]
        ]

    def get_write_exprs(self):
        """Indexing expressions of every recorded store/store_reduction."""
        return [
            self.indexing_exprs[entry.index_name]
            for entry in itertools.chain(
                self.memory_usage[MemoryUsageType.STORE],
                self.memory_usage[MemoryUsageType.STORE_REDUCTION],
            )
        ]

    def debug_str(self):
        """Human-readable dump: var ranges, indexing exprs, and all blocks."""
        lines = [f"var_ranges = {dict(self.var_ranges)}"]
        lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
        lines.extend(
            [
                block.debug_str(name)
                for name, block in itertools.chain(
                    [("body", self.root_block)], self.subblocks.items()
                )
            ]
        )
        return "\n".join(lines)

    def is_memory_copy(self) -> bool:
        """
        True if this body contains only a single load and a single store.
        Note, this could involve a layout change.
        """
        return (
            len(self.memory_usage[MemoryUsageType.LOAD]) == 1
            and len(self.memory_usage[MemoryUsageType.STORE]) == 1
            and len(self.submodules) == 1  # get_index
            and self.root_block.contains_only_ops(("load", "store"))
        )

    __repr__ = debug_str  # repr() shows the full debug dump

    def add_index_expr(
        self,
        expr: sympy.Expr,
        mtype: MemoryUsageType,
        buffer_name: Optional[str] = None,
        mode: Optional[str] = None,
    ):
        """Register an indexing expression (deduped) and record its usage.

        Only valid while tracing, since indexing_exprs_name is deleted when
        tracing finishes.  Returns the assigned name, e.g. "index0".
        """
        name = self.indexing_exprs_name.get(expr)
        if not name:
            name = f"index{len(self.indexing_exprs)}"
            self.indexing_exprs_name[expr] = name
            self.indexing_exprs[name] = expr
        self.memory_usage[mtype].append(MemoryEntry(name, buffer_name, mode))
        return name

    def add_submodule(self, block, prefix):
        """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
        # a prefix ending in a digit is already a unique name; otherwise
        # append a counter to make one
        if prefix[-1].isnumeric() and prefix not in self.submodules:
            name = prefix
        else:
            name = f"{prefix}{len(self.submodules)}"
        self.submodules[name] = block
        return name

    def add_indirect(self, size):
        """Create a fresh indirect-indexing symbol bounded by `size`."""
        var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars))
        assert var not in self.indirect_var_ranges
        self.indirect_vars.append(var)
        self.indirect_var_ranges[var] = size
        return var

    def replace_indirect(self, old, new):
        """Swap in a variable used in indirect indexing"""
        if str(old) == str(new):
            return
        assert self.indexing is not None
        self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}

    def get_index(self, name):
        """Look up an indexing expression bound by __call__; backs the
        "get_index" call_module nodes in the traced graphs."""
        assert self.indexing is not None
        return self.indexing[name]

    def indexing_from_args(self, indices):
        """Substitute the given index expressions for this body's variables
        in every indexing expression, returning a new name->expr dict."""
        index = [*itertools.chain.from_iterable(indices)]
        assert len(index) == len(self.var_ranges), (index, self.var_ranges)
        assert all(
            v not in self.var_ranges for v in index
        ), f"{self.var_ranges=}, {indices=}"
        replacements = dict(zip(self.var_ranges.keys(), index))
        return {
            name: sympy_subs(expr, replacements)
            for name, expr in self.indexing_exprs.items()
        }

    def __call__(self, *indices):
        # bind the caller's indices, run the traced graph, then unbind
        self.indexing = self.indexing_from_args(indices)
        result = self.root_block()
        self.indexing = None
        return result

    def bind_set_indirect_shim(self, var, size, check, wrap_neg):
        """Build the call_module shim that rebinds `var` during execution.

        The attached .clone lets _init_with_copy rebind the shim to a new
        LoopBody (functools.partial supplies everything except `self`).
        """

        def set_indirect(new_var):
            self.replace_indirect(
                var, V.ops.indirect_indexing(new_var, size, check, wrap_neg)
            )

        set_indirect.clone = functools.partial(  # type: ignore[attr-defined]
            LoopBody.bind_set_indirect_shim,
            var=var,
            size=size,
            check=check,
            wrap_neg=wrap_neg,
        )
        return set_indirect

    def bind_scan_shim(self, combine_fn):
        """Build the call_module shim backing ops.scan; see
        bind_set_indirect_shim for the .clone protocol."""

        def shim(dtypes, values):
            return V.ops.scan(dtypes, combine_fn, values)

        shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn)  # type: ignore[attr-defined]
        return shim

    def bind_masked_shim(self, name):
        """Build the call_module shim backing ops.masked; runs the named
        subblock under the mask.  See bind_set_indirect_shim for .clone."""

        def shim(mask, other):
            return V.ops.masked(mask, self.subblocks[name], other)

        shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name)  # type: ignore[attr-defined]
        return shim


class LoopBodyBlock:
    """
    Captures the body of a Loops subclass into an FX graph.
    In normal cases there will be a 1:1 mapping between LoopBody and
    LoopBodyBlock, however in the case of ops.masked() the masked out
    operations will manifest as an extra LoopBodyBlock.
    """

    def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]):
        """Trace fn(*args) into an FX graph, recording indexing on `body`."""
        self.body = body

        def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs):
            # register the expression on the parent LoopBody and emit a
            # call_module node that fetches it at execution time
            return tracer.create_proxy(
                "call_module",
                "get_index",
                (body.add_index_expr(expr, mtype, **kwargs),),
                {},
            )

        class CaptureIndexing(V.WrapperHandler):  # type: ignore[name-defined]
            # NOTE(review): a class body does not define `self`, so this
            # resolves via closure to __init__'s `self` (the LoopBodyBlock
            # under construction), not to handler instances — confirm this
            # tagging is intentional.
            self.name = "CaptureIndexing"

            def load(self, name: str, index: sympy.Expr):
                # record the load's indexing, then delegate to the inner handler
                index = add_index(index, MemoryUsageType.LOAD, buffer_name=name)
                return self._inner.load(name, index)

            def load_seed(self, name: str, index: int):
                # seed loads use a literal int offset; record it but pass the
                # raw int through unchanged
                assert isinstance(index, int)
                body.add_index_expr(
                    sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name
                )
                return self._inner.load_seed(name, index)

            def store(self, name, index, value, mode=None):
                index = add_index(
                    index, MemoryUsageType.STORE, buffer_name=name, mode=mode
                )
                return self._inner.store(name, index, value, mode)

            def store_reduction(self, name, index, value):
                index = add_index(
                    index, MemoryUsageType.STORE_REDUCTION, buffer_name=name
                )
                return self._inner.store_reduction(name, index, value)

            def reduction(self, dtype, src_dtype, reduction_type, value):
                result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
                if "welford" in reduction_type:
                    # welford reductions yield 3 values; unpack the proxy
                    # into a real tuple so callers can destructure it
                    return tuple(result[i] for i in range(3))
                return result

            def index_expr(self, index, dtype):
                # constant-fold literal integers instead of emitting an index
                if isinstance(index, (int, sympy.Integer)):
                    return self._inner.constant(int(index), dtype)
                index = add_index(index, MemoryUsageType.INDEX_EXPR)
                return self._inner.index_expr(index, dtype)

            def check_bounds(self, index, size, lower, upper):
                # both the checked index and the bound are indexing exprs
                index = add_index(index, MemoryUsageType.CHECK_BOUNDS)
                size = add_index(size, MemoryUsageType.CHECK_BOUNDS)
                return self._inner.check_bounds(index, size, lower, upper)

            def bucketize(
                self,
                values,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ):
                # record the read of the offsets buffer
                offsets_size = add_index(
                    offsets_size, MemoryUsageType.BUCKETIZE, buffer_name=offsets_name
                )
                return self._inner.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

            @staticmethod
            def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
                """
                Recursively capture the masked out body in another LoopBodyBlock
                """
                # staticmethod: `self` and `tracer` here are closure
                # variables from the enclosing __init__, not handler state.
                # Register the name first; the shim and subblock need it.
                name = self.body.add_submodule(None, "masked_subblock")
                self.body.submodules[name] = self.body.bind_masked_shim(name)
                self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, [])
                return tracer.create_proxy(
                    "call_module", name, (mask_proxy, other_proxy), {}
                )

            @staticmethod
            def scan(
                dtype_proxy,
                combine_fn: Callable[
                    [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]
                ],
                value_proxy,
            ):
                # staticmethod: `self`/`tracer` come from __init__'s closure
                shim = self.body.bind_scan_shim(combine_fn)
                name = self.body.add_submodule(shim, "scan")
                result = tracer.create_proxy(
                    "call_module",
                    name,
                    (dtype_proxy, value_proxy),
                    {},
                )
                # Proxies are iterable, but some methods expect tuples/lists
                return tuple(result[i] for i in range(len(value_proxy)))

            def sort(self, dtypes, values, stable, descending):
                result = self._inner.sort(dtypes, values, stable, descending)
                # Proxies are iterable, but some methods expect tuples/lists
                return tuple(result[i] for i in range(len(values)))

            def frexp(self, value_proxy):
                result = self._inner.frexp(value_proxy)
                # Proxies are iterable, but some methods expect tuples/lists
                return (result[0], result[1])

            @staticmethod
            def indirect_indexing(index_proxy, size, check=True, wrap_neg=True):
                """
                Flow data from tensors into indexing formulas.
                Introduce a call_module to update the indexing.
                """

                # staticmethod: `self`/`tracer` come from __init__'s closure
                var = self.body.add_indirect(size)
                set_indirect = self.body.bind_set_indirect_shim(
                    var, size, check, wrap_neg
                )
                tracer.create_proxy(
                    "call_module",
                    self.body.add_submodule(set_indirect, f"set_{var}"),
                    (index_proxy,),
                    {},
                )
                return var

            @staticmethod
            def output(result):
                # terminate the traced graph with an output node
                tracer.create_proxy("output", "output", (result,), {})

        # build the graph by hand; the placeholder stands in for the ops
        # handler that InterpreterShim.run() will supply at execution time
        tracer = torch.fx.Tracer()
        tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
        proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})

        from .index_propagation import IndexPropagation
        from .sizevars import SimplifyIndexing

        # handler chain: [IndexPropagation ->] SimplifyIndexing -> CaptureIndexing
        handler: Any = SimplifyIndexing(
            CaptureIndexing(proxy_ops), self.body.var_ranges
        )
        if config.constant_and_index_propagation:
            handler = IndexPropagation(
                handler, self.body.var_ranges, self.body.indirect_var_ranges
            )

        with V.set_ops_handler(handler):
            # This indirection is just a cute way to get IndexPropagation to
            # unwrap the return value.
            ops.output(fn(*args))
        self.graph = tracer.graph

    def __call__(self):
        """Execute the traced graph against the currently-active ops handler."""
        graph = self.graph
        submodules = self.body.submodules

        return InterpreterShim(graph, submodules).run(V.get_ops_handler())

    def debug_str(self, name="block"):
        """Return the FX-generated Python code for this block, with the
        function renamed to `name` for readability."""
        code = torch.fx.GraphModule(self.body.submodules, self.graph).code
        return re.sub(
            # strip `; del var0` suffixes to make output prettier
            r";[^\n]*",
            "",
            code.strip().replace("def forward(", f"def {name}("),
        )

    def contains_only_ops(self, allowed_ops) -> bool:
        """True if every call_method node's target is in allowed_ops."""
        return all(
            node.target in allowed_ops
            for node in self.graph.find_nodes(op="call_method")
        )

    def clone(self, body: LoopBody):
        """Shallow copy with a new parent LoopBody"""
        copy = LoopBodyBlock.__new__(LoopBodyBlock)
        copy.__dict__.update({**self.__dict__, "body": body})
        return copy
