# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict

import numpy as np

from onnx.onnx_pb import NodeProto
from onnx.reference.op_run import OpRun, RuntimeTypeError


class OpRunUnary(OpRun):
    """Base class for operators that take exactly one input.

    ``run`` forwards to ``_run`` and, when the implementation raises a
    ``TypeError``, re-raises it with the input type and the operator name
    attached. No dtype check is done here; see ``OpRunUnaryNum`` for that.
    """

    def run(self, x):  # type: ignore
        """Evaluate the operator on its single input *x*.

        Delegates to ``_run`` and passes the produced outputs through
        ``_check_and_fix_outputs``.

        Raises:
            TypeError: if ``_run`` raises a ``TypeError``; the original
                exception is chained as the cause.
        """
        op_name = self.__class__.__name__
        self._log("-- begin %s.run(1 input)", op_name)
        try:
            outputs = self._run(x)
        except TypeError as err:
            # Re-raise with context: which input type and which operator failed.
            raise TypeError(
                f"Issues with types {type(x)} "
                f"(unary operator {op_name!r})."
            ) from err
        self._log("-- done %s.run -> %d outputs", op_name, len(outputs))
        return self._check_and_fix_outputs(outputs)


class OpRunUnaryNum(OpRunUnary):
    """Base class for numerical operators that take exactly one input.

    On top of ``OpRunUnary.run`` it verifies that the first output keeps the
    dtype of the input.
    """

    def run(self, x):  # type: ignore
        """Evaluate through ``OpRunUnary.run`` and validate the output dtype.

        An empty result, or a result whose first output is None, is returned
        unchanged. A first output that is a Python list is not dtype-checked.

        Raises:
            RuntimeTypeError: if the first output's dtype differs from the
                dtype of *x*.
        """
        outputs = OpRunUnary.run(self, x)
        if len(outputs) == 0:
            return outputs
        first = outputs[0]
        if first is None:
            return outputs
        # Lists (sequence outputs) carry no dtype, so they skip the check.
        if isinstance(first, list) or first.dtype == x.dtype:
            return self._check_and_fix_outputs(outputs)
        raise RuntimeTypeError(
            f"Output type mismatch: input '{x.dtype}' != output '{first.dtype}' "
            f"(operator {self.__class__.__name__!r})."
        )


class OpRunBinary(OpRun):
    """Ancestor to all binary operators in this subfolder.

    Checks that both inputs are present and share the same dtype before
    calling ``_run``. The output dtype is not checked here; see
    ``OpRunBinaryNum`` for that.
    """

    def run(self, x, y):  # type: ignore
        """Calls method ``_run``, catches exceptions, displays a longer error message.

        Supports only binary operators.

        Raises:
            RuntimeError: if *x* or *y* is None.
            RuntimeTypeError: if *x* and *y* have different dtypes.
            TypeError: if ``_run`` raises ``TypeError`` or ``ValueError``; the
                original exception is chained as the cause.
        """
        op_name = self.__class__.__name__
        self._log("-- begin %s.run(2 inputs)", op_name)
        if x is None or y is None:
            # The previous message claimed a dtype mismatch, which is not the
            # problem here: one of the inputs is missing.
            raise RuntimeError(
                f"Binary operator {op_name!r} received a None input: "
                f"type(x)={type(x)}, type(y)={type(y)} ({type(self)})"
            )
        if x.dtype != y.dtype:
            # `!r` already adds quotes; wrapping it in '...' printed ''Name''.
            raise RuntimeTypeError(
                f"Input type mismatch: {x.dtype} != {y.dtype} "
                f"(operator {op_name!r}, "
                f"shapes {x.shape}, {y.shape})."
            )
        try:
            res = self._run(x, y)
        except (TypeError, ValueError) as e:
            raise TypeError(
                f"Issues with types {', '.join(str(type(_)) for _ in [x, y])} "
                f"(binary operator {op_name!r})."
            ) from e
        self._log("-- done %s.run -> %d outputs", op_name, len(res))
        return self._check_and_fix_outputs(res)


class OpRunBinaryComparison(OpRunBinary):
    """Base class for binary operators that compare two tensors.

    It adds nothing to ``OpRunBinary``. The inputs must still share a dtype,
    but the output dtype is not compared with the input dtype, which is what
    sets it apart from ``OpRunBinaryNum``. Presumably this is because
    comparison results (e.g. booleans) do not share the input dtype.
    """


