# mypy: ignore-errors

import collections
import functools
import inspect
import operator
import types
from typing import Dict, List, Optional, TYPE_CHECKING

import torch
import torch.fx
from torch._guards import Source

from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import raise_observed_exception, unimplemented
from ..source import AttrSource
from ..utils import (
    get_fake_value,
    guard_if_dyn,
    is_namedtuple,
    istype,
    iter_contains,
    Lit,
    namedtuple_fields,
    odict_values,
    set_example_value,
)
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .functions import UserFunctionVariable, UserMethodVariable
from .iter import IteratorVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


class BaseListVariable(VariableTracker):
    """Base tracker for list-like containers (list, tuple, range, slice, ...).

    Elements are held in ``self.items`` as a plain ``list`` of
    ``VariableTracker`` objects.  Subclasses choose the modeled Python type
    through ``python_type()`` and customize construction / reconstruction.
    """

    @staticmethod
    def cls_for_instance(obj):
        """Return the tracker class (or a partially applied factory) for ``obj``.

        Named tuples must remember their concrete class, so they get a
        ``NamedTupleVariable`` factory with ``tuple_cls`` bound; every other
        object is dispatched on its exact type via ``cls_for``.
        """
        if is_namedtuple(obj):
            return functools.partial(NamedTupleVariable, tuple_cls=type(obj))
        return BaseListVariable.cls_for(type(obj))

    @staticmethod
    def cls_for(obj):
        """Map a container *type* to the tracker class that models it.

        Lookup is by exact type; unsupported types raise ``KeyError``.
        """
        return {
            iter: ListIteratorVariable,
            list: ListVariable,
            slice: SliceVariable,
            torch.Size: SizeVariable,
            tuple: TupleVariable,
            odict_values: ListVariable,
            torch.nn.ParameterList: ListVariable,
            torch.nn.ModuleList: ListVariable,
            collections.deque: DequeVariable,
        }[obj]

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(items, list)
        assert all(isinstance(x, VariableTracker) for x in items)
        self.items: List[VariableTracker] = items

    def _as_proxy(self):
        """Return the elements' proxies as a plain list (no container type)."""
        return [x.as_proxy() for x in self.items]

    def modified(self, items, **kwargs):
        """Build a new tracker of the same class holding ``items``."""
        return type(self)(items, **kwargs)

    @property
    def value(self):
        # Only meaningful when every element is a Python constant.
        return self.as_python_constant()

    def debug_repr_helper(self, prefix, suffix):
        """Join the elements' debug reprs with ", " and wrap them in prefix/suffix."""
        return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix

    def as_python_constant(self):
        """Rebuild the container as a real Python object of ``python_type()``."""
        return self.python_type()([x.as_python_constant() for x in self.items])

    def as_proxy(self):
        """Return a container of ``python_type()`` whose members are proxies."""
        # torch.Size rejects Proxy/Node members, which is why SizeVariable
        # provides its own as_proxy; this generic form must never build one.
        # python_type() returns Python types, so compare against torch.Size
        # (not the SizeVariable tracker class, which it never returns).
        assert self.python_type() is not torch.Size
        return self.python_type()(self._as_proxy())

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index or slice ``self.items`` with a constant or symbolic ``arg``.

        A slice yields a new tracker of the same class with no source (it is
        a fresh local) that is mutable iff ``self`` is; an integer index
        returns the stored element tracker itself.
        """
        from .tensor import SymNodeVariable

        if isinstance(arg, SymNodeVariable):
            index = arg.sym_num
        else:
            index = arg.as_python_constant()

        if isinstance(index, slice):
            # Set source to None because slicing a list gives a new local
            return self.clone(
                items=self.items[index],
                source=None,
                mutable_local=MutableLocal() if self.mutable_local else None,
            )
        else:
            assert isinstance(index, (int, torch.SymInt))
            return self.items[index]

    def unpack_var_sequence(self, tx):
        """Return a shallow copy of the element trackers."""
        return list(self.items)

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle read-only sequence methods shared by all list-like trackers.

        Supports ``__getitem__`` (constant, symbolic, or single-element
        constant-tensor index), ``__contains__`` and ``index`` (inlined from
        ``polyfills.index``); anything else is delegated to the base class.
        """
        if name == "__getitem__":
            from .tensor import TensorVariable

            assert not kwargs and len(args) == 1
            if isinstance(args[0], TensorVariable):
                # A tensor index is only usable when its fake value carries a
                # known constant with exactly one element.
                value = get_fake_value(args[0].as_proxy().node, tx)
                if value.constant is not None and value.constant.numel() == 1:
                    value = variables.ConstantVariable.create(value.constant.item())
                else:
                    unimplemented("__getitem__ with non-constant tensor")
            else:
                value = args[0]
            return self.getitem_const(tx, value)
        elif name == "__contains__":
            assert len(args) == 1
            assert not kwargs
            return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
        elif name == "index":
            from .builder import SourcelessBuilder

            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfills.index),
                [self] + list(args),
                kwargs,
            )

        return super().call_method(tx, name, args, kwargs)

    @staticmethod
    def list_compare(tx: "InstructionTranslator", op, left, right):
        """Compare two list-like trackers with ``op`` by inlining ``polyfills.list_cmp``."""
        return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
            tx, [variables.BuiltinVariable(op), left, right], {}
        )


