# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

import os
import tempfile
import unittest

import numpy as np

import onnx
import onnx.external_data_helper as ext_data
import onnx.helper
import onnx.model_container
import onnx.numpy_helper


def _linear_regression():
    """Build a small in-memory ONNX model computing ``Y = X @ A @ B @ C``.

    ``A``, ``B`` and ``C`` are regular 3x3 float32 initializers created with
    ``numpy_helper.from_array``, so the model holds no external data. The
    model is validated with ``onnx.checker.check_model`` before being returned.

    Returns:
        onnx.ModelProto: the checked model.
    """
    X = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])
    # MatMul of rank-2 operands produces a rank-2 result, so Y must be declared
    # with two dimensions to agree with the graph (it was previously rank-1).
    Y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None, None])
    graph = onnx.helper.make_graph(
        [
            onnx.helper.make_node("MatMul", ["X", "A"], ["XA"]),
            onnx.helper.make_node("MatMul", ["XA", "B"], ["XB"]),
            onnx.helper.make_node("MatMul", ["XB", "C"], ["Y"]),
        ],
        "mm",
        [X],
        [Y],
        [
            onnx.numpy_helper.from_array(
                np.arange(9).astype(np.float32).reshape((-1, 3)), name="A"
            ),
            onnx.numpy_helper.from_array(
                (np.arange(9) * 10).astype(np.float32).reshape((-1, 3)),
                name="B",
            ),
            onnx.numpy_helper.from_array(
                (np.arange(9) * 10).astype(np.float32).reshape((-1, 3)),
                name="C",
            ),
        ],
    )
    onnx_model = onnx.helper.make_model(graph)
    onnx.checker.check_model(onnx_model)
    return onnx_model


def _large_linear_regression():
    """Build a ``ModelContainer`` computing ``Y = X @ A @ B @ C`` with large weights.

    ``A`` and ``C`` are created with ``make_large_tensor_proto`` and reference
    their data only by a location key; the actual arrays are passed separately
    to ``make_large_model`` under the same keys. ``B`` is a regular in-memory
    initializer. The container is validated with ``check_model`` before being
    returned.

    Returns:
        onnx.model_container.ModelContainer: the checked container.
    """
    # The location keys must match between the tensor protos and the
    # large-initializer mapping below; bind them once so they cannot drift.
    loc_a, loc_c = "#loc0", "#loc1"
    X = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])
    # MatMul of rank-2 operands produces a rank-2 result, so Y must be declared
    # with two dimensions to agree with the graph (it was previously rank-1).
    Y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None, None])
    graph = onnx.helper.make_graph(
        [
            onnx.helper.make_node("MatMul", ["X", "A"], ["XA"]),
            onnx.helper.make_node("MatMul", ["XA", "B"], ["XB"]),
            onnx.helper.make_node("MatMul", ["XB", "C"], ["Y"]),
        ],
        "mm",
        [X],
        [Y],
        [
            onnx.model_container.make_large_tensor_proto(
                loc_a, "A", onnx.TensorProto.FLOAT, (3, 3)
            ),
            onnx.numpy_helper.from_array(
                np.arange(9).astype(np.float32).reshape((-1, 3)), name="B"
            ),
            onnx.model_container.make_large_tensor_proto(
                loc_c, "C", onnx.TensorProto.FLOAT, (3, 3)
            ),
        ],
    )
    onnx_model = onnx.helper.make_model(graph)
    large_model = onnx.model_container.make_large_model(
        onnx_model.graph,
        {
            loc_a: (np.arange(9) * 100).astype(np.float32).reshape((-1, 3)),
            loc_c: (np.arange(9) + 10).astype(np.float32).reshape((-1, 3)),
        },
    )
    large_model.check_model()
    return large_model


class TestLargeOnnx(unittest.TestCase):
    """Round-trip tests for saving and loading ``onnx.model_container.ModelContainer``."""

    def _check_external_locations(self, model_path):
        """Return the ``location`` of every external tensor stored in *model_path*.

        The model is reloaded with ``load_external_data=False`` so tensors keep
        their external-data references. With the default
        (``load_external_data=True``) the weights are loaded into the tensors
        and ``uses_external_data`` no longer reports them, which would make
        these checks vacuous.

        For each external tensor this asserts that it declares exactly one
        location and that the referenced file exists. Locations are resolved
        relative to the directory containing *model_path*, not the current
        working directory.
        """
        model = onnx.load_model(model_path, load_external_data=False)
        model_dir = os.path.dirname(model_path)
        locations = []
        for tensor in ext_data._get_all_tensors(model):
            if not ext_data.uses_external_data(tensor):
                continue
            tensor_locations = [
                ext.value
                for ext in tensor.external_data
                if ext.key == "location"  # type: ignore[attr-defined]
            ]
            self.assertEqual(len(tensor_locations), 1)
            location = tensor_locations[0]
            self.assertTrue(
                os.path.exists(os.path.join(model_dir, location)),
                f"missing external data file {location!r}",
            )
            locations.append(location)
        return locations

    def test_large_onnx_no_large_initializer(self):
        model_proto = _linear_regression()
        assert isinstance(model_proto, onnx.ModelProto)
        large_model = onnx.model_container.make_large_model(model_proto.graph)
        assert isinstance(large_model, onnx.model_container.ModelContainer)
        with tempfile.TemporaryDirectory() as temp:
            filename = os.path.join(temp, "model.onnx")
            large_model.save(filename)
            container = onnx.model_container.ModelContainer()
            # A freshly constructed container holds no model yet.
            with self.assertRaises(RuntimeError):
                assert container.model_proto
            container.load(filename)
            assert container.model_proto is not None
            onnx.checker.check_model(container.model_proto)
            # No initializer was registered as large, so nothing is external.
            self.assertEqual(self._check_external_locations(filename), [])

    def test_large_one_weight_file(self):
        large_model = _large_linear_regression()
        assert isinstance(large_model, onnx.model_container.ModelContainer)
        with tempfile.TemporaryDirectory() as temp:
            filename = os.path.join(temp, "model.onnx")
            # Second argument: store all external tensors in a single file.
            saved_proto = large_model.save(filename, True)
            assert isinstance(saved_proto, onnx.ModelProto)
            container = onnx.model_container.ModelContainer()
            container.load(filename)
            container.check_model()
            # The two large initializers (A and C) share one weight file.
            locations = self._check_external_locations(filename)
            self.assertEqual(len(locations), 2)
            self.assertEqual(len(set(locations)), 1)
            loaded_model = onnx.load_model(filename, load_external_data=True)
            onnx.checker.check_model(loaded_model)

    def test_large_multi_files(self):
        large_model = _large_linear_regression()
        assert isinstance(large_model, onnx.model_container.ModelContainer)
        with tempfile.TemporaryDirectory() as temp:
            filename = os.path.join(temp, "model.onnx")
            # Second argument: store each external tensor in its own file.
            saved_proto = large_model.save(filename, False)
            assert isinstance(saved_proto, onnx.ModelProto)
            # The two large initializers (A and C) each get a distinct file.
            locations = self._check_external_locations(filename)
            self.assertEqual(len(locations), 2)
            self.assertEqual(len(set(locations)), 2)
            loaded_model = onnx.load_model(filename, load_external_data=True)
            onnx.checker.check_model(loaded_model)


if __name__ == "__main__":
    # Allow running this test module directly, reporting each test by name.
    unittest.main(verbosity=2)
