import asyncio
import logging
import threading
from typing import Dict, List, Optional

import requests
from langchain_core._api.deprecation import deprecated
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.utils import get_from_dict_or_env, pre_init
from pydantic import BaseModel, ConfigDict, PrivateAttr

logger = logging.getLogger(__name__)


@deprecated(
    since="0.0.13",
    alternative="langchain_community.embeddings.QianfanEmbeddingsEndpoint",
)
class ErnieEmbeddings(BaseModel, Embeddings):
    """`Ernie Embeddings V1` embedding models.

    The API base URL and client credentials come from the constructor
    arguments or, failing that, from the ``ERNIE_API_BASE``,
    ``ERNIE_CLIENT_ID`` and ``ERNIE_CLIENT_SECRET`` environment variables.
    An OAuth access token is fetched lazily on first use and refreshed once
    (followed by a single retry) when the service reports error code 111.
    """

    ernie_api_base: Optional[str] = None
    ernie_client_id: Optional[str] = None
    ernie_client_secret: Optional[str] = None
    # OAuth access token; fetched lazily by ``_refresh_access_token_with_lock``.
    access_token: Optional[str] = None

    # Maximum number of texts sent in a single embedding request.
    chunk_size: int = 16

    model_name: str = "ErnieBot-Embedding-V1"

    # Serializes token refreshes. Declared as a pydantic private attribute with
    # a factory: a bare ``threading.Lock()`` default would be deep-copied for
    # each new instance by pydantic v2, and lock objects cannot be deep-copied.
    _lock = PrivateAttr(default_factory=threading.Lock)

    # ``model_name`` would otherwise clash with pydantic's protected "model_"
    # namespace.
    model_config = ConfigDict(protected_namespaces=())

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API base and client credentials from args or env vars.

        ``ernie_api_base`` defaults to ``https://aip.baidubce.com``; the client
        id and secret have no default and must be supplied one way or the other.
        """
        values["ernie_api_base"] = get_from_dict_or_env(
            values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
        )
        values["ernie_client_id"] = get_from_dict_or_env(
            values,
            "ernie_client_id",
            "ERNIE_CLIENT_ID",
        )
        values["ernie_client_secret"] = get_from_dict_or_env(
            values,
            "ernie_client_secret",
            "ERNIE_CLIENT_SECRET",
        )
        return values

    def _embedding(self, json: object) -> dict:
        """POST ``json`` to the embedding-v1 endpoint and return the decoded body.

        The current ``access_token`` is sent as a query parameter. Error
        responses are returned as-is; callers inspect ``error_code``.
        """
        base_url = (
            f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"
        )
        resp = requests.post(
            f"{base_url}/embedding-v1",
            headers={
                "Content-Type": "application/json",
            },
            params={"access_token": self.access_token},
            json=json,
        )
        return resp.json()

    def _refresh_access_token_with_lock(self) -> None:
        """Fetch a new OAuth access token and store it on ``self.access_token``.

        Raises:
            ValueError: If the token endpoint's response has no ``access_token``.
        """
        with self._lock:
            logger.debug("Refreshing access token")
            base_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
            resp = requests.post(
                base_url,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                params={
                    "grant_type": "client_credentials",
                    "client_id": self.ernie_client_id,
                    "client_secret": self.ernie_client_secret,
                },
            )
            body = resp.json()
            token = body.get("access_token")
            if not token:
                # Without this guard ``str(None)`` would store the truthy string
                # "None", silently disabling every later refresh attempt.
                raise ValueError(f"Failed to obtain Ernie access token: {body}")
            self.access_token = str(token)

    def _embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed one request-sized batch of texts.

        On error code 111 the access token is refreshed and the request is
        retried exactly once. The code treats 111 as a token problem
        (presumably expiry — confirm against the Baidu API docs).

        Raises:
            ValueError: If the service reports an error, including on the retry.
        """
        payload = {"input": list(texts)}
        resp = self._embedding(payload)
        if resp.get("error_code") == 111:
            self._refresh_access_token_with_lock()
            resp = self._embedding(payload)
        # Checked after the optional retry so a failed retry also surfaces as a
        # ValueError rather than a KeyError on ``resp["data"]``.
        if resp.get("error_code"):
            raise ValueError(f"Error from Ernie: {resp}")
        return [item["embedding"] for item in resp["data"]]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Texts are sent in batches of at most ``chunk_size`` per request.

        Args:
            texts: The list of texts to embed

        Returns:
            List[List[float]]: List of embeddings, one for each text.

        Raises:
            ValueError: If the service returns an error.
        """
        if not texts:
            # Nothing to embed: avoid a needless token fetch.
            return []
        if not self.access_token:
            self._refresh_access_token_with_lock()
        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.chunk_size):
            embeddings.extend(self._embed_batch(texts[start : start + self.chunk_size]))
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Args:
            text: The text to embed.

        Returns:
            List[float]: Embeddings for the text.

        Raises:
            ValueError: If the service returns an error.
        """
        if not self.access_token:
            self._refresh_access_token_with_lock()
        return self._embed_batch([text])[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text.

        Runs the blocking ``embed_query`` in the default executor.

        Args:
            text: The text to embed.

        Returns:
            List[float]: Embeddings for the text.
        """

        return await run_in_executor(None, self.embed_query, text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs.

        Runs the blocking ``embed_documents`` in the default executor, so texts
        are batched by ``chunk_size`` exactly as in the synchronous path instead
        of issuing one concurrent request per text.

        Args:
            texts: The list of texts to embed

        Returns:
            List[List[float]]: List of embeddings, one for each text.
        """

        return await run_in_executor(None, self.embed_documents, texts)