class RangeVariable(BaseListVariable):
    """Tracks a ``range`` object.

    ``items`` is always normalized to ``[start, stop, step]``; the defaults
    ``start=0`` and ``step=1`` are filled in for the shorter ``range()``
    forms.  Most helpers below require these components to be Python
    constants (they call ``as_python_constant`` on them).
    """

    def __init__(self, items, **kwargs) -> None:
        items_to_map = items
        start = variables.ConstantVariable.create(0)
        stop = None
        step = variables.ConstantVariable.create(1)

        # Accept the same 1-, 2- and 3-argument forms as range().
        if len(items_to_map) == 1:
            (stop,) = items_to_map
        elif len(items_to_map) == 2:
            start, stop = items_to_map
        elif len(items_to_map) == 3:
            start, stop, step = items_to_map
        else:
            raise AssertionError(
                f"range expected 1 to 3 arguments, got {len(items_to_map)}"
            )

        assert stop is not None
        super().__init__([start, stop, step], **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("range(", ")")

    def python_type(self):
        return range

    def start(self):
        """Constant value of the range's start."""
        return self.items[0].as_python_constant()

    def stop(self):
        """Constant value of the range's stop."""
        return self.items[1].as_python_constant()

    def step(self):
        """Constant value of the range's step."""
        return self.items[2].as_python_constant()

    def range_length(self):
        """Number of elements in the range (CPython's ``compute_range_length``)."""
        lo = self.start()
        hi = self.stop()
        step = self.step()

        assert step != 0
        if step > 0 and lo < hi:
            return 1 + (hi - 1 - lo) // step
        elif step < 0 and lo > hi:
            return 1 + (lo - 1 - hi) // (0 - step)
        else:
            return 0

    def _get_slice_indices(self, length, slice):
        """Resolve ``slice`` against a sequence of ``length`` elements.

        Mirrors CPython's ``_PySlice_GetLongIndices``: ``None`` fields get
        their defaults, negative bounds are offset by ``length`` and all
        bounds are clamped into range.

        Returns:
            ``[start, stop, step]`` as plain ints.

        Raises:
            ValueError: if the slice step is zero, as CPython does.
        """
        if slice.step is None:
            step = 1
        else:
            step = slice.step
            if step == 0:
                raise ValueError("slice step cannot be zero")
        step_is_negative = step < 0

        # Find lower and upper bounds for start and stop.
        if step_is_negative:
            lower = -1
            upper = length + lower
        else:
            lower = 0
            upper = length

        # Compute start
        if slice.start is None:
            start = upper if step_is_negative else lower
        else:
            start = slice.start

        if start < 0:
            start += length
            if start < lower:
                start = lower
        else:
            if start > upper:
                start = upper

        # Compute stop.
        if slice.stop is None:
            stop = lower if step_is_negative else upper

        else:
            stop = slice.stop

            if stop < 0:
                stop += length
                if stop < lower:
                    stop = lower
            else:
                if stop > upper:
                    stop = upper

        return [start, stop, step]

    def apply_index(self, index):
        """Return ``range[index]`` as a constant tracker.

        Negative indices count from the end.

        Raises:
            IndexError: if ``index`` is out of range (CPython's message).
        """
        length = self.range_length()
        if index < 0:
            index = length + index

        if index < 0 or index >= length:
            raise IndexError("range object index out of range")

        return variables.ConstantVariable.create(self.start() + (index * self.step()))

    def apply_slice(self, slice):
        """Return ``range[slice]`` as a new ``RangeVariable`` (CPython's ``compute_slice``)."""
        (slice_start, slice_stop, slice_step) = self._get_slice_indices(
            self.range_length(), slice
        )

        def compute_item(index):
            return self.start() + (index * self.step())

        sub_step = self.step() * slice_step
        sub_start = compute_item(slice_start)
        sub_stop = compute_item(slice_stop)

        result = RangeVariable(
            [
                variables.ConstantVariable.create(x)
                for x in [sub_start, sub_stop, sub_step]
            ],
            mutable_local=MutableLocal() if self.mutable_local else None,
        )
        return result

    def as_python_constant(self):
        return range(*[x.as_python_constant() for x in self.items])

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        # This implementation mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
        index = arg.as_python_constant()

        if isinstance(index, slice):
            return self.apply_slice(index)
        else:
            return self.apply_index(index)

    def as_proxy(self):
        return self.python_type()(*self._as_proxy())

    def unpack_var_sequence(self, tx=None):
        """Materialize every element of the range as a constant tracker."""
        return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]

    def reconstruct(self, codegen):
        """Emit ``range(start, stop, step)``; requires ``range`` not to be shadowed."""
        assert "range" not in codegen.tx.f_globals
        codegen.add_push_null(
            lambda: codegen.append_output(codegen.create_load_python_module(range))
        )
        codegen.foreach(self.items)
        codegen.extend_output(create_call_function(3, False))

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Expose ``start``/``stop``/``step``; other attributes are unsupported."""
        fields = ["start", "stop", "step"]
        if name not in fields:
            unimplemented(f"range.{name}")
        return self.items[fields.index(name)]


class CommonListMethodsVariable(BaseListVariable):
    """
    Implement methods common to List and other List-like things
    """

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Model the mutating list methods on ``self.items``.

        Every mutating branch requires ``self.mutable_local`` and calls
        ``tx.output.side_effects.mutation(self)`` before changing
        ``self.items``.  Unhandled methods (or non-mutable receivers) fall
        through to ``BaseListVariable.call_method``.
        """
        from .tensor import SymNodeVariable

        if name == "append" and self.mutable_local:
            assert not kwargs
            (arg,) = args
            tx.output.side_effects.mutation(self)
            self.items.append(arg)
            return ConstantVariable.create(None)
        elif (
            name == "extend"
            and self.mutable_local
            and args
            and args[0].has_force_unpack_var_sequence(tx)
        ):
            assert not kwargs
            (arg,) = args
            seq = arg.force_unpack_var_sequence(tx)
            tx.output.side_effects.mutation(self)
            self.items.extend(seq)
            return ConstantVariable.create(None)
        elif name == "insert" and self.mutable_local:
            assert not kwargs
            idx, value = args
            # A symbolic index is specialized to a concrete int here.
            if isinstance(idx, SymNodeVariable):
                const_idx = idx.evaluate_expr()
            else:
                const_idx = idx.as_python_constant()
            tx.output.side_effects.mutation(self)
            self.items.insert(const_idx, value)
            return ConstantVariable.create(None)
        elif name == "pop" and self.mutable_local:
            assert not kwargs
            tx.output.side_effects.mutation(self)
            return self.items.pop(*[a.as_python_constant() for a in args])
        elif name == "clear" and self.mutable_local:
            assert not kwargs and not args
            tx.output.side_effects.mutation(self)
            self.items.clear()
            return ConstantVariable.create(None)
        elif (
            name == "__setitem__"
            and self.mutable_local
            and args
            and args[0].is_python_constant()
        ):
            assert not kwargs
            key, value = args
            tx.output.side_effects.mutation(self)
            if isinstance(key, SliceVariable):
                self.items[key.as_python_constant()] = list(value.items)
            else:
                self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)
        elif name == "copy":
            # List copy() doesn't have args and kwargs
            assert not kwargs
            assert not args
            items = list(self.items)
            return self.modified(items, mutable_local=MutableLocal())
        elif name == "reverse" and self.mutable_local:
            assert not kwargs
            assert not args
            # Record the mutation before changing items, like every other
            # mutating branch above, so that if the side-effect tracker
            # rejects it the tracked list is not left already reversed.
            tx.output.side_effects.mutation(self)
            self.items.reverse()
            return ConstantVariable.create(None)
        else:
            return super().call_method(tx, name, args, kwargs)


