import json
import urllib.request
import warnings
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import BaseModel, ConfigDict, SecretStr, model_validator, validator

DEFAULT_TIMEOUT = 50


class AzureMLEndpointClient:
    """AzureML Managed Endpoint client."""

    def __init__(
        self,
        endpoint_url: str,
        endpoint_api_key: str,
        deployment_name: str = "",
        timeout: int = DEFAULT_TIMEOUT,
    ) -> None:
        """Initialize the class."""
        if not endpoint_api_key or not endpoint_url:
            raise ValueError(
                "A key/token and REST endpoint should be provided "
                "to invoke the endpoint."
            )
        self.endpoint_url = endpoint_url
        self.endpoint_api_key = endpoint_api_key
        self.deployment_name = deployment_name
        self.timeout = timeout

    def call(
        self,
        body: bytes,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> bytes:
        """call."""

        # The azureml-model-deployment header will force the request to go to a
        # specific deployment. Remove this header to have the request observe the
        # endpoint traffic rules.
        headers = {
            "Content-Type": "application/json",
            "Authorization": ("Bearer " + self.endpoint_api_key),
        }
        if self.deployment_name:
            headers["azureml-model-deployment"] = self.deployment_name

        req = urllib.request.Request(self.endpoint_url, body, headers)
        response = urllib.request.urlopen(
            req, timeout=kwargs.get("timeout", self.timeout)
        )
        result = response.read()
        return result
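
# A minimal usage sketch for the client (hypothetical URL, key, and body; not
# part of the public API). `call` sends the raw JSON body and returns the raw
# response bytes; a ContentFormatterBase subclass handles (de)serialization:
#
#     client = AzureMLEndpointClient(
#         endpoint_url="https://my-endpoint.eastus2.inference.ml.azure.com/score",
#         endpoint_api_key="my-api-key",
#     )
#     raw = client.call(body=b'{"inputs": {"input_string": ["Hello"]}}')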


class AzureMLEndpointApiType(str, Enum):
    """Azure ML endpoints API types. Use `dedicated` for models deployed in hosted
    infrastructure (also known as Online Endpoints in Azure Machine Learning),
    or `serverless` for models deployed as a service with a
    pay-as-you-go billing or PTU.
    """

    dedicated = "dedicated"
    realtime = "realtime"  #: Deprecated
    serverless = "serverless"


class ContentFormatterBase:
    """Transform request and response of AzureML endpoint to match with
    required schema.

    Example:
        .. code-block:: python

            class ContentFormatter(ContentFormatterBase):
                content_type = "application/json"
                accepts = "application/json"

                def format_request_payload(
                    self,
                    prompt: str,
                    model_kwargs: Dict,
                    api_type: AzureMLEndpointApiType,
                ) -> bytes:
                    input_str = json.dumps(
                        {
                            "inputs": {"input_string": [prompt]},
                            "parameters": model_kwargs,
                        }
                    )
                    return str.encode(input_str)

                def format_response_payload(
                    self, output: str, api_type: AzureMLEndpointApiType
                ) -> str:
                    response_json = json.loads(output)
                    return response_json[0]["0"]
    """
    content_type: Optional[str] = "application/json"
    """The MIME type of the input data passed to the endpoint"""

    accepts: Optional[str] = "application/json"
    """The MIME type of the response data returned from the endpoint"""

    format_error_msg: str = (
        "Error while formatting response payload for model of type "
        "`{api_type}`. Are you using the right formatter for the deployed "
        "model and endpoint type?"
    )

    @staticmethod
    def escape_special_characters(prompt: str) -> str:
        """Escapes any special characters in `prompt`"""
        escape_map = {
            "\\": "\\\\",
            '"': '\\"',
            "\b": "\\b",
            "\f": "\\f",
            "\n": "\\n",
            "\r": "\\r",
            "\t": "\\t",
        }

        # Replace each occurrence of the specified characters with escaped versions
        for escape_sequence, escaped_sequence in escape_map.items():
            prompt = prompt.replace(escape_sequence, escaped_sequence)

        return prompt
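
    # Illustrative behavior (assumed input, shown as Python literals):
    # escape_special_characters('say "hi"\n') returns 'say \\"hi\\"\\n', so the
    # quote and newline become JSON-safe escape sequences when the prompt is
    # embedded in a JSON string.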

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        """Supported APIs for the given formatter. Azure ML supports
        deploying models using different hosting methods. Each method may have
        a different API structure."""

        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(
        self,
        prompt: str,
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
    ) -> Any:
        """Formats the request body according to the input schema of
        the model. Returns bytes or seekable file like object in the
        format specified in the content_type request header.
        """
        raise NotImplementedError()

    @abstractmethod
    def format_response_payload(
        self,
        output: bytes,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
    ) -> Generation:
        """Formats the response body according to the output
        schema of the model. Returns the data type that is
        received from the response.
        """


class GPT2ContentFormatter(ContentFormatterBase):
    """Content handler for GPT2"""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        request_payload = json.dumps(
            {"inputs": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs}
        )
        return str.encode(request_payload)

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        try:
            choice = json.loads(output)[0]["0"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
        return Generation(text=choice)
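
# Request/response shapes assumed by GPT2ContentFormatter (illustrative; actual
# deployments may differ):
#
#     request:  {"inputs": {"input_string": ["\"Hello\""]}, "parameters": {...}}
#     response: [{"0": "Hello, world!"}]
#
# format_response_payload therefore reads element 0, key "0".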


class OSSContentFormatter(GPT2ContentFormatter):
    """Deprecated: Kept for backwards compatibility

    Content handler for LLMs from the OSS catalog."""

    content_formatter: Any = None

    def __init__(self) -> None:
        super().__init__()
        warnings.warn(
            "`OSSContentFormatter` will be deprecated in the future. "
            "Please use `GPT2ContentFormatter` instead."
        )


class HFContentFormatter(ContentFormatterBase):
    """Content handler for LLMs from the HuggingFace catalog."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        request_payload = json.dumps(
            {"inputs": [f'"{prompt}"'], "parameters": model_kwargs}
        )
        return str.encode(request_payload)

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        try:
            choice = json.loads(output)[0]["0"]["generated_text"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
        return Generation(text=choice)
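
# Response shape assumed by HFContentFormatter (illustrative): a list of dicts
# keyed by "0", each carrying a "generated_text" field, e.g.
#
#     [{"0": {"generated_text": "Hello, world!"}}]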


class DollyContentFormatter(ContentFormatterBase):
    """Content handler for the Dolly-v2-12b model"""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        request_payload = json.dumps(
            {
                "input_data": {"input_string": [f'"{prompt}"']},
                "parameters": model_kwargs,
            }
        )
        return str.encode(request_payload)

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        try:
            choice = json.loads(output)[0]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
        return Generation(text=choice)


class CustomOpenAIContentFormatter(ContentFormatterBase):
    """Content formatter for models that use the OpenAI like API scheme."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        """Formats the request according to the chosen api"""
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        if api_type in [
            AzureMLEndpointApiType.dedicated,
            AzureMLEndpointApiType.realtime,
        ]:
            request_payload = json.dumps(
                {
                    "input_data": {
                        "input_string": [f'"{prompt}"'],
                        "parameters": model_kwargs,
                    }
                }
            )
        elif api_type == AzureMLEndpointApiType.serverless:
            request_payload = json.dumps({"prompt": prompt, **model_kwargs})
        else:
            raise ValueError(
                f"`api_type` {api_type} is not supported by this formatter"
            )
        return str.encode(request_payload)
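
    # Illustrative payloads produced by format_request_payload (hypothetical
    # model_kwargs):
    #
    #   dedicated:  {"input_data": {"input_string": ["\"Hello\""],
    #                               "parameters": {"max_new_tokens": 128}}}
    #   serverless: {"prompt": "Hello", "max_new_tokens": 128}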

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        """Formats response"""
        if api_type in [
            AzureMLEndpointApiType.dedicated,
            AzureMLEndpointApiType.realtime,
        ]:
            try:
                choice = json.loads(output)[0]["0"]
            except (KeyError, IndexError, TypeError) as e:
                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
            return Generation(text=choice)
        if api_type == AzureMLEndpointApiType.serverless:
            try:
                choice = json.loads(output)["choices"][0]
                if not isinstance(choice, dict):
                    raise TypeError(
                        "Endpoint response is not well formed for a chat "
                        "model. Expected `dict` but `{type(choice)}` was "
                        "received."
                    )
            except (KeyError, IndexError, TypeError) as e:
                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
            return Generation(
                text=choice["text"].strip(),
                generation_info=dict(
                    finish_reason=choice.get("finish_reason"),
                    logprobs=choice.get("logprobs"),
                ),
            )
        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")


class LlamaContentFormatter(CustomOpenAIContentFormatter):
    """Deprecated: Kept for backwards compatibility

    Content formatter for Llama."""

    content_formatter: Any = None

    def __init__(self) -> None:
        super().__init__()
        warnings.warn(
            "`LlamaContentFormatter` will be deprecated in the future. "
            "Please use `CustomOpenAIContentFormatter` instead."
        )


class AzureMLBaseEndpoint(BaseModel):
    """Azure ML Online Endpoint models."""

    endpoint_url: str = ""
    """URL of pre-existing Endpoint. Should be passed to constructor or specified as 
        env var `AZUREML_ENDPOINT_URL`."""

    endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated
    """Type of the endpoint being consumed. Possible values are `serverless` for 
        pay-as-you-go and `dedicated` for dedicated endpoints. """

    endpoint_api_key: SecretStr = convert_to_secret_str("")
    """Authentication Key for Endpoint. Should be passed to constructor or specified as
        env var `AZUREML_ENDPOINT_API_KEY`."""

    deployment_name: str = ""
    """Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed 
        to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""

    timeout: int = DEFAULT_TIMEOUT
    """Request timeout for calls to the endpoint"""

    http_client: Any = None  #: :meta private:

    max_retries: int = 1

    content_formatter: Any = None
    """The content formatter that provides an input and output
    transform function to handle formats between the LLM and
    the endpoint"""

    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""

    model_config = ConfigDict(protected_namespaces=())

    @model_validator(mode="before")
    @classmethod
    def validate_environ(cls, values: Dict) -> Any:
        values["endpoint_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
        )
        values["endpoint_url"] = get_from_dict_or_env(
            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
        )
        values["deployment_name"] = get_from_dict_or_env(
            values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
        )
        values["endpoint_api_type"] = get_from_dict_or_env(
            values,
            "endpoint_api_type",
            "AZUREML_ENDPOINT_API_TYPE",
            AzureMLEndpointApiType.dedicated,
        )
        values["timeout"] = get_from_dict_or_env(
            values,
            "timeout",
            "AZUREML_TIMEOUT",
            str(DEFAULT_TIMEOUT),
        )

        return values

    @validator("content_formatter")
    def validate_content_formatter(
        cls, field_value: Any, values: Dict
    ) -> ContentFormatterBase:
        """Validate that content formatter is supported by endpoint type."""
        endpoint_api_type = values.get("endpoint_api_type")
        if endpoint_api_type not in field_value.supported_api_types:
            raise ValueError(
                f"Content formatter f{type(field_value)} is not supported by this "
                f"endpoint. Supported types are {field_value.supported_api_types} "
                f"but endpoint is {endpoint_api_type}."
            )
        return field_value

    @validator("endpoint_url")
    def validate_endpoint_url(cls, field_value: Any) -> str:
        """Validate that endpoint url is complete."""
        if field_value.endswith("/"):
            field_value = field_value[:-1]
        if field_value.endswith("inference.ml.azure.com"):
            raise ValueError(
                "`endpoint_url` should contain the full invocation URL including "
                "`/score` for `endpoint_api_type='dedicated'` or `/completions` "
                "or `/chat/completions` for `endpoint_api_type='serverless'`"
            )
        return field_value

    @validator("endpoint_api_type")
    def validate_endpoint_api_type(
        cls, field_value: Any, values: Dict
    ) -> AzureMLEndpointApiType:
        """Validate that endpoint api type is compatible with the URL format."""
        endpoint_url = values.get("endpoint_url")
        if (
            (
                field_value == AzureMLEndpointApiType.dedicated
                or field_value == AzureMLEndpointApiType.realtime
            )
            and not endpoint_url.endswith("/score")  # type: ignore[union-attr]
        ):
            raise ValueError(
                "Endpoints of type `dedicated` should follow the format "
                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
                " If your endpoint URL ends with `/completions` or"
                "`/chat/completions`, use `endpoint_api_type='serverless'` instead."
            )
        if field_value == AzureMLEndpointApiType.serverless and not (
            endpoint_url.endswith("/completions")  # type: ignore[union-attr]
            or endpoint_url.endswith("/chat/completions")  # type: ignore[union-attr]
        ):
            raise ValueError(
                "Endpoints of type `serverless` should follow the format "
                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/chat/completions`"
                " or `https://<your-endpoint>.<your_region>.inference.ml.azure.com/chat/completions`"
            )

        return field_value

    @validator("http_client", always=True)
    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
        """Validate that api key and python package exists in environment."""
        endpoint_url = values.get("endpoint_url")
        endpoint_key = values.get("endpoint_api_key")
        deployment_name = values.get("deployment_name")
        timeout = values.get("timeout", DEFAULT_TIMEOUT)

        http_client = AzureMLEndpointClient(
            endpoint_url,  # type: ignore
            endpoint_key.get_secret_value(),  # type: ignore
            deployment_name,  # type: ignore
            timeout,  # type: ignore
        )

        return http_client
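
# Configuration sketch (hypothetical values): every connection setting can come
# from environment variables instead of constructor arguments:
#
#     export AZUREML_ENDPOINT_URL="https://my-endpoint.eastus2.inference.ml.azure.com/score"
#     export AZUREML_ENDPOINT_API_KEY="my-api-key"
#     export AZUREML_DEPLOYMENT_NAME="blue"   # optional
#     export AZUREML_TIMEOUT="120"            # optional, in seconds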


class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
    """Azure ML Online Endpoint models.

    Example:
        .. code-block:: python

            azure_llm = AzureMLOnlineEndpoint(
                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
                endpoint_api_type=AzureMLEndpointApiType.dedicated,
                endpoint_api_key="my-api-key",
                timeout=120,
                content_formatter=content_formatter,
            )
    """

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"deployment_name": self.deployment_name},
            **{"model_kwargs": _model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "azureml_endpoint"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.
        Returns:
            An LLMResult with one generation for each prompt.
        Example:
            .. code-block:: python

                response = azureml_model.invoke("Tell me a joke.")
        """
        _model_kwargs = dict(self.model_kwargs or {})
        _model_kwargs.update(kwargs)
        if stop:
            _model_kwargs["stop"] = stop
        generations = []

        for prompt in prompts:
            request_payload = self.content_formatter.format_request_payload(
                prompt, _model_kwargs, self.endpoint_api_type
            )
            response_payload = self.http_client.call(
                body=request_payload, run_manager=run_manager
            )
            generated_text = self.content_formatter.format_response_payload(
                response_payload, self.endpoint_api_type
            )
            generations.append([generated_text])

        return LLMResult(generations=generations)
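
# End-to-end usage sketch (hypothetical endpoint and key; values are
# placeholders, not a guaranteed configuration):
#
#     llm = AzureMLOnlineEndpoint(
#         endpoint_url="https://my-endpoint.eastus2.inference.ml.azure.com/v1/completions",
#         endpoint_api_type=AzureMLEndpointApiType.serverless,
#         endpoint_api_key="my-api-key",
#         content_formatter=CustomOpenAIContentFormatter(),
#         model_kwargs={"temperature": 0.8, "max_tokens": 128},
#     )
#     print(llm.invoke("Tell me a joke."))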
