# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Union

import numpy as np

from onnx import (
    AttributeProto,
    FunctionProto,
    GraphProto,
    ModelProto,
    NodeProto,
    SparseTensorProto,
    TensorProto,
)
from onnx.helper import (
    make_attribute,
    make_function,
    make_graph,
    make_model,
    make_node,
    make_tensor,
    make_tensor_value_info,
    set_model_props,
    tensor_dtype_to_np_dtype,
)
from onnx.numpy_helper import from_array


def _replace_constant(
    node: NodeProto, threshold: int, value_constant_of_shape: float
) -> List[NodeProto]:
    """Replace a large *Constant* node by nodes building a dummy tensor of the same shape.

    Only the first attribute of the node is inspected: it decides whether
    the node is kept, replaced, or rejected.

    Args:
        node: node to process, its type must be *Constant*
        threshold: a constant holding at most this number of elements
            is returned unchanged
        value_constant_of_shape: fill value given to the *ConstantOfShape*
            node standing in for the original tensor

    Returns:
        ``[node]`` when nothing has to change, otherwise two nodes: a
        *Constant* holding the shape of the original tensor followed by a
        *ConstantOfShape* writing into the outputs of the original node

    Raises:
        TypeError: if *node* is not a *Constant*
        NotImplementedError: if the constant is sparse, or if it is stored
            in an attribute other than *value* and either holds more than
            *threshold* elements or is of an unsupported kind
    """
    if node.op_type != "Constant":
        raise TypeError(f"Node type must be 'Constant' not {node.op_type!r}.")
    for att in node.attribute:
        if att.name == "sparse_value":
            raise NotImplementedError(
                f"This feature is not yet implemented for a sparse constant "
                f"(node name={node.name!r})."
            )
        if att.name == "value":
            value = att.t
            dims = value.dims
            if np.prod(dims) <= threshold:
                return [node]
            # The tensor embedded in a Constant node is frequently unnamed.
            # Falling back on the node output keeps the name of the shape
            # tensor distinct when several constants are replaced, instead
            # of every one of them producing a result called "__SHAPE".
            base_name = value.name or next(iter(node.output), "")
            new_name = f"{base_name}__SHAPE"
            init = from_array(np.array(list(dims), dtype=np.int64), name=new_name)
            dtype = tensor_dtype_to_np_dtype(value.data_type)
            node_shape = make_node(
                "Constant",
                [],
                [new_name],
                value=init,
            )
            new_node = make_node(
                "ConstantOfShape",
                [new_name],
                node.output,
                value=from_array(np.array([value_constant_of_shape], dtype=dtype)),
            )
            return [node_shape, new_node]

        # Constants stored as plain attributes: scalars hold one element,
        # list attributes hold as many elements as the list. Small ones are
        # left untouched, like small tensors, so that a graph which already
        # contains such constants (for instance the ones inserted when
        # ConstantOfShape is replaced by Range) can be processed again.
        if att.name in ("value_float", "value_int", "value_string"):
            size = 1
        elif att.name in ("value_floats", "value_ints", "value_strings"):
            # "value_ints" is stored in field "ints", and so on.
            size = len(getattr(att, att.name[len("value_") :]))
        else:
            size = None
        if size is not None and size <= threshold:
            return [node]
        raise NotImplementedError(
            f"Replacement of constant with attribute {att.name!r}"
        )
    return [node]


