from __future__ import annotations

import gzip
import logging
import random

from torch.utils.data import Dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.readers import InputExample

logger = logging.getLogger(__name__)


class ParallelSentencesDataset(Dataset):
    """
    Dataset of parallel sentences for training a student model against a teacher model.

    The input is tab-separated text in which every line holds the same sentence in
    different languages, e.g. with the columns EN, DE, ES:

        hello world     hallo welt  hola mundo
        second sentence zweiter satz    segunda oración

    The sentence in the first column (the source sentence) is mapped to a sentence
    embedding by the teacher model. Every sentence of the line, including the source
    sentence itself, becomes one training example whose label is that embedding.

    Fetching a sample returns one sentence together with the embedding of its source
    sentence. Samples are produced in shuffled order and the index passed to
    ``__getitem__`` is ignored.

    teacher_model can be any object that implements an ``encode`` function. The encode
    function receives a list of sentences and returns a sequence of sentence embeddings.
    """

    def __init__(
        self,
        student_model: SentenceTransformer,
        teacher_model: SentenceTransformer,
        batch_size: int = 8,
        use_embedding_cache: bool = True,
    ):
        """
        Parallel sentences dataset reader to train student model given a teacher model

        Args:
            student_model (SentenceTransformer): The student sentence embedding model that should be trained.
            teacher_model (SentenceTransformer): The teacher model that provides the sentence embeddings for the first column in the dataset file.
            batch_size (int, optional): The batch size passed to the teacher's encode function. Defaults to 8.
            use_embedding_cache (bool, optional): Whether to cache teacher embeddings per sentence. Defaults to True.
        """
        self.student_model = student_model
        self.teacher_model = teacher_model
        # One entry per loaded dataset: a list of (source_sentence, set_of_sentences) pairs.
        self.datasets = []
        # Read position inside each dataset (parallel to self.datasets).
        self.datasets_iterator = []
        self.datasets_tokenized = []
        # Each dataset id is repeated `weight` times; one generation round walks this list.
        self.dataset_indices = []
        self.copy_dataset_indices = []
        # Ready-to-serve examples; refilled by generate_data() when exhausted.
        self.cache = []
        self.batch_size = batch_size
        self.use_embedding_cache = use_embedding_cache
        # sentence -> teacher embedding (only used when use_embedding_cache is True).
        self.embedding_cache = {}
        self.num_sentences = 0

    @staticmethod
    def _is_too_long(sentences, max_sentence_length) -> bool:
        """Return True if a length limit is active and any sentence has more characters than it."""
        if max_sentence_length is None or max_sentence_length <= 0:
            return False
        return any(len(sent) > max_sentence_length for sent in sentences)

    def load_data(
        self,
        filepath: str,
        weight: int = 100,
        max_sentences: int | None = None,
        max_sentence_length: int | None = 128,
    ) -> None:
        """
        Reads in a tab-separated .txt/.csv/.tsv or .gz file. The different columns contain the different translations of the sentence in the first column

        Args:
            filepath (str): Filepath to the file. Files ending in ".gz" are opened with gzip.
            weight (int, optional): If more than one dataset is loaded with load_data, specifies the frequency at which data should be sampled from this dataset. Defaults to 100.
            max_sentences (int, optional): Maximum number of lines to be read from the filepath. None or a non-positive value disables the limit. Defaults to None.
            max_sentence_length (int, optional): Skip the example if one of the sentences has more characters than max_sentence_length. None or a non-positive value disables the check. Defaults to 128.

        Returns:
            None
        """
        logger.info("Load %s", filepath)
        parallel_sentences = []

        open_fn = gzip.open if filepath.endswith(".gz") else open
        with open_fn(filepath, "rt", encoding="utf8") as fIn:
            count = 0
            for line in fIn:
                sentences = line.strip().split("\t")
                if self._is_too_long(sentences, max_sentence_length):
                    continue

                parallel_sentences.append(sentences)
                count += 1
                if max_sentences is not None and max_sentences > 0 and count >= max_sentences:
                    break

        self.add_dataset(
            parallel_sentences, weight=weight, max_sentences=max_sentences, max_sentence_length=max_sentence_length
        )

    def add_dataset(
        self,
        parallel_sentences: list[list[str]],
        weight: int = 100,
        max_sentences: int | None = None,
        max_sentence_length: int | None = 128,
    ):
        """
        Registers a collection of parallel sentences as a new dataset.

        Rows are grouped by their first sentence (the source sentence); all sentences of
        rows sharing a source sentence are merged into one set. Nothing is registered if
        no row survives the filtering.

        Args:
            parallel_sentences (list[list[str]]): Rows of parallel sentences; the first entry of each row is the source sentence. Empty rows are skipped.
            weight (int, optional): Number of entries drawn from this dataset per generation round. Defaults to 100.
            max_sentences (int, optional): Stop once this many distinct source sentences were collected. None or a non-positive value disables the limit. Defaults to None.
            max_sentence_length (int, optional): Skip a row if one of its sentences has more characters than this. None or a non-positive value disables the check. Defaults to 128.
        """
        sentences_map = {}
        for sentences in parallel_sentences:
            # An empty row has no source sentence; skip it instead of failing.
            if not sentences:
                continue

            if self._is_too_long(sentences, max_sentence_length):
                continue

            source_sentence = sentences[0]
            sentences_map.setdefault(source_sentence, set()).update(sentences)

            if max_sentences is not None and max_sentences > 0 and len(sentences_map) >= max_sentences:
                break

        if not sentences_map:
            return

        self.num_sentences += sum(len(targets) for targets in sentences_map.values())

        dataset_id = len(self.datasets)
        self.datasets.append(list(sentences_map.items()))
        self.datasets_iterator.append(0)
        self.dataset_indices.extend([dataset_id] * weight)

    def generate_data(self):
        """
        Refills the example cache with one round of data.

        Draws the next entry of a dataset for every id in ``dataset_indices``, embeds the
        source sentences, creates one InputExample per sentence of each entry (labelled
        with the source embedding) and shuffles the cache.
        """
        source_sentences_list = []
        target_sentences_list = []
        for data_idx in self.dataset_indices:
            src_sentence, trg_sentences = self.next_entry(data_idx)
            source_sentences_list.append(src_sentence)
            target_sentences_list.append(trg_sentences)

        # Generate embeddings
        src_embeddings = self.get_embeddings(source_sentences_list)

        for src_embedding, trg_sentences in zip(src_embeddings, target_sentences_list):
            for trg_sentence in trg_sentences:
                self.cache.append(InputExample(texts=[trg_sentence], label=src_embedding))

        random.shuffle(self.cache)

    def next_entry(self, data_idx):
        """
        Returns the next (source_sentence, sentences) pair of dataset ``data_idx``.

        Advances the read position of that dataset; once the end is reached the position
        is reset to 0 and the dataset is reshuffled in place.
        """
        position = self.datasets_iterator[data_idx]
        source, target_sentences = self.datasets[data_idx][position]

        self.datasets_iterator[data_idx] = position + 1
        if self.datasets_iterator[data_idx] >= len(self.datasets[data_idx]):  # Restart iterator
            self.datasets_iterator[data_idx] = 0
            random.shuffle(self.datasets[data_idx])

        return source, target_sentences

    def get_embeddings(self, sentences):
        """
        Returns the teacher embeddings for ``sentences``.

        Without the embedding cache, the output of the teacher's encode function is
        returned directly. With the cache, only sentences that are not cached yet are
        encoded (each distinct sentence once) and a list with one embedding per input
        sentence, in input order, is returned.
        """
        if not self.use_embedding_cache:
            return self.teacher_model.encode(
                sentences, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True
            )

        # Use caching. dict.fromkeys drops repeated sentences while keeping their first-seen
        # order, so the teacher never encodes the same sentence twice in one call.
        new_sentences = list(dict.fromkeys(sent for sent in sentences if sent not in self.embedding_cache))

        if new_sentences:
            new_embeddings = self.teacher_model.encode(
                new_sentences, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True
            )
            for sent, embedding in zip(new_sentences, new_embeddings):
                self.embedding_cache[sent] = embedding

        return [self.embedding_cache[sent] for sent in sentences]

    def __len__(self):
        """Returns the total number of sentences over all registered datasets."""
        return self.num_sentences

    def __getitem__(self, idx):
        """
        Returns the next example from the shuffled cache; ``idx`` is ignored.

        Raises:
            IndexError: If no example can be generated (e.g. no dataset has been loaded).
        """
        if not self.cache:
            self.generate_data()

        if not self.cache:
            raise IndexError("No examples available: load data with load_data() or add_dataset() first")

        return self.cache.pop()
