# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0

import io
import os
import pathlib
import tempfile
import unittest

import google.protobuf.message
import google.protobuf.text_format
import parameterized

import onnx
from onnx import serialization


def _simple_model() -> onnx.ModelProto:
    """Build a minimal ModelProto used as the round-trip fixture in these tests."""
    # Only the IR version, the producer name and the graph name are populated.
    return onnx.ModelProto(
        ir_version=onnx.IR_VERSION,
        producer_name="onnx-test",
        graph=onnx.GraphProto(name="test"),
    )


def _simple_tensor() -> onnx.TensorProto:
    """Build a small FLOAT TensorProto used as the round-trip fixture in these tests."""
    shape = (2, 3, 4)
    # 24 values (0.5, 1.5, ..., 23.5), one per element of the 2x3x4 shape.
    values = [index + 0.5 for index in range(24)]
    return onnx.helper.make_tensor(
        name="test-tensor",
        data_type=onnx.TensorProto.FLOAT,
        dims=shape,
        vals=values,
    )


@parameterized.parameterized_class(
    [{"format": fmt} for fmt in ("protobuf", "textproto", "json", "onnxtxt")]
)
class TestIO(unittest.TestCase):
    """Test loading and saving of ModelProto for each parameterized format."""

    # Serialization format under test; injected by `parameterized_class`.
    format: str

    def test_load_model_when_input_is_bytes(self) -> None:
        # Serialize with the serializer registered for this format, then parse
        # the resulting payload back with `load_model_from_string`.
        expected = _simple_model()
        serializer = serialization.registry.get(self.format)
        payload = serializer.serialize_proto(expected)
        actual = onnx.load_model_from_string(payload, format=self.format)
        self.assertEqual(expected, actual)

    def test_save_and_load_model_when_input_has_read_function(self) -> None:
        expected = _simple_model()
        # A bytes input handed to `save_model` must be the binary protobuf
        # encoding (format="protobuf"), whatever the output format is.
        # The `format` argument only selects how the saved data is encoded.
        protobuf_serializer = serialization.registry.get("protobuf")
        binary_payload = protobuf_serializer.serialize_proto(expected)
        sink = io.BytesIO()
        onnx.save_model(binary_payload, sink, format=self.format)
        # Read back through a fresh file-like object holding what was written.
        source = io.BytesIO(sink.getvalue())
        actual = onnx.load_model(source, format=self.format)
        self.assertEqual(expected, actual)

    def test_save_and_load_model_when_input_is_file_name(self) -> None:
        # The destination is given as a plain string path.
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.onnx")
            onnx.save_model(expected, target, format=self.format)
            actual = onnx.load_model(target, format=self.format)
            self.assertEqual(expected, actual)

    def test_save_and_load_model_when_input_is_pathlike(self) -> None:
        # The destination is given as a `pathlib.Path` object.
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = pathlib.Path(workdir) / "model.onnx"
            onnx.save_model(expected, target, format=self.format)
            actual = onnx.load_model(target, format=self.format)
            self.assertEqual(expected, actual)


@parameterized.parameterized_class(
    # The onnxtxt format does not support saving/loading tensors yet
    [{"format": fmt} for fmt in ("protobuf", "textproto", "json")]
)
class TestIOTensor(unittest.TestCase):
    """Test loading and saving of TensorProto."""

    # Serialization format under test; injected by `parameterized_class`.
    format: str

    def test_load_tensor_when_input_is_bytes(self) -> None:
        # Serialize with the serializer registered for this format, then parse
        # the resulting payload back with `load_tensor_from_string`.
        expected = _simple_tensor()
        serializer = serialization.registry.get(self.format)
        payload = serializer.serialize_proto(expected)
        actual = onnx.load_tensor_from_string(payload, format=self.format)
        self.assertEqual(expected, actual)

    def test_save_and_load_tensor_when_input_has_read_function(self) -> None:
        # Save into an in-memory buffer and load from a file-like object.
        expected = _simple_tensor()
        sink = io.BytesIO()
        onnx.save_tensor(expected, sink, format=self.format)
        source = io.BytesIO(sink.getvalue())
        actual = onnx.load_tensor(source, format=self.format)
        self.assertEqual(expected, actual)

    def test_save_and_load_tensor_when_input_is_file_name(self) -> None:
        # The destination is given as a plain string path.
        expected = _simple_tensor()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.onnx")
            onnx.save_tensor(expected, target, format=self.format)
            actual = onnx.load_tensor(target, format=self.format)
            self.assertEqual(expected, actual)

    def test_save_and_load_tensor_when_input_is_pathlike(self) -> None:
        # The destination is given as a `pathlib.Path` object.
        expected = _simple_tensor()
        with tempfile.TemporaryDirectory() as workdir:
            target = pathlib.Path(workdir) / "model.onnx"
            onnx.save_tensor(expected, target, format=self.format)
            actual = onnx.load_tensor(target, format=self.format)
            self.assertEqual(expected, actual)


