# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

# Names exported by a star-import of this module. Every entry is either
# imported or defined further down in this file; the list is grouped by kind
# (constants, submodules, proto classes, utility functions).
__all__ = [
    # Constants
    "ONNX_ML",
    "IR_VERSION",
    "IR_VERSION_2017_10_10",
    "IR_VERSION_2017_10_30",
    "IR_VERSION_2017_11_3",
    "IR_VERSION_2019_1_22",
    "IR_VERSION_2019_3_18",
    "IR_VERSION_2019_9_19",
    "IR_VERSION_2020_5_8",
    "IR_VERSION_2021_7_30",
    "IR_VERSION_2023_5_5",
    "EXPERIMENTAL",
    "STABLE",
    # Modules
    "checker",
    "compose",
    "defs",
    "gen_proto",
    "helper",
    "hub",
    "mapping",
    "numpy_helper",
    "parser",
    "printer",
    "shape_inference",
    "utils",
    "version_converter",
    # Proto classes
    "AttributeProto",
    "FunctionProto",
    "GraphProto",
    "MapProto",
    "ModelProto",
    "NodeProto",
    "OperatorProto",
    "OperatorSetIdProto",
    "OperatorSetProto",
    "OperatorStatus",
    "OptionalProto",
    "SequenceProto",
    "SparseTensorProto",
    "StringStringEntryProto",
    "TensorAnnotation",
    "TensorProto",
    "TensorShapeProto",
    "TrainingInfoProto",
    "TypeProto",
    "ValueInfoProto",
    "Version",
    # Utility functions
    "convert_model_to_external_data",
    "load_external_data_for_model",
    "load_model_from_string",
    "load_model",
    "load_tensor_from_string",
    "load_tensor",
    "save_model",
    "save_tensor",
    "write_external_data_tensors",
]
# isort:skip_file

import os
import typing
from typing import IO, Literal, Union


from onnx import serialization
from onnx.onnx_cpp2py_export import ONNX_ML
from onnx.external_data_helper import (
    load_external_data_for_model,
    write_external_data_tensors,
    convert_model_to_external_data,
)
from onnx.onnx_pb import (
    AttributeProto,
    EXPERIMENTAL,
    FunctionProto,
    GraphProto,
    IR_VERSION,
    IR_VERSION_2017_10_10,
    IR_VERSION_2017_10_30,
    IR_VERSION_2017_11_3,
    IR_VERSION_2019_1_22,
    IR_VERSION_2019_3_18,
    IR_VERSION_2019_9_19,
    IR_VERSION_2020_5_8,
    IR_VERSION_2021_7_30,
    IR_VERSION_2023_5_5,
    ModelProto,
    NodeProto,
    OperatorSetIdProto,
    OperatorStatus,
    STABLE,
    SparseTensorProto,
    StringStringEntryProto,
    TensorAnnotation,
    TensorProto,
    TensorShapeProto,
    TrainingInfoProto,
    TypeProto,
    ValueInfoProto,
    Version,
)
from onnx.onnx_operators_pb import OperatorProto, OperatorSetProto
from onnx.onnx_data_pb import MapProto, OptionalProto, SequenceProto
from onnx.version import version as __version__

# Import common subpackages so they're available when you 'import onnx'
from onnx import (
    checker,
    compose,
    defs,
    gen_proto,
    helper,
    hub,
    mapping,
    numpy_helper,
    parser,
    printer,
    shape_inference,
    utils,
    version_converter,
)

# Type of the ``format`` argument accepted by the load/save functions below.
# The literals are the formats with built-in support. Users may also register
# their own formats, so any other str is allowed as well.
_SupportedFormat = Union[Literal["protobuf", "textproto"], str]
# Serialization format used when no format is given and none can be inferred
# from a file extension.
_DEFAULT_FORMAT = "protobuf"


def _load_bytes(f: IO[bytes] | str | os.PathLike) -> bytes:
    if hasattr(f, "read") and callable(typing.cast(IO[bytes], f).read):
        content = typing.cast(IO[bytes], f).read()
    else:
        f = typing.cast(Union[str, os.PathLike], f)
        with open(f, "rb") as readable:
            content = readable.read()
    return content


def _save_bytes(content: bytes, f: IO[bytes] | str | os.PathLike) -> None:
    if hasattr(f, "write") and callable(typing.cast(IO[bytes], f).write):
        typing.cast(IO[bytes], f).write(content)
    else:
        f = typing.cast(Union[str, os.PathLike], f)
        with open(f, "wb") as writable:
            writable.write(content)


def _get_file_path(f: IO[bytes] | str | os.PathLike | None) -> str | None:
    if isinstance(f, (str, os.PathLike)):
        return os.path.abspath(f)
    if hasattr(f, "name"):
        assert f is not None
        return os.path.abspath(f.name)
    return None


def _get_serializer(
    fmt: _SupportedFormat | None, f: str | os.PathLike | IO[bytes] | None = None
) -> serialization.ProtoSerializer:
    """Look up the serializer to use for ``f`` in the serialization registry.

    Resolution order: an explicit ``fmt`` wins; otherwise the format is
    derived from the file extension of the path behind ``f`` (if any);
    otherwise the default format is used.
    """
    if fmt is None:
        path = _get_file_path(f)
        if path is not None:
            extension = os.path.splitext(path)[1]
            fmt = serialization.registry.get_format_from_file_extension(extension)

        # No path, or the extension did not resolve to a format.
        if not fmt:
            fmt = _DEFAULT_FORMAT

    return serialization.registry.get(fmt)


