from __future__ import annotations

import re
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
from langchain_core.callbacks import (
    CallbackManagerForChainRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from pydantic import Field

from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
    PROMPT,
    QUESTION_GENERATOR_PROMPT,
    FinishedOutputParser,
)
from langchain.chains.llm import LLMChain


def _extract_tokens_and_log_probs(
    response: AIMessage,
) -> Tuple[List[str], List[float]]:
    """Extract tokens and log probabilities from a chat model response.

    Assumes the model was invoked with ``logprobs`` enabled, so that
    ``response.response_metadata["logprobs"]["content"]`` is populated.
    """
    tokens = []
    log_probs = []
    for token in response.response_metadata["logprobs"]["content"]:
        tokens.append(token["token"])
        log_probs.append(token["logprob"])
    return tokens, log_probs


class QuestionGeneratorChain(LLMChain):
    """Chain that generates questions from uncertain spans."""

    prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
    """Prompt template for the chain."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain."""
        return ["user_input", "context", "response"]


def _low_confidence_spans(
    tokens: Sequence[str],
    log_probs: Sequence[float],
    min_prob: float,
    min_token_gap: int,
    num_pad_tokens: int,
) -> List[str]:
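    """Join low-probability tokens into padded spans of text.

    Tokens whose probability falls below ``min_prob`` (and that contain at
    least one word character) are treated as low confidence. Each such token
    is extended by ``num_pad_tokens`` following tokens, and spans whose
    low-confidence tokens are closer than ``min_token_gap`` apart are merged.
    """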
    # Indices of tokens whose probability is below ``min_prob`` and that
    # contain at least one word character (whitespace/punctuation is ignored).
    _low_idx = np.where(np.exp(log_probs) < min_prob)[0]
    low_idx = [i for i in _low_idx if re.search(r"\w", tokens[i])]
    if len(low_idx) == 0:
        return []
    # Build [start, end) spans around each low-confidence token, padded by
    # ``num_pad_tokens``; merge spans whose anchor tokens are closer than
    # ``min_token_gap`` tokens.
    spans = [[low_idx[0], low_idx[0] + num_pad_tokens + 1]]
    for i, idx in enumerate(low_idx[1:]):
        end = idx + num_pad_tokens + 1
        if idx - low_idx[i] < min_token_gap:
            spans[-1][1] = end
        else:
            spans.append([idx, end])
    return ["".join(tokens[start:end]) for start, end in spans]


class FlareChain(Chain):
    """Chain that combines a retriever, a question generator,
    and a response generator.

    See the [Active Retrieval Augmented Generation](https://arxiv.org/abs/2305.06983) paper.
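
    Example (a minimal sketch; assumes ``langchain-openai`` is installed,
    ``OPENAI_API_KEY`` is set, and ``my_retriever`` is an existing
    ``BaseRetriever``):

        .. code-block:: python

            from langchain.chains import FlareChain
            from langchain_openai import ChatOpenAI

            flare = FlareChain.from_llm(
                ChatOpenAI(temperature=0),
                retriever=my_retriever,
                max_generation_len=164,
                min_prob=0.3,
            )
            answer = flare.invoke({"user_input": "..."})["response"]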
    """

    question_generator_chain: Runnable
    """Chain that generates questions from uncertain spans."""
    response_chain: Runnable
    """Chain that generates responses from user input and context."""
    output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
    """Parser that determines whether the chain is finished."""
    retriever: BaseRetriever
    """Retriever that retrieves relevant documents from a user input."""
    min_prob: float = 0.2
    """Minimum probability for a token to be considered low confidence."""
    min_token_gap: int = 5
    """Minimum number of tokens between two low confidence spans."""
    num_pad_tokens: int = 2
    """Number of tokens to pad around a low confidence span."""
    max_iter: int = 10
    """Maximum number of iterations."""
    start_with_retrieval: bool = True
    """Whether to start with retrieval."""

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain."""
        return ["user_input"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys for the chain."""
        return ["response"]

    def _do_generation(
        self,
        questions: List[str],
        user_input: str,
        response: str,
        _run_manager: CallbackManagerForChainRun,
    ) -> Tuple[str, bool]:
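        """Retrieve documents for the questions and regenerate the response.

        Retrieves documents for each generated question, joins them into a
        single context string, and re-runs the response chain. Returns the
        newly generated text and whether the response is finished.
        """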
        callbacks = _run_manager.get_child()
        docs = []
        for question in questions:
            docs.extend(self.retriever.invoke(question))
        context = "\n\n".join(d.page_content for d in docs)
        result = self.response_chain.invoke(
            {
                "user_input": user_input,
                "context": context,
                "response": response,
            },
            {"callbacks": callbacks},
        )
        if isinstance(result, AIMessage):
            result = result.content
        marginal, finished = self.output_parser.parse(result)
        return marginal, finished

    def _do_retrieval(
        self,
        low_confidence_spans: List[str],
        _run_manager: CallbackManagerForChainRun,
        user_input: str,
        response: str,
        initial_response: str,
    ) -> Tuple[str, bool]:
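        """Generate a question per low-confidence span and run retrieval.

        Runs the question generator over every uncertain span, logs the
        generated questions, and delegates to ``_do_generation``.
        """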
        question_gen_inputs = [
            {
                "user_input": user_input,
                "current_response": initial_response,
                "uncertain_span": span,
            }
            for span in low_confidence_spans
        ]
        callbacks = _run_manager.get_child()
        if isinstance(self.question_generator_chain, LLMChain):
            question_gen_outputs = self.question_generator_chain.apply(
                question_gen_inputs, callbacks=callbacks
            )
            questions = [
                output[self.question_generator_chain.output_keys[0]]
                for output in question_gen_outputs
            ]
        else:
            questions = self.question_generator_chain.batch(
                question_gen_inputs, config={"callbacks": callbacks}
            )
        _run_manager.on_text(
            f"Generated Questions: {questions}", color="yellow", end="\n"
        )
        return self._do_generation(questions, user_input, response, _run_manager)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
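        """Run the FLARE loop: generate, detect uncertainty, retrieve, repeat."""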
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

        user_input = inputs[self.input_keys[0]]

        response = ""

        for i in range(self.max_iter):
            _run_manager.on_text(
                f"Current Response: {response}", color="blue", end="\n"
            )
            _input = {"user_input": user_input, "context": "", "response": response}
            # Generate the next piece of the response and collect per-token
            # log probabilities.
            tokens, log_probs = _extract_tokens_and_log_probs(
                self.response_chain.invoke(
                    _input, {"callbacks": _run_manager.get_child()}
                )
            )
            # Flag spans of the new generation that the model is uncertain about.
            low_confidence_spans = _low_confidence_spans(
                tokens,
                log_probs,
                self.min_prob,
                self.min_token_gap,
                self.num_pad_tokens,
            )
            initial_response = response.strip() + " " + "".join(tokens)
            # If every token is high confidence, accept the generation as-is
            # and check whether the response is finished.
            if not low_confidence_spans:
                response = initial_response
                final_response, finished = self.output_parser.parse(response)
                if finished:
                    return {self.output_keys[0]: final_response}
                continue

            # Otherwise, ask questions about the uncertain spans, retrieve
            # supporting documents, and regenerate with that context.
            marginal, finished = self._do_retrieval(
                low_confidence_spans,
                _run_manager,
                user_input,
                response,
                initial_response,
            )
            response = response.strip() + " " + marginal
            if finished:
                break
        return {self.output_keys[0]: response}

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
    ) -> FlareChain:
        """Creates a FlareChain from a language model.

        Args:
            llm: Language model to use (replaced internally by a
                logprobs-enabled ``ChatOpenAI``).
            max_generation_len: Maximum length of the generated response.
            kwargs: Additional arguments to pass to the constructor.

        Returns:
            A FlareChain instance.
        """
        try:
            from langchain_openai import ChatOpenAI
        except ImportError:
            raise ImportError(
                "OpenAI is required for FlareChain. "
                "Please install langchain-openai: "
                "`pip install langchain-openai`."
            )
        # FLARE needs per-token log probabilities, so a logprobs-enabled
        # ChatOpenAI is constructed here in place of the ``llm`` argument.
        llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
        response_chain = PROMPT | llm
        question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
        return cls(
            question_generator_chain=question_gen_chain,
            response_chain=response_chain,
            **kwargs,
        )
