# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0


from typing import Dict, List, MutableMapping, Optional, Set, Tuple

from onnx import GraphProto, ModelProto, TensorProto, checker, helper, utils


def check_overlapping_names(
    g1: GraphProto, g2: GraphProto, io_map: Optional[List[Tuple[str, str]]] = None
) -> List[Tuple[str, List[str]]]:
    """Reports the names that are used by both graphs.

    Each entry of the returned list is a ``(category, names)`` tuple. ``category`` is one of
    "edge", "value_info", "initializer" or "sparse_initializer" and ``names`` lists the names
    found in both graphs for that category. Categories without collisions are omitted, so an
    empty list means no overlapping names were found.

    Arguments:
        g1 (GraphProto): First graph
        g2 (GraphProto): Second graph
        io_map (list of pairs of string): Optional pairs (g1 output, g2 input) that are going to
                                          be connected. The g2 side of each pair is left out of
                                          the edge comparison.

    Returns:
        List of (category, overlapping names) tuples

    Raises:
        ValueError: If g1 or g2 is not a GraphProto
    """
    if type(g1) is not GraphProto:
        raise ValueError("g1 argument is not an ONNX graph")
    if type(g2) is not GraphProto:
        raise ValueError("g2 argument is not an ONNX graph")

    def _shared(first: List[str], second: List[str]) -> List[str]:
        return list(set(first) & set(second))

    def _tensor_names(graph: GraphProto, skip: Set[str]) -> List[str]:
        # Collects every non-empty node input/output name, node by node
        # (inputs first, then outputs), leaving out the names listed in ``skip``.
        names: List[str] = []
        for node in graph.node:
            names.extend(t for t in node.input if t != "" and t not in skip)
            names.extend(t for t in node.output if t != "" and t not in skip)
        return names

    # Inputs of g2 that will be wired to outputs of g1 are not reported as edge collisions
    connected_inputs = {pair[1] for pair in io_map} if io_map else set()

    candidates = [
        # Node inputs/outputs already cover the graph inputs/outputs
        (
            "edge",
            _shared(_tensor_names(g1, set()), _tensor_names(g2, connected_inputs)),
        ),
        (
            "value_info",
            _shared(
                [vi.name for vi in g1.value_info], [vi.name for vi in g2.value_info]
            ),
        ),
        (
            "initializer",
            _shared(
                [t.name for t in g1.initializer], [t.name for t in g2.initializer]
            ),
        ),
        (
            "sparse_initializer",
            # A sparse initializer carries two named tensors: values and indices
            _shared(
                [s.values.name for s in g1.sparse_initializer],
                [s.values.name for s in g2.sparse_initializer],
            )
            + _shared(
                [s.indices.name for s in g1.sparse_initializer],
                [s.indices.name for s in g2.sparse_initializer],
            ),
        ),
    ]

    return [(category, names) for category, names in candidates if names]


