# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import warnings

# Names exported by ``from <this module> import *``: only the shared
# ``registry`` instance is listed; the serializer classes are module-private.
__all__ = [
    "registry",
]

import typing
from typing import Any, Collection, Optional, Protocol, TypeVar

import google.protobuf.json_format
import google.protobuf.message
import google.protobuf.text_format

import onnx

# Type variable bound to protobuf messages. It links the type of the ``proto``
# argument of a (de)serializer method to the type that method returns.
_Proto = TypeVar("_Proto", bound=google.protobuf.message.Message)
# Encoding used by the text-based serializers when converting between
# ``str`` and ``bytes``.
_ENCODING = "utf-8"


class ProtoSerializer(Protocol):
    """Structural interface for a serializer-deserializer of in-memory protos.

    An implementation declares two attributes and two methods:

    * ``supported_format``: name of the format the serializer handles,
      e.g. ``"protobuf"``.
    * ``file_extensions``: file extensions associated with the format,
      e.g. ``frozenset({".onnx", ".pb"})``. Every extension must include the
      leading dot.
    """

    supported_format: str
    file_extensions: Collection[str]

    # NOTE: The method names are deliberately the specific ``serialize_proto`` /
    # ``deserialize_proto`` rather than the generic ``serialize`` /
    # ``deserialize``. This keeps the generic names free for future protocols
    # that (de)serialize the ONNX in-memory IR, so that a single class can
    # implement both protocols.

    def serialize_proto(self, proto: _Proto) -> Any:
        """Serialize an in-memory proto to a serialized data type.

        Args:
            proto: The in-memory message to serialize.

        Returns:
            The serialized representation; its type is chosen by the
            implementation.
        """
        ...

    def deserialize_proto(self, serialized: Any, proto: _Proto) -> _Proto:
        """Parse a serialized data type into an in-memory proto.

        Args:
            serialized: The serialized data to parse.
            proto: A message instance of the type to be parsed.

        Returns:
            The parsed in-memory proto.
        """
        ...


class _Registry:
    """Lookup table of serializers, keyed by format name and by file extension."""

    def __init__(self) -> None:
        # A mapping from format name to the serializer handling that format
        self._serializers: dict[str, ProtoSerializer] = {}
        # A mapping from file extension to format
        self._extension_to_format: dict[str, str] = {}

    def register(self, serializer: ProtoSerializer) -> None:
        """Register a serializer under its format and its file extensions.

        A serializer registered later replaces any earlier entry that uses the
        same format name, and takes over any file extension it shares with an
        earlier serializer.

        Args:
            serializer: The serializer to register. Its ``supported_format``
                and ``file_extensions`` attributes are read.
        """
        self._serializers[serializer.supported_format] = serializer
        self._extension_to_format.update(
            {ext: serializer.supported_format for ext in serializer.file_extensions}
        )

    def get(self, fmt: str) -> ProtoSerializer:
        """Get a serializer for a format.

        Args:
            fmt: The format to get a serializer for.

        Returns:
            ProtoSerializer: The serializer for the format.

        Raises:
            ValueError: If the format is not supported.
        """
        try:
            return self._serializers[fmt]
        except KeyError:
            # Render the known formats as a plain sorted list so the message does
            # not leak the ``dict_keys([...])`` repr and is stable across runs.
            raise ValueError(
                f"Unsupported format: '{fmt}'. Supported formats are: {sorted(self._serializers)}"
            ) from None

    def get_format_from_file_extension(self, file_extension: str) -> str | None:
        """Get the corresponding format from a file extension.

        The lookup is an exact match against the registered extensions, which
        include the leading dot.

        Args:
            file_extension: The file extension to get a format for.

        Returns:
            The format for the file extension, or None if not found.
        """
        return self._extension_to_format.get(file_extension)


