# mypy: ignore-errors

import itertools
import operator
import sys
from typing import Dict, List, Optional, TYPE_CHECKING, Union

from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import (
    handle_observed_exception,
    ObservedUserStopIteration,
    raise_observed_exception,
    unimplemented,
    UserError,
)
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


# Upper bound on the number of items CycleIteratorVariable buffers from its
# input iterator; once the buffer grows past this, tracing is aborted via
# unimplemented().
MAX_ITERATOR_LIMIT = 100 * 1024  # 100k


class ItertoolsVariable(VariableTracker):
    """Wraps a callable from the ``itertools`` module so calls to it can be traced.

    ``call_function`` evaluates ``product``, ``accumulate``, ``combinations``
    and ``groupby`` eagerly over the unpacked argument sequences, builds
    dedicated iterator variables for ``repeat``, ``count`` and ``cycle``, and
    routes ``dropwhile`` / ``zip_longest`` (and ``repeat`` with two or more
    positional arguments) through ``polyfills``. Any other callable, or any
    argument pattern not matched below, is delegated to the base class.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        # The wrapped itertools callable (e.g. ``itertools.product``).
        self.value = value

    def __repr__(self) -> str:
        return f"ItertoolsVariable({self.value})"

    def as_python_constant(self):
        """Return the wrapped itertools callable itself."""
        return self.value

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model a call to the wrapped itertools callable.

        Args:
            tx: the active instruction translator.
            args: positional arguments of the traced call.
            kwargs: keyword arguments of the traced call.

        Returns:
            A VariableTracker standing for the result of the call.

        Unsupported argument combinations are reported through
        ``unimplemented``.
        """
        if (
            self.value is itertools.product
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = [
                variables.TupleVariable(list(item))
                for item in itertools.product(*seqs)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.accumulate:
            from .builtin import BuiltinVariable

            unsupported = set(kwargs) - {"initial", "func"}
            if unsupported:
                unimplemented(
                    "Unsupported kwargs for itertools.accumulate: "
                    f"{','.join(sorted(unsupported))}"
                )

            acc = kwargs.get("initial")

            if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx):
                seq = args[0].unpack_var_sequence(tx)

                if "func" in kwargs and len(args) == 2:
                    # The combining function was supplied twice. Check this
                    # first: otherwise the positional one would silently win
                    # and the keyword would be dropped.
                    unimplemented(
                        "itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg"
                    )
                elif "func" in kwargs:
                    func = kwargs["func"].call_function
                elif len(args) == 2:
                    func = args[1].call_function
                else:
                    # Default to operator.add
                    func = BuiltinVariable(operator.add).call_function
            else:
                unimplemented("Unsupported arguments for itertools.accumulate")

            items = []
            if acc is not None:
                # An explicit initial value is emitted before any element.
                items.append(acc)
            for item in seq:
                if acc is None:
                    # No initial value: the first element seeds the running total.
                    acc = item
                else:
                    try:
                        acc = func(tx, [acc, item], {})
                    except Exception as e:
                        unimplemented(
                            "Unexpected failure in invoking function during accumulate. "
                            f"Failed running func {func}({acc}, {item})",
                            from_exc=e,
                        )
                items.append(acc)

            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif (
            self.value is itertools.combinations
            and not kwargs
            and len(args) == 2
            and args[0].has_unpack_var_sequence(tx)
            and args[1].is_python_constant()
        ):
            iterable = args[0].unpack_var_sequence(tx)
            r = args[1].as_python_constant()

            items = [
                variables.TupleVariable(list(item))
                for item in itertools.combinations(iterable, r)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.groupby:
            unsupported = set(kwargs) - {"key"}
            if unsupported:
                unimplemented(
                    "Unsupported kwargs for itertools.groupby: "
                    f"{','.join(sorted(unsupported))}"
                )

            def retrieve_const_key(key):
                # Group keys must reduce to plain Python values so that
                # itertools.groupby can compare them.
                if isinstance(key, variables.SymNodeVariable):
                    return key.evaluate_expr()
                elif isinstance(key, variables.ConstantVariable):
                    return key.as_python_constant()
                else:
                    unimplemented(
                        "Unsupported key type for itertools.groupby: " + str(type(key))
                    )

            if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
                seq = args[0].unpack_var_sequence(tx)
                keyfunc = (
                    (
                        lambda x: (
                            retrieve_const_key(
                                kwargs.get("key").call_function(tx, [x], {})
                            )
                        )
                    )
                    if "key" in kwargs
                    else None
                )
            else:
                unimplemented("Unsupported arguments for itertools.groupby")

            result = []
            try:
                for k, v in itertools.groupby(seq, key=keyfunc):
                    result.append(
                        variables.TupleVariable(
                            [
                                # Literal keys are wrapped as constants; anything
                                # else is stored as produced by groupby.
                                variables.ConstantVariable.create(k)
                                if variables.ConstantVariable.is_literal(k)
                                else k,
                                variables.ListIteratorVariable(
                                    list(v), mutable_local=MutableLocal()
                                ),
                            ],
                            mutable_local=MutableLocal(),
                        )
                    )
            except Exception as e:
                unimplemented(
                    "Unexpected failure when calling itertools.groupby",
                    from_exc=e,
                )
            return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
        elif self.value is itertools.repeat:
            if len(args) < 2:
                # Fewer than two positional arguments: build the dedicated
                # repeat iterator. kwargs are not forwarded on this path.
                return variables.RepeatIteratorVariable(
                    *args, mutable_local=MutableLocal()
                )

            from .builder import SourcelessBuilder

            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfills.repeat), args, kwargs
            )
        elif self.value is itertools.count:
            return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
        elif self.value is itertools.cycle:
            return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
        elif self.value is itertools.dropwhile:
            return variables.UserFunctionVariable(polyfills.dropwhile).call_function(
                tx, args, kwargs
            )
        elif self.value is itertools.zip_longest:
            return variables.UserFunctionVariable(polyfills.zip_longest).call_function(
                tx, args, kwargs
            )
        else:
            return super().call_function(tx, args, kwargs)