def merge_graphs(
    g1: GraphProto,
    g2: GraphProto,
    io_map: List[Tuple[str, str]],
    inputs: Optional[List[str]] = None,
    outputs: Optional[List[str]] = None,
    prefix1: Optional[str] = None,
    prefix2: Optional[str] = None,
    name: Optional[str] = None,
    doc_string: Optional[str] = None,
) -> GraphProto:
    """Combines two ONNX graphs into a single one.

    The combined graph is defined by connecting the specified set of outputs/inputs. Those inputs/outputs
    not specified in the io_map argument will remain as inputs/outputs of the combined graph.

    The graphs passed as arguments are not modified.

    Arguments:
        g1 (GraphProto): First graph
        g2 (GraphProto): Second graph
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model
        prefix1 (string): Optional prefix to be added to all names in g1
        prefix2 (string): Optional prefix to be added to all names in g2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used

    Returns:
        GraphProto

    Raises:
        ValueError: If g1 or g2 is not a GraphProto, if a name in ``io_map`` is not an output of g1 /
                    an input of g2, or if the two graphs have overlapping names
    """
    if type(g1) is not GraphProto:
        raise ValueError("g1 argument is not an ONNX graph")
    if type(g2) is not GraphProto:
        raise ValueError("g2 argument is not an ONNX graph")

    # Prefixing names in the graph if requested, adjusting io_map accordingly.
    # add_prefix_graph already works on a copy (inplace defaults to False), so
    # the caller's graphs stay untouched without an extra deep copy here.
    if prefix1 or prefix2:
        if prefix1:
            g1 = add_prefix_graph(g1, prefix=prefix1)
        if prefix2:
            g2 = add_prefix_graph(g2, prefix=prefix2)
        io_map = [
            (
                prefix1 + io[0] if prefix1 else io[0],
                prefix2 + io[1] if prefix2 else io[1],
            )
            for io in io_map
        ]

    io_map_g1_outs = {io[0] for io in io_map}
    io_map_g2_ins = {io[1] for io in io_map}
    reversed_io_map = {in_name: out_name for out_name, in_name in io_map}
    # Captured before any subgraph extraction below, so io_map is validated
    # against the full set of outputs/inputs of the (possibly prefixed) graphs
    g1_outs = {o.name for o in g1.output}
    g2_ins = {i.name for i in g2.input}

    # If necessary extract subgraphs
    if inputs or outputs:
        if not inputs:
            g1_inputs = [i.name for i in g1.input]
            g2_inputs = [i.name for i in g2.input]
        else:
            input_set = set(inputs)
            g1_inputs = [i.name for i in g1.input if i.name in input_set]
            # Inputs of g2 that get connected to g1 must be kept even if not requested
            g2_inputs = [
                i.name
                for i in g2.input
                if i.name in input_set or i.name in io_map_g2_ins
            ]

        if not outputs:
            g1_outputs = [o.name for o in g1.output]
            g2_outputs = [o.name for o in g2.output]
        else:
            output_set = set(outputs)
            # Outputs of g1 that feed g2 must be kept even if not requested
            g1_outputs = [
                o.name
                for o in g1.output
                if o.name in output_set or o.name in io_map_g1_outs
            ]
            g2_outputs = [o.name for o in g2.output if o.name in output_set]

        # Only extract when something is actually being dropped
        if len(g1_inputs) < len(g1.input) or len(g1_outputs) < len(g1.output):
            e1 = utils.Extractor(helper.make_model(g1))
            g1 = e1.extract_model(g1_inputs, g1_outputs).graph

        if len(g2_inputs) < len(g2.input) or len(g2_outputs) < len(g2.output):
            e2 = utils.Extractor(helper.make_model(g2))
            g2 = e2.extract_model(g2_inputs, g2_outputs).graph

    # Check that input/output names specified in the io_map argument are valid input/output names
    for g1_out_name, g2_in_name in io_map:
        if g1_out_name not in g1_outs:
            raise ValueError(f"Output {g1_out_name} is not present in g1")
        if g2_in_name not in g2_ins:
            raise ValueError(f"Input {g2_in_name} is not present in g2")

    # Check for name collision
    overlapping_names = check_overlapping_names(g1, g2, io_map)
    if len(overlapping_names) > 0:
        # Only the first category with collisions is reported
        category, names = overlapping_names[0]
        raise ValueError(
            "Cant merge two graphs with overlapping names. "
            f"Found repeated {category} names: "
            + ", ".join(names)
            + "\n"
            + "Consider using ``onnx.compose.add_prefix`` to add a prefix to names in one of the graphs."
        )

    g = GraphProto()

    g.node.extend(g1.node)
    g2_nodes_begin = len(g.node)
    g.node.extend(g2.node)
    g2_nodes_end = len(g.node)

    # Connecting outputs of the first graph with the inputs of the second.
    # Only the nodes that came from g2 are rewired.
    for node_idx in range(g2_nodes_begin, g2_nodes_end):
        node = g.node[node_idx]
        for index, name_ in enumerate(node.input):
            if name_ in reversed_io_map:
                node.input[index] = reversed_io_map[name_]

    if inputs:
        input_set = set(inputs)
        g.input.extend([i for i in g1.input if i.name in input_set])
        g.input.extend([i for i in g2.input if i.name in input_set])
    else:
        g.input.extend(g1.input)
        g.input.extend([i for i in g2.input if i.name not in io_map_g2_ins])

    if outputs:
        output_set = set(outputs)
        g.output.extend([o for o in g1.output if o.name in output_set])
        g.output.extend([o for o in g2.output if o.name in output_set])
    else:
        g.output.extend([o for o in g1.output if o.name not in io_map_g1_outs])
        g.output.extend(g2.output)

    # Initializers / value infos of g2 named after a connected input are dropped,
    # since that name no longer exists in the combined graph
    g.initializer.extend(g1.initializer)
    g.initializer.extend(
        [init for init in g2.initializer if init.name not in io_map_g2_ins]
    )

    g.sparse_initializer.extend(g1.sparse_initializer)
    g.sparse_initializer.extend(
        [
            init
            for init in g2.sparse_initializer
            if init.values.name not in io_map_g2_ins
        ]
    )

    g.value_info.extend(g1.value_info)
    g.value_info.extend([vi for vi in g2.value_info if vi.name not in io_map_g2_ins])

    g.name = name if name is not None else "_".join([g1.name, g2.name])

    if doc_string is None:
        doc_string = (
            f"Graph combining {g1.name} and {g2.name}\n"
            + g1.name
            + "\n\n"
            + g1.doc_string
            + "\n\n"
            + g2.name
            + "\n\n"
            + g2.doc_string
        )
    g.doc_string = doc_string

    return g