class TestSaveAndLoadFileExtensions(unittest.TestCase):
    """Test how the file extension and the `format` argument select a format."""

    def test_save_model_picks_correct_format_from_extension(self) -> None:
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.textproto")
            # `format` is omitted on save: the extension should decide the format.
            onnx.save_model(expected, target)
            actual = onnx.load_model(target, format="textproto")
            self.assertEqual(expected, actual)

    def test_load_model_picks_correct_format_from_extension(self) -> None:
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.textproto")
            onnx.save_model(expected, target, format="textproto")
            # `format` is omitted on load: the extension should decide the format.
            actual = onnx.load_model(target)
            self.assertEqual(expected, actual)

    def test_save_model_uses_format_when_it_is_specified(self) -> None:
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.textproto")
            # An explicit `format` should win over the ".textproto" extension.
            onnx.save_model(expected, target, format="protobuf")
            actual = onnx.load_model(target, format="protobuf")
            self.assertEqual(expected, actual)
            # Loading by extension means parsing as textproto, which should fail.
            self.assertRaises(
                google.protobuf.text_format.ParseError,
                onnx.load_model,
                target,
            )

    def test_load_model_uses_format_when_it_is_specified(self) -> None:
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.protobuf")
            onnx.save_model(expected, target)
            # An explicit `format` should win over the extension, so asking for
            # textproto on this file should fail to parse.
            self.assertRaises(
                google.protobuf.text_format.ParseError,
                onnx.load_model,
                target,
                format="textproto",
            )

            actual = onnx.load_model(target, format="protobuf")
            self.assertEqual(expected, actual)

    def test_load_and_save_model_to_path_without_specifying_extension_succeeds(
        self,
    ) -> None:
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            # The file name carries no extension at all.
            target = os.path.join(workdir, "model")
            onnx.save_model(expected, target, format="textproto")
            # With no `format` given, load_model should assume protobuf and
            # therefore fail to decode the textproto content.
            self.assertRaises(
                google.protobuf.message.DecodeError,
                onnx.load_model,
                target,
            )

            actual = onnx.load_model(target, format="textproto")
            self.assertEqual(expected, actual)

    def test_load_and_save_model_without_specifying_extension_or_format_defaults_to_protobuf(
        self,
    ) -> None:
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            # The file name carries no extension at all.
            target = os.path.join(workdir, "model")
            onnx.save_model(expected, target)
            # The file was written as protobuf, so parsing it as textproto should fail.
            self.assertRaises(
                google.protobuf.text_format.ParseError,
                onnx.load_model,
                target,
                format="textproto",
            )

            implicit_load = onnx.load_model(target)
            self.assertEqual(expected, implicit_load)
            explicit_protobuf_load = onnx.load_model(target, format="protobuf")
            self.assertEqual(expected, explicit_protobuf_load)


class TestBasicFunctions(unittest.TestCase):
    """Sanity checks on the proto classes and IR version exposed by `onnx`."""

    def test_protos_exist(self) -> None:
        # Each lookup raises AttributeError if the proto class is missing.
        for proto_name in ("AttributeProto", "NodeProto", "GraphProto", "ModelProto"):
            _ = getattr(onnx, proto_name)

    def test_version_exists(self) -> None:
        proto = onnx.ModelProto()
        # A freshly constructed model should not have `ir_version` set.
        self.assertFalse(proto.HasField("ir_version"))
        # Stamp the model with the IR version of the running ONNX, then
        # round-trip it through the binary encoding.
        proto.ir_version = onnx.IR_VERSION
        proto.ParseFromString(proto.SerializeToString())
        self.assertTrue(proto.HasField("ir_version"))
        # The value must survive the round trip unchanged.
        self.assertEqual(proto.ir_version, onnx.IR_VERSION)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
