# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0

import unittest
from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np

from onnx import (
    FunctionProto,
    GraphProto,
    ModelProto,
    NodeProto,
    SparseTensorProto,
    TensorProto,
    ValueInfoProto,
    checker,
    compose,
    helper,
    parser,
    version_converter,
)


def _load_model(m_def: str) -> ModelProto:
    """Build a ModelProto from ONNX textual syntax and validate it.

    Args:
        m_def: The model written in the ONNX textual format understood by
            ``onnx.parser.parse_model``.

    Returns:
        The parsed model, after it has passed ``onnx.checker.check_model``.
    """
    parsed_model = parser.parse_model(m_def)
    # Fail fast on malformed test fixtures rather than inside the code under test.
    checker.check_model(parsed_model)
    return parsed_model


def _prefixed(prefix: str, s: str) -> str:
    """Prefixes a string (if not empty)"""
    return prefix + s if len(s) > 0 else s


def _get_shape(value_info: ValueInfoProto) -> List[int]:
    """Return the ``dim_value`` of every dimension of a tensor ValueInfoProto.

    Args:
        value_info: A value info whose type is a tensor type.

    Returns:
        One integer per dimension, in order. Symbolic dimensions (which carry
        no ``dim_value``) read back as the protobuf default, 0.
    """
    # Iterate the repeated field directly instead of indexing by position.
    return [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]


def _make_sparse_tensor(name: str) -> SparseTensorProto:
    """Create a small FLOAT sparse tensor with dense shape 3x3 for tests.

    The tensor holds three non-zero values at linear positions 2, 3 and 5.

    Args:
        name: Base name. The values tensor is called ``<name>_values`` and the
            indices tensor ``<name>_idx``.

    Returns:
        The SparseTensorProto produced by ``helper.make_sparse_tensor``.
    """
    nonzero_values = np.array([1.7, 0.4, 0.9]).astype(np.float32)
    flat_positions = np.array([2, 3, 5]).astype(np.int64)

    values = helper.make_tensor(
        name=f"{name}_values",
        data_type=TensorProto.FLOAT,
        dims=[len(nonzero_values)],
        vals=nonzero_values,
        raw=False,
    )
    indices = helper.make_tensor(
        name=f"{name}_idx",
        data_type=TensorProto.INT64,
        dims=[len(flat_positions)],
        vals=flat_positions,
        raw=False,
    )
    return helper.make_sparse_tensor(values, indices, [3, 3])


# First test model (ONNX textual syntax). It has three float inputs A0, A1 and _A
# and three outputs: B00 = A0 + A1, B10 = A0 - A1, B20 = A0 * A1.
# The input ``_A`` is declared but not consumed by any node.
M1_DEF = """
    <
        ir_version: 7,
        opset_import: [ "": 10, "com.microsoft": 1]
    >
    agraph (float[N, M] A0, float[N, M] A1, float[N, M] _A) => (float[N, M] B00, float[N, M] B10, float[N, M] B20)
    {
        B00 = Add(A0, A1)
        B10 = Sub(A0, A1)
        B20 = Mul(A0, A1)
    }
    """

# Second test model (ONNX textual syntax). It has three float inputs B01, B11 and
# B21 and a single output D0 = (B01 + B11) * (B11 - B21). The tests feed its
# inputs from M1_DEF's outputs through an io_map.
# The final node must write the declared graph output ``D0``; otherwise D0 is
# never produced and merges that keep or prune by D0 are meaningless.
M2_DEF = """
    <
        ir_version: 7,
        opset_import: [ "": 10, "com.microsoft": 1]
    >
    agraph (float[N, M] B01, float[N, M] B11, float[N, M] B21) => (float[N, M] D0)
    {
        C0 = Add(B01, B11)
        C1 = Sub(B11, B21)
        D0 = Mul(C0, C1)
    }
    """


