import logging
import os
import re
import statistics
from pathlib import Path
from typing import List, Optional, Union

import click
import pandas as pd

from unstructured.staging.base import elements_from_json, elements_to_text

# Module-level logger registered under the fixed name "unstructured.eval" (not __name__);
# `_prepare_output_cct` uses it to report files it could not read.
logger = logging.getLogger("unstructured.eval")


def _prepare_output_cct(docpath: str, output_type: str) -> str:
    """
    Convert given input document (path) into cct-ready. The function only support conversion
    from `json` or `txt` file.
    """
    try:
        if output_type == "json":
            output_cct = elements_to_text(elements_from_json(docpath))
        elif output_type == "txt":
            output_cct = _read_text_file(docpath)
        else:
            raise ValueError(
                f"File type not supported. Expects one of `json` or `txt`, \
                    but received {output_type} instead."
            )
    except ValueError as e:
        logger.error(f"Could not read the file {docpath}")
        raise e
    return output_cct


def _listdir_recursive(dir: str) -> List[str]:
    """
    Recursively lists all files in the given directory and its subdirectories.
    Returns a list of all files found, with each file's path relative to the
    initial directory.
    """
    listdir = []
    for dirpath, _, filenames in os.walk(dir):
        for filename in filenames:
            # Remove the starting directory from the path to show the relative path
            relative_path = os.path.relpath(dirpath, dir)
            if relative_path == ".":
                listdir.append(filename)
            else:
                listdir.append(os.path.join(relative_path, filename))
    return listdir


def _rename_aggregated_columns(df):
    """
    Renames aggregated columns in a DataFrame based on a predefined mapping.

    Parameters:
    df (pandas.DataFrame): The DataFrame with aggregated columns to rename.

    Returns:
    pandas.DataFrame: A new DataFrame with renamed aggregated columns.
    """
    rename_map = {"_mean": "mean", "_stdev": "stdev", "_pstdev": "pstdev", "_count": "count"}
    return df.rename(columns=rename_map)


def _format_grouping_output(*df):
    """
    Concatenates multiple pandas DataFrame objects along the columns (side-by-side)
    and resets the index.
    """
    return pd.concat(df, axis=1).reset_index()


def _display(df):
    """
    Displays the evaluation metrics in a formatted text table.

    Every cell is rendered first: floats with three decimals, everything else with
    ``str()``. Each column is then left-aligned and padded to the width of its widest
    rendered cell or header. Nothing is printed for an empty DataFrame.
    """
    if len(df) == 0:
        return

    def _render(item) -> str:
        return f"{item:.3f}" if isinstance(item, float) else str(item)

    # str() so non-string column labels (e.g. ints) can be measured and padded.
    headers = [str(header) for header in df.columns]
    # itertuples keeps each column's own dtype; iterrows builds one Series per row and
    # upcasts an int column (e.g. "count") to float in rows that also contain floats.
    rows = [[_render(item) for item in row] for row in df.itertuples(index=False, name=None)]
    # Measure widths on the rendered text so 3-decimal floats never overflow their column.
    col_widths = [
        max([len(header)] + [len(row[i]) for row in rows]) for i, header in enumerate(headers)
    ]
    click.echo(" ".join(header.ljust(width) for header, width in zip(headers, col_widths)))
    click.echo("-" * sum(col_widths) + "-" * (len(headers) - 1))
    for row in rows:
        click.echo(" ".join(cell.ljust(width) for cell, width in zip(row, col_widths)))


def _write_to_file(
    directory: str, filename: str, df: pd.DataFrame, mode: str = "w", overwrite: bool = True
):
    """
    Save the metrics report to tsv file. The function allows an option 1) to choose `mode`
    as `w` (write) or `a` (append) and 2) to `overwrite` the file if filename existed or not.
    """
    if mode not in ["w", "a"]:
        raise ValueError("Mode not supported. Mode must be one of [w, a].")
    if directory:
        Path(directory).mkdir(exist_ok=True)
    if "count" in df.columns:
        df["count"] = df["count"].astype(int)
    if "filename" in df.columns and "connector" in df.columns:
        df.sort_values(by=["connector", "filename"], inplace=True)
    if not overwrite:
        filename = _get_non_duplicated_filename(directory, filename)
    df.to_csv(
        os.path.join(directory, filename), sep="\t", mode=mode, index=False, header=(mode == "w")
    )


