"""Wrapper around the Baidu vector database."""

from __future__ import annotations

import json
import logging
import time
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)


class ConnectionParams:
    """Connection settings for a Baidu VectorDB server.

    See the following documentation for details:
    https://cloud.baidu.com/doc/VDB/s/6lrsob0wy

    Attributes:
        endpoint (str): The access address of the vector database server
            that the client needs to connect to.
        api_key (str): API key the client uses to authenticate against the
            vector database server.
        account (str): Account used to access the vector database server.
            Defaults to ``"root"``.
        connection_timeout_in_mills (int): Request timeout. Defaults to
            ``50 * 1000``.
    """

    def __init__(
        self,
        endpoint: str,
        api_key: str,
        account: str = "root",
        connection_timeout_in_mills: int = 50_000,
    ):
        # Plain value holder: every argument is stored unchanged.
        self.endpoint, self.api_key = endpoint, api_key
        self.account, self.connection_timeout_in_mills = (
            account,
            connection_timeout_in_mills,
        )


class TableParams:
    """Table settings for a Baidu VectorDB table.

    See the following documentation for details:
    https://cloud.baidu.com/doc/VDB/s/mlrsob0p6

    Attributes:
        dimension (int): Dimension of the vector field.
        replication (int): Replication value used when the table is created.
            Defaults to 3.
        partition (int): Partition count used when the table is created.
            Defaults to 1.
        index_type (str): Name of the vector index type. Defaults to
            ``"HNSW"``.
        metric_type (str): Name of the distance metric. Defaults to ``"L2"``.
        params (Optional[Dict]): Optional index parameters; ``None`` means
            the defaults are used.
    """

    def __init__(
        self,
        dimension: int,
        replication: int = 3,
        partition: int = 1,
        index_type: str = "HNSW",
        metric_type: str = "L2",
        params: Optional[Dict] = None,
    ):
        # Sizing / layout of the table.
        self.dimension = dimension
        self.partition = partition
        self.replication = replication
        # Vector index configuration.
        self.index_type, self.metric_type = index_type, metric_type
        self.params = params