class TestComposeFunctions(unittest.TestCase):
    """Tests for ``onnx.compose``.

    Covers merging graphs and models, prefixing names, expanding output
    dimensions and detecting overlapping names.
    """

    def _test_merge_models(
        self,
        m1def: str,
        m2def: str,
        io_map: List[Tuple[str, str]],
        check_expectations: Callable[[GraphProto, GraphProto, GraphProto], None],
        inputs: Optional[List[str]] = None,
        outputs: Optional[List[str]] = None,
        prefix1: Optional[str] = None,
        prefix2: Optional[str] = None,
    ) -> None:
        """Merges two textual models at graph level and at model level.

        ``compose.merge_graphs`` and ``compose.merge_models`` are both called
        with the same arguments. Each result is validated with the checker and
        then passed to ``check_expectations(g1, g2, merged_graph)``.

        Args:
            m1def: Textual definition of the first model.
            m2def: Textual definition of the second model.
            io_map: Pairs ``(output_of_first, input_of_second)`` to connect.
            check_expectations: Callback asserting properties of the merged graph.
            inputs: Forwarded as ``inputs=`` to the compose functions.
            outputs: Forwarded as ``outputs=`` to the compose functions.
            prefix1: Forwarded as ``prefix1=`` to the compose functions.
            prefix2: Forwarded as ``prefix2=`` to the compose functions.
        """
        m1, m2 = _load_model(m1def), _load_model(m2def)
        g3 = compose.merge_graphs(
            m1.graph,
            m2.graph,
            io_map=io_map,
            inputs=inputs,
            outputs=outputs,
            prefix1=prefix1,
            prefix2=prefix2,
        )
        checker.check_graph(g3)
        check_expectations(m1.graph, m2.graph, g3)
        m3 = compose.merge_models(
            m1,
            m2,
            io_map=io_map,
            inputs=inputs,
            outputs=outputs,
            prefix1=prefix1,
            prefix2=prefix2,
        )
        checker.check_model(m3)
        check_expectations(m1.graph, m2.graph, m3.graph)

    def test_case_connect_all_no_name_collision(self) -> None:
        """Tests a simple scenario where two models without overlapping names are merged by
        connecting all the outputs in the first models to all the inputs in the second model
        """

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
            self.assertEqual(g3.input, g1.input)
            self.assertEqual(g3.output, g2.output)
            self.assertEqual(
                ["Add", "Sub", "Mul", "Add", "Sub", "Mul"],
                [item.op_type for item in g3.node],
            )

        io_map = [("B00", "B01"), ("B10", "B11"), ("B20", "B21")]
        self._test_merge_models(M1_DEF, M2_DEF, io_map, check_expectations)

    def test_case_connect_same_output_twice(self) -> None:
        """Tests a scenario where we merge two models by connecting a single output in the first model
        to all the inputs in the second
        """

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
            del g2  # Unused
            self.assertEqual(g3.input, g1.input)
            # B10 and B20 are not consumed via io_map, so they stay graph outputs.
            self.assertEqual(["B10", "B20", "D0"], [elem.name for elem in g3.output])
            self.assertEqual(
                ["Add", "Sub", "Mul", "Add", "Sub", "Mul"],
                [item.op_type for item in g3.node],
            )

        io_map = [("B00", "B01"), ("B00", "B11"), ("B00", "B21")]
        self._test_merge_models(M1_DEF, M2_DEF, io_map, check_expectations)

    def test_case_connect_same_output_drop_outputs(self) -> None:
        """Tests a scenario where we merge two models by connecting a single output in the first model
        to all the inputs in the second, while dropping the rest of the outputs in the first model
        """

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
            del g2  # Unused
            self.assertEqual(g3.input, g1.input)
            self.assertEqual(["D0"], [elem.name for elem in g3.output])
            # Only the nodes needed to compute D0 remain: m1's Add plus all of m2.
            self.assertEqual(
                ["Add", "Add", "Sub", "Mul"], [item.op_type for item in g3.node]
            )

        io_map = [("B00", "B01"), ("B00", "B11"), ("B00", "B21")]
        outputs = ["D0"]
        self._test_merge_models(
            M1_DEF, M2_DEF, io_map, check_expectations, outputs=outputs
        )

    def test_case_connect_same_input_output_name(self) -> None:
        """Tests a scenario where we merge two models, where the inputs/outputs connected
        are named exactly the same
        """
        m1_def = """
            <
                ir_version: 7,
                opset_import: [ "": 10]
            >
            agraph (float[N, M] A) => (float[N, M] B)
            {
                B = Add(A, A)
            }
            """
        m2_def = """
            <
                ir_version: 7,
                opset_import: [ "": 10]
            >
            agraph (float[N, M] B) => (float[N, M] C)
            {
                C = Add(B, B)
            }
            """
        io_map = [("B", "B")]

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
            del g1, g2  # Unused

            self.assertEqual(["A"], [elem.name for elem in g3.input])
            self.assertEqual(["C"], [elem.name for elem in g3.output])

        self._test_merge_models(m1_def, m2_def, io_map, check_expectations)

    def test_case_drop_inputs_outputs(self) -> None:
        """Tests a scenario where we merge two models, not including some of the inputs/outputs"""
        m1_def = """
            <
                ir_version: 7,
                opset_import: [ "": 10]
            >
            agraph (float[N] A0, float[N] B0) => (float[N] A1, float[N] B1)
            {
                A1 = Add(A0, A0)
                B1 = Sub(B0, B0)
            }
            """
        m2_def = """
            <
                ir_version: 7,
                opset_import: [ "": 10]
            >
            agraph (float[N] A2, float[N] B2) => (float[N] A3, float[N] B3)
            {
                A3 = Add(A2, A2)
                B3 = Sub(B2, B2)
            }
            """
        io_map = [("A1", "B2")]

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
            del g1, g2  # Unused

            self.assertEqual(["A0"], [elem.name for elem in g3.input])
            self.assertEqual(["B3"], [elem.name for elem in g3.output])
            self.assertEqual(["Add", "Sub"], [elem.op_type for elem in g3.node])

        inputs = ["A0"]
        outputs = ["B3"]
        self._test_merge_models(
            m1_def, m2_def, io_map, check_expectations, inputs=inputs, outputs=outputs
        )

    def test_case_name_collision_prefix(self) -> None:
        """Tests a scenario where we merge two models that have name collisions, but the
        collisions are avoided by prefixing the names in each model.
        """
        m1_def = """
            <
                ir_version: 7,
                opset_import: [ "": 10]
            >
            agraph (float[N] A, float[N] B) => (float[N] C)
            {
                C = Add(A, B)
            }
            """
        io_map = [("C", "A")]

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
            del g1, g2  # Unused

            self.assertEqual(["m1/A", "m1/B", "m2/B"], [elem.name for elem in g3.input])
            self.assertEqual(["m2/C"], [elem.name for elem in g3.output])
            self.assertEqual(["Add", "Add"], [elem.op_type for elem in g3.node])

        # The same model is merged with itself; only the prefixes keep names apart.
        self._test_merge_models(
            m1_def, m1_def, io_map, check_expectations, prefix1="m1/", prefix2="m2/"
        )

    def test_case_connect_partially_no_name_collision(self) -> None:
        """Tests a scenario where two models without overlapping names are merged by
        connecting some outputs from the first model to some inputs in the second.
        The remaining inputs/outputs should be present in the combined model
        """

        def check_expectations(g1: GraphProto, g2: GraphProto, g4: GraphProto) -> None:
            del g1, g2  # Unused

            # B20 <-> B21 not connected. They should still be present
            # in the inputs and outputs of the combined graph
            self.assertEqual(
                ["A0", "A1", "_A", "B21"], [elem.name for elem in g4.input]
            )
            self.assertEqual(["B20", "D0"], [elem.name for elem in g4.output])

        io_map = [("B00", "B01"), ("B10", "B11")]
        self._test_merge_models(M1_DEF, M2_DEF, io_map, check_expectations)

    def test_merge_models_with_metadata_props(self) -> None:
        """Tests how metadata_props of both models are combined when merging."""
        m1 = _load_model(M1_DEF)
        helper.set_model_props(m1, {"p1": "v1", "p2": "v2"})

        m2 = _load_model(M2_DEF)
        helper.set_model_props(m2, {"p3": "v3", "p4": "v4"})

        io_map = [("B00", "B01")]
        m3 = compose.merge_models(m1, m2, io_map=io_map)
        # Disjoint keys: all four entries are kept.
        self.assertEqual(4, len(m3.metadata_props))

        # Overlap, but same value
        helper.set_model_props(m2, {"p1": "v1", "p4": "v4"})
        m3 = compose.merge_models(m1, m2, io_map=io_map)
        self.assertEqual(3, len(m3.metadata_props))

        # Same keys but not same value. Error
        helper.set_model_props(m2, {"p1": "v5", "p4": "v4"})
        self.assertRaises(ValueError, compose.merge_models, m1, m2, io_map=io_map)

    def test_error_wrong_input_output_name(self) -> None:
        """Tests that providing a non existing output/input name in the io_map argument produces an error."""
        m1, m2 = _load_model(M1_DEF), _load_model(M2_DEF)

        # Wrong output name (first element of an io_map pair refers to m1's outputs)
        self.assertRaises(
            ValueError,
            compose.merge_models,
            m1,
            m2,
            io_map=[("wrong_outname", "B01"), ("B10", "B11"), ("B20", "B21")],
        )

        # Wrong input name (second element of an io_map pair refers to m2's inputs)
        self.assertRaises(
            ValueError,
            compose.merge_models,
            m1,
            m2,
            io_map=[("B00", "wrong_input"), ("B10", "B11"), ("B20", "B21")],
        )

    def test_error_ir_version_mismatch(self) -> None:
        """Tests that merging models with different IR versions produces an error."""
        m1 = _load_model(
            """
    <
        ir_version: 7,
        opset_import: [ "": 13]
    >
    agraph (float[N, M] X0) => (float[N, M] Y0)
    {
        Y0 = Add(X0, X0)
    }
    """
        )

        m2 = _load_model(
            """
    <
        ir_version: 6,
        opset_import: [ "": 13]
    >
    agraph (float[N, M] X1) => (float[N, M] Y1)
    {
        Y1 = Add(X1, X1)
    }
    """
        )
        # Mismatched IR versions (7 vs 6)
        self.assertRaises(
            ValueError, compose.merge_models, m1, m2, io_map=[("Y0", "X1")]
        )

    def test_error_opset_import_mismatch(self) -> None:
        """Tests that providing models with different operator set imported produces an error."""
        m1, m2 = _load_model(M1_DEF), _load_model(M2_DEF)
        m1 = helper.make_model(
            m1.graph, producer_name="test", opset_imports=[helper.make_opsetid("", 10)]
        )
        m2 = helper.make_model(
            m2.graph, producer_name="test", opset_imports=[helper.make_opsetid("", 15)]
        )

        io_map = [("B00", "B01"), ("B10", "B11"), ("B20", "B21")]
        self.assertRaises(ValueError, compose.merge_models, m1, m2, io_map)

        # Converting to the same Operator set version, should work
        m1 = version_converter.convert_version(m1, 15)
        m3 = compose.merge_models(m1, m2, io_map=io_map)
        checker.check_model(m3)

    # FIXME: This function should be removed, as tests should not contain a copy of the tested logic.
    def _test_add_prefix(
        self,
        rename_nodes: bool = False,
        rename_edges: bool = False,
        rename_inputs: bool = False,
        rename_outputs: bool = False,
        rename_initializers: bool = False,
        rename_value_infos: bool = False,
        inplace: bool = False,
    ) -> None:
        """Applies ``compose.add_prefix`` to M1_DEF and checks the resulting names.

        The expected name of every entity is recomputed from the input graph
        and compared with the prefixed graph. With ``inplace=True`` the prefix
        is applied to a copy of the model in place; otherwise the model
        returned by ``compose.add_prefix`` is checked.
        """
        m1 = _load_model(M1_DEF)

        prefix = "pre/"

        if inplace:
            m2 = ModelProto()
            m2.CopyFrom(m1)
            compose.add_prefix(
                m2,
                prefix,
                rename_nodes=rename_nodes,
                rename_edges=rename_edges,
                rename_inputs=rename_inputs,
                rename_outputs=rename_outputs,
                rename_initializers=rename_initializers,
                rename_value_infos=rename_value_infos,
                inplace=True,
            )
        else:
            m2 = compose.add_prefix(
                m1,
                prefix,
                rename_nodes=rename_nodes,
                rename_edges=rename_edges,
                rename_inputs=rename_inputs,
                rename_outputs=rename_outputs,
                rename_initializers=rename_initializers,
                rename_value_infos=rename_value_infos,
            )
        g_in = m1.graph
        g_out = m2.graph

        if (
            rename_edges
            or rename_inputs
            or rename_outputs
            or rename_initializers
            or rename_value_infos
        ):
            name_mapping = {}

            # Rename inputs/outputs/edges. Propagate name changes from and to edges
            if rename_edges:
                for n in g_in.node:
                    for e in n.input:
                        name_mapping[e] = _prefixed(prefix, e)
                    for e in n.output:
                        name_mapping[e] = _prefixed(prefix, e)
            if rename_inputs:
                for elem in g_in.input:
                    name_mapping[elem.name] = _prefixed(prefix, elem.name)
            if rename_outputs:
                for elem in g_in.output:
                    name_mapping[elem.name] = _prefixed(prefix, elem.name)

            if rename_initializers:
                for init in g_in.initializer:
                    name_mapping[init.name] = _prefixed(prefix, init.name)
                for sparse_init in g_in.sparse_initializer:
                    name_mapping[sparse_init.values.name] = _prefixed(
                        prefix, sparse_init.values.name
                    )
                    name_mapping[sparse_init.indices.name] = _prefixed(
                        prefix, sparse_init.indices.name
                    )

            if rename_value_infos:
                # Value infos live in ``value_info``, not in the graph outputs.
                for value_info in g_in.value_info:
                    name_mapping[value_info.name] = _prefixed(prefix, value_info.name)

            for n1, n0 in zip(g_out.node, g_in.node):
                for e1, e0 in zip(n1.input, n0.input):
                    self.assertEqual(name_mapping.get(e0, e0), e1)
                for e1, e0 in zip(n1.output, n0.output):
                    self.assertEqual(name_mapping.get(e0, e0), e1)
            for i1, i0 in zip(g_out.input, g_in.input):
                self.assertEqual(name_mapping.get(i0.name, i0.name), i1.name)
            for o1, o0 in zip(g_out.output, g_in.output):
                self.assertEqual(name_mapping.get(o0.name, o0.name), o1.name)

            for init1, init0 in zip(g_out.initializer, g_in.initializer):
                self.assertEqual(name_mapping.get(init0.name, init0.name), init1.name)

            for sparse_init1, sparse_init0 in zip(
                g_out.sparse_initializer, g_in.sparse_initializer
            ):
                self.assertEqual(
                    name_mapping.get(
                        sparse_init0.values.name, sparse_init0.values.name
                    ),
                    sparse_init1.values.name,
                )
                self.assertEqual(
                    name_mapping.get(
                        sparse_init0.indices.name, sparse_init0.indices.name
                    ),
                    sparse_init1.indices.name,
                )

            for vi1, vi0 in zip(g_out.value_info, g_in.value_info):
                self.assertEqual(name_mapping.get(vi0.name, vi0.name), vi1.name)

        # Node names are checked independently of the other flags, so that
        # ``rename_nodes=True`` on its own still asserts something.
        if rename_nodes:
            for n1, n0 in zip(g_out.node, g_in.node):
                self.assertEqual(_prefixed(prefix, n0.name), n1.name)

    def test_add_prefix_nodes(self) -> None:
        """Tests renaming nodes only"""
        self._test_add_prefix(rename_nodes=True)

    def test_add_prefix_edges(self) -> None:
        """Tests prefixing node edges. This will also rename inputs/outputs, since the names are shared"""
        self._test_add_prefix(rename_edges=True)

    def test_add_prefix_inputs(self) -> None:
        """Tests prefixing graph inputs only. Relevant node edges should be renamed as well"""
        self._test_add_prefix(rename_inputs=True)

    def test_add_prefix_outputs(self) -> None:
        """Tests prefixing graph outputs only. Relevant node edges should be renamed as well"""
        self._test_add_prefix(rename_outputs=True)

    def test_add_prefix_attribute_subgraph(self) -> None:
        """Tests prefixing attribute's subgraph. Relevant subgraph should be renamed as well"""
        C = helper.make_tensor_value_info("C", TensorProto.BOOL, [1])
        X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, 1])
        Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, 1])
        Z = helper.make_tensor_value_info("Z", TensorProto.FLOAT, [None, 1])
        Out = helper.make_tensor_value_info("Out", TensorProto.FLOAT, [None, 1])

        XY = helper.make_node("Mul", inputs=["X", "Y"], outputs=["XY"])
        add = helper.make_node("Add", inputs=["XY", "Z"], outputs=["Out"])
        sub = helper.make_node("Sub", inputs=["XY", "Z"], outputs=["Out"])

        # Both If branches read "XY" and "Z" from the outer scope, so prefixing
        # must rewrite names inside the subgraphs too.
        cond = helper.make_node(
            "If",
            inputs=["C"],
            outputs=["Out"],
            then_branch=helper.make_graph(
                nodes=[add], name="then", inputs=[], outputs=[Out]
            ),
            else_branch=helper.make_graph(
                nodes=[sub], name="else", inputs=[], outputs=[Out]
            ),
        )
        graph = helper.make_graph(
            nodes=[XY, cond], name="graph", inputs=[C, X, Y, Z], outputs=[Out]
        )
        prefix = "prefix."
        prefixed_graph = compose.add_prefix_graph(graph, prefix)
        checker.check_graph(prefixed_graph)
        for n1, n0 in zip(prefixed_graph.node, graph.node):
            self.assertEqual(_prefixed(prefix, n0.name), n1.name)
            for attribute1, attribute0 in zip(n1.attribute, n0.attribute):
                if attribute1.g:
                    for subgraph_n1, subgraph_n0 in zip(
                        attribute1.g.node, attribute0.g.node
                    ):
                        for input_n1, input_n0 in zip(
                            subgraph_n1.input, subgraph_n0.input
                        ):
                            self.assertEqual(_prefixed(prefix, input_n0), input_n1)
                        for output_n1, output_n0 in zip(
                            subgraph_n1.output, subgraph_n0.output
                        ):
                            self.assertEqual(_prefixed(prefix, output_n0), output_n1)

    def test_add_prefix_all(self) -> None:
        """Tests prefixing all names in the graph"""
        self._test_add_prefix(True, True, True, True, True, True)

    def test_add_prefix_inplace(self) -> None:
        """Tests prefixing inplace"""
        self._test_add_prefix(inplace=True)

    def test_expand_out_dim(self) -> None:
        """Tests expanding output dimensions. The resulting graph should have the same output names,
        but with one more dimension at the specified index.
        """
        m1 = _load_model(M1_DEF)

        def _check_model(m1: ModelProto, m2: ModelProto, dim_idx: int) -> None:
            for out_g2, out_g1 in zip(m2.graph.output, m1.graph.output):
                self.assertEqual(out_g2.name, out_g1.name)
                self.assertEqual(
                    out_g2.type.tensor_type.elem_type, out_g1.type.tensor_type.elem_type
                )
                expected_out_shape = _get_shape(out_g1)
                # list.insert follows the same negative-index semantics as the
                # dimension index under test.
                expected_out_shape.insert(dim_idx, 1)
                self.assertEqual(_get_shape(out_g2), expected_out_shape)

        for dim_idx in [0, 2, -1, -3]:
            m2 = compose.expand_out_dim(m1, dim_idx)
            _check_model(m1, m2, dim_idx)

        # Test inplace
        m2 = ModelProto()
        m2.CopyFrom(m1)
        dim_idx = 0
        compose.expand_out_dim(m2, dim_idx, inplace=True)
        _check_model(m1, m2, dim_idx)

    def _test_overlapping_names(
        self,
        inputs0: Sequence[str] = ("i0", "i1"),
        inputs1: Sequence[str] = ("i2", "i3"),
        outputs0: Sequence[str] = ("o0", "o1"),
        outputs1: Sequence[str] = ("o2", "o3"),
        value_info0: Sequence[str] = ("v0", "v1"),
        value_info1: Sequence[str] = ("v2", "v3"),
        initializer0: Sequence[str] = ("init0", "init1"),
        initializer1: Sequence[str] = ("init2", "init3"),
        sparse_initializer0: Sequence[str] = ("sparse_init0", "sparse_init1"),
        sparse_initializer1: Sequence[str] = ("sparse_init2", "sparse_init3"),
    ) -> None:
        """Builds two Identity-only models from the given names and checks
        ``compose.check_overlapping_names``.

        Overlaps are expected in the order edge, value_info, initializer,
        sparse_initializer. The check is repeated after prefixing the first
        model with ``"g0/"``, which must leave no overlap.
        """
        n0 = [
            helper.make_node("Identity", inputs=[inputs0[i]], outputs=[outputs0[i]])
            for i in range(len(inputs0))
        ]
        i0 = [
            helper.make_tensor_value_info(inputs0[i], TensorProto.FLOAT, [])
            for i in range(len(inputs0))
        ]
        o0 = [
            helper.make_tensor_value_info(outputs0[i], TensorProto.FLOAT, [])
            for i in range(len(outputs0))
        ]
        vi0 = [
            helper.make_tensor_value_info(value_info0[i], TensorProto.FLOAT, [])
            for i in range(len(value_info0))
        ]
        init0 = [
            helper.make_tensor(
                name=initializer0[i], data_type=TensorProto.INT64, dims=(), vals=[1]
            )
            for i in range(len(initializer0))
        ]

        sparse_init0 = [
            _make_sparse_tensor(sparse_initializer0[i])
            for i in range(len(sparse_initializer0))
        ]

        n1 = [
            helper.make_node("Identity", inputs=[inputs1[i]], outputs=[outputs1[i]])
            for i in range(len(inputs1))
        ]
        i1 = [
            helper.make_tensor_value_info(inputs1[i], TensorProto.FLOAT, [])
            for i in range(len(inputs1))
        ]
        o1 = [
            helper.make_tensor_value_info(outputs1[i], TensorProto.FLOAT, [])
            for i in range(len(outputs1))
        ]
        vi1 = [
            helper.make_tensor_value_info(value_info1[i], TensorProto.FLOAT, [])
            for i in range(len(value_info1))
        ]
        init1 = [
            helper.make_tensor(
                name=initializer1[i], data_type=TensorProto.INT64, dims=(), vals=[1]
            )
            for i in range(len(initializer1))
        ]
        sparse_init1 = [
            _make_sparse_tensor(sparse_initializer1[i])
            for i in range(len(sparse_initializer1))
        ]

        ops = [helper.make_opsetid("", 10)]
        m0 = helper.make_model(
            helper.make_graph(
                nodes=n0,
                name="g0",
                inputs=i0,
                outputs=o0,
                value_info=vi0,
                initializer=init0,
                sparse_initializer=sparse_init0,
            ),
            producer_name="test",
            opset_imports=ops,
        )
        m1 = helper.make_model(
            helper.make_graph(
                nodes=n1,
                name="g1",
                inputs=i1,
                outputs=o1,
                value_info=vi1,
                initializer=init1,
                sparse_initializer=sparse_init1,
            ),
            producer_name="test",
            opset_imports=ops,
        )

        overlap = compose.check_overlapping_names(m0.graph, m1.graph)
        # Index of the next expected entry in ``overlap``; only categories with
        # at least one collision produce an entry.
        i = 0

        # NOTE(review): the expected name lists come from set intersections,
        # whose iteration order is not guaranteed. This is only reliable while
        # each category has a single overlapping name, as in the callers below.
        overlapping_inputs = list(set(inputs0) & set(inputs1))
        overlapping_outputs = list(set(outputs0) & set(outputs1))
        overlapping_edges = list(set(overlapping_inputs + overlapping_outputs))
        if overlapping_edges:
            self.assertEqual(overlap[i], ("edge", overlapping_edges))
            i += 1

        overlapping_vis = list(set(value_info0) & set(value_info1))
        if overlapping_vis:
            self.assertEqual(overlap[i], ("value_info", overlapping_vis))
            i += 1

        overlapping_init = list(set(initializer0) & set(initializer1))
        if overlapping_init:
            self.assertEqual(overlap[i], ("initializer", overlapping_init))
            i += 1

        overlapping_sparse_init = list(
            set(sparse_initializer0) & set(sparse_initializer1)
        )
        if overlapping_sparse_init:
            # Sparse initializers are reported by their values/indices tensor
            # names, which _make_sparse_tensor derives from the base name.
            expected_overlap = []
            for overlapping_name in overlapping_sparse_init:
                expected_overlap.append(overlapping_name + "_values")
                expected_overlap.append(overlapping_name + "_idx")
            self.assertEqual(overlap[i], ("sparse_initializer", expected_overlap))
            i += 1

        m0_new = compose.add_prefix(m0, prefix="g0/")
        overlap = compose.check_overlapping_names(m0_new.graph, m1.graph)
        self.assertEqual(0, len(overlap))

    def test_overlapping_input_names(self) -> None:
        """Tests error checking when the name of the inputs overlaps"""
        self._test_overlapping_names(inputs0=["i0", "i1"], inputs1=["i1", "i2"])

    def test_overlapping_output_names(self) -> None:
        """Tests error checking when the name of the output overlaps"""
        self._test_overlapping_names(outputs0=["o0", "o1"], outputs1=["o1", "o2"])

    def test_overlapping_value_info_names(self) -> None:
        """Tests error checking when the name of value_info entries overlaps"""
        self._test_overlapping_names(
            value_info0=["vi0", "vi1"], value_info1=["vi1", "vi2"]
        )

    def test_overlapping_initializer_names(self) -> None:
        """Tests error checking when the name of initializer entries overlaps"""
        self._test_overlapping_names(
            initializer0=["init0", "init1"], initializer1=["init1", "init2"]
        )

    def test_overlapping_sparse_initializer_names(self) -> None:
        """Tests error checking when the name of sparse_initializer entries overlaps"""
        self._test_overlapping_names(
            sparse_initializer0=["sparse_init0", "sparse_init1"],
            sparse_initializer1=["sparse_init1", "sparse_init2"],
        )

    def test_overlapping_function_names(self) -> None:
        """Tests error checking when the name of local function entries overlaps"""
        ops = [helper.make_opsetid("", 10), helper.make_opsetid("local", 10)]

        def _make_function(
            domain: str,
            fname: str,
            inputs: List[str],
            outputs: List[str],
            nodes: List[NodeProto],
        ) -> FunctionProto:
            """Builds a FunctionProto that imports the same opsets as the models."""
            f = FunctionProto()
            f.domain = domain
            f.name = fname
            f.input.extend(inputs)
            f.output.extend(outputs)
            f.node.extend(nodes)
            f.opset_import.extend(ops)
            return f

        # Template graph: a single call to local function "f1".
        g = GraphProto()
        g.input.extend(
            [
                helper.make_tensor_value_info("x0", TensorProto.FLOAT, []),
                helper.make_tensor_value_info("x1", TensorProto.FLOAT, []),
            ]
        )
        g.output.extend(
            [
                helper.make_tensor_value_info("y", TensorProto.FLOAT, []),
            ]
        )
        g.node.extend(
            [helper.make_node("f1", domain="local", inputs=["x0", "x1"], outputs=["y"])]
        )

        g1 = GraphProto()
        g1.CopyFrom(g)
        g1.name = "g1"
        m1 = helper.make_model(g1, producer_name="test", opset_imports=ops)
        m1.functions.extend(
            [
                _make_function(
                    "local",
                    "f1",
                    ["x0", "x1"],
                    ["y"],
                    [helper.make_node("Add", inputs=["x0", "x1"], outputs=["y"])],
                )
            ]
        )
        checker.check_model(m1)

        # m2 defines a different "f1" with the same name as m1's.
        g2 = GraphProto()
        g2.CopyFrom(g)
        g2.name = "g2"
        m2 = helper.make_model(g2, producer_name="test", opset_imports=ops)
        m2.functions.extend(
            [
                _make_function(
                    "local",
                    "f1",
                    ["x0", "x1"],
                    ["y"],
                    [helper.make_node("Mul", inputs=["x0", "x1"], outputs=["y"])],
                )
            ]
        )
        checker.check_model(m2)

        m = compose.merge_models(
            m1, m2, io_map=[("y", "x0"), ("y", "x1")], prefix1="m1/", prefix2="m2/"
        )
        checker.check_model(m)

        # Prefixing renames both the functions and the nodes that call them.
        nodes = [n.op_type for n in m.graph.node]
        self.assertEqual(["m1/f1", "m2/f1"], nodes)

        functions = [f.name for f in m.functions]
        self.assertEqual(["m1/f1", "m2/f1"], functions)

        # m3 calls "f2", which in turn calls the local function "f1".
        g3 = GraphProto()
        g3.CopyFrom(g)
        g3.name = "g3"
        g3.node[0].op_type = "f2"
        m3 = helper.make_model(g3, producer_name="test", opset_imports=ops)
        m3.functions.extend(
            [
                _make_function(
                    "local",
                    "f1",
                    ["x0", "x1"],
                    ["y"],
                    [
                        helper.make_node("Add", inputs=["x0", "x1"], outputs=["y0"]),
                        helper.make_node("Mul", inputs=["x0", "x1"], outputs=["y1"]),
                        helper.make_node("Add", inputs=["y0", "y1"], outputs=["y"]),
                    ],
                ),
                _make_function(
                    "local",
                    "f2",
                    ["x0", "x1"],
                    ["y"],
                    [
                        helper.make_node(
                            "f1", domain="local", inputs=["x0", "x1"], outputs=["y0"]
                        ),
                        helper.make_node("Mul", inputs=["x0", "x1"], outputs=["y1"]),
                        helper.make_node("Add", inputs=["y0", "y1"], outputs=["y"]),
                    ],
                ),
            ]
        )
        checker.check_model(m3)

        m = compose.merge_models(
            m1, m3, io_map=[("y", "x0"), ("y", "x1")], prefix1="m1/", prefix2="m3/"
        )
        checker.check_model(m)

        nodes = [n.op_type for n in m.graph.node]
        self.assertEqual(["m1/f1", "m3/f2"], nodes)

        functions = [f.name for f in m.functions]
        self.assertEqual(["m1/f1", "m3/f1", "m3/f2"], functions)

        self.assertEqual(["Add"], [n.op_type for n in m.functions[0].node])
        self.assertEqual(
            ["Add", "Mul", "Add"], [n.op_type for n in m.functions[1].node]
        )
        # The call to "f1" inside "f2" must be renamed to the prefixed function.
        self.assertEqual(
            ["m3/f1", "Mul", "Add"], [n.op_type for n in m.functions[2].node]
        )

    def test_merge_drop_unnecessary_initializers_and_value_info(self) -> None:
        """Tests automatic removal of initializers when merging graphs"""
        ops = [helper.make_opsetid("", 10)]

        # Template graph: y = Identity(x).
        g = GraphProto()
        g.input.extend([helper.make_tensor_value_info("x", TensorProto.FLOAT, [])])
        g.output.extend([helper.make_tensor_value_info("y", TensorProto.FLOAT, [])])
        g.node.extend([helper.make_node("Identity", inputs=["x"], outputs=["y"])])

        g1 = GraphProto()
        g1.CopyFrom(g)
        g1.name = "g1"
        m1 = helper.make_model(g1, producer_name="test", opset_imports=ops)
        checker.check_model(m1)

        # Same graph plus a dense initializer named "x".
        g2 = GraphProto()
        g2.CopyFrom(g)
        g2.name = "g2"
        g2.initializer.extend(
            [
                helper.make_tensor(
                    name="x", data_type=TensorProto.FLOAT, dims=(), vals=[0]
                )
            ]
        )
        m2 = helper.make_model(g2, producer_name="test", opset_imports=ops)
        checker.check_model(m2)

        # Same graph plus a sparse initializer built from base name "x".
        g3 = GraphProto()
        g3.CopyFrom(g)
        g3.name = "g3"
        g3.sparse_initializer.extend([_make_sparse_tensor("x")])
        m3 = helper.make_model(g3, producer_name="test", opset_imports=ops)
        checker.check_model(m3)

        # Same graph plus a value_info entry named "x".
        g4 = GraphProto()
        g4.CopyFrom(g)
        g4.name = "g4"
        g4.value_info.extend(
            [helper.make_tensor_value_info("x", TensorProto.FLOAT, [])]
        )
        m4 = helper.make_model(g4, producer_name="test", opset_imports=ops)
        checker.check_model(m4)

        # Initializer 'x' of the second model (m2) is removed, because 'x' is no
        # longer a graph input: it is fed by m1's output 'y' via io_map.
        out_m1 = compose.merge_models(m1, m2, prefix1="m1/", io_map=[("y", "x")])
        self.assertEqual(0, len(out_m1.graph.initializer))

        # Sparse initializer of the second model (m3) is expected to be removed
        # for the same reason.
        # NOTE(review): this asserts on the dense ``initializer`` field, and the
        # sparse tensor's components are named 'x_values'/'x_idx' rather than
        # 'x'. Confirm whether ``sparse_initializer`` should be checked here.
        out_m2 = compose.merge_models(m1, m3, prefix1="m1/", io_map=[("y", "x")])
        self.assertEqual(0, len(out_m2.graph.initializer))

        # Value info 'x' of the second model (m4) is removed, because 'x' is no
        # longer a graph input.
        out_m3 = compose.merge_models(m1, m4, prefix1="m1/", io_map=[("y", "x")])
        self.assertEqual(0, len(out_m3.graph.value_info))


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
