from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

if TYPE_CHECKING:
    from transwarp_hippo_api.hippo_client import HippoClient

# Default connection
DEFAULT_HIPPO_CONNECTION = {
    "host": "localhost",
    "port": "7788",
    "username": "admin",
    "password": "admin",
}

logger = logging.getLogger(__name__)


class Hippo(VectorStore):
    """`Hippo` vector store.

    You need to install `hippo-api` and run Hippo.

    Please visit our official website for how to run a Hippo instance:
    https://www.transwarp.cn/starwarp

    Args:
        embedding_function (Embeddings): Function used to embed the text.
        table_name (str): Which Hippo table to use. Defaults to
            "test".
        database_name (str): Which Hippo database to use. Defaults to
            "default".
        number_of_shards (int): The number of shards for the Hippo table.Defaults to
            1.
        number_of_replicas (int): The number of replicas for the Hippo table.Defaults to
            1.
        connection_args (Optional[dict[str, any]]): The connection args used for
            this class comes in the form of a dict.
        index_params (Optional[dict]): Which index params to use. Defaults to
            IVF_FLAT.
        drop_old (Optional[bool]): Whether to drop the current collection. Defaults
            to False.
        primary_field (str): Name of the primary key field. Defaults to "pk".
        text_field (str): Name of the text field. Defaults to "text".
        vector_field (str): Name of the vector field. Defaults to "vector".

    The connection args used for this class comes in the form of a dict,
    here are a few of the options:
        host (str): The host of Hippo instance. Default at "localhost".
        port (str/int): The port of Hippo instance. Default at 7788.
        user (str): Use which user to connect to Hippo instance. If user and
            password are provided, we will add related header in every RPC call.
        password (str): Required when user is provided. The password
            corresponding to the user.

    Example:
        .. code-block:: python

        from langchain_community.vectorstores import Hippo
        from langchain_community.embeddings import OpenAIEmbeddings

        embedding = OpenAIEmbeddings()
        # Connect to a hippo instance on localhost
        vector_store = Hippo.from_documents(
            docs,
            embedding=embeddings,
            table_name="langchain_test",
            connection_args=HIPPO_CONNECTION
        )

    Raises:
        ValueError: If the hippo-api python package is not installed.
    """

    def __init__(
        self,
        embedding_function: Embeddings,
        table_name: str = "test",
        database_name: str = "default",
        number_of_shards: int = 1,
        number_of_replicas: int = 1,
        connection_args: Optional[Dict[str, Any]] = None,
        index_params: Optional[dict] = None,
        drop_old: Optional[bool] = False,
    ):
        self.number_of_shards = number_of_shards
        self.number_of_replicas = number_of_replicas
        self.embedding_func = embedding_function
        self.table_name = table_name
        self.database_name = database_name
        self.index_params = index_params

        # In order for a collection to be compatible,
        # 'pk' should be an auto-increment primary key and string
        self._primary_field = "pk"
        # In order for compatibility, the text field will need to be called "text"
        self._text_field = "text"
        # In order for compatibility, the vector field needs to be called "vector"
        self._vector_field = "vector"
        self.fields: List[str] = []
        # Create the connection to the server
        if connection_args is None:
            connection_args = DEFAULT_HIPPO_CONNECTION
        self.hc = self._create_connection_alias(connection_args)
        self.col: Any = None

        # If the collection exists, delete it
        try:
            if (
                self.hc.check_table_exists(self.table_name, self.database_name)
                and drop_old
            ):
                self.hc.delete_table(self.table_name, self.database_name)
        except Exception as e:
            logging.error(
                f"An error occurred while deleting the table " f"{self.table_name}: {e}"
            )
            raise

        try:
            if self.hc.check_table_exists(self.table_name, self.database_name):
                self.col = self.hc.get_table(self.table_name, self.database_name)
        except Exception as e:
            logging.error(
                f"An error occurred while getting the table " f"{self.table_name}: {e}"
            )
            raise

        # Initialize the vector database
        self._get_env()

    def _create_connection_alias(self, connection_args: dict) -> HippoClient:
        """Create the connection to the Hippo server."""
        # Grab the connection arguments that are used for checking existing connection
        try:
            from transwarp_hippo_api.hippo_client import HippoClient
        except ImportError as e:
            raise ImportError(
                "Unable to import transwarp_hipp_api, please install with "
                "`pip install hippo-api`."
            ) from e

        host: str = connection_args.get("host", None)
        port: int = connection_args.get("port", None)
        username: str = connection_args.get("username", "shiva")
        password: str = connection_args.get("password", "shiva")

        # Order of use is host/port, uri, address
        if host is not None and port is not None:
            if "," in host:
                hosts = host.split(",")
                given_address = ",".join([f"{h}:{port}" for h in hosts])
            else:
                given_address = str(host) + ":" + str(port)
        else:
            raise ValueError("Missing standard address type for reuse attempt")

        try:
            logger.info(f"create HippoClient[{given_address}]")
            return HippoClient([given_address], username=username, pwd=password)
        except Exception as e:
            logger.error("Failed to create new connection")
            raise e

    def _get_env(
        self, embeddings: Optional[list] = None, metadatas: Optional[List[dict]] = None
    ) -> None:
        logger.info("init ...")
        if embeddings is not None:
            logger.info("create collection")
            self._create_collection(embeddings, metadatas)
        self._extract_fields()
        self._create_index()

    def _create_collection(
        self, embeddings: list, metadatas: Optional[List[dict]] = None
    ) -> None:
        from transwarp_hippo_api.hippo_client import HippoField
        from transwarp_hippo_api.hippo_type import HippoType

        # Determine embedding dim
        dim = len(embeddings[0])
        logger.debug(f"[_create_collection] dim: {dim}")
        fields = []

        # Create the primary key field
        fields.append(HippoField(self._primary_field, True, HippoType.STRING))

        # Create the text field

        fields.append(HippoField(self._text_field, False, HippoType.STRING))

        # Create the vector field, supports binary or float vectors
        # to The binary vector type is to be developed.
        fields.append(
            HippoField(
                self._vector_field,
                False,
                HippoType.FLOAT_VECTOR,
                type_params={"dimension": dim},
            )
        )
        # to In Hippo,there is no method similar to the infer_type_data
        # types, so currently all non-vector data is converted to string type.

        if metadatas:
            #     # Create FieldSchema for each entry in metadata.
            for key, value in metadatas[0].items():
                #         # Infer the corresponding datatype of the metadata
                if isinstance(value, list):
                    value_dim = len(value)
                    fields.append(
                        HippoField(
                            key,
                            False,
                            HippoType.FLOAT_VECTOR,
                            type_params={"dimension": value_dim},
                        )
                    )
                else:
                    fields.append(HippoField(key, False, HippoType.STRING))

        logger.debug(f"[_create_collection] fields: {fields}")

        # Create the collection
        self.hc.create_table(
            name=self.table_name,
            auto_id=True,
            fields=fields,
            database_name=self.database_name,
            number_of_shards=self.number_of_shards,
            number_of_replicas=self.number_of_replicas,
        )
        self.col = self.hc.get_table(self.table_name, self.database_name)
        logger.info(
            f"[_create_collection] : "
            f"create table {self.table_name} in {self.database_name} successfully"
        )

    def _extract_fields(self) -> None:
        """Grab the existing fields from the Collection"""
        from transwarp_hippo_api.hippo_client import HippoTable

        if isinstance(self.col, HippoTable):
            schema = self.col.schema
            logger.debug(f"[_extract_fields] schema:{schema}")
            for x in schema:
                self.fields.append(x.name)
            logger.debug(f"04 [_extract_fields] fields:{self.fields}")

    # TO CAN: Translated into English, your statement would be: "Currently,
    # only the field named 'vector' (the automatically created vector field)
    # is checked for indexing. Indexes need to be created manually for other
    # vector type columns.
    def _get_index(self) -> Optional[Dict[str, Any]]:
        """Return the vector index information if it exists"""
        from transwarp_hippo_api.hippo_client import HippoTable

        if isinstance(self.col, HippoTable):
            table_info = self.hc.get_table_info(
                self.table_name, self.database_name
            ).get(self.table_name, {})
            embedding_indexes = table_info.get("embedding_indexes", None)
            if embedding_indexes is None:
                return None
            else:
                for x in self.hc.get_table_info(self.table_name, self.database_name)[
                    self.table_name
                ]["embedding_indexes"]:
                    logger.debug(f"[_get_index] embedding_indexes {embedding_indexes}")
                    if x["column"] == self._vector_field:
                        return x
        return None

    # TO Indexes can only be created for the self._vector_field field.
    def _create_index(self) -> None:
        """Create a index on the collection"""
        from transwarp_hippo_api.hippo_client import HippoTable
        from transwarp_hippo_api.hippo_type import IndexType, MetricType

        if isinstance(self.col, HippoTable) and self._get_index() is None:
            if self._get_index() is None:
                if self.index_params is None:
                    self.index_params = {
                        "index_name": "langchain_auto_create",
                        "metric_type": MetricType.L2,
                        "index_type": IndexType.IVF_FLAT,
                        "nlist": 10,
                    }

                    self.col.create_index(
                        self._vector_field,
                        self.index_params["index_name"],
                        self.index_params["index_type"],
                        self.index_params["metric_type"],
                        nlist=self.index_params["nlist"],
                    )
                    logger.debug(
                        self.col.activate_index(self.index_params["index_name"])
                    )
                    logger.info("create index successfully")
                else:
                    index_dict = {
                        "IVF_FLAT": IndexType.IVF_FLAT,
                        "FLAT": IndexType.FLAT,
                        "IVF_SQ": IndexType.IVF_SQ,
                        "IVF_PQ": IndexType.IVF_PQ,
                        "HNSW": IndexType.HNSW,
                    }

                    metric_dict = {
                        "ip": MetricType.IP,
                        "IP": MetricType.IP,
                        "l2": MetricType.L2,
                        "L2": MetricType.L2,
                    }
                    self.index_params["metric_type"] = metric_dict[
                        self.index_params["metric_type"]
                    ]

                    if self.index_params["index_type"] == "FLAT":
                        self.index_params["index_type"] = index_dict[
                            self.index_params["index_type"]
                        ]
                        self.col.create_index(
                            self._vector_field,
                            self.index_params["index_name"],
                            self.index_params["index_type"],
                            self.index_params["metric_type"],
                        )
                        logger.debug(
                            self.col.activate_index(self.index_params["index_name"])
                        )
                    elif (
                        self.index_params["index_type"] == "IVF_FLAT"
                        or self.index_params["index_type"] == "IVF_SQ"
                    ):
                        self.index_params["index_type"] = index_dict[
                            self.index_params["index_type"]
                        ]
                        self.col.create_index(
                            self._vector_field,
                            self.index_params["index_name"],
                            self.index_params["index_type"],
                            self.index_params["metric_type"],
                            nlist=self.index_params.get("nlist", 10),
                            nprobe=self.index_params.get("nprobe", 10),
                        )
                        logger.debug(
                            self.col.activate_index(self.index_params["index_name"])
                        )
                    elif self.index_params["index_type"] == "IVF_PQ":
                        self.index_params["index_type"] = index_dict[
                            self.index_params["index_type"]
                        ]
                        self.col.create_index(
                            self._vector_field,
                            self.index_params["index_name"],
                            self.index_params["index_type"],
                            self.index_params["metric_type"],
                            nlist=self.index_params.get("nlist", 10),
                            nprobe=self.index_params.get("nprobe", 10),
                            nbits=self.index_params.get("nbits", 8),
                            m=self.index_params.get("m"),
                        )
                        logger.debug(
                            self.col.activate_index(self.index_params["index_name"])
                        )
                    elif self.index_params["index_type"] == "HNSW":
                        self.index_params["index_type"] = index_dict[
                            self.index_params["index_type"]
                        ]
                        self.col.create_index(
                            self._vector_field,
                            self.index_params["index_name"],
                            self.index_params["index_type"],
                            self.index_params["metric_type"],
                            M=self.index_params.get("M"),
                            ef_construction=self.index_params.get("ef_construction"),
                            ef_search=self.index_params.get("ef_search"),
                        )
                        logger.debug(
                            self.col.activate_index(self.index_params["index_name"])
                        )
                    else:
                        raise ValueError(
                            "Index name does not match, "
                            "please enter the correct index name. "
                            "(FLAT, IVF_FLAT, IVF_PQ,IVF_SQ, HNSW)"
                        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        timeout: Optional[int] = None,
        batch_size: int = 1000,
        **kwargs: Any,
    ) -> List[str]:
        """
        Add text to the collection.

        Args:
            texts: An iterable that contains the text to be added.
            metadatas: An optional list of dictionaries,
            each dictionary contains the metadata associated with a text.
            timeout: Optional timeout, in seconds.
            batch_size: The number of texts inserted in each batch, defaults to 1000.
            **kwargs: Other optional parameters.

        Returns:
            A list of strings, containing the unique identifiers of the inserted texts.

        Note:
            If the collection has not yet been created,
            this method will create a new collection.
        """
        from transwarp_hippo_api.hippo_client import HippoTable

        if not texts or all(t == "" for t in texts):
            logger.debug("Nothing to insert, skipping.")
            return []
        texts = list(texts)

        logger.debug(f"[add_texts] texts: {texts}")

        try:
            embeddings = self.embedding_func.embed_documents(texts)
        except NotImplementedError:
            embeddings = [self.embedding_func.embed_query(x) for x in texts]

        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []

        logger.debug(f"[add_texts] len_embeddings:{len(embeddings)}")

        # 如果还没有创建collection则创建collection
        if not isinstance(self.col, HippoTable):
            self._get_env(embeddings, metadatas)

        # Dict to hold all insert columns
        insert_dict: Dict[str, list] = {
            self._text_field: texts,
            self._vector_field: embeddings,
        }
        logger.debug(f"[add_texts] metadatas:{metadatas}")
        logger.debug(f"[add_texts] fields:{self.fields}")
        if metadatas is not None:
            for d in metadatas:
                for key, value in d.items():
                    if key in self.fields:
                        insert_dict.setdefault(key, []).append(value)

        logger.debug(insert_dict[self._text_field])

        # Total insert count
        vectors: list = insert_dict[self._vector_field]
        total_count = len(vectors)

        if "pk" in self.fields:
            self.fields.remove("pk")

        logger.debug(f"[add_texts] total_count:{total_count}")
        for i in range(0, total_count, batch_size):
            # Grab end index
            end = min(i + batch_size, total_count)
            # Convert dict to list of lists batch for insertion
            insert_list = [insert_dict[x][i:end] for x in self.fields]
            try:
                res = self.col.insert_rows(insert_list)
                logger.info(f"05 [add_texts] insert {res}")
            except Exception as e:
                logger.error(
                    "Failed to insert batch starting at entity: %s/%s", i, total_count
                )
                raise e
        return [""]

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """
        Perform a similarity search on the query string.

        Args:
            query (str): The text to search for.
            k (int, optional): The number of results to return. Default is 4.
            param (dict, optional): Specifies the search parameters for the index.
            Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (int, optional): Time to wait before a timeout error.
            Defaults to None.
            kwargs: Keyword arguments for Collection.search().

        Returns:
            List[Document]: The document results of the search.
        """

        if self.col is None:
            logger.debug("No existing collection to search.")
            return []
        res = self.similarity_search_with_score(
            query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return [doc for doc, _ in res]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """
        Performs a search on the query string and returns results with scores.

        Args:
            query (str): The text being searched.
            k (int, optional): The number of results to return.
            Default is 4.
            param (dict): Specifies the search parameters for the index.
            Default is None.
            expr (str, optional): Filtering expression. Default is None.
            timeout (int, optional): The waiting time before a timeout error.
            Default is None.
            kwargs: Keyword arguments for Collection.search().

        Returns:
            List[float], List[Tuple[Document, any, any]]:
        """

        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        # Embed the query text.
        embedding = self.embedding_func.embed_query(query)

        ret = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return ret

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """
        Performs a search on the query string and returns results with scores.

        Args:
            embedding (List[float]): The embedding vector being searched.
            k (int, optional): The number of results to return.
            Default is 4.
            param (dict): Specifies the search parameters for the index.
            Default is None.
            expr (str, optional): Filtering expression. Default is None.
            timeout (int, optional): The waiting time before a timeout error.
            Default is None.
            kwargs: Keyword arguments for Collection.search().

        Returns:
            List[Tuple[Document, float]]: Resulting documents and scores.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        # if param is None:
        #     param = self.search_params

        # Determine result metadata fields.
        output_fields = self.fields[:]
        output_fields.remove(self._vector_field)

        # Perform the search.
        logger.debug(f"search_field:{self._vector_field}")
        logger.debug(f"vectors:{[embedding]}")
        logger.debug(f"output_fields:{output_fields}")
        logger.debug(f"topk:{k}")
        logger.debug(f"dsl:{expr}")

        res = self.col.query(
            search_field=self._vector_field,
            vectors=[embedding],
            output_fields=output_fields,
            topk=k,
            dsl=expr,
        )
        # Organize results.
        logger.debug(f"[similarity_search_with_score_by_vector] res:{res}")
        score_col = self._text_field + "%scores"
        ret = []
        count = 0
        for items in zip(*[res[0][field] for field in output_fields]):
            meta = {field: value for field, value in zip(output_fields, items)}
            doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
            logger.debug(
                f"[similarity_search_with_score_by_vector] "
                f"res[0][score_col]:{res[0][score_col]}"
            )
            score = res[0][score_col][count]
            count += 1
            ret.append((doc, score))

        return ret

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        table_name: str = "test",
        database_name: str = "default",
        connection_args: Dict[str, Any] = DEFAULT_HIPPO_CONNECTION,
        index_params: Optional[Dict[Any, Any]] = None,
        search_params: Optional[Dict[str, Any]] = None,
        drop_old: bool = False,
        **kwargs: Any,
    ) -> "Hippo":
        """
        Creates an instance of the VST class from the given texts.

        Args:
            texts (List[str]): List of texts to be added.
            embedding (Embeddings): Embedding model for the texts.
            metadatas (List[dict], optional):
            List of metadata dictionaries for each text.Defaults to None.
            table_name (str): Name of the table. Defaults to "test".
            database_name (str): Name of the database. Defaults to "default".
            connection_args (dict[str, Any]): Connection parameters.
            Defaults to DEFAULT_HIPPO_CONNECTION.
            index_params (dict): Indexing parameters. Defaults to None.
            search_params (dict): Search parameters. Defaults to an empty dictionary.
            drop_old (bool): Whether to drop the old collection. Defaults to False.
            kwargs: Other arguments.

        Returns:
            Hippo: An instance of the VST class.
        """

        if search_params is None:
            search_params = {}
        logger.info("00 [from_texts] init the class of Hippo")
        vector_db = cls(
            embedding_function=embedding,
            table_name=table_name,
            database_name=database_name,
            connection_args=connection_args,
            index_params=index_params,
            drop_old=drop_old,
            **kwargs,
        )
        logger.debug(f"[from_texts] texts:{texts}")
        logger.debug(f"[from_texts] metadatas:{metadatas}")
        vector_db.add_texts(texts=texts, metadatas=metadatas)
        return vector_db
