import itertools
from collections.abc import Hashable
from operator import itemgetter
from typing import Any, Callable, Dict, Iterable, List, Tuple, TypeVar, Union

from .._typing import T_num, T_obj


def cluster_list(xs: List[T_num], tolerance: T_num = 0) -> List[List[T_num]]:
    if tolerance == 0:
        return [[x] for x in sorted(xs)]
    if len(xs) < 2:
        return [[x] for x in sorted(xs)]
    groups = []
    xs = list(sorted(xs))
    current_group = [xs[0]]
    last = xs[0]
    for x in xs[1:]:
        if x <= (last + tolerance):
            current_group.append(x)
        else:
            groups.append(current_group)
            current_group = [x]
        last = x
    groups.append(current_group)
    return groups


def make_cluster_dict(values: Iterable[T_num], tolerance: T_num) -> Dict[T_num, int]:
    clusters = cluster_list(list(set(values)), tolerance)

    nested_tuples = [
        [(val, i) for val in value_cluster] for i, value_cluster in enumerate(clusters)
    ]

    return dict(itertools.chain(*nested_tuples))


Clusterable = TypeVar("Clusterable", T_obj, Tuple[Any, ...])


def cluster_objects(
    xs: List[Clusterable],
    key_fn: Union[Hashable, Callable[[Clusterable], T_num]],
    tolerance: T_num,
    preserve_order: bool = False,
) -> List[List[Clusterable]]:

    if not callable(key_fn):
        key_fn = itemgetter(key_fn)

    values = map(key_fn, xs)
    cluster_dict = make_cluster_dict(values, tolerance)

    get_0, get_1 = itemgetter(0), itemgetter(1)

    if preserve_order:
        cluster_tuples = [(x, cluster_dict.get(key_fn(x))) for x in xs]
    else:
        cluster_tuples = sorted(
            ((x, cluster_dict.get(key_fn(x))) for x in xs), key=get_1
        )

    grouped = itertools.groupby(cluster_tuples, key=get_1)

    return [list(map(get_0, v)) for k, v in grouped]
