# This module contains utility classes and functions for interacting with Arcee API.
# For more information and updates, refer to the Arcee utils page:
# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py]

from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union

import requests
from langchain_core.retrievers import Document
from pydantic import BaseModel, SecretStr, model_validator


class ArceeRoute(str, Enum):
    """Enumeration of the Arcee API route paths used by this module.

    The class mixes in ``str``, so every member is also a plain string equal
    to its value.
    """

    # Route used for text generation requests.
    generate = "models/generate"
    # Route used for context retrieval requests.
    retrieve = "models/retrieve"
    # Template route: ``{id_or_name}`` must be filled in with ``str.format``
    # on the member's value before the route is used.
    model_training_status = "models/status/{id_or_name}"


class DALMFilterType(str, Enum):
    """Enumeration of the filter kinds accepted for a DALM retrieval.

    The class mixes in ``str``, so every member is also a plain string equal
    to its value.
    """

    # Approximate matching on the chosen field.
    fuzzy_search = "fuzzy_search"
    # Exact-substring matching on the chosen field.
    strict_search = "strict_search"


class DALMFilter(BaseModel):
    """Filters available for a DALM retrieval and generation.

    Arguments:
        field_name: The field to filter on. Can be 'document' or 'name' to filter
            on your document's raw text or title. Any other field will be presumed
            to be a metadata field you included when uploading your context data
        filter_type: Currently 'fuzzy_search' and 'strict_search' are supported.
            'fuzzy_search' means a fuzzy search on the provided field is performed.
            The exact string doesn't need to exist in the document
            for this to find a match.
            Very useful for scanning a document for some keyword terms.
            'strict_search' means that the exact string must appear
            in the provided field.
            This is NOT an exact eq filter. ie a document with content
            "the happy dog crossed the street" will match on a strict_search of
            "dog" but won't match on "the dog".
            Python equivalent of `return search_string in full_string`.
        value: The actual value to search for in the context data/metadata
    """

    field_name: str
    filter_type: DALMFilterType
    value: str
    # Private attribute (leading underscore), not a model field. It is derived
    # from `field_name` by `set_meta` once the fields have been validated.
    _is_metadata: bool = False

    @model_validator(mode="after")
    def set_meta(self) -> "DALMFilter":
        """Mark the filter as a metadata filter unless it targets a reserved key.

        'document' and 'name' are reserved arcee keys. Anything else is metadata.

        This runs after field validation so that it can assign the private
        attribute directly; writing a key into the raw input data (as a
        "before" validator would) never reaches `_is_metadata`.

        Returns:
            The validated model instance, with `_is_metadata` set.
        """
        self._is_metadata = self.field_name not in ("document", "name")
        return self


class ArceeDocumentSource(BaseModel):
    """Source payload nested inside an `ArceeDocument`.

    All three fields are required strings.
    """

    # Text of the source; used as the langchain `page_content` by
    # `ArceeDocumentAdapter.adapt`.
    document: str
    # Name of the source; exposed as the "name" metadata entry by the adapter.
    name: str
    # Identifier of the source; exposed as the "source_id" metadata entry
    # by the adapter.
    id: str


class ArceeDocument(BaseModel):
    """A single document entry as returned in an Arcee retrieval response.

    `ArceeWrapper.retrieve` builds one instance per item of the response's
    "results" list.
    """

    # Index the document came from — presumably the backing search index;
    # TODO confirm against the Arcee API.
    index: str
    # Identifier of this document entry.
    id: str
    # Score attached to the entry — presumably its relevance to the query;
    # TODO confirm against the Arcee API.
    score: float
    # Nested source payload (text, name and source id).
    source: ArceeDocumentSource


class ArceeDocumentAdapter:
    """Converts Arcee documents into langchain documents."""

    @classmethod
    def adapt(cls, arcee_document: ArceeDocument) -> Document:
        """Build a langchain `Document` from an `ArceeDocument`.

        Args:
            arcee_document: The Arcee document to convert.

        Returns:
            A `Document` whose `page_content` is the source text and whose
            metadata carries the source name and id ("name", "source_id")
            followed by the document's own "index", "id" and "score".
        """
        source = arcee_document.source
        metadata = {
            # fields taken from the nested source
            "name": source.name,
            "source_id": source.id,
            # fields taken from the arcee document itself
            "index": arcee_document.index,
            "id": arcee_document.id,
            "score": arcee_document.score,
        }
        return Document(page_content=source.document, metadata=metadata)