def _sorting_key(filename):
    """
    A function that defines the sorting method for duplicated file names. For example,
    with filename.ext filename (1).ext filename (2).ext filename (10).ext - this function
    extracts the integer in the bracket and sort those numbers ascendingly.
    """
    # Regular expression to find the number in the filename
    numbers = re.findall(r"(\d+)", filename)
    if numbers:
        # If there's a number, return it as an integer for sorting
        return int(numbers[-1])
    else:
        # If no number, return 0 so these files come first
        return 0


def _uniquity_file(file_list, target_filename) -> str:
    """
    Checks the duplicity of the file name from the list and run the numerical check
    of the minimum number needed as extension to not overwrite the exising file.
    Returns a string of file name in the format of `filename (<min number>).ext`.
    """
    original_filename, extension = target_filename.rsplit(".", 1)
    pattern = rf"^{re.escape(original_filename)}(?: \((\d+)\))?\.{re.escape(extension)}$"
    duplicated_files = sorted([f for f in file_list if re.match(pattern, f)], key=_sorting_key)

    numbers = []
    for file in duplicated_files:
        match = re.search(r"\((\d+)\)", file)
        if match:
            numbers.append(int(match.group(1)))

    numbers.sort()

    counter = 1
    for number in numbers:
        if number == counter:
            counter += 1
        else:
            break

    return original_filename + " (" + str(counter) + ")." + extension


def _get_non_duplicated_filename(dir, filename) -> str:
    """
    Return a file name that does not collide with the files already present in ``dir``.

    Thin wrapper that lists ``dir`` and delegates to `_uniquity_file`.

    Args:
        dir: Directory to check. An empty string (or None) means the current working
            directory, matching how ``os.path.join("", filename)`` resolves.
        filename: Candidate file name, e.g. ``"scores.tsv"``.

    Returns:
        The name chosen by `_uniquity_file`, e.g. ``"scores (1).tsv"``.
    """
    # os.listdir("") raises FileNotFoundError, so fall back to the cwd explicitly.
    return _uniquity_file(os.listdir(dir or "."), filename)


def _mean(scores: Union[pd.Series, List[float]], rounding: Optional[int] = 3) -> Union[float, None]:
    """
    Find mean from the list. Returns None if no element in the list.

    Args:
        rounding (int): optional argument that allows user to define decimal points. Default at 3.
    """
    if len(scores) == 0:
        return None
    mean = statistics.mean(scores)
    if not rounding:
        return mean
    return round(mean, rounding)


def _stdev(scores: List[Optional[float]], rounding: Optional[int] = 3) -> Union[float, None]:
    """
    Find standard deviation from the list.
    Returns None if only 0 or 1 element in the list.

    Args:
        rounding (int): optional argument that allows user to define decimal points. Default at 3.
    """
    # Filter out None values
    scores = [score for score in scores if score is not None]
    # Proceed only if there are more than one value
    if len(scores) <= 1:
        return None
    if not rounding:
        return statistics.stdev(scores)
    return round(statistics.stdev(scores), rounding)


def _pstdev(scores: List[Optional[float]], rounding: Optional[int] = 3) -> Union[float, None]:
    """
    Find population standard deviation from the list.
    Returns None if only 0 or 1 element in the list.

    Args:
        rounding (int): optional argument that allows user to define decimal points. Default at 3.
    """
    scores = [score for score in scores if score is not None]
    if len(scores) <= 1:
        return None
    if not rounding:
        return statistics.pstdev(scores)
    return round(statistics.pstdev(scores), rounding)


def _count(scores: List[Optional[float]]) -> float:
    """
    Returns the row count of the list.
    """
    return len(scores)


def _read_text_file(path):
    """
    Reads the contents of a text file and returns it as a string.
    """
    # Check if the file exists
    if not os.path.exists(path):
        raise FileNotFoundError(f"The file at {path} does not exist.")

    try:
        with open(path, errors="ignore") as f:
            text = f.read()
        return text
    except OSError as e:
        # Handle other I/O related errors
        raise IOError(f"An error occurred when reading the file at {path}: {e}")
