# SPDX-License-Identifier: Apache-2.0

# Copyright (c) ONNX Project Contributors

import unittest
from typing import Iterable

import onnx
from onnx import checker, utils


class TestFunction(unittest.TestCase):
    """Tests for how ``utils.Extractor`` carries model-local functions into sub-models."""

    def _verify_function_set(
        self,
        extracted_model: onnx.ModelProto,
        function_set: Iterable[str],
        func_domain: str,
    ) -> None:
        """Assert that *extracted_model* is valid and holds exactly the expected functions.

        Args:
            extracted_model: Model produced by ``utils.Extractor.extract_model``.
            function_set: Names of the local functions expected in
                ``extracted_model.functions``; pass an empty collection when
                no function should be kept.
            func_domain: Domain every expected function must belong to.
        """
        checker.check_model(extracted_model)
        actual = [(f.domain, f.name) for f in extracted_model.functions]
        expected = [(func_domain, name) for name in function_set]
        # assertCountEqual compares both sides as multisets: it fails on
        # missing functions, unexpected functions and duplicated entries, and
        # its failure message lists the offending (domain, name) pairs.
        self.assertCountEqual(actual, expected)

    def test_extract_model_with_local_function(self) -> None:
        r"""Extracted sub-models keep exactly the local functions they use.

        1. Build a model with the graph below.
        2. Extract sub-models for several output combinations and check that
           each one carries exactly the local functions it references,
           including functions referenced only from inside another function.

        Model graph::

              i0                    i1                 i2
              |   __________________|__________________/_________
              |  |                  |             |   /          |
              |  |                  |             |  /           |
           func_add        func_identity          add         identity
            |  ___\___________\____________________|_________    |
            | |    \           \                   |  _______|___|
            | |     \           \                  | |       |   |
            add     func_nested_identity_add       add     func_nested_identity_add
             |                 |                    |              |
             |                 |                    |              |
           o_func_add      o_all_func0           o_no_func     o_all_func1

        where func_nested_identity_add is itself built from local functions::

               a               b
               |               |
           func_identity   func_identity
                     \       /
                     func_add
                        |
                        c
        """
        # Settings shared by all local functions.
        func_domain = "local"
        func_opset_imports = [onnx.helper.make_opsetid("", 14)]
        # The nested function calls other local functions, so its opset
        # imports also list the local domain.
        func_nested_opset_imports = [
            onnx.helper.make_opsetid("", 14),
            onnx.helper.make_opsetid(func_domain, 1),
        ]

        # add function: c = a + b
        func_add_name = "func_add"
        func_add_inputs = ["a", "b"]
        func_add_outputs = ["c"]
        func_add_nodes = [onnx.helper.make_node("Add", ["a", "b"], ["c"])]
        func_add = onnx.helper.make_function(
            func_domain,
            func_add_name,
            func_add_inputs,
            func_add_outputs,
            func_add_nodes,
            func_opset_imports,
        )

        # identity function: b = a
        func_identity_name = "func_identity"
        func_identity_inputs = ["a"]
        func_identity_outputs = ["b"]
        func_identity_nodes = [onnx.helper.make_node("Identity", ["a"], ["b"])]
        func_identity = onnx.helper.make_function(
            func_domain,
            func_identity_name,
            func_identity_inputs,
            func_identity_outputs,
            func_identity_nodes,
            func_opset_imports,
        )

        # nested function: c = func_add(func_identity(a), func_identity(b))
        func_nested_identity_add_name = "func_nested_identity_add"
        func_nested_identity_add_inputs = ["a", "b"]
        func_nested_identity_add_outputs = ["c"]
        func_nested_identity_add_nodes = [
            onnx.helper.make_node("func_identity", ["a"], ["a1"], domain=func_domain),
            onnx.helper.make_node("func_identity", ["b"], ["b1"], domain=func_domain),
            onnx.helper.make_node("func_add", ["a1", "b1"], ["c"], domain=func_domain),
        ]
        func_nested_identity_add = onnx.helper.make_function(
            func_domain,
            func_nested_identity_add_name,
            func_nested_identity_add_inputs,
            func_nested_identity_add_outputs,
            func_nested_identity_add_nodes,
            func_nested_opset_imports,
        )

        # Graph nodes (see the diagram in the docstring).
        node_func_add = onnx.helper.make_node(
            func_add_name, ["i0", "i1"], ["t0"], domain=func_domain
        )
        node_add0 = onnx.helper.make_node("Add", ["i1", "i2"], ["t2"])
        node_add1 = onnx.helper.make_node("Add", ["t0", "t2"], ["o_func_add"])
        node_func_identity = onnx.helper.make_node(
            func_identity_name, ["i1"], ["t1"], domain=func_domain
        )
        node_identity = onnx.helper.make_node("Identity", ["i1"], ["t3"])
        node_add2 = onnx.helper.make_node("Add", ["t3", "t2"], ["o_no_func"])
        node_func_nested0 = onnx.helper.make_node(
            func_nested_identity_add_name,
            ["t0", "t1"],
            ["o_all_func0"],
            domain=func_domain,
        )
        node_func_nested1 = onnx.helper.make_node(
            func_nested_identity_add_name,
            ["t3", "t2"],
            ["o_all_func1"],
            domain=func_domain,
        )

        graph_name = "graph_with_imbedded_functions"
        ir_version = 8
        opset_imports = [
            onnx.helper.make_opsetid("", 14),
            onnx.helper.make_opsetid(func_domain, 1),
        ]
        tensor_type_proto = onnx.helper.make_tensor_type_proto(
            elem_type=onnx.TensorProto.UINT8, shape=[5]
        )

        graph = onnx.helper.make_graph(
            [
                node_func_add,
                node_add0,
                node_add1,
                node_func_identity,
                node_identity,
                node_func_nested0,
                node_func_nested1,
                node_add2,
            ],
            graph_name,
            [
                onnx.helper.make_value_info(name="i0", type_proto=tensor_type_proto),
                onnx.helper.make_value_info(name="i1", type_proto=tensor_type_proto),
                onnx.helper.make_value_info(name="i2", type_proto=tensor_type_proto),
            ],
            [
                onnx.helper.make_value_info(
                    name="o_no_func", type_proto=tensor_type_proto
                ),
                onnx.helper.make_value_info(
                    name="o_func_add", type_proto=tensor_type_proto
                ),
                onnx.helper.make_value_info(
                    name="o_all_func0", type_proto=tensor_type_proto
                ),
                onnx.helper.make_value_info(
                    name="o_all_func1", type_proto=tensor_type_proto
                ),
            ],
        )

        meta = {
            "ir_version": ir_version,
            "opset_imports": opset_imports,
            "producer_name": "test_extract_model_with_local_function",
            "functions": [func_identity, func_add, func_nested_identity_add],
        }
        model = onnx.helper.make_model(graph, **meta)

        checker.check_model(model)

        input_names = ["i0", "i1", "i2"]
        all_functions = {
            func_add_name,
            func_identity_name,
            func_nested_identity_add_name,
        }
        # (requested outputs, local functions the extracted model must keep).
        # The nested function pulls in func_add and func_identity even when
        # the graph itself never calls them directly (o_all_func1).
        cases = [
            (["o_no_func"], set()),
            (["o_func_add"], {func_add_name}),
            (["o_all_func0"], all_functions),
            (["o_all_func1"], all_functions),
            (
                ["o_no_func", "o_func_add", "o_all_func0", "o_all_func1"],
                all_functions,
            ),
        ]
        for output_names, expected_functions in cases:
            with self.subTest(outputs=output_names):
                # A fresh Extractor per case, as in the original, so cases
                # cannot influence one another.
                extracted = utils.Extractor(model).extract_model(
                    input_names, output_names
                )
                self._verify_function_set(
                    extracted, expected_functions, func_domain
                )


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