class ArceeWrapper:
    """Wrapper for Arcee API.

    For more details, see: https://www.arcee.ai/
    """

    def __init__(
        self,
        arcee_api_key: Union[str, SecretStr],
        arcee_api_url: str,
        arcee_api_version: str,
        model_kwargs: Optional[Dict[str, Any]],
        model_name: str,
    ):
        """Initialize ArceeWrapper.

        Construction performs a GET request against the model training status
        route to resolve the model id and its training status.

        Arguments:
            arcee_api_key: API key for Arcee API.
            arcee_api_url: URL for Arcee API.
            arcee_api_version: Version of Arcee API.
            model_kwargs: Keyword arguments for Arcee API.
            model_name: Name of an Arcee model.

        Raises:
            ValueError: If the model training status lookup fails for any
                reason; the underlying exception is chained as the cause.
        """
        # Always hold the key as a SecretStr, whichever form was passed in.
        if isinstance(arcee_api_key, str):
            arcee_api_key_ = SecretStr(arcee_api_key)
        else:
            arcee_api_key_ = arcee_api_key
        self.arcee_api_key: SecretStr = arcee_api_key_
        self.model_kwargs = model_kwargs
        self.arcee_api_url = arcee_api_url
        self.arcee_api_version = arcee_api_version

        try:
            route = ArceeRoute.model_training_status.value.format(id_or_name=model_name)
            response = self._make_request("get", route)
            self.model_id = response.get("model_id")
            self.model_training_status = response.get("status")
        except Exception as e:
            raise ValueError(
                f"Error while validating model training status for '{model_name}': {e}"
            ) from e

    def validate_model_training_status(self) -> None:
        """Ensure the model has finished training.

        Raises:
            Exception: If the status recorded at construction time is anything
                other than "training_complete".
        """
        if self.model_training_status != "training_complete":
            raise Exception(
                f"Model {self.model_id} is not ready. "
                "Please wait for training to complete."
            )

    def _make_request(
        self,
        method: Literal["post", "get"],
        route: Union[ArceeRoute, str],
        body: Optional[Mapping[str, Any]] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> dict:
        """Make a request to the Arcee API

        Args:
            method: The HTTP method to use
            route: The route to call
            body: The body of the request
            params: The query params of the request
            headers: The headers of the request

        Returns:
            The decoded JSON body of the response.

        Raises:
            Exception: If the response status code is neither 200 nor 201.
        """
        headers = self._make_request_headers(headers=headers)
        url = self._make_request_url(route=route)

        # Resolve `requests.get` / `requests.post` from the method name.
        req_type = getattr(requests, method)

        response = req_type(url, json=body, params=params, headers=headers)
        if response.status_code not in (200, 201):
            raise Exception(f"Failed to make request. Response: {response.text}")
        return response.json()

    def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
        """Build the headers for a request.

        Args:
            headers: Optional extra headers. This mapping is not modified.

        Returns:
            A new dict holding `headers` plus the "X-Token" and "Content-Type"
            headers; the latter two win if the caller supplied the same keys.

        Raises:
            TypeError: If `self.arcee_api_key` is not a `SecretStr`.
        """
        if not isinstance(self.arcee_api_key, SecretStr):
            raise TypeError(
                f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
            )
        # Merge into a fresh dict so the caller's mapping is never mutated.
        return {
            **(headers or {}),
            "X-Token": self.arcee_api_key.get_secret_value(),
            "Content-Type": "application/json",
        }

    def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
        """Join the API url, API version and route into the request URL.

        Args:
            route: Either a plain route string or a route enum member.
        """
        # Formatting an Enum member directly is version-dependent: newer Python
        # versions render the member name (e.g. "ArceeRoute.generate") instead
        # of its value, so unwrap the value explicitly.
        path = route.value if isinstance(route, Enum) else route
        return f"{self.arcee_api_url}/{self.arcee_api_version}/{path}"

    def _make_request_body_for_models(
        self, prompt: str, **kwargs: Any
    ) -> Mapping[str, Any]:
        """Make the request body for generate/retrieve models endpoint

        Args:
            prompt: The query text sent to the model.
            **kwargs: Per-call parameters ("size", "filters"); these override
                the same keys in `self.model_kwargs`.

        Returns:
            A dict with "model_id", "query", "size" (default 3), "filters"
            and "id" keys, made only of JSON-serializable values.
        """
        _model_kwargs = self.model_kwargs or {}
        _params = {**_model_kwargs, **kwargs}

        # Validate every filter through DALMFilter, then dump it back to a
        # plain dict: the body is sent with `json=`, and model instances are
        # not JSON-serializable.
        filters = [
            (f if isinstance(f, DALMFilter) else DALMFilter(**f)).model_dump(
                mode="json"
            )
            for f in (_params.get("filters") or [])
        ]
        return dict(
            model_id=self.model_id,
            query=prompt,
            size=_params.get("size", 3),
            filters=filters,
            id=self.model_id,
        )

    def generate(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Generate text from Arcee DALM.

        Args:
            prompt: Prompt to generate text from.
            size: The max number of context results to retrieve. Defaults to 3.
              (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            The "text" entry of the API response.
        """

        response = self._make_request(
            method="post",
            route=ArceeRoute.generate.value,
            body=self._make_request_body_for_models(
                prompt=prompt,
                **kwargs,
            ),
        )
        return response["text"]

    def retrieve(
        self,
        query: str,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve {size} contexts with your retriever for a given query

        Args:
            query: Query to submit to the model
            size: The max number of context results to retrieve. Defaults to 3.
              (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            One langchain `Document` per entry of the response's "results".
        """

        response = self._make_request(
            method="post",
            route=ArceeRoute.retrieve.value,
            body=self._make_request_body_for_models(
                prompt=query,
                **kwargs,
            ),
        )
        return [
            ArceeDocumentAdapter.adapt(ArceeDocument(**doc))
            for doc in response["results"]
        ]