def _replace_constant_of_shape_with_range(
    onx: Union[GraphProto, FunctionProto]
) -> Union[GraphProto, FunctionProto]:
    """Replaces all *ConstantOfShape* by node *Range* to avoid constant tensors.

    Every *ConstantOfShape* node reading a shape ``S`` is replaced by the
    sequence *ReduceProd*, *Range*, *Cast*, *Cast*, *Div*, *Reshape*: the
    number of elements ``N`` is computed from ``S``, ``Range(zero, N, one)``
    is cast to the element type of the original node (float if the node has
    no single attribute), divided by ``N`` cast to the same type and reshaped
    to ``S``. Two *Constant* nodes producing 0 and 1 are added at the
    beginning of the node list.

    The function is not recursive. The recursivity is done by
    *replace_initializer_by_constant_of_shape*.

    Args:
        onx: graph or function to process

    Returns:
        *onx* itself if it contains no *ConstantOfShape* node, otherwise a
        new graph or function with the replaced nodes

    Raises:
        TypeError: if *onx* is neither a GraphProto nor a FunctionProto
    """
    if not isinstance(onx, (GraphProto, FunctionProto)):
        raise TypeError(f"Not implemented for type {type(onx)}.")

    nodes = list(onx.node)
    if all(node.op_type != "ConstantOfShape" for node in nodes):
        # Nothing to replace: do not pollute the graph with unused
        # constants 'zero' and 'one'.
        return onx

    existing_names = set()
    for node in nodes:
        existing_names.update(node.input)
        existing_names.update(node.output)

    def _find_name(prefix):
        # Returns prefix if it is free, otherwise prefix_2, prefix_3, ...
        # The returned name is reserved.
        name, suffix = prefix, 1
        while name in existing_names:
            suffix += 1
            name = f"{prefix}_{suffix}"
        existing_names.add(name)
        return name

    cst0 = make_node("Constant", [], [_find_name("zero")], value_int=0)
    cst1 = make_node("Constant", [], [_find_name("one")], value_int=1)
    new_nodes = [cst0, cst1]
    for node in nodes:
        if node.op_type != "ConstantOfShape":
            new_nodes.append(node)
            continue
        shape = node.input[0]

        n = make_node("ReduceProd", [shape], [_find_name(f"{shape}_N")])
        a = make_node(
            "Range",
            [cst0.output[0], n.output[0], cst1.output[0]],
            [_find_name(f"{shape}_RANGE")],
        )
        # The element type comes from the fill value when there is one.
        if len(node.attribute) == 1:
            to = node.attribute[0].t.data_type
        else:
            to = TensorProto.FLOAT
        ac = make_node("Cast", [a.output[0]], [_find_name(f"{shape}_RANGEf")], to=to)
        cl = make_node("Cast", [n.output[0]], [_find_name(f"{shape}_Nf")], to=to)
        d = make_node(
            "Div", [ac.output[0], cl.output[0]], [_find_name(f"{shape}_FLAT")]
        )
        resh = make_node("Reshape", [d.output[0], shape], node.output)
        new_nodes.extend([n, a, ac, cl, d, resh])

    if isinstance(onx, GraphProto):
        return make_graph(
            new_nodes,
            onx.name,
            onx.input,
            onx.output,
            initializer=onx.initializer,
            sparse_initializer=onx.sparse_initializer,
        )
    # The attributes declared by the function are kept, nodes may refer to them.
    return make_function(
        onx.domain,
        onx.name,
        onx.input,
        onx.output,
        new_nodes,
        opset_imports=onx.opset_import,
        attributes=onx.attribute,
        attribute_protos=onx.attribute_proto,
    )


def _replace_constant_of_shape_value(
    onx: Union[GraphProto, FunctionProto], value_constant_of_shape: float
) -> Union[GraphProto, FunctionProto]:
    """Replaces all fill value of all nodes *ConstantOfShape*.

    The function is not recursive, only the nodes directly held by *onx*
    are processed.

    Args:
        onx: graph or function to process
        value_constant_of_shape: new fill value, stored in a tensor of
            one element with the same element type as the value it replaces

    Returns:
        a new graph or function where every *ConstantOfShape* node
        uses the new fill value

    Raises:
        TypeError: if *onx* is neither a GraphProto nor a FunctionProto
    """
    if not isinstance(onx, (GraphProto, FunctionProto)):
        raise TypeError(f"Not implemented for type {type(onx)}.")

    nodes = list(onx.node)
    for inode, node in enumerate(nodes):
        if node.op_type != "ConstantOfShape":
            continue
        if node.attribute:
            att_name = node.attribute[0].name
            tensor_name = node.attribute[0].t.name
            data_type = node.attribute[0].t.data_type
        else:
            # The fill value is optional for ConstantOfShape,
            # the tensor is then filled with float values.
            att_name, tensor_name, data_type = "value", "", TensorProto.FLOAT
        new_tensor = make_tensor(
            tensor_name, data_type, [1], [value_constant_of_shape]
        )
        # The node is copied to keep its name, inputs and outputs,
        # only the fill value changes.
        new_node = NodeProto()
        new_node.CopyFrom(node)
        del new_node.attribute[:]
        new_node.attribute.append(make_attribute(att_name, value=new_tensor))
        nodes[inode] = new_node

    if isinstance(onx, GraphProto):
        return make_graph(
            nodes,
            onx.name,
            onx.input,
            onx.output,
            initializer=onx.initializer,
            sparse_initializer=onx.sparse_initializer,
        )
    # The attributes declared by the function are kept, nodes may refer to them.
    return make_function(
        onx.domain,
        onx.name,
        onx.input,
        onx.output,
        nodes,
        opset_imports=onx.opset_import,
        attributes=onx.attribute,
        attribute_protos=onx.attribute_proto,
    )