class IteratorVariable(VariableTracker):
    """Base class for variables that are consumed one item at a time.

    Subclasses provide ``next_variable``; the helpers here build on it.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def next_variable(self, tx):
        """Produce the next item; subclasses are expected to override this."""
        unimplemented("abstract method, must implement")

    # NOTE: only call this when eagerly draining the iterator is safe!
    # Iterators are normally consumed lazily.
    # Safe eager unpacking:   list(map(f, seq))
    # Unsafe eager unpacking: list(islice(map(f, seq), 5))
    def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
        """Drain the iterator and return every item it produced, in order."""
        collected = []
        try:
            while True:
                collected.append(self.next_variable(tx))
        except ObservedUserStopIteration:
            # Exhaustion is the normal way out of the loop above.
            handle_observed_exception(tx)
        return collected

    # Deliberately does not probe via force_unpack_var_sequence: that call
    # can mutate IteratorVariable state.
    def has_force_unpack_var_sequence(self, tx) -> bool:
        return True


class RepeatIteratorVariable(IteratorVariable):
    """Iterator variable that yields the same ``item`` on every step."""

    def __init__(self, item: VariableTracker, **kwargs) -> None:
        super().__init__(**kwargs)
        self.item = item

    def next_variable(self, tx):
        # No state is advanced: the stored item is handed back each time.
        return self.item

    def reconstruct(self, codegen):
        """Emit bytecode equivalent to ``itertools.repeat(item)``."""

        def load_repeat():
            return codegen.extend_output(
                [
                    codegen.create_load_python_module(itertools),
                    codegen.create_load_attr("repeat"),
                ]
            )

        codegen.add_push_null(load_repeat)
        codegen(self.item)
        call_instructions = create_call_function(1, False)
        codegen.extend_output(call_instructions)


class CountIteratorVariable(IteratorVariable):
    """Iterator variable holding a current ``item`` and a ``step`` added per call."""

    def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
        super().__init__(**kwargs)
        # Plain Python values are wrapped as constants; VariableTrackers are
        # stored unchanged.
        self.item = (
            item if isinstance(item, VariableTracker) else ConstantVariable.create(item)
        )
        self.step = (
            step if isinstance(step, VariableTracker) else ConstantVariable.create(step)
        )

    def next_variable(self, tx):
        """Return the current value and advance the stored one by ``step``."""
        assert self.mutable_local
        current = self.item
        tx.output.side_effects.mutation(self)
        self.item = current.call_method(tx, "__add__", [self.step], {})
        return current

    def reconstruct(self, codegen):
        """Emit bytecode equivalent to ``itertools.count(item, step)``."""

        def load_count():
            return codegen.extend_output(
                [
                    codegen.create_load_python_module(itertools),
                    codegen.create_load_attr("count"),
                ]
            )

        codegen.add_push_null(load_count)
        for operand in (self.item, self.step):
            codegen(operand)
        codegen.extend_output(create_call_function(2, False))


class CycleIteratorVariable(IteratorVariable):
    """Iterator variable modelling ``itertools.cycle``.

    While the wrapped ``iterator`` still produces items, each one is returned
    and also appended to ``saved``. Once the iterator is exhausted, the items
    in ``saved`` are replayed in order, wrapping around indefinitely.
    ``saved_index`` is the position in ``saved`` of the next item to replay.
    """

    def __init__(
        self,
        iterator: IteratorVariable,
        saved: Optional[List[VariableTracker]] = None,
        saved_index: int = 0,
        item: Optional[VariableTracker] = None,
        **kwargs,
    ) -> None:
        if saved is None:
            saved = []
        super().__init__(**kwargs)
        self.iterator = iterator
        self.saved = saved
        self.saved_index = saved_index
        # Most recently returned item.
        self.item = item

    def next_variable(self, tx):
        """Return the next item of the cycle.

        Raises an observed ``StopIteration`` when the wrapped iterator was
        empty, so there is nothing to replay. Calls ``unimplemented`` when
        more than ``MAX_ITERATOR_LIMIT`` items have been buffered.
        """
        assert self.mutable_local

        if self.iterator is not None:
            try:
                new_item = self.iterator.next_variable(tx)
                if len(self.saved) > MAX_ITERATOR_LIMIT:
                    unimplemented(
                        "input iterator to itertools.cycle has too many items"
                    )
                tx.output.side_effects.mutation(self)
                self.saved.append(new_item)
                self.item = new_item
                if self.item is None:
                    return self.next_variable(tx)
                return self.item
            except ObservedUserStopIteration:
                handle_observed_exception(tx)
                # Source exhausted: switch to replaying the buffered items.
                self.iterator = None
                return self.next_variable(tx)
        elif len(self.saved) > 0:
            tx.output.side_effects.mutation(self)
            # Replay the buffered item at the cursor, then advance the cursor
            # with wrap-around. Returning the last produced item here would
            # repeat the final element forever instead of cycling.
            self.item = self.saved[self.saved_index]
            self.saved_index = (self.saved_index + 1) % len(self.saved)
            return self.item
        else:
            raise_observed_exception(StopIteration, tx, self)


class ZipVariable(IteratorVariable):
    """
    Represents zip(*iterables)

    Each entry of ``iterables`` is either a plain Python list of
    VariableTrackers (read by position via ``self.index``) or a
    VariableTracker implementing ``next_variable`` (advanced by calling it).
    ``strict`` mirrors the ``strict`` keyword of the builtin ``zip``.
    """

    # Extends the set inherited from IteratorVariable with this class's own
    # plain-Python fields.
    # NOTE(review): presumably these are fields that do not hold
    # VariableTrackers -- confirm against VariableTracker._nonvar_fields.
    _nonvar_fields = {
        "index",
        "strict",
        *IteratorVariable._nonvar_fields,
    }

    def __init__(
        self,
        iterables: List[Union[List[VariableTracker], VariableTracker]],
        strict: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(iterables, list)
        # can be list[Variable] or VariableTracker (with next_variable implemented)
        self.iterables = iterables
        # Number of tuples already produced by next_variable; also the read
        # position into any list-typed entry of ``iterables``.
        self.index = 0
        self.strict = strict

    def python_type(self):
        return zip

    def has_unpack_var_sequence(self, tx) -> bool:
        """True when every input is a list or can itself be unpacked."""
        return all(
            isinstance(it, list) or it.has_unpack_var_sequence(tx)
            for it in self.iterables
        )

    def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
        """Return the remaining zipped tuples as a list of TupleVariables."""
        assert self.has_unpack_var_sequence(tx)
        iterables = []
        for it in self.iterables:
            if isinstance(it, list):
                # Skip the items already handed out by next_variable.
                iterables.append(it[self.index :])
            else:
                iterables.append(it.unpack_var_sequence(tx))
        # ``strict`` is only forwarded to the builtin zip when it is set.
        kwargs = {"strict": self.strict} if self.strict else {}
        zipped = zip(*iterables, **kwargs)
        return [variables.TupleVariable(list(var)) for var in zipped]

    def next_variable(self, tx):
        """Produce the next zipped tuple and advance ``self.index``.

        Re-raises the observed StopIteration when an input is exhausted. In
        strict mode, raises UserError(ValueError) instead if the inputs turn
        out to have different lengths.
        """
        assert self.mutable_local
        old_index = self.index
        args = []

        def get_item(it):
            # Lists are read by position; iterator variables are advanced.
            if isinstance(it, list):
                if old_index >= len(it):
                    raise_observed_exception(StopIteration, tx, self)
                return it[old_index]
            else:
                return it.next_variable(tx)

        try:
            for idx, it in enumerate(self.iterables):
                args.append(get_item(it))
        except ObservedUserStopIteration:
            if self.strict:
                if idx == 0:
                    # The first input ran out: in strict mode that is only
                    # legal if every input is exhausted as well.
                    # all other iterables should be exhausted
                    for it in self.iterables:
                        try:
                            get_item(it)
                        except ObservedUserStopIteration:
                            handle_observed_exception(tx)
                            continue
                        # no ObservedUserStopIteration - fall through to UserError
                        break
                    else:
                        # all iterables exhausted, raise original error
                        raise
                # Reached when a later input (idx > 0) ran out first, or when
                # the probe above found an input that still had items.
                handle_observed_exception(tx)
                raise UserError(
                    ValueError,
                    "zip() has one argument of len differing from others",
                ) from None
            # Non-strict: exhaustion of any input ends the zip.
            raise

        # Only record the mutation once a full tuple was assembled.
        tx.output.side_effects.mutation(self)
        self.index += 1
        return variables.TupleVariable(args)

    def reconstruct_items(self, codegen):
        """Emit one stack value per input iterable.

        List inputs are emitted as a tuple of their not-yet-consumed items;
        other inputs are emitted via ``codegen`` directly.
        """
        for it in self.iterables:
            if isinstance(it, list):
                remaining_items = it[self.index :]
                codegen.foreach(remaining_items)
                codegen.append_output(
                    create_instruction("BUILD_TUPLE", arg=len(remaining_items))
                )
            else:
                codegen(it)

    def reconstruct(self, codegen):
        """Emit bytecode that calls the builtin ``zip`` on the remaining inputs."""
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
        )
        self.reconstruct_items(codegen)
        # Pack the per-iterable values into the positional-args tuple.
        codegen.append_output(
            create_instruction("BUILD_TUPLE", arg=len(self.iterables))
        )
        if sys.version_info >= (3, 10):
            # Python >= 3.10: also pass {"strict": self.strict} as the kwargs map.
            codegen.extend_output(
                [
                    codegen.create_load_const("strict"),
                    codegen.create_load_const(self.strict),
                    create_instruction("BUILD_MAP", arg=1),
                    create_instruction("CALL_FUNCTION_EX", arg=1),
                ]
            )
        else:
            # Older Python: positional arguments only, no ``strict`` keyword.
            codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))


class MapVariable(ZipVariable):
    """
    Represents map(fn, *iterables)

    Reuses ZipVariable to step the inputs together, then applies ``fn`` to
    each resulting tuple of arguments.
    """

    def __init__(
        self,
        fn: VariableTracker,
        iterables: List[Union[List[VariableTracker], VariableTracker]],
        **kwargs,
    ) -> None:
        super().__init__(iterables, **kwargs)
        self.fn = fn

    def python_type(self):
        return map

    def has_unpack_var_sequence(self, tx) -> bool:
        # Unlike the parent class, this is never reported as unpackable.
        return False

    def next_variable(self, tx):
        """Advance the inputs by one step and call ``fn`` on the items."""
        zipped_args = super().next_variable(tx)
        return self.fn.call_function(tx, zipped_args.items, {})

    def reconstruct(self, codegen):
        """Emit bytecode that calls the builtin ``map`` with ``fn`` and the inputs."""

        def load_map():
            return codegen.load_import_from("builtins", "map")

        codegen.add_push_null(load_map, call_function_ex=True)
        codegen(self.fn)
        self.reconstruct_items(codegen)
        # fn plus one value per iterable form the positional-args tuple.
        positional_count = len(self.iterables) + 1
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=positional_count),
                create_instruction("CALL_FUNCTION_EX", arg=0),
            ]
        )
