"""Chain that makes API calls and summarizes the responses to answer a question."""

from __future__ import annotations

import json
from typing import Any, Dict, List, NamedTuple, Optional, cast

from langchain.chains.api.openapi.requests_chain import APIRequesterChain
from langchain.chains.api.openapi.response_chain import APIResponderChain
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks
from langchain_core.language_models import BaseLanguageModel
from pydantic import BaseModel, Field
from requests import Response

from langchain_community.tools.openapi.utils.api_models import APIOperation
from langchain_community.utilities.requests import Requests


class _ParamMapping(NamedTuple):
    """Mapping from parameter name to parameter value."""

    query_params: List[str]
    body_params: List[str]
    path_params: List[str]


class OpenAPIEndpointChain(Chain, BaseModel):
    """Chain interacts with an OpenAPI endpoint using natural language."""

    api_request_chain: LLMChain
    api_response_chain: Optional[LLMChain] = None
    api_operation: APIOperation
    requests: Requests = Field(exclude=True, default_factory=Requests)
    param_mapping: _ParamMapping = Field(alias="param_mapping")
    return_intermediate_steps: bool = False
    instructions_key: str = "instructions"  #: :meta private:
    output_key: str = "output"  #: :meta private:
    max_text_length: Optional[int] = Field(ge=0)  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.instructions_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, "intermediate_steps"]

    def _construct_path(self, args: Dict[str, str]) -> str:
        """Construct the path from the deserialized input."""
        path = self.api_operation.base_url + self.api_operation.path
        for param in self.param_mapping.path_params:
            path = path.replace(f"{{{param}}}", str(args.pop(param, "")))
        return path

    def _extract_query_params(self, args: Dict[str, str]) -> Dict[str, str]:
        """Extract the query params from the deserialized input."""
        query_params = {}
        for param in self.param_mapping.query_params:
            if param in args:
                query_params[param] = args.pop(param)
        return query_params

    def _extract_body_params(self, args: Dict[str, str]) -> Optional[Dict[str, str]]:
        """Extract the request body params from the deserialized input."""
        body_params = None
        if self.param_mapping.body_params:
            body_params = {}
            for param in self.param_mapping.body_params:
                if param in args:
                    body_params[param] = args.pop(param)
        return body_params

    def deserialize_json_input(self, serialized_args: str) -> dict:
        """Use the serialized typescript dictionary.

        Resolve the path, query params dict, and optional requestBody dict.
        """
        args: dict = json.loads(serialized_args)
        path = self._construct_path(args)
        body_params = self._extract_body_params(args)
        query_params = self._extract_query_params(args)
        return {
            "url": path,
            "data": body_params,
            "params": query_params,
        }

    def _get_output(self, output: str, intermediate_steps: dict) -> dict:
        """Return the output from the API call."""
        if self.return_intermediate_steps:
            return {
                self.output_key: output,
                "intermediate_steps": intermediate_steps,
            }
        else:
            return {self.output_key: output}

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        intermediate_steps = {}
        instructions = inputs[self.instructions_key]
        instructions = instructions[: self.max_text_length]
        _api_arguments = self.api_request_chain.predict_and_parse(
            instructions=instructions, callbacks=_run_manager.get_child()
        )
        api_arguments = cast(str, _api_arguments)
        intermediate_steps["request_args"] = api_arguments
        _run_manager.on_text(
            api_arguments, color="green", end="\n", verbose=self.verbose
        )
        if api_arguments.startswith("ERROR"):
            return self._get_output(api_arguments, intermediate_steps)
        elif api_arguments.startswith("MESSAGE:"):
            return self._get_output(
                api_arguments[len("MESSAGE:") :], intermediate_steps
            )
        try:
            request_args = self.deserialize_json_input(api_arguments)
            method = getattr(self.requests, self.api_operation.method.value)
            api_response: Response = method(**request_args)
            if api_response.status_code != 200:
                method_str = str(self.api_operation.method.value)
                response_text = (
                    f"{api_response.status_code}: {api_response.reason}"
                    + f"\nFor {method_str.upper()}  {request_args['url']}\n"
                    + f"Called with args: {request_args['params']}"
                )
            else:
                response_text = api_response.text
        except Exception as e:
            response_text = f"Error with message {str(e)}"
        response_text = response_text[: self.max_text_length]
        intermediate_steps["response_text"] = response_text
        _run_manager.on_text(
            response_text, color="blue", end="\n", verbose=self.verbose
        )
        if self.api_response_chain is not None:
            _answer = self.api_response_chain.predict_and_parse(
                response=response_text,
                instructions=instructions,
                callbacks=_run_manager.get_child(),
            )
            answer = cast(str, _answer)
            _run_manager.on_text(answer, color="yellow", end="\n", verbose=self.verbose)
            return self._get_output(answer, intermediate_steps)
        else:
            return self._get_output(response_text, intermediate_steps)

    @classmethod
    def from_url_and_method(
        cls,
        spec_url: str,
        path: str,
        method: str,
        llm: BaseLanguageModel,
        requests: Optional[Requests] = None,
        return_intermediate_steps: bool = False,
        **kwargs: Any,
        # TODO: Handle async
    ) -> "OpenAPIEndpointChain":
        """Create an OpenAPIEndpoint from a spec at the specified url."""
        operation = APIOperation.from_openapi_url(spec_url, path, method)
        return cls.from_api_operation(
            operation,
            requests=requests,
            llm=llm,
            return_intermediate_steps=return_intermediate_steps,
            **kwargs,
        )

    @classmethod
    def from_api_operation(
        cls,
        operation: APIOperation,
        llm: BaseLanguageModel,
        requests: Optional[Requests] = None,
        verbose: bool = False,
        return_intermediate_steps: bool = False,
        raw_response: bool = False,
        callbacks: Callbacks = None,
        **kwargs: Any,
        # TODO: Handle async
    ) -> "OpenAPIEndpointChain":
        """Create an OpenAPIEndpointChain from an operation and a spec."""
        param_mapping = _ParamMapping(
            query_params=operation.query_params,
            body_params=operation.body_params,
            path_params=operation.path_params,
        )
        requests_chain = APIRequesterChain.from_llm_and_typescript(
            llm,
            typescript_definition=operation.to_typescript(),
            verbose=verbose,
            callbacks=callbacks,
        )
        if raw_response:
            response_chain = None
        else:
            response_chain = APIResponderChain.from_llm(
                llm, verbose=verbose, callbacks=callbacks
            )
        _requests = requests or Requests()
        return cls(
            api_request_chain=requests_chain,
            api_response_chain=response_chain,
            api_operation=operation,
            requests=_requests,
            param_mapping=param_mapping,
            verbose=verbose,
            return_intermediate_steps=return_intermediate_steps,
            callbacks=callbacks,
            **kwargs,
        )
