from __future__ import annotations

import logging
from copy import deepcopy
from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

# Module-level logger. The vector store uses it to report Rockset query and
# delete failures that it handles itself instead of re-raising.
logger = logging.getLogger(__name__)


class Rockset(VectorStore):
    """`Rockset` vector store.

    To use, you should have the `rockset` python package installed. Note that to use
    this, the collection being used must already exist in your Rockset instance.
    You must also ensure you use a Rockset ingest transformation to apply
    `VECTOR_ENFORCE` on the column being used to store `embedding_key` in the
    collection.
    See: https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details

    Unless a different `workspace` is passed, everything below assumes the
    `commons` Rockset workspace.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Rockset
            from langchain_community.embeddings.openai import OpenAIEmbeddings
            import rockset

            # Make sure you use the right host (region) for your Rockset instance
            # and APIKEY has both read-write access to your collection.

            rs = rockset.RocksetClient(host=rockset.Regions.use1a1, api_key="***")
            collection_name = "langchain_demo"
            embeddings = OpenAIEmbeddings()
            vectorstore = Rockset(rs, embeddings, collection_name,
                "description", "description_embedding")

    """

    def __init__(
        self,
        client: Any,
        embeddings: Embeddings,
        collection_name: str,
        text_key: str,
        embedding_key: str,
        workspace: str = "commons",
    ):
        """Initialize with Rockset client.

        Args:
            client: Rockset client object (an instance of `rockset.RocksetClient`).
            embeddings: Langchain Embeddings object to use to generate
                        embedding for given text.
            collection_name: Rockset collection to insert docs into / query.
            text_key: column in Rockset collection to use to store the text
            embedding_key: column in Rockset collection to use to store the embedding.
                           Note: We must apply `VECTOR_ENFORCE()` on this column via
                           Rockset ingest transformation.
            workspace: Rockset workspace containing the collection.
                       Defaults to "commons".

        Raises:
            ImportError: if the `rockset` package is not installed.
            ValueError: if `client` is not a `rockset.RocksetClient`.
        """
        try:
            from rockset import RocksetClient
        except ImportError as e:
            raise ImportError(
                "Could not import rockset client python package. "
                "Please install it with `pip install rockset`."
            ) from e

        if not isinstance(client, RocksetClient):
            raise ValueError(
                f"client should be an instance of rockset.RocksetClient, "
                f"got {type(client)}"
            )
        # TODO: check that `collection_name` exists in rockset. Create if not.
        self._client = client
        self._collection_name = collection_name
        self._embeddings = embeddings
        self._text_key = text_key
        self._embedding_key = embedding_key
        self._workspace = workspace

        try:
            self._client.set_application("langchain")
        except AttributeError:
            # Clients without `set_application` are still usable; tagging the
            # application is best-effort only.
            pass

    @property
    def embeddings(self) -> Embeddings:
        """The Embeddings object used to embed texts and queries."""
        return self._embeddings

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
                Each metadata dict is copied and becomes the document's columns.
            ids: Optional list of ids to associate with the texts (stored in the
                Rockset `_id` column).
            batch_size: Send documents in batches of this size to rockset.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        batch: List[dict] = []
        stored_ids: List[str] = []

        for i, text in enumerate(texts):
            # Flush a full batch before starting to build the next document.
            if len(batch) == batch_size:
                stored_ids += self._write_documents_to_rockset(batch)
                batch = []
            doc: dict = {}
            if metadatas and len(metadatas) > i:
                # Copy so the caller's metadata dicts are not mutated below.
                doc = deepcopy(metadatas[i])
            if ids and len(ids) > i:
                doc["_id"] = ids[i]
            doc[self._text_key] = text
            # NOTE(review): texts are embedded with `embed_query` (not
            # `embed_documents`); kept as-is so new vectors stay comparable with
            # vectors already stored by earlier versions.
            doc[self._embedding_key] = self._embeddings.embed_query(text)
            batch.append(doc)
        if batch:
            stored_ids += self._write_documents_to_rockset(batch)
        return stored_ids

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        client: Any = None,
        collection_name: str = "",
        text_key: str = "",
        embedding_key: str = "",
        ids: Optional[List[str]] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> Rockset:
        """Create Rockset wrapper with existing texts.
        This is intended as a quicker way to get started.

        Args:
            texts: Texts to embed and insert.
            embedding: Embeddings object used for the texts.
            metadatas: Optional metadata dicts, one per text.
            client: Rockset client (required).
            collection_name: Target collection (required).
            text_key: Column that stores the text (required).
            embedding_key: Column that stores the embedding (required).
            ids: Optional document ids, one per text.
            batch_size: Number of documents sent to Rockset per request.
            **kwargs: May contain `workspace` (defaults to "commons").

        Returns:
            The initialized `Rockset` vector store.
        """

        # Sanitize inputs
        assert client is not None, "Rockset Client cannot be None"
        assert collection_name, "Collection name cannot be empty"
        assert text_key, "Text key name cannot be empty"
        assert embedding_key, "Embedding key cannot be empty"

        store = cls(
            client,
            embedding,
            collection_name,
            text_key,
            embedding_key,
            workspace=kwargs.get("workspace", "commons"),
        )
        store.add_texts(texts, metadatas, ids, batch_size)
        return store

    # Rockset supports these vector distance functions.
    class DistanceFunction(Enum):
        COSINE_SIM = "COSINE_SIM"
        EUCLIDEAN_DIST = "EUCLIDEAN_DIST"
        DOT_PRODUCT = "DOT_PRODUCT"

        # how to sort results for "similarity"
        def order_by(self) -> str:
            """SQL sort direction that puts the most similar rows first.

            Euclidean distance is smaller for closer vectors (ascending); the
            similarity-style functions are larger for closer vectors (descending).
            """
            if self.value == "EUCLIDEAN_DIST":
                return "ASC"
            return "DESC"

    def similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a similarity search with Rockset

        Args:
            query (str): Text to look up documents similar to.
            distance_func (DistanceFunction): how to compute distance between two
                vectors in Rockset.
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): Metadata filters supplied as a
                SQL `where` condition string. Defaults to None.
                eg. "price<=70.0 AND brand='Nintendo'"

            NOTE: `where_str` is inserted into the SQL query verbatim. Never let
                  end users supply it directly; be aware of SQL injection.

        Returns:
            List[Tuple[Document, float]]: List of documents with the raw value
            computed by `distance_func` (not normalized to [0, 1]).
        """
        return self.similarity_search_by_vector_with_relevance_scores(
            self._embeddings.embed_query(query),
            k,
            distance_func,
            where_str,
            **kwargs,
        )

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Same as `similarity_search_with_relevance_scores` but
        doesn't return the scores.
        """
        return self.similarity_search_by_vector(
            self._embeddings.embed_query(query),
            k,
            distance_func,
            where_str,
            **kwargs,
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Accepts a query_embedding (vector), and returns documents with
        similar embeddings."""

        docs_and_scores = self.similarity_search_by_vector_with_relevance_scores(
            embedding, k, distance_func, where_str, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_by_vector_with_relevance_scores(
        self,
        embedding: List[float],
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Accepts a query_embedding (vector), and returns documents with
        similar embeddings along with their relevance scores.

        Args:
            embedding: Query vector.
            k: Number of rows to return.
            distance_func: Rockset vector function used to score rows.
            where_str: Optional raw SQL `WHERE` condition (see the SQL
                injection note on `similarity_search_with_relevance_scores`).
            **kwargs: `exclude_embeddings` (default True) controls whether the
                embedding column is dropped from the returned metadata.

        Returns:
            (Document, score) pairs in Rockset's result order. An empty list is
            returned (and the error logged) if the Rockset query itself fails.

        Raises:
            ValueError: if a result row lacks the text column or the `dist` column.
        """

        exclude_embeddings = kwargs.get("exclude_embeddings", True)
        q_str = self._build_query_sql(
            embedding, distance_func, k, where_str, exclude_embeddings
        )
        try:
            query_response = self._client.Queries.query(sql={"query": q_str})
        except Exception as e:
            logger.error("Exception when querying Rockset: %s\n", e)
            return []

        # These columns are populated by Rockset when documents are inserted.
        # No need to return them in metadata dict.
        system_columns = ("_id", "_event_time", "_meta")
        final_result: List[Tuple[Document, float]] = []
        for document in query_response.results:
            assert isinstance(
                document, dict
            ), "document should be of type `dict[str,Any]`. But found: `{}`".format(
                type(document)
            )
            # Reset per row so a row missing a column can never silently reuse
            # the value parsed from the previous row.
            page_content: Optional[str] = None
            score: Optional[float] = None
            metadata: dict = {}
            for column, value in document.items():
                if column == self._text_key:
                    assert isinstance(value, str), (
                        "page content stored in column `{}` must be of type `str`. "
                        "But found: `{}`"
                    ).format(self._text_key, type(value))
                    page_content = value
                elif column == "dist":
                    assert isinstance(value, float), (
                        "Computed distance between vectors must of type `float`. "
                        "But found {}"
                    ).format(type(value))
                    score = value
                elif column not in system_columns:
                    metadata[column] = value
            if page_content is None:
                raise ValueError(
                    f"Rockset result row is missing the text column "
                    f"`{self._text_key}`. Found columns: {sorted(document)}"
                )
            if score is None:
                raise ValueError(
                    "Rockset result row is missing the computed `dist` column. "
                    f"Found columns: {sorted(document)}"
                )
            final_result.append(
                (Document(page_content=page_content, metadata=metadata), score)
            )
        return final_result

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        *,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
            where_str: where clause for the sql query
            **kwargs: Forwarded to `similarity_search_by_vector`; e.g.
                `distance_func` (DistanceFunction) selects how Rockset computes
                distance between two vectors. `exclude_embeddings` is ignored
                because embeddings are required for MMR.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        # MMR runs client-side and needs the stored embeddings back, so any
        # caller-supplied value is dropped rather than passed twice (which would
        # raise TypeError for a duplicate keyword argument).
        kwargs.pop("exclude_embeddings", None)

        query_embedding = self._embeddings.embed_query(query)
        initial_docs = self.similarity_search_by_vector(
            query_embedding,
            k=fetch_k,
            where_str=where_str,
            exclude_embeddings=False,
            **kwargs,
        )
        if not initial_docs:
            return []

        embeddings = [doc.metadata[self._embedding_key] for doc in initial_docs]

        selected_indices = maximal_marginal_relevance(
            np.array(query_embedding),
            embeddings,
            lambda_mult=lambda_mult,
            k=k,
        )

        # remove embeddings key before returning for cleanup to be consistent with
        #   other search functions
        for i in selected_indices:
            del initial_docs[i].metadata[self._embedding_key]

        return [initial_docs[i] for i in selected_indices]

    # Helper functions

    def _build_query_sql(
        self,
        query_embedding: List[float],
        distance_func: DistanceFunction,
        k: int = 4,
        where_str: Optional[str] = None,
        exclude_embeddings: bool = True,
    ) -> str:
        """Builds Rockset SQL query to query similar vectors to query_vector.

        The query selects every column (optionally except the embedding column),
        plus `distance_func(embedding_key, [query vector]) AS dist`. It orders
        by `dist` in the direction given by `distance_func.order_by()` and
        limits the result to `k` rows.

        NOTE: `where_str`, the workspace and the collection name are interpolated
        verbatim (no escaping); they must come from trusted sources.
        """

        q_embedding_str = ",".join(map(str, query_embedding))
        distance_str = f"""{distance_func.value}({self._embedding_key}, \
[{q_embedding_str}]) as dist"""
        where_clause = f"WHERE {where_str}\n" if where_str else ""
        select_embedding = (
            f" EXCEPT({self._embedding_key})," if exclude_embeddings else ","
        )
        return f"""\
SELECT *{select_embedding} {distance_str}
FROM {self._workspace}.{self._collection_name}
{where_clause}\
ORDER BY dist {distance_func.order_by()}
LIMIT {str(k)}
"""

    def _write_documents_to_rockset(self, batch: List[dict]) -> List[str]:
        """Insert one batch of documents and return the ids Rockset reports."""
        add_doc_res = self._client.Documents.add_documents(
            collection=self._collection_name, data=batch, workspace=self._workspace
        )
        return [doc_status._id for doc_status in add_doc_res.data]

    def delete_texts(self, ids: List[str]) -> None:
        """Delete a list of docs from the Rockset collection.

        Raises:
            ImportError: if the `rockset` package is not installed.
        """
        try:
            from rockset.models import DeleteDocumentsRequestData
        except ImportError as e:
            raise ImportError(
                "Could not import rockset client python package. "
                "Please install it with `pip install rockset`."
            ) from e

        self._client.Documents.delete_documents(
            collection=self._collection_name,
            data=[DeleteDocumentsRequestData(id=i) for i in ids],
            workspace=self._workspace,
        )

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete documents by id.

        `ids=None` is treated as an empty list (nothing is explicitly deleted).
        Any error is logged and reported as `False` instead of being raised.

        Returns:
            True on success, False if the deletion raised an exception.
        """
        try:
            if ids is None:
                ids = []
            self.delete_texts(ids)
        except Exception as e:
            logger.error("Exception when deleting docs from Rockset: %s\n", e)
            return False

        return True

    async def adelete(
        self, ids: Optional[List[str]] = None, **kwargs: Any
    ) -> Optional[bool]:
        """Async variant of `delete`; runs the blocking call in an executor."""
        return await run_in_executor(None, self.delete, ids, **kwargs)