class ListVariable(CommonListMethodsVariable):
    """Tracks a Python ``list`` (or a subclass that keeps list behavior)."""

    def python_type(self):
        return list

    def __repr__(self) -> str:
        return f"{type(self).__name__}(length={len(self.items)})"

    def debug_repr(self):
        return self.debug_repr_helper("[", "]")

    def reconstruct(self, codegen):
        """Push every element, then pack them with a single BUILD_LIST."""
        element_count = len(self.items)
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_LIST", arg=element_count))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``__setitem__`` with a constant key; defer everything else.

        Slice assignment expands the right-hand side into its element
        trackers, graph-breaking when that value cannot be unpacked.
        """
        handles_setitem = (
            name == "__setitem__"
            and self.mutable_local
            and args
            and args[0].is_python_constant()
        )
        if not handles_setitem:
            return super().call_method(tx, name, args, kwargs)

        assert not kwargs
        key, value = args
        tx.output.side_effects.mutation(self)
        if not isinstance(key, SliceVariable):
            self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)

        if not value.has_force_unpack_var_sequence(tx):
            unimplemented(
                f"Missing dynamo support for expanding {value} into a list for slice assignment."
            )
        expanded = value.force_unpack_var_sequence(tx)
        self.items[key.as_python_constant()] = expanded
        return ConstantVariable.create(None)

    def var_getattr(self, tx, name):
        """Resolve ``__class__`` to the (possibly user-defined) list type."""
        if name != "__class__":
            return super().var_getattr(tx, name)
        source = AttrSource(self.source, name) if self.source else None
        class_type = self.python_type()
        wrapper_cls = (
            variables.BuiltinVariable
            if class_type is list
            else variables.UserDefinedClassVariable
        )
        return wrapper_cls(class_type, source=source)

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """Answer ``hasattr`` statically for plain lists; defer for subclasses."""
        if self.python_type() is list:
            return variables.ConstantVariable.create(hasattr([], name))
        return super().call_hasattr(tx, name)


class DequeVariable(CommonListMethodsVariable):
    """Tracks a ``collections.deque`` as a list of element trackers.

    NOTE(review): ``maxlen`` is not modeled here; only the element list is
    tracked and reconstructed.
    """

    def python_type(self):
        return collections.deque

    def debug_repr(self):
        return self.debug_repr_helper("deque([", "])")

    def reconstruct(self, codegen):
        """Emit ``collections.deque([...])``; requires ``deque`` not to be shadowed."""
        assert "deque" not in codegen.tx.f_globals
        codegen.add_push_null(
            lambda: codegen.append_output(
                codegen.create_load_python_module(collections.deque)
            )
        )
        codegen.foreach(self.items)
        codegen.extend_output(
            [
                create_instruction("BUILD_LIST", arg=len(self.items)),
                *create_call_function(1, False),
            ]
        )

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Model deque-specific methods; defer the rest to the list base.

        Handles integer ``__setitem__``, ``extendleft``, ``popleft`` and
        ``appendleft`` on mutable deques, recording the mutation before
        changing ``self.items``.
        """
        if (
            name == "__setitem__"
            and self.mutable_local
            and args
            and args[0].is_python_constant()
        ):
            assert not kwargs
            key, value = args
            # deque does not support slice assignment; only int keys.
            assert key.is_python_constant() and isinstance(
                key.as_python_constant(), int
            )
            tx.output.side_effects.mutation(self)
            self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)
        elif (
            name == "extendleft"
            and self.mutable_local
            and args
            and args[0].has_force_unpack_var_sequence(tx)
        ):
            assert not kwargs

            (arg,) = args
            prefix = arg.force_unpack_var_sequence(tx)
            # extendleft prepends one element at a time, so the input ends up
            # reversed at the front.
            prefix.reverse()
            tx.output.side_effects.mutation(self)
            self.items = prefix + list(self.items)
            return ConstantVariable.create(None)
        elif name == "popleft" and self.mutable_local:
            assert not args
            assert not kwargs
            item = self.items[0]
            tx.output.side_effects.mutation(self)
            self.items = self.items[1:]
            return item
        elif name == "appendleft" and self.mutable_local:
            assert not kwargs
            # appendleft takes exactly one argument, like append.
            (arg,) = args
            tx.output.side_effects.mutation(self)
            self.items = [arg] + list(self.items)
            return ConstantVariable.create(None)
        else:
            return super().call_method(tx, name, args, kwargs)


