# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
import os
import tempfile
import unittest

import onnx

# A small model in ONNX textual syntax, parsed with ``onnx.parser.parse_model``
# by the tests in this file.
#
# The main graph calls the local function ``foo``. ``foo`` computes Add(x, x)
# and passes the result to a second local function ``bar``, which computes
# Mul(x, x); overall Y = (X + X) * (X + X). Because the model carries local
# functions (one calling another), a model round trip must also preserve its
# FunctionProto entries, not just the main graph.
#
# NOTE: the literal is runtime data; any edit changes what the tests parse.
_TEST_MODEL = """\
<
    ir_version: 8,
    opset_import: ["" : 17, "local" : 1]
>
agraph (float[N] X) => (float[N] Y) {
    Y = local.foo (X)
}

<opset_import: ["" : 17, "local" : 1], domain: "local">
foo (x) => (y) {
    temp = Add(x, x)
    y = local.bar(temp)
}

<opset_import: ["" : 17], domain: "local">
bar (x) => (y) {
    y = Mul (x, x)
}"""


class _OnnxTestTextualSerializer(onnx.serialization.ProtoSerializer):
    """Serialize and deserialize the ONNX textual representation.

    A minimal custom serializer, registered under the ``"onnxtext"`` format,
    used to check that user-defined serializers can be plugged into
    ``onnx.serialization.registry``.
    """

    supported_format = "onnxtext"
    file_extensions = frozenset({".onnxtext"})

    def serialize_proto(self, proto) -> bytes:
        """Render ``proto`` with ``onnx.printer.to_text`` and encode it as UTF-8."""
        text = onnx.printer.to_text(proto)
        return text.encode("utf-8")

    def deserialize_proto(self, serialized: bytes, proto):
        """Parse UTF-8 encoded ONNX text into ``proto``.

        Args:
            serialized: UTF-8 encoded ONNX textual representation.
            proto: An instance of the target proto type. Its type selects the
                parser to use, and it is overwritten in place with the parsed
                content.

        Returns:
            ``proto`` itself, now holding the parsed content.

        Raises:
            ValueError: If ``proto`` is not a ModelProto, GraphProto,
                FunctionProto or NodeProto.
        """
        text = serialized.decode("utf-8")
        if isinstance(proto, onnx.ModelProto):
            parsed = onnx.parser.parse_model(text)
        elif isinstance(proto, onnx.GraphProto):
            parsed = onnx.parser.parse_graph(text)
        elif isinstance(proto, onnx.FunctionProto):
            parsed = onnx.parser.parse_function(text)
        elif isinstance(proto, onnx.NodeProto):
            parsed = onnx.parser.parse_node(text)
        else:
            raise ValueError(f"Unsupported proto type: {type(proto)}")
        # Populate the caller's message instead of discarding it, so code that
        # relies on the argument being filled (like protobuf's in-place
        # ParseFromString) sees the parsed content too. CopyFrom replaces any
        # existing content of ``proto``.
        proto.CopyFrom(parsed)
        return proto


class TestRegistry(unittest.TestCase):
    """Registering a custom serializer in ``onnx.serialization.registry``."""

    def setUp(self) -> None:
        # A fresh serializer is registered before every test so each test
        # knows exactly which instance the registry should hand back.
        # NOTE(review): nothing unregisters it afterwards, so the global
        # registry keeps this serializer once the tests have run.
        serializer = _OnnxTestTextualSerializer()
        onnx.serialization.registry.register(serializer)
        self.serializer = serializer

    def test_get_returns_the_registered_instance(self) -> None:
        registry = onnx.serialization.registry
        self.assertIs(registry.get("onnxtext"), self.serializer)

    def test_get_raises_for_unsupported_format(self) -> None:
        registry = onnx.serialization.registry
        self.assertRaises(ValueError, registry.get, "unsupported")

    def test_onnx_save_load_model_uses_the_custom_serializer(self) -> None:
        original = onnx.parser.parse_model(_TEST_MODEL)
        with tempfile.TemporaryDirectory() as tmp_dir:
            # ".onnx" is not one of the serializer's file_extensions, so the
            # explicit ``format`` argument is what routes save/load to it.
            path = os.path.join(tmp_dir, "model.onnx")
            onnx.save_model(original, path, format="onnxtext")

            # The file on disk must hold the textual form, not protobuf bytes.
            with open(path, encoding="utf-8") as handle:
                written = handle.read()
            self.assertEqual(written, onnx.printer.to_text(original))

            reloaded = onnx.load_model(path, format="onnxtext")

        # Loading the text back must reproduce the original model exactly.
        self.assertEqual(
            original.SerializeToString(deterministic=True),
            reloaded.SerializeToString(deterministic=True),
        )


class TestCustomSerializer(unittest.TestCase):
    """Direct round trips through the serializer, without the registry."""

    def test_serialize_deserialize_model(self) -> None:
        textual = _OnnxTestTextualSerializer()
        expected = onnx.parser.parse_model(_TEST_MODEL)
        # Text out, then text back in, targeting an empty ModelProto.
        encoded = textual.serialize_proto(expected)
        round_tripped = textual.deserialize_proto(encoded, onnx.ModelProto())
        self.assertEqual(
            expected.SerializeToString(deterministic=True),
            round_tripped.SerializeToString(deterministic=True),
        )