def load_model(
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
    load_external_data: bool = True,
) -> ModelProto:
    """Read a serialized ModelProto from a file or file-like object.

    Args:
        f: A file-like object (anything with a "read" function) or a
            string/PathLike holding a file name.
        format: The serialization format. If omitted, it is inferred from the
            file extension when ``f`` is a path; if ``f`` is not a path either,
            'protobuf' is used. Text formats are assumed to be "utf-8" encoded.
        load_external_data: Whether to also load the external data.
            Leave as True when the data lives in the same directory as the
            model. Otherwise call :func:`load_external_data_for_model`
            with the directory that holds the external data.

    Returns:
        The in-memory ModelProto.
    """
    serializer = _get_serializer(format, f)
    model = serializer.deserialize_proto(_load_bytes(f), ModelProto())

    if not load_external_data:
        return model

    # External data is resolved relative to the model's own directory, which
    # is only known when ``f`` maps to a file path.
    model_path = _get_file_path(f)
    if model_path:
        load_external_data_for_model(model, os.path.dirname(model_path))

    return model


def load_tensor(
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
) -> TensorProto:
    """Read a serialized TensorProto from a file or file-like object.

    Args:
        f: A file-like object (anything with a "read" function) or a
            string/PathLike holding a file name.
        format: The serialization format. If omitted, it is inferred from the
            file extension when ``f`` is a path; if ``f`` is not a path either,
            'protobuf' is used. Text formats are assumed to be "utf-8" encoded.

    Returns:
        The in-memory TensorProto.
    """
    serializer = _get_serializer(format, f)
    return serializer.deserialize_proto(_load_bytes(f), TensorProto())


def load_model_from_string(
    s: bytes | str,
    format: _SupportedFormat = _DEFAULT_FORMAT,  # noqa: A002
) -> ModelProto:
    """Loads a serialized ModelProto from an in-memory string.

    Args:
        s: The serialized ModelProto: bytes, or a str for serializers that
            accept text input.
        format: The serialization format of ``s``. Defaults to 'protobuf'.
            There is no file name here to infer the format from, so any
            other format must be passed explicitly; ``None`` also selects
            'protobuf'. The encoding is assumed to be "utf-8" when the
            format is a text format.

    Returns:
        Loaded in-memory ModelProto.
    """
    return _get_serializer(format).deserialize_proto(s, ModelProto())


def load_tensor_from_string(
    s: bytes | str,
    format: _SupportedFormat = _DEFAULT_FORMAT,  # noqa: A002
) -> TensorProto:
    """Loads a serialized TensorProto from an in-memory string.

    Args:
        s: The serialized TensorProto: bytes, or a str for serializers that
            accept text input.
        format: The serialization format of ``s``. Defaults to 'protobuf'.
            There is no file name here to infer the format from, so any
            other format must be passed explicitly; ``None`` also selects
            'protobuf'. The encoding is assumed to be "utf-8" when the
            format is a text format.

    Returns:
        Loaded in-memory TensorProto.
    """
    return _get_serializer(format).deserialize_proto(s, TensorProto())


def save_model(
    proto: ModelProto | bytes,
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
    *,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = True,
    location: str | None = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """Write a ModelProto to ``f``, optionally moving raw tensor data to external files first.

    Args:
        proto: An in-memory ModelProto, or bytes holding a model serialized
            in the default ('protobuf') format.
        f: A file-like object (anything with a "write" function), a string
            holding a file name, or a PathLike object.
        format: The serialization format. If omitted, it is inferred from the
            file extension when ``f`` is a path; if ``f`` is not a path either,
            'protobuf' is used. Text formats are assumed to be "utf-8" encoded.
        save_as_external_data: If true, save tensors to external file(s).
        all_tensors_to_one_file: Only used when save_as_external_data is True.
            If true, all tensors go into the single external file given by location.
            If false, each tensor goes into a file named after the tensor.
        location: Only used when save_as_external_data is True.
            The external file that all tensors are saved to.
            The path is relative to the model path.
            If not specified, the model name is used.
        size_threshold: Only used when save_as_external_data is True.
            Size threshold for the data: a tensor is converted to external data
            only when its data is >= size_threshold. Set size_threshold=0 to
            convert every tensor that has raw data.
        convert_attribute: Only used when save_as_external_data is True.
            If true, convert all tensors to external data.
            If false, convert only non-attribute tensors to external data.
    """
    if isinstance(proto, bytes):
        default_serializer = _get_serializer(_DEFAULT_FORMAT)
        proto = default_serializer.deserialize_proto(proto, ModelProto())

    if save_as_external_data:
        convert_model_to_external_data(
            proto, all_tensors_to_one_file, location, size_threshold, convert_attribute
        )

    # External tensor files are written next to the model, which is only
    # possible when the destination maps to a file path.
    target_path = _get_file_path(f)
    if target_path is not None:
        proto = write_external_data_tensors(proto, os.path.dirname(target_path))

    serializer = _get_serializer(format, target_path)
    _save_bytes(serializer.serialize_proto(proto), f)


def save_tensor(
    proto: TensorProto,
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
) -> None:
    """Write a TensorProto to ``f``.

    Args:
        proto: An in-memory TensorProto.
        f: A file-like object (anything with a "write" function), a string
            holding a file name, or a PathLike object.
        format: The serialization format. If omitted, it is inferred from the
            file extension when ``f`` is a path; if ``f`` is not a path either,
            'protobuf' is used. Text formats are assumed to be "utf-8" encoded.
    """
    serializer = _get_serializer(format, f)
    payload = serializer.serialize_proto(proto)
    _save_bytes(payload, f)


# For backward compatibility: shorter aliases bound to the same function
# objects as load_model, load_model_from_string and save_model.
load = load_model
load_from_string = load_model_from_string
save = save_model