class TupleVariable(BaseListVariable):
    """Tracks a Python ``tuple`` (or a subclass that keeps tuple behavior)."""

    def python_type(self):
        return tuple

    def __repr__(self) -> str:
        return f"{type(self).__name__}(length={len(self.items)})"

    def debug_repr(self):
        return self.debug_repr_helper("(", ")")

    def reconstruct(self, codegen):
        """Push every element, then pack them with a single BUILD_TUPLE."""
        element_count = len(self.items)
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_TUPLE", arg=element_count))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        # Tuples are immutable: there is nothing beyond the read-only
        # sequence methods, so defer entirely to the base class.
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        """Resolve ``__class__`` to the (possibly user-defined) tuple type."""
        if name != "__class__":
            return super().var_getattr(tx, name)
        source = AttrSource(self.source, name) if self.source else None
        class_type = self.python_type()
        wrapper_cls = (
            variables.BuiltinVariable
            if class_type is tuple
            else variables.UserDefinedClassVariable
        )
        return wrapper_cls(class_type, source=source)

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """Answer ``hasattr`` statically for plain tuples; defer for subclasses."""
        if self.python_type() is tuple:
            return variables.ConstantVariable.create(hasattr((), name))
        return super().call_hasattr(tx, name)


class SizeVariable(TupleVariable):
    """torch.Size(...)

    Elements are expected to be ``ConstantVariable`` or ``SymNodeVariable``
    (``numel`` asserts this).  An optional FX ``proxy`` may be attached at
    construction; when present, ``as_proxy`` returns it unchanged.
    """

    # Attributes that do not hold VariableTrackers; ``proxy`` is an FX object.
    # (Presumably consumed by the VariableTracker base when walking/cloning
    # fields — confirm in base.py.)
    _nonvar_fields = {
        "proxy",
        *TupleVariable._nonvar_fields,
    }

    def __init__(
        self,
        items: List[VariableTracker],
        proxy: Optional[torch.fx.Proxy] = None,
        **kwargs,
    ) -> None:
        self.proxy = proxy
        super().__init__(items, **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("torch.Size([", "])")

    def python_type(self):
        return torch.Size

    def as_proxy(self):
        """Return an FX value standing for this torch.Size.

        Uses the attached ``proxy`` if there is one.  Otherwise, if no element
        proxy is an ``fx.Proxy`` (all plain ints), returns a real
        ``torch.Size``; if any is, emits a ``call_function(torch.Size, ...)``
        node and records its example value.
        """
        if self.proxy is not None:
            return self.proxy

        # torch.Size needs special handling.  Normally, we pun a list-like
        # container to directly contain Proxy/Node objects from FX, and FX
        # knows to look inside containers (via map_aggregate).  But torch.Size
        # is weird; although it subclasses from tuple, it doesn't allow
        # members which aren't int-like (rejecting Proxy and Node).  This
        # means we can't use the normal representation trick
        # torch.Size([proxy0, proxy1]).  I looked into seeing if I could
        # relax torch.Size in PyTorch proper, but if torch.Size constructor
        # sees a type that it doesn't recognize, it will try to call
        # __index__() on it, so there is no BC way to actually change this
        # behavior (though it occurs to me that I could have just added a
        # YOLO no checking alternate constructor.)
        #
        # To work around this problem, I represent a torch.Size proxy as
        # a straight up proxy, that would have been constructed by taking
        # the constituent proxies as arguments.  This trick can be generally
        # used for any construct that we need a proxy for but we can't
        # directly represent as an aggregate; I don't see very many examples
        # of this in torchdynamo though!

        # Look for a proxy.  If there are none, do the legacy behavior
        tracer = None
        proxies = self._as_proxy()
        for proxy in proxies:
            if isinstance(proxy, torch.fx.Proxy):
                tracer = proxy.tracer
                break

        if tracer is None:
            return torch.Size(proxies)

        proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {})
        # Example value: plain ints pass through; proxies contribute the
        # "example_value" recorded in their node's meta.
        set_example_value(
            proxy.node,
            torch.Size(
                [
                    p.node.meta["example_value"] if not isinstance(p, int) else p
                    for p in proxies
                ]
            ),
        )
        return proxy

    def reconstruct(self, codegen):
        """Emit ``torch.Size((e0, e1, ...))``."""
        codegen.add_push_null(lambda: codegen.load_import_from("torch", "Size"))
        codegen.foreach(self.items)
        build_torch_size = [
            create_instruction("BUILD_TUPLE", arg=len(self.items)),
        ] + create_call_function(1, False)
        codegen.extend_output(build_torch_size)

    def unpack_var_sequence(self, tx):
        return list(self.items)

    def numel(self, tx):
        """Return the product of all dimensions as a tracker.

        Constant dimensions are folded eagerly; symbolic dimensions are then
        multiplied in via ``operator.mul``.  A constant product of 0 short-
        circuits to the constant 0; a constant product of 1 is skipped.
        """
        from .builtin import BuiltinVariable
        from .tensor import SymNodeVariable

        const_result = 1
        sym_sizes = []

        for v in self.items:
            if isinstance(v, ConstantVariable):
                const_result *= v.value
            else:
                assert isinstance(v, SymNodeVariable), type(v)
                # Delay proxy calls until we know it will be necessary
                sym_sizes.append(v)

        result = ConstantVariable.create(const_result)
        if sym_sizes and const_result == 1:
            # Skip multiplying by 1
            result, *sym_sizes = sym_sizes

        if not sym_sizes or const_result == 0:
            return result

        mul = BuiltinVariable(operator.mul)
        for v in sym_sizes:
            result = mul.call_function(tx, [result, v], {})
        return result

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Route ``__getitem__`` to ``get_item_dyn`` and ``numel`` to ``numel``."""
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            out = self.get_item_dyn(tx, args[0])
            return out
        elif name == "numel":
            assert not args and not kwargs
            return self.numel(tx)

        return super().call_method(tx, name, args, kwargs)

    def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index with a constant or symbolic ``arg``; a slice gives a new SizeVariable."""
        from .tensor import SymNodeVariable

        if isinstance(arg, SymNodeVariable):
            index = arg.sym_num
        else:
            index = arg.as_python_constant()
        if isinstance(index, slice):
            return SizeVariable(self.items[index])
        else:
            assert isinstance(index, (int, torch.SymInt))
            return self.items[index]

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        return variables.ConstantVariable.create(hasattr(torch.Size, name))


