import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict

import aiohttp
import requests
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.utils import get_from_dict_or_env
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self

from langchain_community.llms.utils import enforce_stop_tokens


class TrainResult(TypedDict):
    """Outcome of a Gradient unsupervised fine-tuning request.

    Attributes:
        loss: The reported summed loss divided by the number of trainable
            tokens, i.e. the average loss per trainable token.
    """

    loss: float  # average loss per trainable token


class GradientLLM(BaseLLM):
    """Gradient.ai LLM Endpoints.

    GradientLLM is a class to interact with LLMs on gradient.ai

    To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
    API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
    or alternatively provide them as keywords to the constructor of this class.

    Example:
        .. code-block:: python

            from langchain_community.llms import GradientLLM
            GradientLLM(
                model="99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model",
                model_kwargs={
                    "max_generated_token_count": 128,
                    "temperature": 0.75,
                    "top_p": 0.95,
                    "top_k": 20,
                    "stop": [],
                },
                gradient_workspace_id="12345614fc0_workspace",
                gradient_access_token="gradientai-access_token",
            )

    """

    model_id: str = Field(alias="model", min_length=2)
    "Underlying gradient.ai model id (base or fine-tuned)."

    gradient_workspace_id: Optional[str] = None
    "Underlying gradient.ai workspace_id."

    gradient_access_token: Optional[str] = None
    """gradient.ai API Token, which can be generated by going to
        https://auth.gradient.ai/select-workspace
        and selecting "Access tokens" under the profile drop-down.
    """

    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""

    gradient_api_url: str = "https://api.gradient.ai/api"
    """Endpoint URL to use."""

    aiosession: Optional[aiohttp.ClientSession] = None  #: :meta private:
    """ClientSession, private, subject to change in upcoming releases."""

    # `populate_by_name` accepts both `model=` (the alias) and `model_id=`;
    # `extra="forbid"` rejects unknown constructor arguments.
    model_config = ConfigDict(
        populate_by_name=True,
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Resolve credentials and the API URL from kwargs or environment.

        ``gradient_access_token`` and ``gradient_workspace_id`` come from the
        constructor arguments or from ``GRADIENT_ACCESS_TOKEN`` /
        ``GRADIENT_WORKSPACE_ID``. ``gradient_api_url`` may be overridden by
        ``GRADIENT_API_URL`` and otherwise falls back to the public endpoint.
        """
        values["gradient_access_token"] = get_from_dict_or_env(
            values, "gradient_access_token", "GRADIENT_ACCESS_TOKEN"
        )
        values["gradient_workspace_id"] = get_from_dict_or_env(
            values, "gradient_workspace_id", "GRADIENT_WORKSPACE_ID"
        )

        # This validator runs before field defaults are applied, so an explicit
        # default is needed; otherwise a missing kwarg *and* env var would fail
        # construction even though the field itself has a default URL.
        values["gradient_api_url"] = get_from_dict_or_env(
            values,
            "gradient_api_url",
            "GRADIENT_API_URL",
            default="https://api.gradient.ai/api",
        )
        return values

    @model_validator(mode="after")
    def post_init(self) -> Self:
        """Validate credentials and the default sampling parameters.

        Raises:
            ValueError: If the access token or workspace id is missing or too
                short, or if ``model_kwargs`` contains out-of-range values.
        """
        # `gradientai` is optional for now: this class talks to the REST API
        # directly through `requests` / `aiohttp`.
        try:
            import gradientai  # noqa
        except ImportError:
            logging.warning(
                "DeprecationWarning: `GradientLLM` will use "
                "`pip install gradientai` in future releases of langchain."
            )
        except Exception:
            # Best effort only: a broken optional install must not block use.
            logging.debug("Importing `gradientai` failed.", exc_info=True)

        if self.gradient_access_token is None or len(self.gradient_access_token) < 10:
            raise ValueError("env variable `GRADIENT_ACCESS_TOKEN` must be set")

        if self.gradient_workspace_id is None or len(self.gradient_workspace_id) < 3:
            raise ValueError("env variable `GRADIENT_WORKSPACE_ID` must be set")

        if self.model_kwargs:
            kw = self.model_kwargs
            if not 0 <= kw.get("temperature", 0.5) <= 1:
                raise ValueError("`temperature` must be in the range [0.0, 1.0]")

            if not 0 <= kw.get("top_p", 0.5) <= 1:
                raise ValueError("`top_p` must be in the range [0.0, 1.0]")

            if 0 >= kw.get("top_k", 0.5):
                raise ValueError("`top_k` must be positive")

            if 0 >= kw.get("max_generated_token_count", 1):
                raise ValueError("`max_generated_token_count` must be positive")

        return self

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.

        ``model_id`` is included so that two instances pointing at different
        models are never treated as the same LLM.
        """
        return {
            "gradient_api_url": self.gradient_api_url,
            "model_id": self.model_id,
            "model_kwargs": self.model_kwargs or {},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "gradient"

    def _request_headers(self) -> Dict[str, str]:
        """Return the auth/JSON headers sent with every Gradient request."""
        return {
            "authorization": f"Bearer {self.gradient_access_token}",
            "x-gradient-workspace-id": f"{self.gradient_workspace_id}",
            "accept": "application/json",
            "content-type": "application/json",
        }

    def _kwargs_post_fine_tune_request(
        self, inputs: Sequence[str], kwargs: Mapping[str, Any]
    ) -> Mapping[str, Any]:
        """Build the POST kwargs for the ``models/{id}/fine-tune`` endpoint.

        Args:
            inputs: Training samples, one request sample per entry.
            kwargs: Call-time parameters merged over ``model_kwargs``. An
                optional ``multipliers`` sequence sets a per-sample
                ``fineTuningParameters.multiplier``.

        Returns:
            Mapping with ``url``, ``headers`` and ``json`` keys, suitable for
            ``requests.post(**...)`` / ``ClientSession.post(**...)``.

        Raises:
            ValueError: If ``multipliers`` is given and its length differs from
                ``inputs`` (pairing them would otherwise silently drop samples).
        """
        _params = {**(self.model_kwargs or {}), **kwargs}
        multipliers = _params.get("multipliers", None)

        if multipliers is None:
            samples = tuple({"inputs": sample} for sample in inputs)
        else:
            multipliers = list(multipliers)
            if len(multipliers) != len(inputs):
                raise ValueError(
                    f"`multipliers` has {len(multipliers)} entries but "
                    f"{len(inputs)} inputs were given"
                )
            samples = tuple(
                {"inputs": sample, "fineTuningParameters": {"multiplier": multiplier}}
                for sample, multiplier in zip(inputs, multipliers)
            )

        return dict(
            url=f"{self.gradient_api_url}/models/{self.model_id}/fine-tune",
            headers=self._request_headers(),
            json=dict(samples=samples),
        )

    def _kwargs_post_request(
        self, prompt: str, kwargs: Mapping[str, Any]
    ) -> Mapping[str, Any]:
        """Build the POST kwargs for the ``models/{id}/complete`` endpoint.

        Args:
            prompt: Prompt sent as the ``query`` field.
            kwargs: Call-time parameters merged over ``model_kwargs``; unset
                sampling parameters are sent as ``null``.

        Returns:
            Mapping with ``url``, ``headers`` and ``json`` keys.
        """
        _params = {**(self.model_kwargs or {}), **kwargs}

        return dict(
            url=f"{self.gradient_api_url}/models/{self.model_id}/complete",
            headers=self._request_headers(),
            json=dict(
                query=prompt,
                maxGeneratedTokenCount=_params.get("max_generated_token_count", None),
                temperature=_params.get("temperature", None),
                topK=_params.get("top_k", None),
                topP=_params.get("top_p", None),
            ),
        )

    @staticmethod
    def _post_json(request_kwargs: Mapping[str, Any]) -> Any:
        """POST synchronously with ``requests`` and return the decoded JSON.

        Raises:
            Exception: On a transport error or a non-200 status code.
        """
        try:
            response = requests.post(**request_kwargs)
        except requests.exceptions.RequestException as e:
            raise Exception(
                f"RequestException while calling Gradient Endpoint: {e}"
            ) from e
        if response.status_code != 200:
            raise Exception(
                f"Gradient returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )
        return response.json()

    async def _apost_json(self, request_kwargs: Mapping[str, Any]) -> Any:
        """POST asynchronously and return the decoded JSON body.

        Uses ``self.aiosession`` when set; otherwise a temporary
        ``aiohttp.ClientSession`` is opened and closed around the request.
        """
        if self.aiosession:
            return await self._apost_json_with_session(self.aiosession, request_kwargs)
        async with aiohttp.ClientSession() as session:
            return await self._apost_json_with_session(session, request_kwargs)

    @staticmethod
    async def _apost_json_with_session(
        session: aiohttp.ClientSession, request_kwargs: Mapping[str, Any]
    ) -> Any:
        """POST via ``session``; raise on a non-200 status, else return JSON."""
        async with session.post(**request_kwargs) as response:
            if response.status != 200:
                # `ClientResponse.text` is a coroutine method: await it so the
                # error carries the body rather than a bound-method repr.
                body = await response.text()
                raise Exception(
                    f"Gradient returned an unexpected response with status "
                    f"{response.status}: {body}"
                )
            return await response.json()

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call to Gradients API `model/{id}/complete`.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Raises:
            Exception: On a transport error or a non-200 response.
        """
        response_json = self._post_json(self._kwargs_post_request(prompt, kwargs))
        text = response_json["generatedOutput"]

        if stop is not None:
            # Stop sequences are not part of the request payload, so they are
            # enforced client-side on the returned text.
            text = enforce_stop_tokens(text, stop)

        return text

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Async Call to Gradients API `model/{id}/complete`.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Raises:
            Exception: On a non-200 response.
        """
        response_json = await self._apost_json(
            self._kwargs_post_request(prompt=prompt, kwargs=kwargs)
        )
        text = response_json["generatedOutput"]

        if stop is not None:
            # Stop sequences are not part of the request payload, so they are
            # enforced client-side on the returned text.
            text = enforce_stop_tokens(text, stop)

        return text

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on each prompt, using up to 8 threads for batches.

        Each prompt yields one ``Generation``; results keep prompt order.
        """

        def _inner_generate(prompt: str) -> List[Generation]:
            text = self._call(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            )
            return [Generation(text=text)]

        if len(prompts) <= 1:
            generations = [_inner_generate(prompt) for prompt in prompts]
        else:
            # `_call` blocks on HTTP, so threads let the requests overlap.
            with ThreadPoolExecutor(min(8, len(prompts))) as pool:
                generations = list(pool.map(_inner_generate, prompts))

        return LLMResult(generations=generations)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on all prompts concurrently; results keep prompt order."""
        texts = await asyncio.gather(
            *(
                self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)
                for prompt in prompts
            )
        )
        return LLMResult(generations=[[Generation(text=text)] for text in texts])

    @staticmethod
    def _train_result(response_json: Mapping[str, Any]) -> TrainResult:
        """Convert a fine-tune response into a per-trainable-token loss."""
        # NOTE(review): a response with zero trainable tokens raises
        # ZeroDivisionError here — confirm whether the API can return that.
        loss = response_json["sumLoss"] / response_json["numberOfTrainableTokens"]
        return TrainResult(loss=loss)

    def train_unsupervised(
        self,
        inputs: Sequence[str],
        **kwargs: Any,
    ) -> TrainResult:
        """Fine-tune the model on ``inputs`` via the ``fine-tune`` endpoint.

        Args:
            inputs: Training samples.
            **kwargs: Extra parameters, e.g. ``multipliers`` (one per input).

        Returns:
            ``TrainResult`` with the average loss per trainable token.

        Raises:
            ValueError: If ``multipliers`` does not match ``inputs`` in length.
            Exception: On a transport error or a non-200 response.
        """
        response_json = self._post_json(
            self._kwargs_post_fine_tune_request(inputs, kwargs)
        )
        return self._train_result(response_json)

    async def atrain_unsupervised(
        self,
        inputs: Sequence[str],
        **kwargs: Any,
    ) -> TrainResult:
        """Async version of :meth:`train_unsupervised`.

        Raises:
            ValueError: If ``multipliers`` does not match ``inputs`` in length.
            Exception: On a non-200 response.
        """
        response_json = await self._apost_json(
            self._kwargs_post_fine_tune_request(inputs, kwargs)
        )
        return self._train_result(response_json)