class BaiduVectorDB(VectorStore):
    """Baidu VectorDB as a vector store.

    In order to use this you need to have a database instance.
    See the following documentation for details:
    https://cloud.baidu.com/doc/VDB/index.html
    """

    field_id: str = "id"
    field_vector: str = "vector"
    field_text: str = "text"
    field_metadata: str = "metadata"

    index_vector: str = "vector_idx"

    def __init__(
        self,
        embedding: Embeddings,
        connection_params: ConnectionParams,
        table_params: TableParams = TableParams(128),
        database_name: str = "LangChainDatabase",
        table_name: str = "LangChainTable",
        drop_old: Optional[bool] = False,
    ):
        pymochow = guard_import("pymochow")
        configuration = guard_import("pymochow.configuration")
        auth = guard_import("pymochow.auth.bce_credentials")
        self.mochowtable = guard_import("pymochow.model.table")
        self.mochowenum = guard_import("pymochow.model.enum")
        self.embedding_func = embedding
        self.table_params = table_params
        config = configuration.Configuration(
            credentials=auth.BceCredentials(
                connection_params.account, connection_params.api_key
            ),
            endpoint=connection_params.endpoint,
            connection_timeout_in_mills=connection_params.connection_timeout_in_mills,
        )
        self.vdb_client = pymochow.MochowClient(config)
        db_list = self.vdb_client.list_databases()
        db_exist: bool = False
        for db in db_list:
            if database_name == db.database_name:
                db_exist = True
                break
        if db_exist:
            self.database = self.vdb_client.database(database_name)
        else:
            self.database = self.vdb_client.create_database(database_name)
        try:
            self.table = self.database.describe_table(table_name)
            if drop_old:
                self.database.drop_table(table_name)
                self._create_table(table_name)
        except pymochow.exception.ServerError:
            self._create_table(table_name)

    def _create_table(self, table_name: str) -> None:
        schema = guard_import("pymochow.model.schema")
        index_type = None
        for k, v in self.mochowenum.IndexType.__members__.items():
            if k == self.table_params.index_type:
                index_type = v
        if index_type is None:
            raise ValueError("unsupported index_type")
        metric_type = None
        for k, v in self.mochowenum.MetricType.__members__.items():
            if k == self.table_params.metric_type:
                metric_type = v
        if metric_type is None:
            raise ValueError("unsupported metric_type")
        if self.table_params.params is None:
            params = schema.HNSWParams(m=16, efconstruction=200)
        else:
            params = schema.HNSWParams(
                m=self.table_params.params.get("M", 16),
                efconstruction=self.table_params.params.get("efConstruction", 200),
            )
        fields = []
        fields.append(
            schema.Field(
                self.field_id,
                self.mochowenum.FieldType.STRING,
                primary_key=True,
                partition_key=True,
                auto_increment=False,
                not_null=True,
            )
        )
        fields.append(
            schema.Field(
                self.field_vector,
                self.mochowenum.FieldType.FLOAT_VECTOR,
                dimension=self.table_params.dimension,
                not_null=True,
            )
        )
        fields.append(schema.Field(self.field_text, self.mochowenum.FieldType.STRING))
        fields.append(
            schema.Field(self.field_metadata, self.mochowenum.FieldType.STRING)
        )
        indexes = []
        indexes.append(
            schema.VectorIndex(
                index_name=self.index_vector,
                index_type=index_type,
                field=self.field_vector,
                metric_type=metric_type,
                params=params,
            )
        )

        self.table = self.database.create_table(
            table_name=table_name,
            replication=self.table_params.replication,
            partition=self.mochowtable.Partition(
                partition_num=self.table_params.partition
            ),
            schema=schema.Schema(fields=fields, indexes=indexes),
        )

        while True:
            time.sleep(1)
            table = self.database.describe_table(table_name)
            if table.state == self.mochowenum.TableState.NORMAL:
                break

    @property
    def embeddings(self) -> Embeddings:
        return self.embedding_func

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        connection_params: Optional[ConnectionParams] = None,
        table_params: Optional[TableParams] = None,
        database_name: str = "LangChainDatabase",
        table_name: str = "LangChainTable",
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> BaiduVectorDB:
        """Create a table, indexes it with HNSW, and insert data."""
        if len(texts) == 0:
            raise ValueError("texts is empty")
        if connection_params is None:
            raise ValueError("connection_params is empty")
        try:
            embeddings = embedding.embed_documents(texts[0:1])
        except NotImplementedError:
            embeddings = [embedding.embed_query(texts[0])]
        dimension = len(embeddings[0])
        if table_params is None:
            table_params = TableParams(dimension=dimension)
        else:
            table_params.dimension = dimension
        vector_db = cls(
            embedding=embedding,
            connection_params=connection_params,
            table_params=table_params,
            database_name=database_name,
            table_name=table_name,
            drop_old=drop_old,
        )
        vector_db.add_texts(texts=texts, metadatas=metadatas)
        return vector_db

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        batch_size: int = 1000,
        **kwargs: Any,
    ) -> List[str]:
        """Insert text data into Baidu VectorDB."""
        texts = list(texts)
        try:
            embeddings = self.embedding_func.embed_documents(texts)
        except NotImplementedError:
            embeddings = [self.embedding_func.embed_query(x) for x in texts]
        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []
        pks: list[str] = []
        total_count = len(embeddings)
        for start in range(0, total_count, batch_size):
            # Grab end index
            rows = []
            end = min(start + batch_size, total_count)
            for id in range(start, end, 1):
                metadata = "{}"
                if metadatas is not None:
                    metadata = json.dumps(metadatas[id])
                row = self.mochowtable.Row(
                    id="{}-{}-{}".format(time.time_ns(), hash(texts[id]), id),
                    vector=[float(num) for num in embeddings[id]],
                    text=texts[id],
                    metadata=metadata,
                )
                rows.append(row)
                pks.append(str(id))
            self.table.upsert(rows=rows)
        # need rebuild vindex after upsert
        self.table.rebuild_index(self.index_vector)
        while True:
            time.sleep(2)
            index = self.table.describe_index(self.index_vector)
            if index.state == self.mochowenum.IndexState.NORMAL:
                break
        return pks

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against the query string."""
        res = self.similarity_search_with_score(
            query=query, k=k, param=param, expr=expr, **kwargs
        )
        return [doc for doc, _ in res]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on a query string and return results with score."""
        # Embed the query text.
        embedding = self.embedding_func.embed_query(query)
        res = self._similarity_search_with_score(
            embedding=embedding, k=k, param=param, expr=expr, **kwargs
        )
        return res

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against the query string."""
        res = self._similarity_search_with_score(
            embedding=embedding, k=k, param=param, expr=expr, **kwargs
        )
        return [doc for doc, _ in res]

    def _similarity_search_with_score(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on a query string and return results with score."""
        ef = 10 if param is None else param.get("ef", 10)

        anns = self.mochowtable.AnnSearch(
            vector_field=self.field_vector,
            vector_floats=[float(num) for num in embedding],
            params=self.mochowtable.HNSWSearchParams(ef=ef, limit=k),
            filter=expr,
        )
        res = self.table.search(anns=anns)

        rows = [[item] for item in res.rows]
        # Organize results.
        ret: List[Tuple[Document, float]] = []
        if rows is None or len(rows) == 0:
            return ret
        for row in rows:
            for result in row:
                row_data = result.get("row", {})
                meta = row_data.get(self.field_metadata)
                if meta is not None:
                    meta = json.loads(meta)
                doc = Document(
                    page_content=row_data.get(self.field_text), metadata=meta
                )
                pair = (doc, result.get("score", 0.0))
                ret.append(pair)
        return ret

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results that are reordered by MMR."""
        embedding = self.embedding_func.embed_query(query)
        return self._max_marginal_relevance_search(
            embedding=embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            param=param,
            expr=expr,
            **kwargs,
        )

    def _max_marginal_relevance_search(
        self,
        embedding: list[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results that are reordered by MMR."""
        ef = 10 if param is None else param.get("ef", 10)
        anns = self.mochowtable.AnnSearch(
            vector_field=self.field_vector,
            vector_floats=[float(num) for num in embedding],
            params=self.mochowtable.HNSWSearchParams(ef=ef, limit=k),
            filter=expr,
        )
        res = self.table.search(anns=anns, retrieve_vector=True)

        # Organize results.
        documents: List[Document] = []
        ordered_result_embeddings = []
        rows = [[item] for item in res.rows]
        if rows is None or len(rows) == 0:
            return documents
        for row in rows:
            for result in row:
                row_data = result.get("row", {})
                meta = row_data.get(self.field_metadata)
                if meta is not None:
                    meta = json.loads(meta)
                doc = Document(
                    page_content=row_data.get(self.field_text), metadata=meta
                )
                documents.append(doc)
                ordered_result_embeddings.append(row_data.get(self.field_vector))
        # Get the new order of results.
        new_ordering = maximal_marginal_relevance(
            np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
        )
        # Reorder the values and return.
        ret = []
        for x in new_ordering:
            # Function can return -1 index
            if x == -1:
                break
            else:
                ret.append(documents[x])
        return ret