class OpRunBinaryNum(OpRunBinary):
    """Ancestor to all binary numerical operators in this subfolder.

    Checks that input and output types are the same. ``OpRunBinary`` already
    requires both inputs to share a dtype; this class also requires the
    first output to keep that dtype.
    """

    def run(self, x, y):  # type: ignore
        """Calls method ``OpRunBinary.run``, catches exceptions, displays a longer error message.

        As in ``OpRunUnaryNum``, an empty result or a None first output is
        returned unchanged, and a list output is not dtype-checked.

        Raises:
            RuntimeTypeError: if the first output's dtype differs from the
                dtype of *x*.
        """
        res = OpRunBinary.run(self, x, y)
        # Without this guard, `res[0].dtype` raised IndexError/AttributeError
        # on degenerate results instead of behaving like OpRunUnaryNum.
        if len(res) == 0 or res[0] is None:
            return res
        if not isinstance(res[0], list) and res[0].dtype != x.dtype:
            raise RuntimeTypeError(
                f"Output type mismatch: {x.dtype} != {res[0].dtype} or {y.dtype} "
                f"(operator {self.__class__.__name__!r})"
                f" type(x)={type(x)} type(y)={type(y)}"
            )
        return self._check_and_fix_outputs(res)


class OpRunBinaryNumpy(OpRunBinaryNum):
    """Binary numerical operator whose computation is a single callable.

    *numpy_fct* is called as ``numpy_fct(a, b)`` on the two inputs, and its
    return value becomes the operator's only output. Dtype checks are
    inherited from ``OpRunBinaryNum``.
    """

    def __init__(
        self, numpy_fct: Any, onnx_node: NodeProto, run_params: Dict[str, Any]
    ):
        OpRunBinaryNum.__init__(self, onnx_node, run_params)
        # Stored after base initialization, exactly as before.
        self.numpy_fct = numpy_fct

    def _run(self, a, b):  # type: ignore
        # Wrap the single result in a 1-tuple: operators return output tuples.
        return self._check_and_fix_outputs((self.numpy_fct(a, b),))


class OpRunReduceNumpy(OpRun):  # type: ignore
    """Implements the reduce logic.

    It must have a parameter *axes*. At construction, an empty *axes*
    (``[]``, ``()`` or an empty array) is normalized to ``None``, and a
    non-empty list or array is converted to a tuple.
    """

    def __init__(self, onnx_node: NodeProto, run_params: Dict[str, Any]):
        OpRun.__init__(self, onnx_node, run_params)
        if hasattr(self, "axes"):
            if isinstance(self.axes, np.ndarray):  # type: ignore
                # NOTE(review): a 0-d array is mapped to None here (reduce over
                # every axis), whereas handle_axes maps a 0-d array to one int
                # axis. Confirm which is intended if such an attribute can occur.
                if len(self.axes.shape) == 0 or self.axes.shape[0] == 0:  # type: ignore
                    self.axes = None
                else:
                    self.axes = tuple(self.axes)
            elif self.axes in [[], ()]:
                self.axes = None
            elif isinstance(self.axes, list):
                self.axes = tuple(self.axes)

    def is_axes_empty(self, axes):
        """Return True when *axes* is None, i.e. no reduction axes were given."""
        return axes is None

    def handle_axes(self, axes):  # noqa: PLR0911
        """Normalize *axes* into a value usable as a numpy ``axis`` argument.

        Returns:
            None for None, an empty tuple/list or an empty array; an int for
            an int, a numpy integer scalar or a 0-d array; a tuple otherwise.

        Raises:
            TypeError: if *axes* is none of the supported kinds.
        """
        if axes is None:
            return None
        if isinstance(axes, (tuple, list)):
            # Lists used to be rejected with TypeError; treat them like tuples.
            return tuple(axes) if len(axes) > 0 else None
        if isinstance(axes, int):
            return axes
        if isinstance(axes, np.integer):
            # numpy integer scalars are not `int` instances nor arrays.
            return int(axes)
        if not isinstance(axes, np.ndarray):
            raise TypeError(f"axes must be an array, not {type(axes)}.")
        if len(axes.shape) == 0:
            return int(axes)
        if 0 in axes.shape:
            return None
        return tuple(axes.ravel().tolist())

    def output_shape(self, data, axes, keepdims):
        """Return the shape of ``np.sum(data, axis=axes, keepdims=keepdims)``.

        The sum is actually computed; only its shape is kept.
        """
        return np.sum(data, axis=axes, keepdims=keepdims).shape

    def reduce_constant(self, data, const_val, axes, keepdims):
        """Special case reduction where the output value is a constant.

        Returns a 1-tuple holding an array of the reduced shape, filled with
        *const_val* and having the dtype of *data*.
        """
        output_shape = self.output_shape(data, axes, keepdims)
        return (np.full(output_shape, const_val, dtype=data.dtype),)
