from __future__ import annotations

import json
import logging
import uuid
from typing import Any, Iterable, List, Optional, Type

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore

logger = logging.getLogger(__name__)


def _uuid_key() -> str:
    return uuid.uuid4().hex


class Tair(VectorStore):
    """`Tair` vector store."""

    def __init__(
        self,
        embedding_function: Embeddings,
        url: str,
        index_name: str,
        content_key: str = "content",
        metadata_key: str = "metadata",
        search_params: Optional[dict] = None,
        **kwargs: Any,
    ):
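        """Initialize the vector store client.

        Args:
            embedding_function: Embeddings implementation used to embed texts
                and queries.
            url: Tair connection URL, e.g. ``redis://user:password@host:port``.
            index_name: Name of the TairVector index to read from and write to.
            content_key: Hash field under which the raw text is stored.
            metadata_key: Hash field under which JSON-encoded metadata is stored.
            search_params: Optional extra parameters forwarded to search commands.
            **kwargs: Additional arguments passed to ``tair.Tair.from_url``.
        """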
        self.embedding_function = embedding_function
        self.index_name = index_name
        try:
            from tair import Tair as TairClient
        except ImportError:
            raise ImportError(
                "Could not import tair python package. "
                "Please install it with `pip install tair`."
            )
        try:
            # connect to tair from url
            client = TairClient.from_url(url, **kwargs)
        except ValueError as e:
            raise ValueError(f"Tair failed to connect: {e}")

        self.client = client
        self.content_key = content_key
        self.metadata_key = metadata_key
        self.search_params = search_params

    @property
    def embeddings(self) -> Embeddings:
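        """Access the embedding object used by this store."""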
        return self.embedding_function

    def create_index_if_not_exist(
        self,
        dim: int,
        distance_type: str,
        index_type: str,
        data_type: str,
        **kwargs: Any,
    ) -> bool:
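        """Create the index if it does not already exist.

        Args:
            dim: Dimensionality of the stored vectors.
            distance_type: Distance metric, e.g. ``tairvector.DistanceMetric.InnerProduct``.
            index_type: Index algorithm, e.g. ``tairvector.IndexType.HNSW``.
            data_type: Vector element type, e.g. ``tairvector.DataType.Float32``.
            **kwargs: Additional index parameters passed to ``tvs_create_index``.

        Returns:
            True if a new index was created, False if it already existed.
        """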
        index = self.client.tvs_get_index(self.index_name)
        if index is not None:
            logger.info("Index already exists")
            return False
        self.client.tvs_create_index(
            self.index_name,
            dim,
            distance_type,
            index_type,
            data_type,
            **kwargs,
        )
        return True

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add texts data to an existing index."""
        ids = []
        keys = kwargs.get("keys", None)
        use_hybrid_search = False
        index = self.client.tvs_get_index(self.index_name)
        if index is not None and index.get("lexical_algorithm") == "bm25":
            use_hybrid_search = True
        # Write data to tair
        pipeline = self.client.pipeline(transaction=False)
        # Materialize the iterable once so it is not exhausted between
        # embedding and writing (texts may be a generator).
        texts = list(texts)
        embeddings = self.embedding_function.embed_documents(texts)
        for i, text in enumerate(texts):
            # Use the provided key if given, otherwise generate a UUID-based key
            key = keys[i] if keys else _uuid_key()
            metadata = metadatas[i] if metadatas else {}
            if use_hybrid_search:
                # Tair uses the TEXT attribute for hybrid (BM25 + vector) search
                pipeline.tvs_hset(
                    self.index_name,
                    key,
                    embeddings[i],
                    False,
                    **{
                        "TEXT": text,
                        self.content_key: text,
                        self.metadata_key: json.dumps(metadata),
                    },
                )
            else:
                pipeline.tvs_hset(
                    self.index_name,
                    key,
                    embeddings[i],
                    False,
                    **{
                        self.content_key: text,
                        self.metadata_key: json.dumps(metadata),
                    },
                )
            ids.append(key)
        pipeline.execute()
        return ids

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """
        Returns the most similar indexed documents to the query text.

        Args:
            query (str): The query text for which to find similar documents.
            k (int): The number of documents to return. Default is 4.

        Returns:
            List[Document]: A list of documents that are most similar to the query text.
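
        Example (illustrative sketch; ``vector_store`` is an already
        initialized ``Tair`` instance):

            .. code-block:: python

                docs = vector_store.similarity_search("what is tair?", k=2)
                for doc in docs:
                    print(doc.page_content, doc.metadata)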
        """
        # Creates embedding vector from user query
        embedding = self.embedding_function.embed_query(query)

        keys_and_scores = self.client.tvs_knnsearch(
            self.index_name, k, embedding, False, None, **kwargs
        )

        pipeline = self.client.pipeline(transaction=False)
        for key, _ in keys_and_scores:
            pipeline.tvs_hmget(
                self.index_name, key, self.metadata_key, self.content_key
            )
        docs = pipeline.execute()

        return [
            Document(
                page_content=d[1],
                metadata=json.loads(d[0]),
            )
            for d in docs
        ]

    @classmethod
    def from_texts(
        cls: Type[Tair],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        index_name: str = "langchain",
        content_key: str = "content",
        metadata_key: str = "metadata",
        **kwargs: Any,
    ) -> Tair:
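        """Build a Tair vector store from raw texts.

        Creates the index if needed, embeds the texts, and writes them to Tair.

        Args:
            texts: Texts to embed and store.
            embedding: Embeddings implementation to use.
            metadatas: Optional list of metadata dicts, one per text.
            index_name: Name of the index to create or reuse.
            content_key: Hash field under which the raw text is stored.
            metadata_key: Hash field under which JSON-encoded metadata is stored.
            **kwargs: May include ``tair_url`` (or set the ``TAIR_URL``
                environment variable), ``distance_type``, ``index_type``,
                ``data_type``, ``index_params``, ``search_params`` and ``keys``;
                remaining arguments are passed to the Tair client.

        Returns:
            A ``Tair`` vector store populated with the given texts.
        """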
        try:
            from tair import tairvector
        except ImportError:
            raise ImportError(
                "Could not import tair python package. "
                "Please install it with `pip install tair`."
            )
        url = get_from_dict_or_env(kwargs, "tair_url", "TAIR_URL")
        if "tair_url" in kwargs:
            kwargs.pop("tair_url")

        distance_type = tairvector.DistanceMetric.InnerProduct
        if "distance_type" in kwargs:
            distance_type = kwargs.pop("distance_type")
        index_type = tairvector.IndexType.HNSW
        if "index_type" in kwargs:
            index_type = kwargs.pop("index_type")
        data_type = tairvector.DataType.Float32
        if "data_type" in kwargs:
            data_type = kwargs.pop("data_type")
        index_params = {}
        if "index_params" in kwargs:
            index_params = kwargs.pop("index_params")
        search_params = {}
        if "search_params" in kwargs:
            search_params = kwargs.pop("search_params")

        keys = None
        if "keys" in kwargs:
            keys = kwargs.pop("keys")
        try:
            tair_vector_store = cls(
                embedding,
                url,
                index_name,
                content_key=content_key,
                metadata_key=metadata_key,
                search_params=search_params,
                **kwargs,
            )
        except ValueError as e:
            raise ValueError(f"Tair failed to connect: {e}")

        # Create embeddings for documents
        embeddings = embedding.embed_documents(texts)

        tair_vector_store.create_index_if_not_exist(
            len(embeddings[0]),
            distance_type,
            index_type,
            data_type,
            **index_params,
        )

        tair_vector_store.add_texts(texts, metadatas, keys=keys)
        return tair_vector_store

    @classmethod
    def from_documents(
        cls,
        documents: List[Document],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        index_name: str = "langchain",
        content_key: str = "content",
        metadata_key: str = "metadata",
        **kwargs: Any,
    ) -> Tair:
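        """Build a Tair vector store from LangChain ``Document`` objects.

        Args:
            documents: Documents whose ``page_content`` is embedded and stored.
            embedding: Embeddings implementation to use.
            metadatas: Optional metadata overrides; if omitted, each document's
                own ``metadata`` is used.
            index_name: Name of the index to create or reuse.
            content_key: Hash field under which the raw text is stored.
            metadata_key: Hash field under which JSON-encoded metadata is stored.
            **kwargs: Passed through to :meth:`from_texts`.

        Returns:
            A ``Tair`` vector store populated with the given documents.
        """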
        texts = [d.page_content for d in documents]
        # Prefer explicitly supplied metadatas; otherwise use each document's own.
        if metadatas is None:
            metadatas = [d.metadata for d in documents]

        return cls.from_texts(
            texts, embedding, metadatas, index_name, content_key, metadata_key, **kwargs
        )

    @staticmethod
    def drop_index(
        index_name: str = "langchain",
        **kwargs: Any,
    ) -> bool:
        """
        Drop an existing index.

        Args:
            index_name (str): Name of the index to drop.

        Returns:
            bool: True if the index was dropped, False if it did not exist.
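
        Example (illustrative sketch; assumes ``TAIR_URL`` is set or a
        ``tair_url`` keyword argument is supplied):

            .. code-block:: python

                Tair.drop_index(
                    index_name="langchain",
                    tair_url="redis://localhost:6379",
                )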
        """
        try:
            from tair import Tair as TairClient
        except ImportError:
            raise ImportError(
                "Could not import tair python package. "
                "Please install it with `pip install tair`."
            )
        url = get_from_dict_or_env(kwargs, "tair_url", "TAIR_URL")

        try:
            kwargs.pop("tair_url", None)
            client = TairClient.from_url(url=url, **kwargs)
        except ValueError as e:
            raise ValueError(f"Tair connection error: {e}")
        # delete index
        ret = client.tvs_del_index(index_name)
        if ret == 0:
            # index not exist
            logger.info("Index does not exist")
            return False
        return True

    @classmethod
    def from_existing_index(
        cls,
        embedding: Embeddings,
        index_name: str = "langchain",
        content_key: str = "content",
        metadata_key: str = "metadata",
        **kwargs: Any,
    ) -> Tair:
        """Connect to an existing Tair index."""
        url = get_from_dict_or_env(kwargs, "tair_url", "TAIR_URL")

        search_params = kwargs.pop("search_params", {})

        return cls(
            embedding,
            url,
            index_name,
            content_key=content_key,
            metadata_key=metadata_key,
            search_params=search_params,
            **kwargs,
        )
