# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0


import numpy as np

from onnx.reference.custom_element_types import (
    bfloat16,
    float8e4m3fn,
    float8e4m3fnuz,
    float8e5m2,
    float8e5m2fnuz,
    int4,
    uint4,
)
from onnx.reference.op_run import OpRun, RefAttrName


def _check_dtype(val):  # type: ignore
    a = val.dtype
    if not isinstance(a, np.dtype) and a not in {
        bfloat16,
        float8e4m3fn,
        float8e4m3fnuz,
        float8e5m2,
        float8e5m2fnuz,
        uint4,
        int4,
        np.int8,
        np.uint8,
        np.float16,
        np.float32,
        np.float64,
        np.int32,
        np.int64,
        np.int16,
        np.uint16,
        np.uint32,
        np.bool_,
        np.str_,
        np.uint64,
        bool,
        str,
    }:
        raise TypeError(
            f"Type ({a}, {type(a)}) is not a numpy type (operator 'Constant')"
        )


class ConstantCommon(OpRun):
    def _check(self, cst):  # type: ignore
        if isinstance(cst, tuple):
            raise TypeError(f"Unexpected type {type(cst)} for a constant.")
        return cst


class Constant_1(ConstantCommon):
    def __init__(self, onnx_node, run_params):  # type: ignore
        ConstantCommon.__init__(self, onnx_node, run_params)
        self.cst = self.value  # type: ignore
        _check_dtype(self.cst)

    def _run(self, **overridden_attributes):  # type: ignore
        if overridden_attributes and (
            len(overridden_attributes) > 1
            or "value" not in overridden_attributes
            or id(overridden_attributes["value"]) != id(self.value)
        ):
            raise RuntimeError(
                "Function attributes are not implemented for opset <= 11. Use opset > 12."
            )
        return (self._check(self.cst),)


class Constant_9(Constant_1):
    def __init__(self, onnx_node, run_params):  # type: ignore
        Constant_1.__init__(self, onnx_node, run_params)


class Constant_11(ConstantCommon):
    def __init__(self, onnx_node, run_params):  # type: ignore
        ConstantCommon.__init__(self, onnx_node, run_params)
        if getattr(self, "sparse_value", None) is None:
            self.cst = self.value  # type: ignore
        else:
            self.cst = self.sparse_value  # type: ignore
        _check_dtype(self.cst)

    def _run(self, **overridden_attributes):  # type: ignore
        if overridden_attributes and (
            len(overridden_attributes) > 1
            or "value" not in overridden_attributes
            or id(overridden_attributes["value"]) != id(self.value)
        ):
            raise RuntimeError(
                "Function attributes are not implemented for opset <= 11. Use opset > 12."
            )
        return (self._check(self.cst),)


class Constant_12(ConstantCommon):
    def __init__(self, onnx_node, run_params):  # type: ignore
        ConstantCommon.__init__(self, onnx_node, run_params)
        if hasattr(self, "sparse_value") and self.sparse_value is not None:  # type: ignore
            self.cst_name = "sparse_value"
            self.cst = self.sparse_value  # type: ignore
            self.cst_convert = lambda v: v
        elif hasattr(self, "value") and self.value is not None:  # type: ignore
            self.cst_name = "value"  # type: ignore
            self.cst = self.value if isinstance(self.value, RefAttrName) else self.value  # type: ignore
            self.cst_convert = lambda v: v
        else:
            for attr, np_dtype in {
                "value_float": np.float32,
                "value_floats": np.float32,
                "value_int": np.int64,
                "value_ints": np.int64,
                "value_string": np.str_,
                "value_strings": np.str_,
            }.items():
                if hasattr(self, attr) and getattr(self, attr) is not None:  # type: ignore
                    self.cst_name = attr
                    v = getattr(self, attr)
                    self.cst = (
                        v  # type: ignore
                        if isinstance(v, RefAttrName)  # type: ignore
                        else np.array(v, dtype=np_dtype)  # type: ignore
                    )
                    self.cst_convert = lambda v, np_dtype=np_dtype: np.array(  # type: ignore
                        v, dtype=np_dtype
                    )
                    break
        if not hasattr(self, "cst_name"):
            raise AttributeError(
                f"No constant is defined for operator 'Constant', outputs are {onnx_node.output}."
            )

    def _run(self, **overridden_attributes):  # type: ignore
        if self.has_linked_attribute:
            if overridden_attributes is None:
                raise RuntimeError(
                    f"Attributes are empty, cannot retrieve value for {self.cst!r}."
                )
            if self.cst_name not in overridden_attributes:
                raise RuntimeError(
                    f"Cannot find attribute {self.cst_name!r} in {list(overridden_attributes)!r}."
                )
            value = overridden_attributes[self.cst_name]
            if isinstance(value, np.ndarray):
                return (value,)
            return (self.cst_convert(value),)
        return (self._check(self.cst),)