class NamedTupleVariable(TupleVariable):
    """Tracks an instance of a named tuple class (``tuple_cls``).

    Field access is resolved positionally through ``namedtuple_fields``;
    methods defined on the class are wrapped as user functions/methods.
    """

    _nonvar_fields = {
        "tuple_cls",
        *TupleVariable._nonvar_fields,
    }

    def __init__(self, items, tuple_cls, **kwargs) -> None:
        super().__init__(items, **kwargs)
        self.tuple_cls = tuple_cls

    def debug_repr(self):
        return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))

    def python_type(self):
        return self.tuple_cls

    def as_python_constant(self):
        # Named tuples take their fields positionally, not as one iterable.
        return self.python_type()(*[x.as_python_constant() for x in self.items])

    def as_proxy(self):
        # python_type() returns a Python class, so guard against torch.Size
        # (which cannot hold proxies), not against the SizeVariable tracker.
        assert self.python_type() is not torch.Size
        return self.python_type()(*self._as_proxy())

    def reconstruct(self, codegen):
        """Emit ``tuple_cls._make((...))``, or ``tuple_cls((...))`` without ``_make``."""
        create_fn = getattr(self.tuple_cls, "_make", self.tuple_cls)
        codegen.add_push_null(
            lambda: codegen.append_output(codegen._create_load_const(create_fn))
        )
        codegen.foreach(self.items)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(self.items)),
            ]
            + create_call_function(1, False)
        )

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Resolve a field by position, else a class-level method, else defer."""

        def check_and_create_method():
            # getattr_static avoids triggering descriptors on the class.
            method = inspect.getattr_static(self.tuple_cls, name, None)
            if isinstance(method, classmethod):
                # We need the unbounded cls method to avoid the inline __self__
                return UserMethodVariable(
                    method.__func__,
                    variables.UserDefinedClassVariable(self.tuple_cls),
                )
            elif isinstance(method, staticmethod):
                return UserFunctionVariable(method.__func__)
            elif inspect.isfunction(method):
                return UserMethodVariable(method, self)
            else:
                return None

        fields = namedtuple_fields(self.tuple_cls)
        if name not in fields:
            method = check_and_create_method()
            if method is None:
                return super().var_getattr(tx, name)
            return method
        return self.items[fields.index(name)]

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        return variables.ConstantVariable.create(hasattr(self.tuple_cls, name))


class SliceVariable(BaseListVariable):
    """Tracks a ``slice`` object, normalized to ``[start, stop, step]``.

    Missing components default to a constant ``None`` tracker, mirroring the
    1/2/3-argument forms of ``slice()``.  Tensor-valued start/stop are
    rejected with a graph break.
    """

    def __init__(self, items, **kwargs) -> None:
        # One shared None tracker fills every unspecified component.
        none_var = variables.ConstantVariable.create(None)
        start = stop = step = none_var

        count = len(items)
        if count == 1:
            (stop,) = items
        elif count == 2:
            start, stop = items
        elif count == 3:
            start, stop, step = items
        else:
            raise AssertionError

        if any(isinstance(bound, variables.TensorVariable) for bound in (start, stop)):
            unimplemented("Dynamic slicing on data-dependent value is not supported")

        super().__init__([start, stop, step], **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("slice(", ")")

    def as_proxy(self):
        return slice(*self._as_proxy())

    def python_type(self):
        return slice

    def as_python_constant(self):
        # guard_if_dyn turns symbolic components into concrete values.
        return slice(*map(guard_if_dyn, self.items))

    def reconstruct(self, codegen):
        """Push the three components, then emit BUILD_SLICE."""
        component_count = len(self.items)
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_SLICE", arg=component_count))

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Expose ``start``/``stop``/``step``; other attributes are unsupported."""
        field_names = ("start", "stop", "step")
        if name not in field_names:
            unimplemented(f"slice.{name}")
        return self.items[field_names.index(name)]