def merge_models(
    m1: ModelProto,
    m2: ModelProto,
    io_map: List[Tuple[str, str]],
    inputs: Optional[List[str]] = None,
    outputs: Optional[List[str]] = None,
    prefix1: Optional[str] = None,
    prefix2: Optional[str] = None,
    name: Optional[str] = None,
    doc_string: Optional[str] = None,
    producer_name: Optional[str] = "onnx.compose.merge_models",
    producer_version: Optional[str] = "1.0",
    domain: Optional[str] = "",
    model_version: Optional[int] = 1,
) -> ModelProto:
    """Combines two ONNX models into a single one.

    The combined model is defined by connecting the specified set of outputs/inputs.
    Those inputs/outputs not specified in the io_map argument will remain as
    inputs/outputs of the combined model.

    Both models should have the same IR version, and same operator sets imported.
    The models passed as arguments are not modified.

    Arguments:
        m1 (ModelProto): First model
        m2 (ModelProto): Second model
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model
        prefix1 (string): Optional prefix to be added to all names in m1
        prefix2 (string): Optional prefix to be added to all names in m2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used
        producer_name (string): Optional producer name for the combined model. Default: 'onnx.compose.merge_models'
        producer_version (string): Optional producer version for the combined model. Default: "1.0"
        domain (string): Optional domain of the combined model. Default: ""
        model_version (int): Optional version of the graph encoded. Default: 1

    Returns:
        ModelProto

    Raises:
        ValueError: If m1 or m2 is not a ModelProto, if the IR versions differ, if the same operator
                    set domain is imported with different versions, if a metadata property has
                    different values in the two models, or if local function names overlap
    """
    if type(m1) is not ModelProto:
        raise ValueError("m1 argument is not an ONNX model")
    if type(m2) is not ModelProto:
        raise ValueError("m2 argument is not an ONNX model")

    if m1.ir_version != m2.ir_version:
        raise ValueError(
            f"IR version mismatch {m1.ir_version} != {m2.ir_version}."
            " Both models should have the same IR version"
        )
    ir_version = m1.ir_version

    # Collect one opset import per domain. Both models usually import the same
    # domains, so a plain concatenation would list each of them twice.
    opset_import_map: MutableMapping[str, int] = {}
    opset_imports = []

    for entry in list(m1.opset_import) + list(m2.opset_import):
        if entry.domain in opset_import_map:
            found_version = opset_import_map[entry.domain]
            if entry.version != found_version:
                raise ValueError(
                    "Can't merge two models with different operator set ids for a given domain. "
                    f"Got: {m1.opset_import} and {m2.opset_import}"
                )
        else:
            opset_import_map[entry.domain] = entry.version
            opset_imports.append(entry)

    # Prefixing names in the graph if requested, adjusting io_map accordingly.
    # add_prefix already works on a copy (inplace defaults to False), so the
    # caller's models stay untouched without an extra deep copy here.
    if prefix1 or prefix2:
        if prefix1:
            m1 = add_prefix(m1, prefix=prefix1)
        if prefix2:
            m2 = add_prefix(m2, prefix=prefix2)
        io_map = [
            (
                prefix1 + io[0] if prefix1 else io[0],
                prefix2 + io[1] if prefix2 else io[1],
            )
            for io in io_map
        ]

    # Prefixes were already applied above, so they are not forwarded
    graph = merge_graphs(
        m1.graph,
        m2.graph,
        io_map,
        inputs=inputs,
        outputs=outputs,
        name=name,
        doc_string=doc_string,
    )
    model = helper.make_model(
        graph,
        producer_name=producer_name,
        producer_version=producer_version,
        domain=domain,
        model_version=model_version,
        opset_imports=opset_imports,
        ir_version=ir_version,
    )

    # Merging model metadata props
    model_props = {}
    for meta_entry in m1.metadata_props:
        model_props[meta_entry.key] = meta_entry.value
    for meta_entry in m2.metadata_props:
        if meta_entry.key in model_props:
            value = model_props[meta_entry.key]
            if value != meta_entry.value:
                raise ValueError(
                    "Can't merge models with different values for the same model metadata property."
                    f" Found: property = {meta_entry.key}, with values {value} and {meta_entry.value}."
                )
        else:
            model_props[meta_entry.key] = meta_entry.value
    helper.set_model_props(model, model_props)

    # Merging functions
    function_overlap = list(
        {f.name for f in m1.functions} & {f.name for f in m2.functions}
    )
    if function_overlap:
        raise ValueError(
            "Can't merge models with overlapping local function names."
            " Found in both graphs: " + ", ".join(function_overlap)
        )
    model.functions.MergeFrom(m1.functions)
    model.functions.MergeFrom(m2.functions)

    checker.check_model(model)
    return model


