from __future__ import annotations

import logging
import uuid
from typing import Any, List, Optional, Tuple

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)


class VikingDBConfig:
    """Connection settings for a VikingDB service endpoint.

    See the following documentation for details:
    https://www.volcengine.com/docs/6459/1167770

    Attributes:
        host (str): The access address of the vector database server
            that the client needs to connect to.
        region (str): "cn-shanghai" or "cn-beijing".
        ak (str): Access Key ID, security credentials for accessing
            Volcano Engine services.
        sk (str): Secret Access Key, security credentials for accessing
            Volcano Engine services.
        scheme (str): "http" or "https", defaulting to "http".
    """

    def __init__(
        self,
        host: str = "host",
        region: str = "region",
        ak: str = "ak",
        sk: str = "sk",
        scheme: str = "http",
    ) -> None:
        """Store the connection settings verbatim; no validation is performed."""
        self.host = host
        self.region = region
        self.ak = ak
        self.sk = sk
        self.scheme = scheme


class VikingDB(VectorStore):
    """vikingdb as a vector store

    In order to use this you need to have a database instance.
    See the following documentation for details:
    https://www.volcengine.com/docs/6459/1167774
    """

    def __init__(
        self,
        embedding_function: Embeddings,
        collection_name: str = "LangChainCollection",
        connection_args: Optional[VikingDBConfig] = None,
        index_params: Optional[dict] = None,
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ):
        try:
            from volcengine.viking_db import Collection, VikingDBService
        except ImportError:
            raise ImportError(
                "Could not import volcengine python package. "
                "Please install it with `pip install --upgrade volcengine`."
            )
        self.embedding_func = embedding_function
        self.collection_name = collection_name
        self.index_name = "LangChainIndex"
        self.connection_args = connection_args
        self.index_params = index_params
        self.drop_old = drop_old
        self.service = VikingDBService(
            connection_args.host,  # type: ignore[union-attr]
            connection_args.region,  # type: ignore[union-attr]
            connection_args.ak,  # type: ignore[union-attr]
            connection_args.sk,  # type: ignore[union-attr]
            connection_args.scheme,  # type: ignore[union-attr]
        )

        try:
            col = self.service.get_collection(collection_name)
        except Exception:
            col = None
        self.collection = col
        self.index = None
        if self.collection is not None:
            self.index = self.service.get_index(self.collection_name, self.index_name)

        if drop_old and isinstance(self.collection, Collection):
            indexes = self.service.list_indexes(collection_name)
            for index in indexes:
                self.service.drop_index(collection_name, index.index_name)
            self.service.drop_collection(collection_name)
            self.collection = None
            self.index = None

    @property
    def embeddings(self) -> Embeddings:
        return self.embedding_func

    def _create_collection(
        self, embeddings: List, metadatas: Optional[List[dict]] = None
    ) -> None:
        try:
            from volcengine.viking_db import Field, FieldType
        except ImportError:
            raise ImportError(
                "Could not import volcengine python package. "
                "Please install it with `pip install --upgrade volcengine`."
            )
        dim = len(embeddings[0])
        fields = []
        if metadatas:
            for key, value in metadatas[0].items():
                # print(key, value)
                if isinstance(value, str):
                    fields.append(Field(key, FieldType.String))
                elif isinstance(value, int):
                    fields.append(Field(key, FieldType.Int64))
                elif isinstance(value, bool):
                    fields.append(Field(key, FieldType.Bool))
                elif isinstance(value, list) and all(
                    isinstance(item, str) for item in value
                ):
                    fields.append(Field(key, FieldType.List_String))
                elif isinstance(value, list) and all(
                    isinstance(item, int) for item in value
                ):
                    fields.append(Field(key, FieldType.List_Int64))
                elif isinstance(value, bytes):
                    fields.append(Field(key, FieldType.Text))
                else:
                    raise ValueError(
                        "metadatas value is invalid"
                        "please change the type of metadatas."
                    )
        # fields.append(Field("text", FieldType.String))
        fields.append(Field("text", FieldType.Text))

        fields.append(Field("primary_key", FieldType.String, is_primary_key=True))

        fields.append(Field("vector", FieldType.Vector, dim=dim))

        self.collection = self.service.create_collection(self.collection_name, fields)

    def _create_index(self) -> None:
        try:
            from volcengine.viking_db import VectorIndexParams
        except ImportError:
            raise ImportError(
                "Could not import volcengine python package. "
                "Please install it with `pip install --upgrade volcengine`."
            )
        cpu_quota = 2
        vector_index = VectorIndexParams()
        partition_by = ""
        scalar_index = None
        if self.index_params is not None:
            if self.index_params.get("cpu_quota") is not None:
                cpu_quota = self.index_params["cpu_quota"]
            if self.index_params.get("vector_index") is not None:
                vector_index = self.index_params["vector_index"]
            if self.index_params.get("partition_by") is not None:
                partition_by = self.index_params["partition_by"]
            if self.index_params.get("scalar_index") is not None:
                scalar_index = self.index_params["scalar_index"]

        self.index = self.service.create_index(
            self.collection_name,
            self.index_name,
            vector_index=vector_index,
            cpu_quota=cpu_quota,
            partition_by=partition_by,
            scalar_index=scalar_index,
        )

    def add_texts(  # type: ignore[override]
        self,
        texts: List[str],
        metadatas: Optional[List[dict]] = None,
        batch_size: int = 1000,
        **kwargs: Any,
    ) -> List[str]:
        """Insert text data into VikingDB."""
        try:
            from volcengine.viking_db import Data
        except ImportError:
            raise ImportError(
                "Could not import volcengine python package. "
                "Please install it with `pip install --upgrade volcengine`."
            )
        texts = list(texts)
        try:
            embeddings = self.embedding_func.embed_documents(texts)
        except NotImplementedError:
            embeddings = [self.embedding_func.embed_query(x) for x in texts]
        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []
        if self.collection is None:
            self._create_collection(embeddings, metadatas)
            self._create_index()

        # insert data
        data = []
        pks: List[str] = []
        for index in range(len(embeddings)):
            primary_key = str(uuid.uuid4())
            pks.append(primary_key)
            field = {
                "text": texts[index],
                "primary_key": primary_key,
                "vector": embeddings[index],
            }
            if metadatas is not None and index < len(metadatas):
                names = list(metadatas[index].keys())
                for name in names:
                    field[name] = metadatas[index].get(name)  # type: ignore[assignment]
            data.append(Data(field))

        total_count = len(data)
        for i in range(0, total_count, batch_size):
            end = min(i + batch_size, total_count)
            insert_data = data[i:end]
            # print(insert_data)
            self.collection.upsert_data(insert_data)  # type: ignore[union-attr]
        return pks

    def similarity_search(  # type: ignore[override]
        self,
        query: str,
        params: Optional[dict] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against the query string."""
        res = self.similarity_search_with_score(query=query, params=params, **kwargs)
        return [doc for doc, _ in res]

    def similarity_search_with_score(
        self,
        query: str,
        params: Optional[dict] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on a query string and return results with score."""
        embedding = self.embedding_func.embed_query(query)

        res = self.similarity_search_with_score_by_vector(
            embedding=embedding, params=params, **kwargs
        )
        return res

    def similarity_search_by_vector(  # type: ignore[override]
        self,
        embedding: List[float],
        params: Optional[dict] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against the query string."""
        res = self.similarity_search_with_score_by_vector(
            embedding=embedding, params=params, **kwargs
        )
        return [doc for doc, _ in res]

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        params: Optional[dict] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on a query string and return results with score."""
        if self.collection is None:
            logger.debug("No existing collection to search.")
            return []

        filter = None
        limit = 10
        output_fields = None
        partition = "default"
        if params is not None:
            if params.get("filter") is not None:
                filter = params["filter"]
            if params.get("limit") is not None:
                limit = params["limit"]
            if params.get("output_fields") is not None:
                output_fields = params["output_fields"]
            if params.get("partition") is not None:
                partition = params["partition"]

        res = self.index.search_by_vector(  # type: ignore[union-attr]
            embedding,
            filter=filter,
            limit=limit,
            output_fields=output_fields,
            partition=partition,
        )

        ret = []
        for item in res:
            if "primary_key" in item.fields:
                item.fields.pop("primary_key")
            if "vector" in item.fields:
                item.fields.pop("vector")
            page_content = ""
            if "text" in item.fields:
                page_content = item.fields.pop("text")
            doc = Document(page_content=page_content, metadata=item.fields)
            pair = (doc, item.score)
            ret.append(pair)
        return ret

    def max_marginal_relevance_search(  # type: ignore[override]
        self,
        query: str,
        k: int = 4,
        lambda_mult: float = 0.5,
        params: Optional[dict] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results that are reordered by MMR."""
        embedding = self.embedding_func.embed_query(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding=embedding,
            k=k,
            lambda_mult=lambda_mult,
            params=params,
            **kwargs,
        )

    def max_marginal_relevance_search_by_vector(  # type: ignore[override]
        self,
        embedding: List[float],
        k: int = 4,
        lambda_mult: float = 0.5,
        params: Optional[dict] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results that are reordered by MMR."""
        if self.collection is None:
            logger.debug("No existing collection to search.")
            return []
        filter = None
        limit = 10
        output_fields = None
        partition = "default"
        if params is not None:
            if params.get("filter") is not None:
                filter = params["filter"]
            if params.get("limit") is not None:
                limit = params["limit"]
            if params.get("output_fields") is not None:
                output_fields = params["output_fields"]
            if params.get("partition") is not None:
                partition = params["partition"]

        res = self.index.search_by_vector(  # type: ignore[union-attr]
            embedding,
            filter=filter,
            limit=limit,
            output_fields=output_fields,
            partition=partition,
        )
        documents = []
        ordered_result_embeddings = []
        for item in res:
            if (
                "vector" not in item.fields
                or "primary_key" not in item.fields
                or "text" not in item.fields
            ):
                continue
            ordered_result_embeddings.append(item.fields.pop("vector"))
            item.fields.pop("primary_key")
            page_content = item.fields.pop("text")
            doc = Document(page_content=page_content, metadata=item.fields)
            documents.append(doc)

        new_ordering = maximal_marginal_relevance(
            np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
        )
        # Reorder the values and return.
        ret = []
        for x in new_ordering:
            # Function can return -1 index
            if x == -1:
                break
            else:
                ret.append(documents[x])
        return ret

    def delete(
        self,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        if self.collection is None:
            logger.debug("No existing collection to search.")
        self.collection.delete_data(ids)  # type: ignore[union-attr]

    @classmethod
    def from_texts(  # type: ignore[no-untyped-def, override]
        cls,
        texts: List[str],
        embedding: Embeddings,
        connection_args: Optional[VikingDBConfig] = None,
        metadatas: Optional[List[dict]] = None,
        collection_name: str = "LangChainCollection",
        index_params: Optional[dict] = None,
        drop_old: bool = False,
        **kwargs: Any,
    ):
        """Create a collection, indexes it and insert data."""
        if connection_args is None:
            raise Exception("VikingDBConfig does not exists")
        vector_db = cls(
            embedding_function=embedding,
            collection_name=collection_name,
            connection_args=connection_args,
            index_params=index_params,
            drop_old=drop_old,
            **kwargs,
        )
        vector_db.add_texts(texts=texts, metadatas=metadatas)
        return vector_db
