from typing import Any, Dict, List, Optional

import requests
from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, ConfigDict, Field

# Workers AI embedding model used when ``model_name`` is not supplied.
DEFAULT_MODEL_NAME: str = "@cf/baai/bge-base-en-v1.5"


class CloudflareWorkersAIEmbeddings(BaseModel, Embeddings):
    """Cloudflare Workers AI embedding model.

    To use, you need to provide an API token and
    account ID to access Cloudflare Workers AI.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import CloudflareWorkersAIEmbeddings

            account_id = "my_account_id"
            api_token = "my_secret_api_token"
            model_name = "@cf/baai/bge-small-en-v1.5"

            cf = CloudflareWorkersAIEmbeddings(
                account_id=account_id,
                api_token=api_token,
                model_name=model_name
            )
    """

    # Base URL of the Cloudflare accounts API; the account ID and the
    # ``ai/run/<model_name>`` path are appended for each request.
    api_base_url: str = "https://api.cloudflare.com/client/v4/accounts"
    # Cloudflare account ID inserted into the request URL.
    account_id: str
    # Bearer token sent in the Authorization header. Hidden from repr so the
    # secret does not end up in logs or tracebacks.
    api_token: str = Field(repr=False)
    # Workers AI model identifier appended to the ``ai/run/`` endpoint.
    model_name: str = DEFAULT_MODEL_NAME
    # Maximum number of texts sent per request by ``embed_documents``.
    batch_size: int = 50
    # When True, "\n" characters are replaced with spaces before embedding.
    strip_new_lines: bool = True
    # Per-request timeout in seconds passed to ``requests.post``. ``None``
    # (the default) waits indefinitely, matching the previous behavior.
    request_timeout: Optional[float] = None
    # HTTP headers for every request; overwritten from ``api_token`` in
    # ``__init__``. Hidden from repr because it embeds the token.
    headers: Dict[str, str] = Field(
        default_factory=lambda: {"Authorization": "Bearer "}, repr=False
    )

    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    def __init__(self, **kwargs: Any):
        """Initialize the Cloudflare Workers AI client."""
        super().__init__(**kwargs)

        self.headers = {"Authorization": f"Bearer {self.api_token}"}

    def _embed_batch(self, batch: List[str]) -> List[List[float]]:
        """Send one batch of texts to the Workers AI run endpoint.

        Args:
            batch: Texts to embed in a single request.

        Returns:
            The ``result.data`` list from the JSON response, one embedding
            per input text.

        Raises:
            requests.HTTPError: If the API responds with a 4xx/5xx status.
        """
        response = requests.post(
            f"{self.api_base_url}/{self.account_id}/ai/run/{self.model_name}",
            headers=self.headers,
            json={"text": batch},
            timeout=self.request_timeout,
        )
        # Surface API failures (bad token, unknown model, rate limiting) as a
        # descriptive HTTPError instead of a KeyError on the missing "result".
        response.raise_for_status()
        return response.json()["result"]["data"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using Cloudflare Workers AI.

        Texts are sent in consecutive batches of at most ``batch_size``.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text, in input order.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
            requests.HTTPError: If any batch request fails.
        """
        # A non-positive step would make range() raise (0) or silently yield
        # nothing (negative), so reject it up front with a clear message.
        if self.batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {self.batch_size}"
            )

        if self.strip_new_lines:
            texts = [text.replace("\n", " ") for text in texts]

        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start : start + self.batch_size]
            embeddings.extend(self._embed_batch(batch))

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using Cloudflare Workers AI.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.

        Raises:
            requests.HTTPError: If the request fails.
        """
        if self.strip_new_lines:
            text = text.replace("\n", " ")
        return self._embed_batch([text])[0]
