import json
import logging
from typing import Any, List, Optional, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from pydantic import Field

from langchain_community.llms.utils import enforce_stop_tokens

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# HTTP headers sent with every POST to the inference endpoint (JSON body).
HEADERS = {"Content-Type": "application/json"}
# Default value for ``ChatGLM3.timeout``; handed to ``httpx.Client(timeout=...)``
# when the caller does not supply their own HTTP client.
DEFAULT_TIMEOUT = 30


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Translate a LangChain message into a ``{"role", "content"}`` dict.

    The role is chosen from the message class: human -> ``"user"``,
    AI -> ``"assistant"``, system -> ``"system"``, function -> ``"function"``.

    Args:
        message: The message to convert.

    Returns:
        A dict with the keys ``"role"`` and ``"content"``.

    Raises:
        ValueError: If the message is not one of the supported types.
    """
    # Checked in order; the first matching class decides the role.
    role_by_type = (
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
        (SystemMessage, "system"),
        (FunctionMessage, "function"),
    )
    for message_type, role in role_by_type:
        if isinstance(message, message_type):
            return {"role": role, "content": message.content}
    raise ValueError(f"Got unknown type {message}")


class ChatGLM3(LLM):
    """ChatGLM3 LLM service.

    Posts the prompt (preceded by any ``prefix_messages``) as a JSON
    chat-completions request to ``endpoint_url`` and returns the content of
    the first choice in the response.
    """

    model_name: str = Field(default="chatglm3-6b", alias="model")
    """Model identifier sent as ``"model"`` in the request payload."""
    endpoint_url: str = "http://127.0.0.1:8000/v1/chat/completions"
    """Endpoint URL to use."""
    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""
    max_tokens: int = 20000
    """Max token allowed to pass to the model."""
    temperature: float = 0.1
    """LLM model temperature from 0 to 10."""
    top_p: float = 0.7
    """Top P for nucleus sampling from 0 to 1"""
    prefix_messages: List[BaseMessage] = Field(default_factory=list)
    """Series of messages for Chat input."""
    streaming: bool = False
    """Whether to stream the results or not."""
    http_client: Union[Any, None] = None
    """Optional caller-supplied HTTP client; must provide a ``post`` method.

    When set it is used as-is and never closed by this class."""
    timeout: int = DEFAULT_TIMEOUT
    """Timeout passed to ``httpx.Client`` when no ``http_client`` is given."""

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this LLM type."""
        return "chat_glm_3"

    @property
    def _invocation_params(self) -> dict:
        """Get the parameters used to invoke the model.

        A fresh dict is built on every access; entries in ``model_kwargs``
        override the field-derived defaults.
        """
        params = {
            "model": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "stream": self.streaming,
        }
        return {**params, **(self.model_kwargs or {})}

    @property
    def client(self) -> Any:
        """Return the HTTP client used for requests.

        Returns ``http_client`` when one was supplied; otherwise a *new*
        ``httpx.Client`` is created on each access, and whoever obtains it is
        responsible for closing it.
        """
        import httpx

        return self.http_client or httpx.Client(timeout=self.timeout)

    def _get_payload(self, prompt: str) -> dict:
        """Build the JSON request body for ``prompt``.

        Args:
            prompt: The user prompt, appended after ``prefix_messages``.

        Returns:
            The invocation parameters plus a ``"messages"`` list.
        """
        params = self._invocation_params
        messages = self.prefix_messages + [HumanMessage(content=prompt)]
        params["messages"] = [_convert_message_to_dict(m) for m in messages]
        return params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a ChatGLM3 LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: On a network error, a non-200 status, a body that is
                not valid JSON, or a JSON body that does not contain a
                usable ``choices[0]["message"]["content"]`` entry.

        Example:
            .. code-block:: python

                response = chatglm_llm.invoke("Who are you?")
        """
        import httpx

        payload = self._get_payload(prompt)
        logger.debug("ChatGLM3 payload: %s", payload)

        client = self.client
        # The `client` property builds a fresh httpx.Client when no
        # `http_client` was supplied; that one is ours to close, otherwise
        # every call would leak a connection pool.
        owns_client = client is not self.http_client
        try:
            response = client.post(self.endpoint_url, headers=HEADERS, json=payload)
        except httpx.NetworkError as e:
            raise ValueError(f"Error raised by inference endpoint: {e}") from e
        finally:
            if owns_client:
                client.close()

        logger.debug("ChatGLM3 response: %s", response)

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            parsed_response = response.json()
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference endpoint: {e}."
                f"\nResponse: {response.text}"
            ) from e

        if not isinstance(parsed_response, dict):
            raise ValueError(f"Unexpected response type: {parsed_response}")

        # A missing or empty "choices" entry both mean there is nothing to
        # return; previously the empty case fell through with `text` unset.
        choices = parsed_response.get("choices")
        if not choices:
            raise ValueError(f"No content in response : {parsed_response}")

        try:
            text = choices[0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(f"No content in response : {parsed_response}") from e

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text
