# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0

import unittest
from typing import List, Optional

import onnx.shape_inference
from onnx import ModelProto, TensorProto, TensorShapeProto, ValueInfoProto, helper
from onnx.helper import make_model, make_tensor_value_info


class TestSymbolicShape(unittest.TestCase):
    """Tests for symbolic dimensions produced by ``onnx.shape_inference.infer_shapes``.

    Each test builds a small Concat/Cast graph whose inputs carry symbolic
    (``dim_param``) or unknown dimensions, runs strict-mode shape inference and
    checks the shapes recorded on the inferred model.

    The checks use ``unittest`` assertion methods rather than bare ``assert``
    statements so that they are still executed when Python runs with ``-O``.
    """

    @staticmethod
    def _shape_of_value_info(value_info: ValueInfoProto) -> Optional[TensorShapeProto]:
        """Return the shape stored in a tensor or sparse tensor value info.

        Returns None when the value info holds neither a ``tensor_type`` nor a
        ``sparse_tensor_type``.
        """
        if value_info.type.HasField("tensor_type"):
            return value_info.type.tensor_type.shape  # type: ignore
        if value_info.type.HasField("sparse_tensor_type"):
            return value_info.type.sparse_tensor_type.shape  # type: ignore
        return None

    def _assert_valueinfo_shape(
        self, onnx_model: ModelProto, value_infos: List[ValueInfoProto]
    ) -> None:
        """Assert that shapes in onnx_model match the expected value_infos.

        Only the names listed in value_infos are checked; they are looked up
        among the graph inputs, outputs and value_info of onnx_model.
        Instead of an exact symbol, use -1 in the expected value_infos to
        represent a symbolic dimension: the actual dimension must then carry a
        non-empty ``dim_param``. Any other expected dimension must match the
        actual ``dim_value`` exactly.
        """
        for expected_vi in value_infos:
            shape = self._get_shape_from_name(onnx_model, expected_vi.name)
            if shape is None:
                self.fail(f"{onnx_model}")
            expected_shape = self._shape_of_value_info(expected_vi)
            if expected_shape is None:
                # Previously this case left expected_shape unbound and crashed
                # with UnboundLocalError instead of reporting a test failure.
                self.fail(
                    f"Expected value info '{expected_vi.name}' is neither a "
                    "tensor nor a sparse tensor"
                )
            self.assertEqual(len(shape.dim), len(expected_shape.dim), f"{onnx_model}")
            # Ranks are equal at this point, so zip pairs every dimension.
            for dim, expected_dim in zip(shape.dim, expected_shape.dim):
                # -1 means it's a symbolic shape
                if expected_dim.dim_value == -1:
                    # symbolic dimension must exist
                    self.assertTrue(dim.dim_param, f"{onnx_model}")
                else:
                    self.assertEqual(
                        dim.dim_value, expected_dim.dim_value, f"{onnx_model}"
                    )

    def _count_unique_dim_param_number(self, onnx_model: ModelProto) -> int:
        """Return the number of unique symbolic dimension names in onnx_model.

        Graph inputs, outputs and value_info are scanned; both tensor and
        sparse tensor shapes are taken into account.
        """
        symbol_shape_set = set()
        inputs = list(onnx_model.graph.input)
        outputs = list(onnx_model.graph.output)
        valueinfos = list(onnx_model.graph.value_info)
        for v in inputs + outputs + valueinfos:
            shape = self._shape_of_value_info(v)
            if shape is None:
                continue
            for dim in shape.dim:
                if dim.dim_param:
                    symbol_shape_set.add(dim.dim_param)
        return len(symbol_shape_set)

    def _get_shape_from_name(
        self, onnx_model: ModelProto, name: str
    ) -> Optional[TensorShapeProto]:
        """Get shape from tensor_type or sparse_tensor_type according to given name.

        Graph inputs, outputs and value_info are searched in that order.
        Returns None when no entry with that name holds a tensor or sparse
        tensor type.
        """
        inputs = list(onnx_model.graph.input)
        outputs = list(onnx_model.graph.output)
        valueinfos = list(onnx_model.graph.value_info)
        for v in inputs + outputs + valueinfos:
            if v.name == name:
                shape = self._shape_of_value_info(v)
                if shape is not None:
                    return shape
        return None

    def test_concat_enable_symbolic(self) -> None:
        """Concat of a symbolic and a fixed dimension yields a symbolic dimension."""
        concat = helper.make_node(
            "Concat", inputs=["A", "B"], outputs=["C"], name="Concat", axis=1
        )
        cast = helper.make_node(
            "Cast", inputs=["C"], outputs=["output"], to=TensorProto.FLOAT
        )
        graph_def = helper.make_graph(
            name="test_graph",
            nodes=[concat, cast],
            inputs=[
                helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, "A"]),
                helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3]),
            ],
            outputs=[
                helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, None])
            ],
        )

        onnx_model = make_model(graph_def)
        inferred_model = onnx.shape_inference.infer_shapes(onnx_model, strict_mode=True)
        self._assert_valueinfo_shape(
            inferred_model, [make_tensor_value_info("C", TensorProto.FLOAT, (2, -1))]
        )
        # the symbolic shape of C and output should be the same
        self.assertEqual(
            self._get_shape_from_name(inferred_model, "C"),
            self._get_shape_from_name(inferred_model, "output"),
        )

    def test_two_symbolic_concat(self) -> None:
        """Two chained Concat nodes each yield a symbolic dimension on axis 1."""
        concat1 = helper.make_node(
            "Concat", inputs=["A", "B"], outputs=["C"], name="Concat", axis=1
        )
        concat2 = helper.make_node(
            "Concat", inputs=["C", "D"], outputs=["E"], name="Concat", axis=1
        )
        cast = helper.make_node(
            "Cast", inputs=["E"], outputs=["output"], to=TensorProto.FLOAT
        )
        graph_def = helper.make_graph(
            name="test_graph",
            nodes=[concat1, concat2, cast],
            inputs=[
                helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, "A"]),
                helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3]),
                helper.make_tensor_value_info("D", TensorProto.FLOAT, [2, "D"]),
            ],
            outputs=[
                helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, None])
            ],
        )

        onnx_model = make_model(graph_def)
        inferred_model = onnx.shape_inference.infer_shapes(onnx_model, strict_mode=True)
        self._assert_valueinfo_shape(
            inferred_model,
            [
                make_tensor_value_info("C", TensorProto.FLOAT, (2, -1)),
                make_tensor_value_info("E", TensorProto.FLOAT, (2, -1)),
            ],
        )
        # the symbolic shape of E and output should be the same
        self.assertEqual(
            self._get_shape_from_name(inferred_model, "E"),
            self._get_shape_from_name(inferred_model, "output"),
        )

    def test_duplicate_symbolic_shape(self) -> None:
        """Generated symbols must not collide with 'unk__N' names already in use."""
        concat1 = helper.make_node(
            "Concat", inputs=["A", "B"], outputs=["C"], name="Concat", axis=1
        )
        concat2 = helper.make_node(
            "Concat", inputs=["C", "D"], outputs=["E"], name="Concat", axis=1
        )
        cast = helper.make_node(
            "Cast", inputs=["E"], outputs=["output"], to=TensorProto.FLOAT
        )
        graph_def = helper.make_graph(
            name="test_graph",
            nodes=[concat1, concat2, cast],
            inputs=[
                helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, "unk__0"]),
                helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3]),
                helper.make_tensor_value_info("D", TensorProto.FLOAT, [2, "unk__1"]),
            ],
            outputs=[
                helper.make_tensor_value_info(
                    "output", TensorProto.FLOAT, [2, "unk__0"]
                )
            ],
        )

        onnx_model = make_model(graph_def)
        original_count = self._count_unique_dim_param_number(onnx_model)
        inferred_model = onnx.shape_inference.infer_shapes(onnx_model, strict_mode=True)
        inferred_count = self._count_unique_dim_param_number(inferred_model)
        # to prevent duplicates, the inferred count will be original count + 2:
        # new symbols 'unk__2' and 'unk__3' should be generated
        # original: {'unk__0', 'unk__1'}
        # inferred: {'unk__0', 'unk__1', 'unk__2', 'unk__3'}
        self.assertEqual(
            inferred_count, original_count + 2, f"{inferred_model}{onnx_model}"
        )

    def test_unknown_shape(self) -> None:
        """Concat of inputs with unknown dimensions yields a symbolic dimension."""
        concat = helper.make_node(
            "Concat", inputs=["A", "B"], outputs=["C"], name="Concat", axis=1
        )
        cast = helper.make_node(
            "Cast", inputs=["C"], outputs=["output"], to=TensorProto.FLOAT
        )
        graph_def = helper.make_graph(
            name="test_graph",
            nodes=[concat, cast],
            inputs=[
                helper.make_tensor_value_info(
                    "A", TensorProto.FLOAT, [3, None]
                ),  # unknown shape
                helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, None]),
            ],
            outputs=[
                helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, None])
            ],
        )

        onnx_model = make_model(graph_def)
        inferred_model = onnx.shape_inference.infer_shapes(onnx_model, strict_mode=True)
        self._assert_valueinfo_shape(
            inferred_model, [make_tensor_value_info("C", TensorProto.FLOAT, (3, -1))]
        )
        # the symbolic shape of C and output should be the same
        # ('unk__0', 'unk__1')
        self.assertEqual(
            self._get_shape_from_name(inferred_model, "C"),
            self._get_shape_from_name(inferred_model, "output"),
        )

# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