def add_prefix_graph(
    graph: GraphProto,
    prefix: str,
    rename_nodes: Optional[bool] = True,
    rename_edges: Optional[bool] = True,
    rename_inputs: Optional[bool] = True,
    rename_outputs: Optional[bool] = True,
    rename_initializers: Optional[bool] = True,
    rename_value_infos: Optional[bool] = True,
    inplace: Optional[bool] = False,
    name_map: Optional[Dict[str, str]] = None,
) -> GraphProto:
    """Adds a prefix to names of elements in a graph: nodes, edges, inputs, outputs,
    initializers, sparse initializer, value infos.

    It can be used as a utility before merging graphs that have overlapping names.
    Empty names are not prefixed.

    Subgraphs stored in node attributes (both the single-graph ``g`` field and the
    ``graphs`` list) are processed recursively with the same ``rename_*`` options.

    Arguments:
        graph (GraphProto): Graph
        prefix (str): Prefix to be added to each name in the graph
        rename_nodes (bool): Whether to prefix node names
        rename_edges (bool): Whether to prefix node edge names
        rename_inputs (bool): Whether to prefix input names
        rename_outputs (bool): Whether to prefix output names
        rename_initializers (bool): Whether to prefix initializer and sparse initializer names
        rename_value_infos (bool): Whether to prefix value info names
        inplace (bool): If True, mutates the graph directly.
                        Otherwise, a copy will be created
        name_map: (Dict): Mapping from original names to prefixed names. It is shared with
                          subgraphs, so that references to names of the enclosing graph are
                          renamed consistently. If provided, it is updated in place.

    Returns:
        GraphProto

    Raises:
        ValueError: If graph is not a GraphProto
    """
    if type(graph) is not GraphProto:
        raise ValueError("graph argument is not an ONNX graph")

    if not inplace:
        g = GraphProto()
        g.CopyFrom(graph)
    else:
        g = graph

    def _prefixed(prefix: str, name: str) -> str:
        return prefix + name if len(name) > 0 else name

    # First collect every name that has to be renamed in this graph...
    if name_map is None:
        name_map = {}
    if rename_edges:
        for n in g.node:
            for e in n.input:
                name_map[e] = _prefixed(prefix, e)
            for e in n.output:
                name_map[e] = _prefixed(prefix, e)

    if rename_inputs:
        for entry in g.input:
            name_map[entry.name] = _prefixed(prefix, entry.name)
    if rename_outputs:
        for entry in g.output:
            name_map[entry.name] = _prefixed(prefix, entry.name)

    if rename_initializers:
        for init in g.initializer:
            name_map[init.name] = _prefixed(prefix, init.name)
        for sparse_init in g.sparse_initializer:
            name_map[sparse_init.values.name] = _prefixed(
                prefix, sparse_init.values.name
            )
            name_map[sparse_init.indices.name] = _prefixed(
                prefix, sparse_init.indices.name
            )

    if rename_value_infos:
        for entry in g.value_info:
            name_map[entry.name] = _prefixed(prefix, entry.name)

    def _rename_subgraph(subgraph: GraphProto) -> None:
        # Subgraphs belong to the graph being renamed, so they are always edited
        # in place. The caller's options are forwarded and the name map is shared,
        # which keeps references to names of the enclosing graph valid.
        add_prefix_graph(
            subgraph,
            prefix,
            rename_nodes=rename_nodes,
            rename_edges=rename_edges,
            rename_inputs=rename_inputs,
            rename_outputs=rename_outputs,
            rename_initializers=rename_initializers,
            rename_value_infos=rename_value_infos,
            inplace=True,
            name_map=name_map,
        )

    # ...then rename the nodes and descend into subgraphs. This happens once the
    # name map of this graph is complete, and regardless of ``rename_nodes``, so
    # that subgraphs never keep stale references to renamed outer names.
    for n in g.node:
        if rename_nodes:
            n.name = _prefixed(prefix, n.name)
        for attribute in n.attribute:
            # HasField tells whether the attribute really carries a graph
            if attribute.HasField("g"):
                _rename_subgraph(attribute.g)
            for subgraph in attribute.graphs:
                _rename_subgraph(subgraph)

    # Finally apply the collected renames to every element of this graph
    for n in g.node:
        for i, output in enumerate(n.output):
            if n.output[i] in name_map:
                n.output[i] = name_map[output]
        for i, input_ in enumerate(n.input):
            if n.input[i] in name_map:
                n.input[i] = name_map[input_]

    for in_desc in g.input:
        if in_desc.name in name_map:
            in_desc.name = name_map[in_desc.name]
    for out_desc in g.output:
        if out_desc.name in name_map:
            out_desc.name = name_map[out_desc.name]

    for initializer in g.initializer:
        if initializer.name in name_map:
            initializer.name = name_map[initializer.name]
    for sparse_initializer in g.sparse_initializer:
        if sparse_initializer.values.name in name_map:
            sparse_initializer.values.name = name_map[sparse_initializer.values.name]
        if sparse_initializer.indices.name in name_map:
            sparse_initializer.indices.name = name_map[sparse_initializer.indices.name]

    for value_info in g.value_info:
        if value_info.name in name_map:
            value_info.name = name_map[value_info.name]

    return g