def replace_initializer_by_constant_of_shape(
    onx: Union[FunctionProto, GraphProto, ModelProto],
    threshold: int = 128,
    ir_version: Optional[int] = None,
    use_range: bool = False,
    value_constant_of_shape: float = 0.5,
) -> Union[FunctionProto, GraphProto, ModelProto]:
    """Replace initializers or constant node by nodes *ConstantOfShape* to reduce the size.

    This reduces the cost to write a unit test about a specific graph structure.

    Args:
        onx: ModelProto, GraphProto or FunctionProto
        threshold: every initializer or constant holding at most this
            number of elements is not impacted
        ir_version: initializer must be specified as input for
            `ir_version <= 3`, this must be specified if onx is
            :class:`FunctionProto` or :class:`GraphProto`
        use_range: if uses operator *Range* instead of *ConstantOfShape*
            to avoid constant tensors
        value_constant_of_shape: value to use as a value for all nodes
            *ConstantOfShape*, a high value may produce nan or inf
            predictions

    Returns:
        the modified proto, of the same type as *onx*; *onx* itself may be
        returned when nothing had to be changed

    Raises:
        TypeError: if *onx* is not a ModelProto, GraphProto or FunctionProto
        NotImplementedError: if a sparse initializer or a sparse constant
            holds more elements than *threshold*
        RuntimeError: if *onx* is a model whose default opset does not
            define the operators the replacement relies on

    The function is designed so that the function can be reapplied on a modified model
    and either replace *ConstantOfShape* with *Range* operators, either replace the fill value
    for every *ConstantOfShape*.
    """

    def _finalize(proto):
        # Post-processing shared by functions and graphs: it also covers
        # the ConstantOfShape nodes which were already part of the input.
        if use_range:
            return _replace_constant_of_shape_with_range(proto)
        if value_constant_of_shape != 1:
            return _replace_constant_of_shape_value(proto, value_constant_of_shape)
        return proto

    if isinstance(onx, FunctionProto):
        modified = False
        fct_nodes: List[NodeProto] = []
        for node in onx.node:
            if node.op_type != "Constant":
                fct_nodes.append(node)
                continue
            cst_nodes = _replace_constant(node, threshold, value_constant_of_shape)
            if len(cst_nodes) == 2:  # noqa: PLR2004
                modified = True
            fct_nodes.extend(cst_nodes)
        if not modified:
            return _finalize(onx)
        # The attributes declared by the function are kept,
        # nodes may refer to them.
        return _finalize(
            make_function(
                onx.domain,
                onx.name,
                onx.input,
                onx.output,
                fct_nodes,
                opset_imports=onx.opset_import,
                attributes=onx.attribute,
                attribute_protos=onx.attribute_proto,
            )
        )

    if isinstance(onx, ModelProto):
        model_ir_version = ir_version or onx.ir_version
        new_graph = replace_initializer_by_constant_of_shape(
            onx.graph,
            ir_version=model_ir_version,
            threshold=threshold,
            use_range=use_range,
            value_constant_of_shape=value_constant_of_shape,
        )
        new_functions = [
            replace_initializer_by_constant_of_shape(
                f,
                threshold=threshold,
                ir_version=model_ir_version,
                use_range=use_range,
                value_constant_of_shape=value_constant_of_shape,
            )
            for f in onx.functions
        ]
        model = make_model(
            new_graph,
            functions=new_functions,
            producer_name=onx.producer_name,
            producer_version=onx.producer_version,
            ir_version=model_ir_version,
            doc_string=onx.doc_string,
            domain=onx.domain,
            model_version=onx.model_version,
        )
        if len(onx.metadata_props) > 0:  # pragma: no cover
            values = {p.key: p.value for p in onx.metadata_props}
            set_model_props(model, values)

        # The opsets of the original model are restored.
        del model.opset_import[:]
        for oimp in onx.opset_import:
            if oimp.domain == "" and oimp.version < 11 and use_range:  # noqa: PLR2004
                raise RuntimeError(
                    f"Range was introduced in opset 11 but opset is {oimp.version}."
                )
            if oimp.domain == "" and oimp.version < 9:  # noqa: PLR2004
                raise RuntimeError(
                    f"ConstantOfShape was introduced in "
                    f"opset 9 but opset is {oimp.version}."
                )
            op_set = model.opset_import.add()
            op_set.domain = oimp.domain
            op_set.version = oimp.version
        return model

    if not isinstance(onx, GraphProto):
        raise TypeError(f"onx should be a GraphProto at this stage not {type(onx)}.")

    n_modifications = 0
    new_nodes = []
    removed = set()
    additional_inputs = []

    # Large initializers become ConstantOfShape nodes fed by a small
    # initializer holding the shape.
    new_inits: List[TensorProto] = []
    for init in onx.initializer:
        dims = tuple(init.dims)
        if np.prod(dims) <= threshold:
            new_inits.append(init)
            continue
        n_modifications += 1
        new_name = f"{init.name}__SHAPE"
        new_inits.append(
            from_array(np.array(list(dims), dtype=np.int64), name=new_name)
        )
        dtype = tensor_dtype_to_np_dtype(init.data_type)
        # The requested fill value is used, as for constants: no later
        # pass rewrites it when value_constant_of_shape is 1.
        new_nodes.append(
            make_node(
                "ConstantOfShape",
                [new_name],
                [init.name],
                value=from_array(np.array([value_constant_of_shape], dtype=dtype)),
            )
        )
        removed.add(init.name)
        if ir_version is not None and ir_version <= 3:  # noqa: PLR2004
            additional_inputs.append(
                make_tensor_value_info(new_name, TensorProto.INT64, [len(dims)])
            )

    new_sparse_inits: List[SparseTensorProto] = []
    for sp_init in onx.sparse_initializer:
        dims = tuple(sp_init.dims)
        if np.prod(dims) <= threshold:
            new_sparse_inits.append(sp_init)
            continue
        raise NotImplementedError(
            f"This feature is not yet implemented for a sparse initializer "
            f"(indices.name={sp_init.indices.name!r}, "
            f"values.name={sp_init.values.name!r})."
        )

    for node in onx.node:
        if node.op_type == "Constant":
            shape_nodes = _replace_constant(node, threshold, value_constant_of_shape)
            if len(shape_nodes) == 2:  # noqa: PLR2004
                n_modifications += 1
            new_nodes.extend(shape_nodes)
            continue

        # Subgraphs (If, Loop, Scan, ...) are processed recursively.
        modified = False
        atts = []
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:
                g = replace_initializer_by_constant_of_shape(
                    att.g,
                    threshold=threshold,
                    ir_version=ir_version,
                    use_range=use_range,
                    value_constant_of_shape=value_constant_of_shape,
                )
                if g is not att.g:
                    modified = True
                    att = make_attribute(att.name, g)  # noqa: PLW2901
            atts.append(att)
        if not modified:
            new_nodes.append(node)
            continue
        # The node is copied so that its domain, name and documentation
        # are preserved, only the attributes are replaced.
        new_node = NodeProto()
        new_node.CopyFrom(node)
        del new_node.attribute[:]
        new_node.attribute.extend(atts)
        new_nodes.append(new_node)
        n_modifications += 1

    if n_modifications == 0:
        return _finalize(onx)
    return _finalize(
        make_graph(
            new_nodes,
            onx.name,
            [i for i in onx.input if i.name not in removed] + additional_inputs,
            onx.output,
            initializer=new_inits,
            sparse_initializer=new_sparse_inits,
        )
    )
