# SPDX-License-Identifier: Apache-2.0

# Copyright (c) ONNX Project Contributors

import unittest
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

import onnx
from onnx import TensorProto, TypeProto
from onnx.checker import ValidationError
from onnx.defs import OpSchema, get_all_schemas_with_history, get_schema
from onnx.helper import (
    make_graph,
    make_node,
    make_opsetid,
    make_tensor_type_proto,
    make_tensor_value_info,
)
from onnx.numpy_helper import from_array
from onnx.shape_inference import InferenceError, infer_node_outputs

def _latest_schema(op_name: str) -> OpSchema:
    """Return the default-domain schema of ``op_name`` with the highest ``since_version``."""
    candidates = [
        schema
        for schema in get_all_schemas_with_history()
        if schema.name == op_name and schema.domain == ""
    ]
    return max(candidates, key=lambda schema: schema.since_version)


# Most recent schema versions of the operators exercised by the tests below.
ADD_SCHEMA = _latest_schema("Add")
RESHAPE_SCHEMA = _latest_schema("Reshape")


def _to_tensor_types(
    tensor_types: Dict[str, Tuple[int, Tuple[Union[int, str, None], ...]]]
) -> Dict[str, TypeProto]:
    """Build a ``TypeProto`` for every named ``(elem_type, shape)`` description.

    Each value tuple is forwarded positionally to ``make_tensor_type_proto``.
    """
    type_protos: Dict[str, TypeProto] = {}
    for name, spec in tensor_types.items():
        type_protos[name] = make_tensor_type_proto(*spec)
    return type_protos


def _run_case(
    schema: OpSchema,
    input_names: List[str],
    output_names: List[str],
    input_types: Dict[str, TypeProto],
    input_data: Optional[Dict[str, np.ndarray]] = None,
) -> Dict[str, TypeProto]:
    """Run ``infer_node_outputs`` for a single node built from ``schema``.

    Args:
        schema: Operator schema; its name and domain define the node.
        input_names: Names of the node inputs.
        output_names: Names of the node outputs.
        input_types: Type of each input, keyed by input name.
        input_data: Optional constant values for inputs, keyed by input name;
            each array is converted to a tensor with ``from_array``.

    Returns:
        Whatever ``infer_node_outputs`` returns for the node.
    """
    node = make_node(schema.name, input_names, output_names, domain=schema.domain)
    constant_inputs = {
        name: from_array(array) for name, array in (input_data or {}).items()
    }
    return infer_node_outputs(schema, node, input_types, constant_inputs)