def add_prefix(
    model: ModelProto,
    prefix: str,
    rename_nodes: Optional[bool] = True,
    rename_edges: Optional[bool] = True,
    rename_inputs: Optional[bool] = True,
    rename_outputs: Optional[bool] = True,
    rename_initializers: Optional[bool] = True,
    rename_value_infos: Optional[bool] = True,
    rename_functions: Optional[bool] = True,
    inplace: Optional[bool] = False,
) -> ModelProto:
    """Prefixes the names used in a model: nodes, edges, inputs, outputs, initializers,
    sparse initializers, value infos of its graph, and its local functions.

    Useful to make the names of two models disjoint before merging them.
    The graph part is delegated to ``add_prefix_graph``.

    Arguments:
        model (ModelProto): Model
        prefix (str): Prefix to be added to each name in the graph
        rename_nodes (bool): Whether to prefix node names
        rename_edges (bool): Whether to prefix node edge names
        rename_inputs (bool): Whether to prefix input names
        rename_outputs (bool): Whether to prefix output names
        rename_initializers (bool): Whether to prefix initializer and sparse initializer names
        rename_value_infos (bool): Whether to prefix value info names
        rename_functions (bool): Whether to prefix local function names
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        ModelProto

    Raises:
        ValueError: If model is not a ModelProto
    """
    if type(model) is not ModelProto:
        raise ValueError("model argument is not an ONNX model")

    target = model
    if not inplace:
        target = ModelProto()
        target.CopyFrom(model)

    # ``target`` is either the caller's model (inplace) or a private copy,
    # so the graph can always be edited in place from here on
    add_prefix_graph(
        target.graph,
        prefix,
        rename_nodes=rename_nodes,
        rename_edges=rename_edges,
        rename_inputs=rename_inputs,
        rename_outputs=rename_outputs,
        rename_initializers=rename_initializers,
        rename_value_infos=rename_value_infos,
        inplace=True,
    )

    if not rename_functions:
        return target

    # Rename the local functions, remembering old name -> new name
    renamed_functions = {}
    for function in target.functions:
        old_name = function.name
        function.name = prefix + old_name
        renamed_functions[old_name] = function.name

    def _retarget_calls(nodes) -> None:
        # A node calls a local function through its op_type
        for node in nodes:
            if node.op_type in renamed_functions:
                node.op_type = renamed_functions[node.op_type]

    # Calls made from the bodies of other local functions
    for function in target.functions:
        _retarget_calls(function.node)
    # Calls made from the main graph
    _retarget_calls(target.graph.node)

    return target