class ListIteratorVariable(IteratorVariable):
    """Iterator over a fixed list of element trackers.

    ``items`` holds the elements and ``index`` the position of the next one
    to yield; advancing is recorded as a side-effect mutation.
    """

    _nonvar_fields = {
        "index",
        *IteratorVariable._nonvar_fields,
    }

    def __init__(self, items, index: int = 0, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(items, list)
        # Removing this check as it slows things down too much
        # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492

        # assert all(isinstance(x, VariableTracker) for x in items)
        self.items = items
        self.index = index

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"

    def next_variable(self, tx):
        """Return the next element, or raise an observed StopIteration when exhausted."""
        assert self.mutable_local
        old_index = self.index
        if old_index >= len(self.items):
            raise_observed_exception(StopIteration, tx, self)

        # Register the mutation before advancing the cursor.
        tx.output.side_effects.mutation(self)
        self.index += 1
        return self.items[old_index]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        """Handle ``__contains__`` over the not-yet-consumed elements only."""
        if name == "__contains__":
            assert len(args) == 1
            assert not kwargs
            return iter_contains(self.items[self.index :], args[0], tx)

        return super().call_method(tx, name, args, kwargs)

    def python_type(self):
        return type(iter([]))

    def as_python_constant(self):
        # Only a fresh (unadvanced) iterator can be turned into a constant.
        if self.index > 0:
            raise NotImplementedError
        return iter([x.as_python_constant() for x in self.items])

    def unpack_var_sequence(self, tx):
        """Return the remaining (unconsumed) elements."""
        return list(self.items[self.index :])

    def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
        return self.unpack_var_sequence(tx)

    def reconstruct(self, codegen):
        """Rebuild as ``iter(tuple(remaining_items))`` via BUILD_TUPLE + GET_ITER."""
        remaining_items = self.items[self.index :]
        codegen.foreach(remaining_items)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(remaining_items)),
                create_instruction("GET_ITER"),
            ]
        )


