"""Base classes for chain routing."""

from __future__ import annotations

from abc import ABC
from typing import Any, Dict, List, Mapping, NamedTuple, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
    Callbacks,
)
from pydantic import ConfigDict

from langchain.chains.base import Chain


class Route(NamedTuple):
    """A routing decision: which chain to run next and what to feed it.

    Attributes:
        destination: Name of the chain that should handle the request, or
            ``None``. ``MultiRouteChain`` sends falsy destinations to its
            default chain.
        next_inputs: Input dict to forward to the selected chain.
    """

    destination: Optional[str]
    next_inputs: Dict[str, Any]


class RouterChain(Chain, ABC):
    """Chain that outputs the name of a destination chain and the inputs to it."""

    @property
    def output_keys(self) -> List[str]:
        """Keys this router emits: the destination name and its inputs."""
        return ["destination", "next_inputs"]

    def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
        """Run the router synchronously and package its output as a Route.

        Args:
            inputs: inputs to the chain
            callbacks: callbacks to use for the chain

        Returns:
            a Route built from the ``destination`` and ``next_inputs`` outputs
        """
        outputs = self(inputs, callbacks=callbacks)
        destination = outputs["destination"]
        next_inputs = outputs["next_inputs"]
        return Route(destination=destination, next_inputs=next_inputs)

    async def aroute(
        self, inputs: Dict[str, Any], callbacks: Callbacks = None
    ) -> Route:
        """Run the router asynchronously (via ``acall``) and return a Route.

        Args:
            inputs: inputs to the chain
            callbacks: callbacks to use for the chain

        Returns:
            a Route built from the ``destination`` and ``next_inputs`` outputs
        """
        outputs = await self.acall(inputs, callbacks=callbacks)
        destination = outputs["destination"]
        next_inputs = outputs["next_inputs"]
        return Route(destination=destination, next_inputs=next_inputs)


class MultiRouteChain(Chain):
    """Use a single chain to route an input to one of multiple candidate chains.

    ``router_chain`` produces a destination name plus the inputs to forward.
    The chain registered under that name in ``destination_chains`` is then run
    on those inputs, and its output dict is returned from ``_call``/``_acall``.
    A falsy destination always selects ``default_chain``. An unknown name
    selects ``default_chain`` only when ``silent_errors`` is True; otherwise
    it raises ``ValueError``.
    """

    router_chain: RouterChain
    """Chain that routes inputs to destination chains."""
    destination_chains: Mapping[str, Chain]
    """Chains that return final answer to inputs."""
    default_chain: Chain
    """Default chain to use when none of the destination chains are suitable."""
    silent_errors: bool = False
    """If True, use default_chain when an invalid destination name is provided.
    Defaults to False."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """No fixed output keys.

        The output dict is whatever the selected destination (or default)
        chain returns, so no keys can be declared up front.

        :meta private:
        """
        return []

    def _select_chain(self, route: Route) -> Chain:
        """Return the chain that should handle ``route``.

        Args:
            route: Routing decision produced by ``router_chain``.

        Returns:
            The matching destination chain, or ``default_chain`` when the
            destination is falsy (or unknown and ``silent_errors`` is True).

        Raises:
            ValueError: If ``route.destination`` is not a known destination
                and ``silent_errors`` is False.
        """
        if not route.destination:
            # No destination requested: always use the default chain.
            return self.default_chain
        if route.destination in self.destination_chains:
            return self.destination_chains[route.destination]
        if self.silent_errors:
            # Unknown destination, but the caller opted into silent fallback.
            return self.default_chain
        raise ValueError(
            f"Received invalid destination chain name '{route.destination}'"
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` synchronously and run the selected chain.

        Raises:
            ValueError: If the router picks an unknown destination and
                ``silent_errors`` is False.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        route = self.router_chain.route(inputs, callbacks=callbacks)

        # Report the routing decision before dispatch (including before a
        # potential ValueError for an unknown destination).
        _run_manager.on_text(
            f"{route.destination}: {route.next_inputs}", verbose=self.verbose
        )
        chain = self._select_chain(route)
        return chain(route.next_inputs, callbacks=callbacks)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async counterpart of ``_call``: route ``inputs`` and run the chain.

        Raises:
            ValueError: If the router picks an unknown destination and
                ``silent_errors`` is False.
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        route = await self.router_chain.aroute(inputs, callbacks=callbacks)

        # Report the routing decision before dispatch (including before a
        # potential ValueError for an unknown destination).
        await _run_manager.on_text(
            f"{route.destination}: {route.next_inputs}", verbose=self.verbose
        )
        chain = self._select_chain(route)
        return await chain.acall(route.next_inputs, callbacks=callbacks)