class _ProtobufSerializer(ProtoSerializer):
    """Serializer for the "protobuf" format (``.onnx`` and ``.pb`` files)."""

    supported_format = "protobuf"
    file_extensions = frozenset({".onnx", ".pb"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Serialize ``proto`` by calling its ``SerializeToString`` method.

        Args:
            proto: The message to serialize.

        Returns:
            The value returned by ``proto.SerializeToString()``.

        Raises:
            TypeError: If ``proto`` has no callable ``SerializeToString``.
            ValueError: If ``SerializeToString`` raises ``ValueError``. When the
                proto's ``ByteSize()`` is at least
                ``onnx.checker.MAXIMUM_PROTOBUF``, the error is replaced by one
                that suggests saving tensors as external data; otherwise the
                original error propagates unchanged.
        """
        serialize = getattr(proto, "SerializeToString", None)
        if not callable(serialize):
            raise TypeError(
                f"No SerializeToString method is detected.\ntype is {type(proto)}"
            )
        try:
            return serialize()  # type: ignore
        except ValueError as err:
            # Only rewrite the error when the message size explains it;
            # any other ValueError is re-raised as is.
            if proto.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
                raise
            raise ValueError(
                "The proto size is larger than the 2 GB limit. "
                "Please use save_as_external_data to save tensors separately from the model file."
            ) from err

    def deserialize_proto(self, serialized: bytes, proto: _Proto) -> _Proto:
        """Parse ``serialized`` into ``proto`` with ``ParseFromString``.

        Args:
            serialized: The serialized message; must be ``bytes``.
            proto: The message instance to parse into.

        Returns:
            The same ``proto`` object that was passed in.

        Raises:
            TypeError: If ``serialized`` is not ``bytes``.
            google.protobuf.message.DecodeError: If ``ParseFromString`` reports
                a byte count different from ``len(serialized)``.
        """
        if not isinstance(serialized, bytes):
            raise TypeError(
                f"Parameter 'serialized' must be bytes, but got type: {type(serialized)}"
            )
        # The result is treated as an optional byte count; ``None`` skips the
        # length check.
        consumed = typing.cast(Optional[int], proto.ParseFromString(serialized))
        if consumed is None or consumed == len(serialized):
            return proto
        raise google.protobuf.message.DecodeError(
            f"Protobuf decoding consumed too few bytes: {consumed} out of {len(serialized)}"
        )


class _TextProtoSerializer(ProtoSerializer):
    """Serializer for the "textproto" format (protobuf text format files)."""

    supported_format = "textproto"
    file_extensions = frozenset({".textproto", ".prototxt", ".pbtxt"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` with ``text_format.MessageToString`` and encode it.

        Args:
            proto: The message to serialize.

        Returns:
            The text representation encoded with the module's text encoding.
        """
        return google.protobuf.text_format.MessageToString(proto).encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse text-format data into ``proto`` with ``text_format.Parse``.

        Args:
            serialized: The text-format data; ``bytes`` input is decoded with
                the module's text encoding first.
            proto: The message instance to parse into.

        Returns:
            The value returned by ``google.protobuf.text_format.Parse``.

        Raises:
            TypeError: If ``serialized`` is neither ``bytes`` nor ``str``.
        """
        if isinstance(serialized, bytes):
            text = serialized.decode(_ENCODING)
        elif isinstance(serialized, str):
            text = serialized
        else:
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        return google.protobuf.text_format.Parse(text, proto)


class _JsonSerializer(ProtoSerializer):
    """Serializer for the "json" format (``.json`` and ``.onnxjson`` files)."""

    supported_format = "json"
    file_extensions = frozenset({".json", ".onnxjson"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` with ``json_format.MessageToJson`` and encode it.

        ``MessageToJson`` is called with ``preserving_proto_field_name=True``.

        Args:
            proto: The message to serialize.

        Returns:
            The JSON text encoded with the module's text encoding.
        """
        as_json = google.protobuf.json_format.MessageToJson(
            proto,
            preserving_proto_field_name=True,
        )
        encoded = as_json.encode(_ENCODING)
        return encoded

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse JSON data into ``proto`` with ``json_format.Parse``.

        Args:
            serialized: The JSON data; ``bytes`` input is decoded with the
                module's text encoding first.
            proto: The message instance to parse into.

        Returns:
            The value returned by ``google.protobuf.json_format.Parse``.

        Raises:
            TypeError: If ``serialized`` is neither ``bytes`` nor ``str``.
        """
        if isinstance(serialized, bytes):
            json_text = serialized.decode(_ENCODING)
        elif isinstance(serialized, str):
            json_text = serialized
        else:
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        return google.protobuf.json_format.Parse(json_text, proto)


class _TextualSerializer(ProtoSerializer):
    """Serializer for the "onnxtxt" format, the ONNX textual representation."""

    supported_format = "onnxtxt"
    file_extensions = frozenset({".onnxtxt"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` with ``onnx.printer.to_text`` and encode it.

        Args:
            proto: The message to serialize.

        Returns:
            The textual representation encoded with the module's text encoding.
        """
        return onnx.printer.to_text(proto).encode(_ENCODING)  # type: ignore[arg-type]

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse ONNX textual data with the parser matching the type of ``proto``.

        ``proto`` is only inspected for its type: the result of the selected
        ``onnx.parser.parse_*`` function is returned, and ``proto`` itself is
        not passed to the parser. A warning that the format is experimental is
        emitted on every call, before the input is validated.

        Args:
            serialized: The textual data; ``bytes`` input is decoded with the
                module's text encoding first.
            proto: An instance of ``onnx.ModelProto``, ``onnx.GraphProto``,
                ``onnx.FunctionProto`` or ``onnx.NodeProto`` selecting the parser.

        Returns:
            The value returned by the selected ``onnx.parser`` function.

        Raises:
            TypeError: If ``serialized`` is neither ``bytes`` nor ``str``.
            ValueError: If ``proto`` is not one of the supported proto types.
        """
        # Kept directly in this method so that ``stacklevel=2`` attributes the
        # warning to the caller of ``deserialize_proto``.
        warnings.warn(
            "The onnxtxt format is experimental. Please report any errors to the ONNX GitHub repository.",
            stacklevel=2,
        )
        if isinstance(serialized, bytes):
            text = serialized.decode(_ENCODING)
        elif isinstance(serialized, str):
            text = serialized
        else:
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )

        # Select the parser from the runtime type of ``proto``.
        if isinstance(proto, onnx.ModelProto):
            parse = onnx.parser.parse_model
        elif isinstance(proto, onnx.GraphProto):
            parse = onnx.parser.parse_graph
        elif isinstance(proto, onnx.FunctionProto):
            parse = onnx.parser.parse_function
        elif isinstance(proto, onnx.NodeProto):
            parse = onnx.parser.parse_node
        else:
            raise ValueError(f"Unsupported proto type: {type(proto)}")
        return parse(text)  # type: ignore[return-value]


# Register default serializers.
# ``registry`` is the module's exported object (see ``__all__``). Registration
# happens at import time, one instance per built-in serializer class. Because
# ``register`` overwrites existing entries, a serializer registered later would
# take over any format name or file extension it shares with an earlier one.
registry = _Registry()
registry.register(_ProtobufSerializer())
registry.register(_TextProtoSerializer())
registry.register(_JsonSerializer())
registry.register(_TextualSerializer())