class TupleIteratorVariable(ListIteratorVariable):
    """Iterator over a tuple; shares all behavior with ``ListIteratorVariable``."""


class RestrictedListSubclassVariable(ListVariable):
    """
    This is a special case of UserDefinedObjectVariable where:
        1) The user subclasses list
        2) None of the list methods are overridden, merely some new methods are added

    In these cases, we can prevent graph breaks by not using the general
    UserDefinedObjectVariable machinery and instead treating it like
    a ListVariable.
    """

    _nonvar_fields = {"user_cls", "user_cls_source", *ListVariable._nonvar_fields}
    # Class-dict names a subclass may define without disqualifying itself.
    _allowed_names = {
        "__call__",
        "__module__",
        "__dict__",
        "__doc__",
        "__name__",
        "__qualname__",
    }
    # Names a subclass must not define even though ``list`` may lack them.
    _disallowed_names = {
        "__getattribute__",
        "__getattr__",
        "__setattr__",
    }

    @classmethod
    def _is_non_conflicting_subclass(
        cls,
        user_cls: type,
        python_cls: type,
    ):
        """Ensures user_cls inherits from python_cls (e.g. list) and does not override any methods on python_cls

        ``user_cls`` must be a plain ``type`` whose only base is
        ``python_cls`` (MRO exactly ``user_cls, python_cls, object``), and
        none of its own class-dict names (outside ``_allowed_names``) may
        exist on ``python_cls`` or appear in ``_disallowed_names``.
        """
        if (
            not istype(user_cls, type)
            or user_cls.__bases__ != (python_cls,)
            or user_cls.__mro__ != (user_cls, python_cls, object)
        ):
            return False  # not subclass
        return not any(
            hasattr(python_cls, name) or name in cls._disallowed_names
            for name in set(user_cls.__dict__.keys()) - cls._allowed_names
        )

    @classmethod
    def is_matching_cls(cls, user_cls: type):
        """True if ``user_cls`` is a list subclass this tracker can model."""
        return cls._is_non_conflicting_subclass(user_cls, list)

    def __init__(
        self, items, *, user_cls: type, user_cls_source: Source, **kwargs
    ) -> None:
        super().__init__(items=items, **kwargs)
        self.user_cls = user_cls
        # Source used to load the class when reconstructing / sourcing methods.
        self.user_cls_source = user_cls_source
        assert istype(user_cls, type)
        assert isinstance(user_cls_source, Source)

    def debug_repr(self):
        # The constructor is safe as no methods, including __init__, are
        # allowed to be overridden
        # NB: This is guaranteed to print like a list, as __repr__ cannot be
        # overridden, this is... well, it's OK I guess (consistent with
        # eager), but it could be misleading.  You will have to query type
        # instead for details.
        return repr(self.user_cls([Lit(x.debug_repr()) for x in self.items]))

    def python_type(self):
        return self.user_cls

    def as_proxy(self):
        # Represented in the graph as a plain list of element proxies.
        return [x.as_proxy() for x in self.items]

    def as_python_constant(self):
        raise NotImplementedError

    def is_python_constant(self):
        return False

    @property
    def value(self):
        raise AttributeError("value")

    def modified(self, items, **kwargs):
        """Build a same-class tracker with ``items``, keeping the user class info."""
        return type(self)(
            items,
            user_cls=self.user_cls,
            user_cls_source=self.user_cls_source,
            **kwargs,
        )

    def reconstruct(self, codegen):
        """Emit ``user_cls([...])``: load the class, build the list, call it."""
        codegen.add_push_null(lambda: codegen(self.user_cls_source))
        super().reconstruct(codegen)
        codegen.extend_output(create_call_function(1, False))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Inline plain-function methods defined on ``user_cls``.

        Other kinds of attributes defined on the subclass graph-break; names
        not in ``user_cls.__dict__`` are handled as regular list methods.
        """
        if name in self.user_cls.__dict__:
            method = self.user_cls.__dict__[name]
            if isinstance(method, types.FunctionType):
                # inline the method
                source = AttrSource(self.user_cls_source, name)
                return UserMethodVariable(method, self, source=source).call_function(
                    tx, args, kwargs
                )
            unimplemented(
                f"RestrictedListSubclassVariable method {self.user_cls.__name__}.{name}"
            )
        return super().call_method(tx, name, args, kwargs)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Calling the instance dispatches to a user-defined ``__call__``.
        return self.call_method(tx, "__call__", args, kwargs)
