from __future__ import annotations

import time
from itertools import repeat
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore


class XataVectorStore(VectorStore):
    """`Xata` vector store.

    It assumes you have a Xata database
    created with the right schema. See the guide at:
    https://integrations.langchain.com/vectorstores?integration_name=XataVectorStore

    Each record written by this class stores the document text in a
    ``content`` column, its vector in an ``embedding`` column, and every
    metadata key as a column of its own.
    """

    # Columns owned by this class. Metadata keys with these names are dropped
    # on insert so they cannot overwrite the record's own fields.
    _RESERVED_COLUMNS = ("id", "content", "embedding")
    # Maximum number of records sent in one bulk-insert request.
    _INSERT_CHUNK_SIZE = 1000
    # Maximum number of delete operations sent in one transaction.
    _DELETE_CHUNK_SIZE = 500

    def __init__(
        self,
        api_key: str,
        db_url: str,
        embedding: Embeddings,
        table_name: str,
    ) -> None:
        """Initialize with Xata client.

        Args:
            api_key: Xata API key.
            db_url: Xata database URL passed to ``XataClient``.
            embedding: Embedding model used for both documents and queries.
            table_name: Table holding the vectors; ``"vectors"`` is used when
                this is empty.

        Raises:
            ImportError: If the ``xata`` package is not installed.
        """
        try:
            from xata.client import XataClient
        except ImportError as e:
            raise ImportError(
                "Could not import xata python package. "
                "Please install it with `pip install xata`."
            ) from e
        self._client = XataClient(api_key=api_key, db_url=db_url)
        self._embedding: Embeddings = embedding
        self._table_name = table_name or "vectors"

    @property
    def embeddings(self) -> Embeddings:
        """The embedding model used by this store."""
        return self._embedding

    def add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document],
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """Store precomputed vectors alongside their documents.

        Args:
            vectors: One embedding per document.
            documents: Documents whose text and metadata are stored.
            ids: Optional record IDs; when omitted Xata generates them.

        Returns:
            The record IDs returned by Xata, in insertion order.
        """
        return self._add_vectors(vectors, documents, ids)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[Dict[Any, Any]]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts`` and store them.

        Args:
            texts: Texts to embed and store. Any iterable is accepted.
            metadatas: Optional metadata dict per text.
            ids: Optional record IDs; when omitted Xata generates them.
            **kwargs: Ignored.

        Returns:
            The record IDs returned by Xata, in insertion order.
        """
        # Materialize once: the texts are iterated twice (for Documents and
        # for embedding), which would silently drop everything if a
        # one-shot iterator such as a generator were passed.
        text_list = list(texts)
        docs = self._texts_to_documents(text_list, metadatas)
        vectors = self._embedding.embed_documents(text_list)
        return self.add_vectors(vectors, docs, ids)

    def _add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document],
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """Add vectors to the Xata database.

        Args:
            vectors: One embedding per document.
            documents: Documents whose ``page_content`` and metadata are
                stored next to each vector.
            ids: Optional record IDs, one per vector. Falsy means Xata
                generates the IDs.

        Returns:
            The record IDs returned by Xata, in insertion order.

        Raises:
            ValueError: If ``documents`` (or non-empty ``ids``) do not have
                the same length as ``vectors``.
            Exception: If a bulk-insert request does not return HTTP 200.
        """
        if len(documents) != len(vectors):
            raise ValueError(
                f"Got {len(vectors)} vectors but {len(documents)} documents."
            )
        if ids and len(ids) != len(vectors):
            raise ValueError(f"Got {len(vectors)} vectors but {len(ids)} ids.")

        rows: List[Dict[str, Any]] = []
        for idx, (vector, doc) in enumerate(zip(vectors, documents)):
            row: Dict[str, Any] = {
                "content": doc.page_content,
                "embedding": vector,
            }
            if ids:
                row["id"] = ids[idx]
            for key, val in doc.metadata.items():
                if key not in self._RESERVED_COLUMNS:
                    row[key] = val
            rows.append(row)

        # The SDK's BulkProcessor would be simpler, but it does not return the
        # created record IDs, which callers need. So chunk manually instead.
        id_list: List[str] = []
        for start in range(0, len(rows), self._INSERT_CHUNK_SIZE):
            chunk = rows[start : start + self._INSERT_CHUNK_SIZE]
            r = self._client.records().bulk_insert(self._table_name, {"records": chunk})
            if r.status_code != 200:
                raise Exception(f"Error adding vectors to Xata: {r.status_code} {r}")
            id_list.extend(r["recordIDs"])
        return id_list

    @staticmethod
    def _texts_to_documents(
        texts: Iterable[str],
        metadatas: Optional[Iterable[Dict[Any, Any]]] = None,
    ) -> List[Document]:
        """Return list of Documents from list of texts and metadatas.

        Pairs are formed with ``zip``, so the shorter input determines the
        number of Documents. A missing (``None``) metadata entry becomes a
        fresh empty dict, so Documents never share one metadata object.
        """
        if metadatas is None:
            metadatas = repeat(None)

        return [
            Document(
                page_content=text,
                metadata=metadata if metadata is not None else {},
            )
            for text, metadata in zip(texts, metadatas)
        ]

    @classmethod
    def from_texts(
        cls: Type["XataVectorStore"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        api_key: Optional[str] = None,
        db_url: Optional[str] = None,
        table_name: str = "vectors",
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> "XataVectorStore":
        """Return VectorStore initialized from texts and embeddings.

        Args:
            texts: Texts to embed and store.
            embedding: Embedding model for documents and later queries.
            metadatas: Optional metadata dict per text.
            api_key: Xata API key (required).
            db_url: Xata database URL (required).
            table_name: Table holding the vectors.
            ids: Optional record IDs; when omitted Xata generates them.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``api_key`` or ``db_url`` is missing.
        """
        if not api_key or not db_url:
            raise ValueError("Xata api_key and db_url must be set.")

        # Build the client first so a missing ``xata`` package fails before
        # any (possibly paid) embedding calls are made.
        vector_db = cls(
            api_key=api_key,
            db_url=db_url,
            embedding=embedding,
            table_name=table_name,
        )

        text_list = list(texts)
        docs = cls._texts_to_documents(text_list, metadatas)
        vectors = embedding.embed_documents(text_list)
        # ``ids`` is forwarded as given; ``None`` lets Xata generate the IDs.
        vector_db._add_vectors(vectors, docs, ids)
        return vector_db

    def similarity_search(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Optional Xata filter applied to the search.

        Returns:
            List of Documents most similar to the query.
        """
        docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
        return [doc for doc, _score in docs_and_scores]

    def similarity_search_with_score(
        self, query: str, k: int = 4, filter: Optional[dict] = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Run a Xata vector search and return each hit with its score.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[dict]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: Documents most similar to the query,
                each paired with the score Xata reports for the hit
                (``hit["xata"]["score"]``).

        Raises:
            Exception: If the search request does not return HTTP 200.
        """
        embedding = self._embedding.embed_query(query)
        payload: Dict[str, Any] = {
            "queryVector": embedding,
            "column": "embedding",
            "size": k,
        }
        if filter:
            payload["filter"] = filter
        r = self._client.data().vector_search(self._table_name, payload=payload)
        if r.status_code != 200:
            raise Exception(f"Error running similarity search: {r.status_code} {r}")
        return [
            (
                Document(
                    page_content=hit["content"],
                    metadata=self._extractMetadata(hit),
                ),
                hit["xata"]["score"],
            )
            for hit in r["records"]
        ]

    def _extractMetadata(self, record: dict) -> dict:
        """Extract metadata from a record. Filters out known columns.

        Drops the columns this class manages plus the ``xata`` entry present
        on search hits; every other column is returned as metadata.
        """
        return {
            key: val
            for key, val in record.items()
            if key not in self._RESERVED_COLUMNS and key != "xata"
        }

    def delete(
        self,
        ids: Optional[List[str]] = None,
        delete_all: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
            delete_all: Delete all records in the table, then wait until the
                search index reports zero documents.

        Raises:
            ValueError: If neither ``ids`` nor ``delete_all`` is given.
            Exception: If a delete transaction does not return HTTP 200.
        """
        if delete_all:
            self._delete_all()
            self.wait_for_indexing(ndocs=0)
        elif ids is not None:
            self._delete_ids(ids)
        else:
            raise ValueError("Either ids or delete_all must be set.")

    def _delete_ids(self, ids: List[str]) -> None:
        """Delete the given record IDs in chunked transactions."""
        for start in range(0, len(ids), self._DELETE_CHUNK_SIZE):
            chunk = ids[start : start + self._DELETE_CHUNK_SIZE]
            operations = [
                {"delete": {"table": self._table_name, "id": record_id}}
                for record_id in chunk
            ]
            r = self._client.records().transaction(payload={"operations": operations})
            # Surface failures; otherwise ``_delete_all`` would loop forever
            # re-reading records that were never deleted.
            if r.status_code != 200:
                raise Exception(f"Error deleting records: {r.status_code} {r}")

    def _delete_all(self) -> None:
        """Delete all records in the table, one query page at a time."""
        while True:
            r = self._client.data().query(self._table_name, payload={"columns": ["id"]})
            if r.status_code != 200:
                raise Exception(f"Error running query: {r.status_code} {r}")
            ids = [rec["id"] for rec in r["records"]]
            if not ids:
                break
            self._delete_ids(ids)

    def wait_for_indexing(self, timeout: float = 5, ndocs: int = 1) -> None:
        """Wait for the search index to contain a certain number of
        documents. Useful in tests.

        Args:
            timeout: Maximum number of seconds to wait.
            ndocs: Exact document count to wait for.

        Raises:
            Exception: If a search request fails or the timeout elapses.
        """
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        start = time.monotonic()
        while True:
            r = self._client.data().search_table(
                self._table_name, payload={"query": "", "page": {"size": 0}}
            )
            if r.status_code != 200:
                raise Exception(f"Error running search: {r.status_code} {r}")
            if r["totalCount"] == ndocs:
                break
            if time.monotonic() - start > timeout:
                raise Exception("Timed out waiting for indexing to complete.")
            time.sleep(0.5)
