from __future__ import annotations

import csv
import logging
import os
from collections import defaultdict
from contextlib import nullcontext
from typing import TYPE_CHECKING

from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.util import paraphrase_mining

if TYPE_CHECKING:
    from sentence_transformers.SentenceTransformer import SentenceTransformer

logger = logging.getLogger(__name__)


class ParaphraseMiningEvaluator(SentenceEvaluator):
    """
    Given a large set of sentences, this evaluator performs paraphrase (duplicate) mining and
    identifies the pairs with the highest similarity. It compare the extracted paraphrase pairs
    with a set of gold labels and computes the F1 score.

    Example:
        ::

            from datasets import load_dataset
            from sentence_transformers.SentenceTransformer import SentenceTransformer
            from sentence_transformers.evaluation import ParaphraseMiningEvaluator

            # Load a model
            model = SentenceTransformer('all-mpnet-base-v2')

            # Load the Quora Duplicates Mining dataset
            questions_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "questions", split="dev")
            duplicates_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "duplicates", split="dev")

            # Create a mapping from qid to question & a list of duplicates (qid1, qid2)
            qid_to_questions = dict(zip(questions_dataset["qid"], questions_dataset["question"]))
            duplicates = list(zip(duplicates_dataset["qid1"], duplicates_dataset["qid2"]))

            # Initialize the paraphrase mining evaluator
            paraphrase_mining_evaluator = ParaphraseMiningEvaluator(
                sentences_map=qid_to_questions,
                duplicates_list=duplicates,
                name="quora-duplicates-dev",
            )
            results = paraphrase_mining_evaluator(model)
            '''
            Paraphrase Mining Evaluation of the model on the quora-duplicates-dev dataset:
            Number of candidate pairs: 250564
            Average Precision: 56.51
            Optimal threshold: 0.8325
            Precision: 52.76
            Recall: 59.19
            F1: 55.79
            '''
            print(paraphrase_mining_evaluator.primary_metric)
            # => "quora-duplicates-dev_average_precision"
            print(results[paraphrase_mining_evaluator.primary_metric])
            # => 0.5650940787776353
    """

    def __init__(
        self,
        sentences_map: dict[str, str],
        duplicates_list: list[tuple[str, str]] = None,
        duplicates_dict: dict[str, dict[str, bool]] = None,
        add_transitive_closure: bool = False,
        query_chunk_size: int = 5000,
        corpus_chunk_size: int = 100000,
        max_pairs: int = 500000,
        top_k: int = 100,
        show_progress_bar: bool = False,
        batch_size: int = 16,
        name: str = "",
        write_csv: bool = True,
        truncate_dim: int | None = None,
    ):
        """
        Initializes the ParaphraseMiningEvaluator.

        Args:
            sentences_map (Dict[str, str]): A dictionary that maps sentence-ids to sentences.
                For example, sentences_map[id] => sentence.
            duplicates_list (List[Tuple[str, str]], optional): A list with id pairs [(id1, id2), (id1, id5)]
                that identifies the duplicates / paraphrases in the sentences_map. Defaults to None.
            duplicates_dict (Dict[str, Dict[str, bool]], optional): A default dictionary mapping [id1][id2]
                to true if id1 and id2 are duplicates. Must be symmetric, i.e., if [id1][id2] => True,
                then [id2][id1] => True. Defaults to None.
            add_transitive_closure (bool, optional): If true, it adds a transitive closure,
                i.e. if dup[a][b] and dup[b][c], then dup[a][c]. Defaults to False.
            query_chunk_size (int, optional): To identify the paraphrases, the cosine-similarity between
                all sentence-pairs will be computed. As this might require a lot of memory, we perform
                a batched computation. query_chunk_size sentences will be compared against up to
                corpus_chunk_size sentences. In the default setting, 5000 sentences will be grouped
                together and compared up-to against 100k other sentences. Defaults to 5000.
            corpus_chunk_size (int, optional): The corpus will be batched, to reduce the memory requirement.
                Defaults to 100000.
            max_pairs (int, optional): We will only extract up to max_pairs potential paraphrase candidates.
                Defaults to 500000.
            top_k (int, optional): For each query, we extract the top_k most similar pairs and add it to a sorted list.
                I.e., for one sentence we cannot find more than top_k paraphrases. Defaults to 100.
            show_progress_bar (bool, optional): Output a progress bar. Defaults to False.
            batch_size (int, optional): Batch size for computing sentence embeddings. Defaults to 16.
            name (str, optional): Name of the experiment. Defaults to "".
            write_csv (bool, optional): Write results to CSV file. Defaults to True.
            truncate_dim (Optional[int], optional): The dimension to truncate sentence embeddings to.
                `None` uses the model's current truncation dimension. Defaults to None.
        """
        super().__init__()
        self.sentences = []
        self.ids = []

        for id, sentence in sentences_map.items():
            self.sentences.append(sentence)
            self.ids.append(id)

        self.name = name
        self.show_progress_bar = show_progress_bar
        self.batch_size = batch_size
        self.query_chunk_size = query_chunk_size
        self.corpus_chunk_size = corpus_chunk_size
        self.max_pairs = max_pairs
        self.top_k = top_k
        self.truncate_dim = truncate_dim

        self.duplicates = duplicates_dict if duplicates_dict is not None else defaultdict(lambda: defaultdict(bool))
        if duplicates_list is not None:
            for id1, id2 in duplicates_list:
                if id1 in sentences_map and id2 in sentences_map:
                    self.duplicates[id1][id2] = True
                    self.duplicates[id2][id1] = True

        # Add transitive closure
        if add_transitive_closure:
            self.duplicates = self.add_transitive_closure(self.duplicates)

        positive_key_pairs = set()
        for key1 in self.duplicates:
            for key2 in self.duplicates[key1]:
                if (
                    key1 in sentences_map
                    and key2 in sentences_map
                    and (self.duplicates[key1][key2] or self.duplicates[key2][key1])
                ):
                    positive_key_pairs.add(tuple(sorted([key1, key2])))

        self.total_num_duplicates = len(positive_key_pairs)

        if name:
            name = "_" + name

        self.csv_file: str = "paraphrase_mining_evaluation" + name + "_results.csv"
        self.csv_headers = ["epoch", "steps", "precision", "recall", "f1", "threshold", "average_precision"]
        self.write_csv = write_csv
        self.primary_metric = "average_precision"

    def __call__(
        self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1
    ) -> dict[str, float]:
        if epoch != -1:
            if steps == -1:
                out_txt = f" after epoch {epoch}"
            else:
                out_txt = f" in epoch {epoch} after {steps} steps"
        else:
            out_txt = ""
        if self.truncate_dim is not None:
            out_txt += f" (truncated to {self.truncate_dim})"

        logger.info(f"Paraphrase Mining Evaluation of the model on the {self.name} dataset{out_txt}:")

        # Compute embedding for the sentences
        with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
            pairs_list = paraphrase_mining(
                model,
                self.sentences,
                self.show_progress_bar,
                self.batch_size,
                self.query_chunk_size,
                self.corpus_chunk_size,
                self.max_pairs,
                self.top_k,
            )

        logger.info("Number of candidate pairs: " + str(len(pairs_list)))

        # Compute F1 score and Average Precision
        n_extract = n_correct = 0
        threshold = 0
        best_f1 = best_recall = best_precision = 0

        average_precision = 0

        for idx in range(len(pairs_list)):
            score, i, j = pairs_list[idx]
            id1 = self.ids[i]
            id2 = self.ids[j]

            # Compute optimal threshold and F1-score
            n_extract += 1
            if self.duplicates[id1][id2] or self.duplicates[id2][id1]:
                n_correct += 1
                precision = n_correct / n_extract
                recall = n_correct / self.total_num_duplicates
                f1 = 2 * precision * recall / (precision + recall)
                average_precision += precision
                if f1 > best_f1:
                    best_f1 = f1
                    best_precision = precision
                    best_recall = recall
                    threshold = (pairs_list[idx][0] + pairs_list[min(idx + 1, len(pairs_list) - 1)][0]) / 2

        average_precision = average_precision / self.total_num_duplicates

        logger.info(f"Average Precision: {average_precision * 100:.2f}")
        logger.info(f"Optimal threshold: {threshold:.4f}")
        logger.info(f"Precision: {best_precision * 100:.2f}")
        logger.info(f"Recall: {best_recall * 100:.2f}")
        logger.info(f"F1: {best_f1 * 100:.2f}\n")

        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            if not os.path.isfile(csv_path):
                with open(csv_path, newline="", mode="w", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow(self.csv_headers)
                    writer.writerow([epoch, steps, best_precision, best_recall, best_f1, threshold, average_precision])
            else:
                with open(csv_path, newline="", mode="a", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow([epoch, steps, best_precision, best_recall, best_f1, threshold, average_precision])

        metrics = {
            "average_precision": average_precision,
            "f1": best_f1,
            "precision": best_precision,
            "recall": best_recall,
            "threshold": threshold,
        }
        metrics = self.prefix_name_to_metrics(metrics, self.name)
        self.store_metrics_in_model_card_data(model, metrics)
        return metrics

    @staticmethod
    def add_transitive_closure(graph):
        nodes_visited = set()
        for a in list(graph.keys()):
            if a not in nodes_visited:
                connected_subgraph_nodes = set()
                connected_subgraph_nodes.add(a)

                # Add all nodes in the connected graph
                neighbor_nodes_queue = list(graph[a])
                while len(neighbor_nodes_queue) > 0:
                    node = neighbor_nodes_queue.pop(0)
                    if node not in connected_subgraph_nodes:
                        connected_subgraph_nodes.add(node)
                        neighbor_nodes_queue.extend(graph[node])

                # Ensure transitivity between all nodes in the graph
                connected_subgraph_nodes = list(connected_subgraph_nodes)
                for i in range(len(connected_subgraph_nodes) - 1):
                    for j in range(i + 1, len(connected_subgraph_nodes)):
                        graph[connected_subgraph_nodes[i]][connected_subgraph_nodes[j]] = True
                        graph[connected_subgraph_nodes[j]][connected_subgraph_nodes[i]] = True

                        nodes_visited.add(connected_subgraph_nodes[i])
                        nodes_visited.add(connected_subgraph_nodes[j])
        return graph
