# mypy: allow-untyped-defs
import operator
from functools import reduce
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn.functional as F

from .base_sparsifier import BaseSparsifier


__all__ = ["WeightNormSparsifier"]


def _flat_idx_to_2d(idx, shape):
    """Split a row-major flat index into a ``(row, col)`` pair.

    Args:
        idx: Flat index (or any object supporting ``//`` and ``%``, so the
            operation is applied element-wise for array-like inputs).
        shape: Indexable 2D shape; only ``shape[1]`` (the column count) is read.

    Returns:
        Tuple ``(rows, cols)`` where ``rows = idx // shape[1]`` and
        ``cols = idx % shape[1]``.
    """
    n_cols = shape[1]
    return idx // n_cols, idx % n_cols


class WeightNormSparsifier(BaseSparsifier):
    r"""Weight-Norm Sparsifier

    This sparsifier computes the norm of every sparse block and "zeroes-out" the
    ones with the lowest norm. The level of sparsity defines how many of the
    blocks are removed.

    This sparsifier is controlled by three variables:
    1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out
    2. `sparse_block_shape` defines the shape of the sparse blocks. Note that
        the sparse blocks originate at the zero-index of the tensor.
    3. `zeros_per_block` is the number of zeros that we are expecting in each
        sparse block. By default we assume that all elements within a block are
        zeroed-out. However, setting this variable sets the target number of
        zeros per block. The zeros within each block are chosen as the *smallest
        absolute values*.

    Args:

        sparsity_level: The target level of sparsity
        sparse_block_shape: The shape of a sparse block (see note below)
        zeros_per_block: Number of zeros in a sparse block
        norm: Norm to use. Could be either `int` or a callable.
            If `int`, only L1 and L2 are implemented.

    Note::
        The `sparse_block_shape` is tuple representing (block_ROWS, block_COLS),
        irrespective of what the rows / cols mean in the data tensor. That means,
        if you were to sparsify a weight tensor in the nn.Linear, which has a
        weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output
        channels, while the `block_COLS` would refer to the input channels.

    Note::
        All arguments to the WeightNormSparsifier constructor are "default"
        arguments and could be overridden by the configuration provided in the
        `prepare` step.

    Note::
        When the tensor height / width is not a multiple of the block shape,
        the masks are computed on a padded grid and then cropped back, so the
        resulting mask always has the same trailing `(rows, cols)` as the tensor.
    """

    def __init__(
        self,
        sparsity_level: float = 0.5,
        sparse_block_shape: Tuple[int, int] = (1, 4),
        zeros_per_block: Optional[int] = None,
        norm: Optional[Union[Callable, int]] = None,
    ):
        r"""Stores the default config and selects the element-wise norm function.

        Raises:
            NotImplementedError: if `norm` is neither callable, `None`, 1 nor 2.
        """
        if zeros_per_block is None:
            # Default: a selected block is zeroed-out entirely.
            zeros_per_block = reduce(operator.mul, sparse_block_shape)
        defaults = {
            "sparsity_level": sparsity_level,
            "sparse_block_shape": sparse_block_shape,
            "zeros_per_block": zeros_per_block,
        }
        if norm is None:
            norm = 2
        # `norm_fn` is applied element-wise; the per-block reduction happens
        # later (average pooling in `_make_tensor_mask`).
        if callable(norm):
            self.norm_fn = norm
        elif norm == 1:
            self.norm_fn = lambda T: T.abs()
        elif norm == 2:
            # Squared magnitude: no square root is taken here.
            self.norm_fn = lambda T: T * T
        else:
            raise NotImplementedError(f"L-{norm} is not yet implemented.")
        super().__init__(defaults=defaults)

    def _scatter_fold_block_mask(
        self,
        output_shape,
        dim,
        indices,
        block_shape,
        mask=None,
        input_shape=None,
        device=None,
    ):
        r"""Creates patches of size `block_shape` after scattering the indices.

        Args:
            output_shape: Spatial size handed to `F.fold` as `output_size`.
            dim: Dimension of `mask` along which `indices` are scattered.
            indices: Index tensor; the addressed entries of `mask` are set to 0.
            block_shape: Used as both `kernel_size` and `stride` of the fold,
                i.e. the patches do not overlap.
            mask: Tensor in the "unfolded" layout to update. If `None`, a
                tensor of ones with shape `input_shape` is created on `device`.
            input_shape: Shape of the mask to create when `mask` is `None`.
            device: Device of the mask to create when `mask` is `None`.

        Returns:
            The same `mask` object; its `.data` is replaced by the folded result.
        """
        if mask is None:
            assert input_shape is not None
            mask = torch.ones(input_shape, device=device)
        # Zero the selected entries in place, then rebind `.data` so callers
        # holding a reference to `mask` observe the folded tensor.
        mask.scatter_(dim=dim, index=indices, value=0)
        mask.data = F.fold(
            mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape
        )
        return mask

    def _make_tensor_mask(
        self, data, input_shape, sparsity_level, sparse_block_shape, mask=None
    ):
        r"""Creates a tensor-level mask.

        Tensor-level mask is described as a mask, where the granularity of sparsification of the
        smallest patch is the sparse_block_shape. That means, that for a given mask and a
        sparse_block_shape, the smallest "patch" of zeros/ones could be the sparse_block_shape.

        In this context, `sparsity_level` describes the fraction of sparse patches.

        Args:
            data: Tensor of per-element scores; its last two dims are `(h, w)`.
            input_shape: Not used by this method (kept for interface compatibility).
            sparsity_level: Fraction of blocks to zero-out.
            sparse_block_shape: `(block_h, block_w)`.
            mask: Optional mask to update; defaults to ones on the padded grid.

        Returns:
            The mask, cropped to `(h, w)` on every return path.
        """
        h, w = data.shape[-2:]
        block_h, block_w = sparse_block_shape
        # Padding needed to reach the next multiple of the block shape.
        dh = (block_h - h % block_h) % block_h
        dw = (block_w - w % block_w) % block_w

        if mask is None:
            mask = torch.ones(h + dh, w + dw, device=data.device)

        # The early exits crop as well, so the returned shape does not depend
        # on which branch was taken.
        if sparsity_level >= 1.0:
            mask.data = torch.zeros_like(mask)[..., :h, :w].contiguous()
            return mask
        elif sparsity_level <= 0.0:
            mask.data = torch.ones_like(mask)[..., :h, :w].contiguous()
            return mask

        values_per_block = reduce(operator.mul, sparse_block_shape)
        if values_per_block > 1:
            # Reduce the data: one average score per block (ceil_mode covers
            # the partial blocks at the bottom / right edge).
            data = F.avg_pool2d(
                data[None, None, :],
                kernel_size=sparse_block_shape,
                stride=sparse_block_shape,
                ceil_mode=True,
            )
        data = data.flatten()
        num_blocks = len(data)

        # Shape (1, values_per_block, num_blocks): the "unfolded" layout that
        # `F.fold` expects, with every row holding the same block scores.
        data = data.repeat(1, values_per_block, 1)

        threshold_idx = int(round(sparsity_level * num_blocks))
        # Sanity check: keeps k within [0, num_blocks - 1].
        threshold_idx = max(0, min(num_blocks - 1, threshold_idx))
        _, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False)

        # Temp reshape for mask
        mask_reshape = mask.reshape(data.shape)  # data might be reshaped
        mask_reshape = self._scatter_fold_block_mask(
            dim=2,
            output_shape=(h + dh, w + dw),
            indices=sorted_idx,
            block_shape=sparse_block_shape,
            mask=mask_reshape,
        )
        # Drop the padding so the mask matches the tensor.
        mask.data = (
            mask_reshape.squeeze().reshape(mask.shape)[..., :h, :w].contiguous()
        )
        return mask

    def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None):
        r"""Creates a block-level mask.

        Block-level mask is described as a mask, where the granularity of sparsification of the
        largest patch is the sparse_block_shape. That means that for a given mask and a
        sparse_block_shape, the sparsity is computed only within a patch of a size sparse_block_shape.

        In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch.

        Args:
            data: Tensor of per-element scores; its last two dims are `(h, w)`.
            sparse_block_shape: `(block_h, block_w)`.
            zeros_per_block: Number of lowest-score entries zeroed in each block.
            mask: Optional mask to update; defaults to ones on the padded grid.

        Returns:
            The mask, cropped to `(h, w)` so that it can be combined with the
            result of `_make_tensor_mask` and applied to the tensor.
        """
        h, w = data.shape[-2:]
        block_h, block_w = sparse_block_shape
        # Padding needed to reach the next multiple of the block shape.
        dh = (block_h - h % block_h) % block_h
        dw = (block_w - w % block_w) % block_w
        values_per_block = reduce(operator.mul, sparse_block_shape)

        if mask is None:
            mask = torch.ones((h + dh, w + dw), device=data.device)

        if values_per_block == zeros_per_block:
            # Everything should be sparsified
            mask.data = torch.zeros_like(mask)[..., :h, :w].contiguous()
            return mask

        # create a new padded tensor like data (to match the block_shape)
        # The padding is NaN; presumably so that `topk(largest=False)` picks
        # real values before padding (torch orders NaN above every number).
        padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device)
        padded_data.fill_(torch.nan)
        padded_data[:h, :w] = data
        # Shape (1, values_per_block, num_blocks): one column per block.
        unfolded_data = F.unfold(
            padded_data[None, None, :],
            kernel_size=sparse_block_shape,
            stride=sparse_block_shape,
        )

        # Temp reshape for mask
        mask_reshape = mask.reshape(unfolded_data.shape)
        _, sorted_idx = torch.topk(
            unfolded_data, k=zeros_per_block, dim=1, largest=False
        )

        mask_reshape = self._scatter_fold_block_mask(
            dim=1,
            indices=sorted_idx,
            output_shape=padded_data.shape,
            block_shape=sparse_block_shape,
            mask=mask_reshape,
        )

        # Drop the padding; without the crop the result could not be combined
        # with the (cropped) tensor-level mask for non-divisible shapes.
        mask.data = (
            mask_reshape.squeeze().reshape(mask.shape)[..., :h, :w].contiguous()
        )
        return mask

    def update_mask(
        self,
        module,
        tensor_name,
        sparsity_level,
        sparse_block_shape,
        zeros_per_block,
        **kwargs,
    ):
        r"""Recomputes the mask stored on `module.parametrizations.<tensor_name>[0]`.

        An element ends up masked-out (0) only if its block is selected by the
        tensor-level mask AND the element is among the `zeros_per_block`
        lowest-score entries of that block.

        Args:
            module: Module holding the tensor and its parametrization.
            tensor_name: Name of the tensor attribute on `module`.
            sparsity_level: Fraction of blocks to sparsify.
            sparse_block_shape: `(block_h, block_w)`.
            zeros_per_block: Number of zeros per selected block.
            **kwargs: Ignored.

        Raises:
            ValueError: if `zeros_per_block` is negative or exceeds the number
                of elements in a block.
        """
        values_per_block = reduce(operator.mul, sparse_block_shape)
        if zeros_per_block > values_per_block:
            raise ValueError(
                "Number of zeros per block cannot be more than the total number of elements in that block."
            )
        if zeros_per_block < 0:
            raise ValueError("Number of zeros per block should be positive.")

        mask = getattr(module.parametrizations, tensor_name)[0].mask
        if sparsity_level <= 0 or zeros_per_block == 0:
            # Nothing to sparsify.
            mask.data = torch.ones_like(mask)
        elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block):
            # Every block is selected and fully zeroed.
            mask.data = torch.zeros_like(mask)
        else:
            ww = self.norm_fn(getattr(module, tensor_name))
            tensor_mask = self._make_tensor_mask(
                data=ww,
                input_shape=ww.shape,
                sparsity_level=sparsity_level,
                sparse_block_shape=sparse_block_shape,
            )
            if values_per_block != zeros_per_block:
                block_mask = self._make_block_mask(
                    data=ww,
                    sparse_block_shape=sparse_block_shape,
                    zeros_per_block=zeros_per_block,
                )
                # Keep an element if either mask keeps it (result is bool).
                tensor_mask = torch.logical_or(tensor_mask, block_mask)
            mask.data = tensor_mask