def expand_out_dim_graph(
    graph: GraphProto,
    dim_idx: int,
    inplace: Optional[bool] = False,
) -> GraphProto:
    """Inserts an extra dimension with extent 1 to each output in the graph.

    Inserts an Unsqueeze node for each output. It can be used as a utility before merging graphs,
    for example when the second one expects a batch dimension.

    The tensors originally produced under the output names are renamed with a
    ``_collapsed_dim_{dim_idx}`` suffix, and the Unsqueeze nodes produce the original output names.
    A single Constant node holding ``dim_idx`` is shared by all the Unsqueeze nodes.

    The declared shape of each output is updated accordingly: symbolic and unknown dimensions
    are preserved, and outputs without shape information are left without shape information.

    Arguments:
        graph (GraphProto): Graph
        dim_idx (int): Index of the dimension to be inserted.
                       A negative value means counting dimensions from the back.
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        GraphProto

    Raises:
        ValueError: If graph is not a GraphProto
    """
    if type(graph) is not GraphProto:
        raise ValueError("graph argument is not an ONNX graph")

    if not inplace:
        g = GraphProto()
        g.CopyFrom(graph)
    else:
        g = graph

    def _dim_extent(dim):
        # Keeps symbolic dimensions as strings and unknown dimensions as None,
        # instead of reporting both as a 0 extent
        if dim.HasField("dim_param"):
            return dim.dim_param
        if dim.HasField("dim_value"):
            return dim.dim_value
        return None

    collapsed_suffix = f"_collapsed_dim_{dim_idx}"
    orig_out_names = {output.name for output in g.output}

    # Rename the tensors currently carrying the output names, so that the
    # Unsqueeze nodes added below can take over the original names
    for n in g.node:
        for i, out in enumerate(n.output):
            if out in orig_out_names:
                n.output[i] = out + collapsed_suffix
        for i, inp in enumerate(n.input):
            if inp in orig_out_names:
                n.input[i] = inp + collapsed_suffix

    # The axis is passed to Unsqueeze as a tensor input, shared by all outputs
    expand_dim_k = g.name + "_expand_out_dim_idx"
    g.node.append(
        helper.make_node(
            "Constant",
            inputs=[],
            outputs=[expand_dim_k],
            name=f"{expand_dim_k}-constant",
            value=helper.make_tensor(
                name=f"{expand_dim_k}-value",
                data_type=TensorProto.INT64,
                dims=[
                    1,
                ],
                vals=[
                    dim_idx,
                ],
            ),
        )
    )

    # Each output is popped from the front and its replacement appended at the
    # back, so the original output order is preserved
    for _ in range(len(g.output)):
        o = g.output.pop(0)
        prev_output = o.name + collapsed_suffix
        g.node.append(
            helper.make_node(
                "Unsqueeze",
                inputs=[prev_output, expand_dim_k],
                outputs=[o.name],
                name=f"unsqueeze-{o.name}",
            )
        )
        tensor_type = o.type.tensor_type
        if tensor_type.HasField("shape"):
            new_shape = [_dim_extent(d) for d in tensor_type.shape.dim]
            # A negative axis is relative to the expanded rank (original rank + 1),
            # e.g. -1 appends the new dimension. list.insert(-1, ...) would instead
            # place it before the last element.
            insert_at = dim_idx if dim_idx >= 0 else dim_idx + len(new_shape) + 1
            new_shape.insert(insert_at, 1)
        else:
            # Unknown rank stays unknown
            new_shape = None
        g.output.append(
            helper.make_tensor_value_info(o.name, tensor_type.elem_type, new_shape)
        )
    return g


def expand_out_dim(
    model: ModelProto,
    dim_idx: int,
    inplace: Optional[bool] = False,
) -> ModelProto:
    """Adds a dimension of extent 1 to every output of the model's graph.

    Thin model-level wrapper around ``expand_out_dim_graph``, which inserts one Unsqueeze node
    per output. It can be used as a utility before merging models, for example when the second
    one expects a batch dimension.

    Arguments:
        model (ModelProto): Model
        dim_idx (int): Index of the dimension to be inserted.
                       A negative value means counting dimensions from the back.
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        ModelProto

    Raises:
        ValueError: If model is not a ModelProto
    """
    if type(model) is not ModelProto:
        raise ValueError("model argument is not an ONNX model")

    target = model
    if not inplace:
        target = ModelProto()
        target.CopyFrom(model)

    # ``target`` is either the caller's model (inplace) or a private copy,
    # so its graph can be edited directly
    expand_out_dim_graph(target.graph, dim_idx, inplace=True)
    return target
