from __future__ import annotations

from typing import TYPE_CHECKING, Any, Iterable, List, Optional

if TYPE_CHECKING:
    from pyspark.sql import DataFrame, Row, SparkSession


class SparkSQL:
    """Utility class for inspecting and querying tables through Spark SQL.

    The object wraps a Spark session, records which tables of the current
    catalog/schema are usable (honouring ``include_tables`` and
    ``ignore_tables``), and exposes helpers that render table definitions,
    sample rows and query results as plain strings.
    """

    def __init__(
        self,
        spark_session: Optional[SparkSession] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
    ):
        """Initialize a SparkSQL object.

        Args:
            spark_session: A SparkSession object.
              If not provided, one will be created.
            catalog: The catalog to use.
              If not provided, the default catalog will be used.
            schema: The schema to use.
              If not provided, the default schema will be used.
            ignore_tables: A list of tables to ignore.
              If not provided, all tables will be used.
            include_tables: A list of tables to include.
              If not provided, all tables will be used.
            sample_rows_in_table_info: The number of rows to include in the table info.
              Defaults to 3.

        Raises:
            TypeError: If ``sample_rows_in_table_info`` is not an integer.
            ImportError: If a session has to be created but pyspark is not
              installed.
            ValueError: If ``include_tables`` or ``ignore_tables`` names a table
              that ``SHOW TABLES`` does not report.
        """
        # Validate plain arguments first so that a bad call fails before the
        # session's current catalog/database is changed.
        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        if spark_session is None:
            # pyspark is only needed here when we must build the session
            # ourselves; a caller-supplied session is used as-is.
            try:
                from pyspark.sql import SparkSession
            except ImportError as e:
                raise ImportError(
                    "pyspark is not installed. "
                    "Please install it with `pip install pyspark`"
                ) from e
            spark_session = SparkSession.builder.getOrCreate()
        self._spark = spark_session

        if catalog is not None:
            self._spark.catalog.setCurrentCatalog(catalog)
        if schema is not None:
            self._spark.catalog.setCurrentDatabase(schema)

        self._all_tables = set(self._get_all_table_names())
        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        # Fall back to every known table when the filters leave nothing.
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        self._sample_rows_in_table_info = sample_rows_in_table_info

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SparkSQL:
        """Create a SparkSQL backed by a remote Spark session via Spark Connect.

        For example: SparkSQL.from_uri("sc://localhost:15002")

        Args:
            database_uri: The Spark Connect URI passed to
              ``SparkSession.builder.remote``.
            engine_args: Accepted for signature compatibility; not used by
              this method.
            **kwargs: Forwarded to the ``SparkSQL`` constructor.

        Raises:
            ImportError: If pyspark is not installed.
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError as e:
            raise ImportError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            ) from e

        spark = SparkSession.builder.remote(database_uri).getOrCreate()
        return cls(spark, **kwargs)

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns the ``include_tables`` if any were given, otherwise every
        known table minus ``ignore_tables``. The result is always sorted:
        a stable order can help an LLM understand it.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    def _get_all_table_names(self) -> Iterable[str]:
        """Return the ``tableName`` column of ``SHOW TABLES`` as a list."""
        rows = self._spark.sql("SHOW TABLES").select("tableName").collect()
        return [row.tableName for row in rows]

    def _get_create_table_stmt(self, table: str) -> str:
        """Return the ``SHOW CREATE TABLE`` statement for ``table``.

        Everything from the first ``USING`` onwards is dropped and a ``;`` is
        appended. If the statement contains no ``USING`` it is kept whole.
        """
        statement = (
            self._spark.sql(f"SHOW CREATE TABLE {table}").collect()[0].createtab_stmt
        )
        # Ignore the data source provider and options to reduce the number of tokens.
        using_clause_index = statement.find("USING")
        if using_clause_index == -1:
            # str.find returns -1 when absent; slicing with it would silently
            # chop the last character off the statement.
            return statement + ";"
        return statement[:using_clause_index] + ";"

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Describe tables as CREATE statements, optionally with sample rows.

        Args:
            table_names: Tables to describe. If ``None``, all usable tables
              are described.

        Returns:
            The per-table descriptions joined by blank lines. When
            ``sample_rows_in_table_info`` is non-zero, each description is
            followed by a ``/* ... */`` block holding sample rows.

        Raises:
            ValueError: If any of ``table_names`` is not a usable table.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names
        tables = []
        for table_name in all_table_names:
            table_info = self._get_create_table_stmt(table_name)
            if self._sample_rows_in_table_info:
                table_info += "\n\n/*"
                table_info += f"\n{self._get_sample_spark_rows(table_name)}\n"
                table_info += "*/"
            tables.append(table_info)
        return "\n\n".join(tables)

    def _get_sample_spark_rows(self, table: str) -> str:
        """Render a tab-separated header and sample rows for ``table``.

        The header line always states the configured sample size, even if
        fewer rows come back.
        """
        query = f"SELECT * FROM {table} LIMIT {self._sample_rows_in_table_info}"
        df = self._spark.sql(query)
        columns_str = "\t".join(field.name for field in df.schema.fields)
        try:
            sample_rows = self._get_dataframe_results(df)
            # save the sample rows in string format
            sample_rows_str = "\n".join("\t".join(row) for row in sample_rows)
        except Exception:
            # Best effort: sample rows are optional context, so a failure to
            # collect them leaves the section empty instead of failing the
            # whole table description.
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _convert_row_as_tuple(self, row: Row) -> tuple:
        """Convert a row into a tuple of its values, each passed through ``str``."""
        return tuple(map(str, row.asDict().values()))

    def _get_dataframe_results(self, df: DataFrame) -> list:
        """Collect ``df`` and return its rows as a list of string tuples."""
        return [self._convert_row_as_tuple(row) for row in df.collect()]

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return the results as a string.

        Args:
            command: The SQL text passed to ``spark.sql``.
            fetch: ``"one"`` limits the result to a single row; any other
              value returns all rows.

        Returns:
            ``str()`` of a list of tuples, one tuple of stringified values
            per row.
        """
        df = self._spark.sql(command)
        if fetch == "one":
            df = df.limit(1)
        return str(self._get_dataframe_results(df))

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        A ``ValueError`` from ``get_table_info`` (e.g. an unknown table name)
        is returned as an ``"Error: ..."`` string instead of being raised;
        other exceptions propagate.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message
            return f"Error: {e}"

    def run_no_throw(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        The return value is whatever ``run`` returns for the command.

        If the statement throws an error, the error message is returned
        as an ``"Error: ..."`` string instead of being raised.
        """
        try:
            return self.run(command, fetch)
        except Exception as e:
            # Format the error message
            return f"Error: {e}"