class TestInferenceFunctionCall(unittest.TestCase):
    """Tests for node-level (``infer_node_outputs``) and model-level shape inference."""

    def test_add_inference(self) -> None:
        """Add yields the expected output type for each (inputs, outputs) case."""
        cases = [
            (
                {"A": (TensorProto.FLOAT, ()), "B": (TensorProto.FLOAT, ())},
                {"C": (TensorProto.FLOAT, ())},
            ),
            (
                {
                    "A": (TensorProto.FLOAT, (None, 2)),
                    "B": (TensorProto.FLOAT, (2,)),
                },
                {"C": (TensorProto.FLOAT, (None, 2))},
            ),
            (
                {
                    "A": (TensorProto.FLOAT, (None, 2)),
                    "B": (TensorProto.FLOAT, (1, 2)),
                },
                {"C": (TensorProto.FLOAT, (None, 2))},
            ),
            (
                {
                    "A": (TensorProto.DOUBLE, ("n", "m")),
                    "B": (TensorProto.DOUBLE, (1, "n", "m")),
                },
                {"C": (TensorProto.DOUBLE, (1, "n", "m"))},
            ),
            (
                {
                    "A": (TensorProto.FLOAT, ("x", 2)),
                    "B": (TensorProto.FLOAT, ("y", 2)),
                },
                {"C": (TensorProto.FLOAT, (None, 2))},
            ),
        ]
        for ins, outs in cases:
            # subTest reports every failing case instead of stopping at the first.
            with self.subTest(inputs=ins):
                input_types = _to_tensor_types(ins)  # type: ignore
                expected = _to_tensor_types(outs)  # type: ignore
                self.assertEqual(
                    _run_case(ADD_SCHEMA, ["A", "B"], ["C"], input_types),
                    expected,
                )

    def test_add_inference_raises_errors(self) -> None:
        """Malformed Add nodes or inputs raise the expected exception types."""
        # Node declares a single input.
        with self.assertRaises(ValidationError):
            _run_case(
                ADD_SCHEMA,
                ["A"],
                ["C"],
                _to_tensor_types({"A": (TensorProto.FLOAT, (3, 4))}),
            )
        # A and B have different element types.
        with self.assertRaises(ValidationError):
            _run_case(
                ADD_SCHEMA,
                ["A", "B"],
                ["C"],
                _to_tensor_types(
                    {
                        "A": (TensorProto.FLOAT, (3, 4)),
                        "B": (TensorProto.UINT8, (3, 4)),
                    }
                ),
            )
        # Leading dimensions 2 and 3 differ and neither is 1.
        with self.assertRaises(InferenceError):
            _run_case(
                ADD_SCHEMA,
                ["A", "B"],
                ["C"],
                _to_tensor_types(
                    {
                        "A": (TensorProto.FLOAT, (2, 4)),
                        "B": (TensorProto.FLOAT, (3, 4)),
                    }
                ),
            )
        # No type is supplied for input "B".
        with self.assertRaises(KeyError):
            _run_case(
                ADD_SCHEMA,
                ["A", "B"],
                ["C"],
                _to_tensor_types({"A": (TensorProto.FLOAT, (3, 4))}),
            )

    def test_reshape_inference(self) -> None:
        """Reshape output shape is taken from the constant data given for ``t``."""
        inferred = _run_case(
            RESHAPE_SCHEMA,
            ["x", "t"],
            ["y"],
            _to_tensor_types(
                {
                    "x": (TensorProto.FLOAT, (5, 4)),
                    "t": (TensorProto.INT64, (3,)),
                }
            ),
            {"t": np.array([2, 2, 5], dtype=np.int64)},
        )
        self.assertEqual(
            inferred, _to_tensor_types({"y": (TensorProto.FLOAT, (2, 2, 5))})
        )

    def test_scan_inference_with_subgraph(self) -> None:
        """Scan-9 output types are inferred for a node carrying a body subgraph."""
        seq_len = "sequence"
        input_size = 2
        loop_state_size = 3

        input_value_infos = [
            make_tensor_value_info("loop_state_in", TensorProto.UNDEFINED, None),
            make_tensor_value_info("input", TensorProto.UNDEFINED, None),
            make_tensor_value_info("outer", TensorProto.UNDEFINED, None),
        ]
        output_value_infos = [
            make_tensor_value_info("loop_state_out", TensorProto.UNDEFINED, None),
            make_tensor_value_info("output", TensorProto.FLOAT, (seq_len, input_size)),
        ]

        subgraph = make_graph(
            [
                make_node("Identity", ["loop_state_in"], ["loop_state_out"]),
                make_node("Add", ["input", "outer"], ["output"]),
            ],
            "subgraph",
            input_value_infos,
            output_value_infos,
        )

        inferred = infer_node_outputs(
            get_schema("Scan", 9),
            make_node(
                "Scan",
                ["loop_state_orig", "scan_input", "scan_outer"],
                ["loop_state_final", "scan_output"],
                num_scan_inputs=1,
                body=subgraph,
            ),
            _to_tensor_types(
                {
                    "loop_state_orig": (TensorProto.FLOAT, (loop_state_size,)),
                    "scan_input": (TensorProto.FLOAT, (seq_len, input_size)),
                    "scan_outer": (TensorProto.FLOAT, (input_size,)),
                }
            ),
            # Same as default value in Scan-9
            opset_imports=[make_opsetid("", 9)],
            ir_version=4,
        )
        expected = _to_tensor_types(
            {
                "loop_state_final": (TensorProto.FLOAT, (loop_state_size,)),
                "scan_output": (TensorProto.FLOAT, (seq_len, input_size)),
            }
        )
        self.assertEqual(inferred, expected)

    def test_inference_with_conflow(self) -> None:
        """Model whose function body contains an If: non-strict passes, strict raises."""
        model_script = """
        <
            ir_version: 8,
            opset_import: ["" : 18, "onnxscript.atenlib" : 1],
            producer_name: "pytorch",
            producer_version: "2.1.0"
        >
        torch_jit (float input_0) => (float reault, int64 index)
        {
            reault, index = onnxscript.atenlib.aten_min_dim <dim = 0, keepdim = 1> (input_0)
        }
        <
            domain: "onnxscript.atenlib",
            opset_import: ["" : 18]
        >
        aten_min_dim <dim>(self) => (result_7, indices_6)
        {
            tmp = Shape (self)
            tmp_0 = Size (tmp)
            tmp_1 = Constant <value = int64 tmp_1 {0}> ()
            tmp_1_cast = CastLike (tmp_1, tmp_0)
            tmp_2 = Equal (tmp_0, tmp_1_cast)
            cond = Not (tmp_2)
            indices_6, result_7 = If (cond) <
                then_branch = thenGraph_4 () => ( indices,  result) {
                    dim = Constant <value_int: int = @dim> ()
                    tmp_3 = Constant <value_ints = [-1]> ()
                    dims = Reshape (dim, tmp_3)
                    result = ReduceMin <keepdims: int = @keepdim> (self, dims)
                    indices = ArgMin <axis: int = @dim, keepdims: int = @keepdim> (self)
                }, else_branch = elseGraph_4 () => ( indices_4,  result_5) {
                    indices_4 = Constant <value_int = 0> ()
                    result_5 = Identity (self)
                }
            >
        }
        """
        model = onnx.parser.parse_model(model_script)
        # Non-strict inference is expected to complete without raising ...
        onnx.shape_inference.infer_shapes(model, strict_mode=False)
        # ... while strict mode must surface the problem as an InferenceError.
        with self.assertRaises(onnx.shape_inference.InferenceError):
            onnx.shape_inference.infer_shapes(model, strict_mode=True)

    def test_inference_with_attribute(self) -> None:
        """Strict inference succeeds when a function body references ``@axes``."""
        model_script = """
        <
            ir_version: 8,
            opset_import: ["" : 18, "custom" : 1],
            producer_name: "",
            producer_version: "1.0"
        >
        MeanVarianceNormalization (float[N] x) => (float[M] y)
        {
            y = custom.custom_mvn <axes = [0]> (x)
        }
        <
            domain: "custom",
            opset_import: ["" : 18]
        >
        custom_mvn <axes>(X) => (Y)
        {
          Exponent = Constant <value = float {2.0}>()
          Epsilon = Constant <value = float {1e-9}>()
          axes = Constant <value_ints: ints = @axes>()
          X_RM = ReduceMean (X, axes)
          EX_squared = Pow (X_RM, Exponent)
          X_squared = Pow (X, Exponent)
          E_Xsquared = ReduceMean (X_squared, axes)
          Variance = Sub (E_Xsquared, EX_squared)
          STD = Sqrt (Variance)
          X_variance = Sub (X, X_RM)
          Processed_STD = Add (STD, Epsilon)
          Y = Div (X_variance, Processed_STD)
        }
        """
        model = onnx.parser.parse_model(model_script)
        # The test passes as long as strict-mode inference does not raise.
        onnx.shape_inference.infer_shapes(model, strict_mode=True)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
