"""
Poolformer from MetaFormer is Actually What You Need for Vision https://arxiv.org/abs/2111.11418

IdentityFormer, RandFormer, PoolFormerV2, ConvFormer, and CAFormer
from MetaFormer Baselines for Vision https://arxiv.org/abs/2210.13452

All implemented models support feature extraction and variable input resolution.

Original implementation by Weihao Yu et al.,
adapted for timm by Fredo Guan and Ross Wightman.

Adapted from https://github.com/sail-sg/metaformer, original copyright below
"""

# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.jit import Final

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \
    use_fused_attn
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model

__all__ = ['MetaFormer']


class Stem(nn.Module):
    """
    Patch-embedding stem: one 7x7 / stride-4 / padding-2 convolution, then an optional norm.
    The convolution hyper-parameters are fixed and shared by every model variant.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            norm_layer=None,
    ):
        super().__init__()
        conv_cfg = dict(kernel_size=7, stride=4, padding=2)
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_cfg)
        # a falsy norm_layer (e.g. None) means "no norm"
        self.norm = nn.Identity() if not norm_layer else norm_layer(out_channels)

    def forward(self, x):
        return self.norm(self.conv(x))


class Downsampling(nn.Module):
    """
    Downsampling layer: optional norm over the input channels followed by a (strided) convolution.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            norm_layer=None,
    ):
        super().__init__()
        # the norm runs before the conv, so it is sized for the *input* channels
        self.norm = nn.Identity() if not norm_layer else norm_layer(in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

    def forward(self, x):
        normed = self.norm(x)
        return self.conv(normed)


class Scale(nn.Module):
    """
    Per-channel scale applied by element-wise multiplication.

    With ``use_nchw=True`` the (dim,) scale is viewed as (dim, 1, 1) so it broadcasts over NCHW maps;
    otherwise it broadcasts over a trailing channel dimension. ``trainable`` sets ``requires_grad``.
    """

    def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True):
        super().__init__()
        if use_nchw:
            self.shape = (dim, 1, 1)
        else:
            self.shape = (dim,)
        initial = torch.ones(dim) * init_value
        self.scale = nn.Parameter(initial, requires_grad=trainable)

    def forward(self, x):
        factor = self.scale.view(self.shape)
        return x * factor


class SquaredReLU(nn.Module):
    """
    Squared ReLU activation, relu(x) ** 2: https://arxiv.org/abs/2109.08668
    """

    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        rectified = self.relu(x)
        return rectified.square()


class StarReLU(nn.Module):
    """
    StarReLU: s * relu(x) ** 2 + b, with scalar parameters s and b (learnable by default).

    ``mode`` is accepted for signature compatibility and is not used.
    """

    def __init__(
            self,
            scale_value=1.0,
            bias_value=0.0,
            scale_learnable=True,
            bias_learnable=True,
            mode=None,
            inplace=False
    ):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(torch.ones(1) * scale_value, requires_grad=scale_learnable)
        self.bias = nn.Parameter(torch.ones(1) * bias_value, requires_grad=bias_learnable)

    def forward(self, x):
        squared = self.relu(x) ** 2
        return self.scale * squared + self.bias


class Attention(nn.Module):
    """
    Vanilla self-attention from Transformer: https://arxiv.org/abs/1706.03762.
    Modified from timm.

    Args:
        dim: channels of the input tokens, and of the output after ``proj``.
        head_dim: channels per attention head.
        num_heads: head count; when falsy it is ``dim // head_dim``, clamped to at least 1.
        qkv_bias: use a bias in the fused q/k/v projection.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        proj_bias: use a bias in the output projection.
        **kwargs: accepted and ignored.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim,
            head_dim=32,
            num_heads=None,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
            proj_bias=False,
            **kwargs
    ):
        super().__init__()

        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.num_heads = num_heads if num_heads else dim // head_dim
        if self.num_heads == 0:
            self.num_heads = 1

        # Total width of all heads. This may differ from `dim`, e.g. when dim is not divisible by
        # head_dim or num_heads is given explicitly. `proj` maps it back to `dim`.
        self.attention_dim = self.num_heads * self.head_dim

        self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Self-attention over ``x`` of shape (B, N, dim); returns a (B, N, dim) tensor."""
        B, N, _ = x.shape
        # (B, N, 3 * attention_dim) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # Merge heads. The concatenated width is attention_dim, not necessarily the input width.
        x = x.transpose(1, 2).reshape(B, N, self.attention_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# Custom norm modules that disable the bias term, since the original model defs
# used a custom norm with a weight term but no bias term.

class GroupNorm1NoBias(GroupNorm1):
    """GroupNorm1 with its bias removed (weight-only affine); eps is 1e-6 unless passed explicitly."""

    def __init__(self, num_channels, **kwargs):
        super().__init__(num_channels, **kwargs)
        # override eps after construction, then drop the bias parameter entirely
        self.eps, self.bias = kwargs.get('eps', 1e-6), None


class LayerNorm2dNoBias(LayerNorm2d):
    """LayerNorm2d with its bias removed (weight-only affine); eps is 1e-6 unless passed explicitly."""

    def __init__(self, num_channels, **kwargs):
        super().__init__(num_channels, **kwargs)
        # override eps after construction, then drop the bias parameter entirely
        self.eps, self.bias = kwargs.get('eps', 1e-6), None


class LayerNormNoBias(nn.LayerNorm):
    """nn.LayerNorm with its bias removed (weight-only affine); eps is 1e-6 unless passed explicitly."""

    def __init__(self, num_channels, **kwargs):
        super().__init__(num_channels, **kwargs)
        # override eps after construction, then drop the bias parameter entirely
        self.eps, self.bias = kwargs.get('eps', 1e-6), None


class SepConv(nn.Module):
    r"""
    Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.

    1x1 expand (dim -> int(expansion_ratio * dim)) -> act1 -> depthwise kxk conv -> act2
    -> 1x1 project back to dim. Extra ``**kwargs`` are accepted and ignored.
    """

    def __init__(
            self,
            dim,
            expansion_ratio=2,
            act1_layer=StarReLU,
            act2_layer=nn.Identity,
            bias=False,
            kernel_size=7,
            padding=3,
            **kwargs
    ):
        super().__init__()
        hidden = int(expansion_ratio * dim)
        self.pwconv1 = nn.Conv2d(dim, hidden, kernel_size=1, bias=bias)
        self.act1 = act1_layer()
        # groups == channels makes this a depthwise convolution
        self.dwconv = nn.Conv2d(
            hidden,
            hidden,
            kernel_size=kernel_size,
            padding=padding,
            groups=hidden,
            bias=bias,
        )
        self.act2 = act2_layer()
        self.pwconv2 = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.pwconv2(self.act2(self.dwconv(self.act1(self.pwconv1(x)))))


class Pooling(nn.Module):
    """
    Implementation of pooling for PoolFormer: https://arxiv.org/abs/2111.11418

    Returns avgpool(x) - x, using a stride-1 average pool padded by pool_size // 2. Padded positions
    are excluded from the average. Extra ``**kwargs`` are accepted and ignored.
    """

    def __init__(self, pool_size=3, **kwargs):
        super().__init__()
        half = pool_size // 2
        self.pool = nn.AvgPool2d(pool_size, stride=1, padding=half, count_include_pad=False)

    def forward(self, x):
        # the input is subtracted; the enclosing MetaFormerBlock adds x back via its residual branch
        return self.pool(x) - x


class MlpHead(nn.Module):
    """ MLP classification head: fc1 -> act -> norm -> dropout -> fc2.

    The hidden width is int(mlp_ratio * dim); fc2 maps it to num_classes.
    """

    def __init__(
            self,
            dim,
            num_classes=1000,
            mlp_ratio=4,
            act_layer=SquaredReLU,
            norm_layer=LayerNorm,
            drop_rate=0.,
            bias=True
    ):
        super().__init__()
        width = int(mlp_ratio * dim)
        self.fc1 = nn.Linear(dim, width, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(width)
        self.fc2 = nn.Linear(width, num_classes, bias=bias)
        self.head_drop = nn.Dropout(drop_rate)

    def forward(self, x):
        hidden = self.norm(self.act(self.fc1(x)))
        return self.fc2(self.head_drop(hidden))


class MetaFormerBlock(nn.Module):
    """
    Implementation of one MetaFormer block.

    Two pre-norm residual branches, each computing
        x = res_scale(x) + layer_scale(drop_path(f(norm(x))))
    with f = token_mixer for the first branch and f = Mlp (hidden width 4 * dim) for the second.
    Layer / res scales are per-channel ``Scale`` modules when their init value is not None,
    otherwise identity. Extra ``**kwargs`` are forwarded to the token mixer.
    """

    def __init__(
            self,
            dim,
            token_mixer=Pooling,
            mlp_act=StarReLU,
            mlp_bias=False,
            norm_layer=LayerNorm2d,
            proj_drop=0.,
            drop_path=0.,
            use_nchw=True,
            layer_scale_init_value=None,
            res_scale_init_value=None,
            **kwargs
    ):
        super().__init__()

        def make_drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        def make_layer_scale():
            if layer_scale_init_value is None:
                return nn.Identity()
            return Scale(dim=dim, init_value=layer_scale_init_value, use_nchw=use_nchw)

        def make_res_scale():
            if res_scale_init_value is None:
                return nn.Identity()
            return Scale(dim=dim, init_value=res_scale_init_value, use_nchw=use_nchw)

        # token-mixer branch
        self.norm1 = norm_layer(dim)
        self.token_mixer = token_mixer(dim=dim, proj_drop=proj_drop, **kwargs)
        self.drop_path1 = make_drop_path()
        self.layer_scale1 = make_layer_scale()
        self.res_scale1 = make_res_scale()

        # MLP branch (conv-based 1x1 MLP when operating on NCHW maps)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            dim,
            int(4 * dim),
            act_layer=mlp_act,
            bias=mlp_bias,
            drop=proj_drop,
            use_conv=use_nchw,
        )
        self.drop_path2 = make_drop_path()
        self.layer_scale2 = make_layer_scale()
        self.res_scale2 = make_res_scale()

    def forward(self, x):
        shortcut = self.res_scale1(x)
        branch = self.layer_scale1(self.drop_path1(self.token_mixer(self.norm1(x))))
        x = shortcut + branch

        shortcut = self.res_scale2(x)
        branch = self.layer_scale2(self.drop_path2(self.mlp(self.norm2(x))))
        x = shortcut + branch
        return x


class MetaFormerStage(nn.Module):
    """
    One MetaFormer stage: an optional stride-2 downsample followed by ``depth`` MetaFormerBlocks.

    Args:
        in_chs: input channels. When equal to ``out_chs`` the downsample is an identity.
        out_chs: output channels (and the width of every block).
        depth: number of MetaFormerBlocks.
        token_mixer: token mixer class for every block. If it is ``Attention`` (or a subclass) the
            blocks run on a flattened (B, H*W, C) sequence; otherwise they run on NCHW maps.
        mlp_act, mlp_bias, norm_layer, proj_drop, layer_scale_init_value, res_scale_init_value:
            forwarded to each MetaFormerBlock.
        downsample_norm: norm layer applied before the 3x3 / stride-2 downsampling conv.
        dp_rates: per-block drop-path rates; block ``i`` uses ``dp_rates[i]``, so at least ``depth``
            entries are needed. ``None`` (default) means a rate of 0 for every block.
        **kwargs: forwarded to each MetaFormerBlock (and from there to the token mixer).
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            depth=2,
            token_mixer=nn.Identity,
            mlp_act=StarReLU,
            mlp_bias=False,
            downsample_norm=LayerNorm2d,
            norm_layer=LayerNorm2d,
            proj_drop=0.,
            dp_rates=None,
            layer_scale_init_value=None,
            res_scale_init_value=None,
            **kwargs,
    ):
        super().__init__()

        # The old default was a fixed 2-element list, which raised IndexError for depth > 2.
        if dp_rates is None:
            dp_rates = [0.] * depth

        self.grad_checkpointing = False
        # Attention consumes (B, N, C) token sequences; all other mixers work on NCHW maps
        self.use_nchw = not issubclass(token_mixer, Attention)

        # don't downsample if in_chs and out_chs are the same
        self.downsample = nn.Identity() if in_chs == out_chs else Downsampling(
            in_chs,
            out_chs,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_layer=downsample_norm,
        )

        self.blocks = nn.Sequential(*[MetaFormerBlock(
            dim=out_chs,
            token_mixer=token_mixer,
            mlp_act=mlp_act,
            mlp_bias=mlp_bias,
            norm_layer=norm_layer,
            proj_drop=proj_drop,
            drop_path=dp_rates[i],
            layer_scale_init_value=layer_scale_init_value,
            res_scale_init_value=res_scale_init_value,
            use_nchw=self.use_nchw,
            **kwargs,
        ) for i in range(depth)])

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, x: Tensor):
        x = self.downsample(x)
        B, C, H, W = x.shape

        if not self.use_nchw:
            # (B, C, H, W) -> (B, H*W, C) token sequence for attention blocks
            x = x.reshape(B, C, -1).transpose(1, 2)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)

        if not self.use_nchw:
            # back to (B, C, H, W)
            x = x.transpose(1, 2).reshape(B, C, H, W)

        return x


class MetaFormer(nn.Module):
    r""" MetaFormer
        A PyTorch impl of : `MetaFormer Baselines for Vision`  -
          https://arxiv.org/abs/2210.13452

    Args:
        in_chans (int): Number of input image channels.
        num_classes (int): Number of classes for classification head.
        global_pool: Pooling for classifier head.
        depths (list, tuple or int): Number of blocks at each stage (an int means a single stage).
        dims (list, tuple or int): Feature dimension at each stage.
        token_mixers (list, tuple or token_fcn): Token mixer for each stage.
        mlp_act: Activation layer for MLP.
        mlp_bias (boolean): Enable or disable mlp bias term.
        drop_path_rate (float): Stochastic depth rate.
        proj_drop_rate (float): Dropout rate passed to the token mixers and MLPs of every block.
        drop_rate (float): Classifier dropout rate (applied inside MlpHead when use_mlp_head,
            otherwise right before the Linear classifier).
        layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale.
            None means not use the layer scale. From: https://arxiv.org/abs/2103.17239.
        res_scale_init_values (list, tuple, float or None): Init value for res Scale on residual connections.
            None means not use the res scale. From: https://arxiv.org/abs/2110.09456.
        downsample_norm (nn.Module): Norm layer used in stem and downsampling layers.
        norm_layers (list, tuple or norm_fcn): Norm layers for each stage.
        output_norm: Norm layer before classifier head.
        use_mlp_head: Use MLP classification head.
    """

    def __init__(
            self,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            depths=(2, 2, 6, 2),
            dims=(64, 128, 320, 512),
            token_mixers=Pooling,
            mlp_act=StarReLU,
            mlp_bias=False,
            drop_path_rate=0.,
            proj_drop_rate=0.,
            drop_rate=0.0,
            layer_scale_init_values=None,
            res_scale_init_values=(None, None, 1.0, 1.0),
            downsample_norm=LayerNorm2dNoBias,
            norm_layers=LayerNorm2dNoBias,
            output_norm=LayerNorm2d,
            use_mlp_head=True,
            **kwargs,
    ):
        super().__init__()
        # Normalize scalar (single-stage) depths/dims to lists *before* len()/indexing below.
        if not isinstance(depths, (list, tuple)):
            depths = [depths]  # it means the model has only one stage
        if not isinstance(dims, (list, tuple)):
            dims = [dims]

        self.num_classes = num_classes
        # Set head_hidden_size unconditionally so it also exists when num_classes == 0.
        self.num_features = self.head_hidden_size = dims[-1]
        self.drop_rate = drop_rate
        self.use_mlp_head = use_mlp_head
        self.num_stages = len(depths)

        # broadcast per-stage options that were given as a single value
        if not isinstance(token_mixers, (list, tuple)):
            token_mixers = [token_mixers] * self.num_stages
        if not isinstance(norm_layers, (list, tuple)):
            norm_layers = [norm_layers] * self.num_stages
        if not isinstance(layer_scale_init_values, (list, tuple)):
            layer_scale_init_values = [layer_scale_init_values] * self.num_stages
        if not isinstance(res_scale_init_values, (list, tuple)):
            res_scale_init_values = [res_scale_init_values] * self.num_stages

        self.grad_checkpointing = False
        self.feature_info = []

        self.stem = Stem(
            in_chans,
            dims[0],
            norm_layer=downsample_norm
        )

        stages = []
        prev_dim = dims[0]
        reduction = 4  # the Stem conv has stride 4
        # drop-path rate grows linearly over all blocks, then is split into per-stage lists
        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        for i in range(self.num_stages):
            if dims[i] != prev_dim:
                # MetaFormerStage adds a stride-2 Downsampling whenever the channel count changes
                reduction *= 2
            stages += [MetaFormerStage(
                prev_dim,
                dims[i],
                depth=depths[i],
                token_mixer=token_mixers[i],
                mlp_act=mlp_act,
                mlp_bias=mlp_bias,
                proj_drop=proj_drop_rate,
                dp_rates=dp_rates[i],
                layer_scale_init_value=layer_scale_init_values[i],
                res_scale_init_value=res_scale_init_values[i],
                downsample_norm=downsample_norm,
                norm_layer=norm_layers[i],
                **kwargs,
            )]
            prev_dim = dims[i]
            self.feature_info += [dict(num_chs=dims[i], reduction=reduction, module=f'stages.{i}')]

        self.stages = nn.Sequential(*stages)

        final = self._make_classifier(num_classes)
        self.head = nn.Sequential(OrderedDict([
            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
            ('norm', output_norm(self.num_features)),
            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
            # If using MlpHead, dropout is handled by MlpHead (its head_drop before fc2), so only
            # the plain Linear classifier gets a dropout here.
            ('drop', nn.Identity() if self.use_mlp_head else nn.Dropout(drop_rate)),
            ('fc', final)
        ]))

        self.apply(self._init_weights)

    def _make_classifier(self, num_classes: int) -> nn.Module:
        """Build the final classifier (``head.fc``): MlpHead, Linear, or Identity if num_classes <= 0."""
        if num_classes <= 0:
            return nn.Identity()
        if self.use_mlp_head:
            # FIXME not actually returning mlp hidden state right now as pre-logits.
            return MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
        return nn.Linear(self.num_features, num_classes)

    def _init_weights(self, m):
        # truncated-normal weights and zero bias for every conv / linear layer
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable
        for stage in self.stages:
            stage.set_grad_checkpointing(enable=enable)

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace the classifier for ``num_classes`` outputs and optionally change the pooling."""
        self.num_classes = num_classes
        if global_pool is not None:
            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
        self.head.fc = self._make_classifier(num_classes)

    def forward_head(self, x: Tensor, pre_logits: bool = False):
        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
        x = self.head.global_pool(x)
        x = self.head.norm(x)
        x = self.head.flatten(x)
        x = self.head.drop(x)
        return x if pre_logits else self.head.fc(x)

    def forward_features(self, x: Tensor):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward(self, x: Tensor):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


# this works but it's long and breaks backwards compatibility with weights from the poolformer-only impl
def checkpoint_filter_fn(state_dict, model):
    """Remap original (sail-sg) MetaFormer / PoolFormer checkpoint keys to this implementation.

    Args:
        state_dict: checkpoint state dict. It is returned unchanged if it already contains
            ``stem.conv.weight`` (i.e. it already uses this implementation's layout).
        model: target model; ``model.state_dict()`` supplies the expected shape of each remapped key.

    Returns:
        A new dict with remapped keys. Tensors whose element count matches the model's but whose
        shape differs are reshaped to the model's shape.

    Raises:
        KeyError: if a remapped key is not present in ``model.state_dict()``.
    """
    if 'stem.conv.weight' in state_dict:
        return state_dict

    import re
    out_dict = {}
    # PoolFormer v1 keeps blocks and downsample layers interleaved in a single `network` list
    is_poolformerv1 = 'network.0.0.mlp.fc1.weight' in state_dict
    model_state_dict = model.state_dict()
    for k, v in state_dict.items():
        if is_poolformerv1:
            k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
            k = k.replace('network.1', 'downsample_layers.1')
            k = k.replace('network.3', 'downsample_layers.2')
            k = k.replace('network.5', 'downsample_layers.3')
            k = k.replace('network.2', 'network.1')
            k = k.replace('network.4', 'network.2')
            k = k.replace('network.6', 'network.3')
            k = k.replace('network', 'stages')

        # dots are escaped so they only match a literal '.' separator
        k = re.sub(r'downsample_layers\.([0-9]+)', r'stages.\1.downsample', k)
        k = k.replace('downsample.proj', 'downsample.conv')
        k = k.replace('patch_embed.proj', 'patch_embed.conv')
        k = re.sub(r'([0-9]+)\.([0-9]+)', r'\1.blocks.\2', k)
        k = k.replace('stages.0.downsample', 'patch_embed')
        k = k.replace('patch_embed', 'stem')
        k = k.replace('post_norm', 'norm')
        k = k.replace('pre_norm', 'norm')
        k = re.sub(r'^head', 'head.fc', k)
        k = re.sub(r'^norm', 'head.norm', k)

        target = model_state_dict[k]
        # e.g. layer scale stored as (C, 1, 1) in the checkpoint but (C,) in the model
        if v.shape != target.shape and v.numel() == target.numel():
            v = v.reshape(target.shape)

        out_dict[k] = v
    return out_dict


def _create_metaformer(variant, pretrained=False, **kwargs):
    """Build a MetaFormer ``variant`` through ``build_model_with_cfg``.

    ``out_indices`` is popped from kwargs and used for feature extraction. By default it is one index
    per entry of ``depths``, or of (2, 2, 6, 2) when ``depths`` is not given.
    """
    stage_depths = kwargs.get('depths', (2, 2, 6, 2))
    out_indices = kwargs.pop('out_indices', tuple(idx for idx, _ in enumerate(stage_depths)))
    return build_model_with_cfg(
        MetaFormer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs,
    )


def _cfg(url='', **kwargs):
    """Return a pretrained config dict with this file's defaults; ``kwargs`` override or extend it."""
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=1.0,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        classifier='head.fc',
        first_conv='stem.conv',
    )
    cfg.update(kwargs)
    return cfg


# Pretrained configs, keyed '<model_name>.<pretrained_tag>'.
# - PoolFormer v1 entries override crop_pct: 0.9 for s12/s24/s36, 0.95 for m36/m48.
# - PoolFormer v2 entries use the _cfg defaults.
# - ConvFormer / CAFormer build an MlpHead (use_mlp_head defaults to True in MetaFormer), so their
#   classifier Linear is 'head.fc.fc2' instead of the _cfg default 'head.fc'.
# - '*_384' tags set input_size (3, 384, 384) and pool_size (12, 12).
# - '*.sail_in22k' tags set num_classes=21841.
# NOTE(review): entry order may matter to generate_default_cfgs (e.g. which tag becomes the
# default for a model name) -- confirm before reordering entries.
default_cfgs = generate_default_cfgs({
    # PoolFormer v1
    'poolformer_s12.sail_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.9),
    'poolformer_s24.sail_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.9),
    'poolformer_s36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.9),
    'poolformer_m36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95),
    'poolformer_m48.sail_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95),

    # PoolFormer v2
    'poolformerv2_s12.sail_in1k': _cfg(hf_hub_id='timm/'),
    'poolformerv2_s24.sail_in1k': _cfg(hf_hub_id='timm/'),
    'poolformerv2_s36.sail_in1k': _cfg(hf_hub_id='timm/'),
    'poolformerv2_m36.sail_in1k': _cfg(hf_hub_id='timm/'),
    'poolformerv2_m48.sail_in1k': _cfg(hf_hub_id='timm/'),

    # ConvFormer
    'convformer_s18.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_s18.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_s18.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_s18.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_s18.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'convformer_s36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_s36.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_s36.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_s36.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_s36.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'convformer_m36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_m36.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_m36.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_m36.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_m36.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'convformer_b36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_b36.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_b36.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'convformer_b36.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'convformer_b36.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    # CAFormer
    'caformer_s18.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_s18.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_s18.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_s18.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_s18.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'caformer_s36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_s36.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_s36.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_s36.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_s36.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'caformer_m36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_m36.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_m36.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_m36.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_m36.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),

    'caformer_b36.sail_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_b36.sail_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_b36.sail_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2'),
    'caformer_b36.sail_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
    'caformer_b36.sail_in22k': _cfg(
        hf_hub_id='timm/',
        classifier='head.fc.fc2', num_classes=21841),
})


@register_model
def poolformer_s12(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormer-S12 (v1): stage depths (2, 2, 6, 2), dims (64, 128, 320, 512).

    Pooling token mixer, GroupNorm1, GELU MLP with bias, layer scale init 1e-5, no res scale,
    no stem/downsample norm, plain Linear classifier.
    """
    arch_args = dict(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        downsample_norm=None,
        mlp_act=nn.GELU,
        mlp_bias=True,
        norm_layers=GroupNorm1,
        layer_scale_init_values=1e-5,
        res_scale_init_values=None,
        use_mlp_head=False,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('poolformer_s12', pretrained=pretrained, **model_kwargs)


@register_model
def poolformer_s24(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormer-S24 (v1): stage depths (4, 4, 12, 4), dims (64, 128, 320, 512).

    Pooling token mixer, GroupNorm1, GELU MLP with bias, layer scale init 1e-5, no res scale,
    no stem/downsample norm, plain Linear classifier.
    """
    arch_args = dict(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        downsample_norm=None,
        mlp_act=nn.GELU,
        mlp_bias=True,
        norm_layers=GroupNorm1,
        layer_scale_init_values=1e-5,
        res_scale_init_values=None,
        use_mlp_head=False,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('poolformer_s24', pretrained=pretrained, **model_kwargs)


@register_model
def poolformer_s36(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormer-S36 (v1): stage depths (6, 6, 18, 6), dims (64, 128, 320, 512).

    Pooling token mixer, GroupNorm1, GELU MLP with bias, layer scale init 1e-6, no res scale,
    no stem/downsample norm, plain Linear classifier.
    """
    arch_args = dict(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        downsample_norm=None,
        mlp_act=nn.GELU,
        mlp_bias=True,
        norm_layers=GroupNorm1,
        layer_scale_init_values=1e-6,
        res_scale_init_values=None,
        use_mlp_head=False,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('poolformer_s36', pretrained=pretrained, **model_kwargs)


@register_model
def poolformer_m36(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormer-M36 (v1): stage depths (6, 6, 18, 6), dims (96, 192, 384, 768).

    Pooling token mixer, GroupNorm1, GELU MLP with bias, layer scale init 1e-6, no res scale,
    no stem/downsample norm, plain Linear classifier.
    """
    arch_args = dict(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        downsample_norm=None,
        mlp_act=nn.GELU,
        mlp_bias=True,
        norm_layers=GroupNorm1,
        layer_scale_init_values=1e-6,
        res_scale_init_values=None,
        use_mlp_head=False,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('poolformer_m36', pretrained=pretrained, **model_kwargs)


@register_model
def poolformer_m48(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormer-M48 (v1): stage depths (8, 8, 24, 8), dims (96, 192, 384, 768).

    Pooling token mixer, GroupNorm1, GELU MLP with bias, layer scale init 1e-6, no res scale,
    no stem/downsample norm, plain Linear classifier.
    """
    arch_args = dict(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        downsample_norm=None,
        mlp_act=nn.GELU,
        mlp_bias=True,
        norm_layers=GroupNorm1,
        layer_scale_init_values=1e-6,
        res_scale_init_values=None,
        use_mlp_head=False,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('poolformer_m48', pretrained=pretrained, **model_kwargs)


@register_model
def poolformerv2_s12(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormerV2-S12: stage depths (2, 2, 6, 2), dims (64, 128, 320, 512).

    Bias-free GroupNorm1 and a plain Linear classifier; everything else uses MetaFormer defaults.
    """
    arch_args = dict(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        norm_layers=GroupNorm1NoBias,
        use_mlp_head=False,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('poolformerv2_s12', pretrained=pretrained, **model_kwargs)


@register_model
def poolformerv2_s24(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormerV2-S24: stage depths (4, 4, 12, 4), dims (64, 128, 320, 512).

    Bias-free GroupNorm1 and a plain Linear classifier; everything else uses MetaFormer defaults.
    """
    arch_args = dict(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        norm_layers=GroupNorm1NoBias,
        use_mlp_head=False,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('poolformerv2_s24', pretrained=pretrained, **model_kwargs)


@register_model
def poolformerv2_s36(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormerV2-S36: stage depths (6, 6, 18, 6), dims (64, 128, 320, 512).

    Bias-free GroupNorm1 and a plain Linear classifier; everything else uses MetaFormer defaults.
    """
    arch_args = dict(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        norm_layers=GroupNorm1NoBias,
        use_mlp_head=False,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('poolformerv2_s36', pretrained=pretrained, **model_kwargs)


@register_model
def poolformerv2_m36(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormerV2-M36: stage depths (6, 6, 18, 6), dims (96, 192, 384, 768).

    Bias-free GroupNorm1 and a plain Linear classifier; everything else uses MetaFormer defaults.
    """
    arch_args = dict(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        norm_layers=GroupNorm1NoBias,
        use_mlp_head=False,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('poolformerv2_m36', pretrained=pretrained, **model_kwargs)


@register_model
def poolformerv2_m48(pretrained=False, **kwargs) -> MetaFormer:
    """PoolFormerV2-M48: stage depths (8, 8, 24, 8), dims (96, 192, 384, 768).

    Bias-free GroupNorm1 and a plain Linear classifier; everything else uses MetaFormer defaults.
    """
    arch_args = dict(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        norm_layers=GroupNorm1NoBias,
        use_mlp_head=False,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('poolformerv2_m48', pretrained=pretrained, **model_kwargs)


@register_model
def convformer_s18(pretrained=False, **kwargs) -> MetaFormer:
    """ConvFormer-S18: stage depths (3, 3, 9, 3), dims (64, 128, 320, 512).

    SepConv token mixer and bias-free LayerNorm2d in every stage; MetaFormer defaults elsewhere.
    """
    arch_args = dict(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        norm_layers=LayerNorm2dNoBias,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('convformer_s18', pretrained=pretrained, **model_kwargs)


@register_model
def convformer_s36(pretrained=False, **kwargs) -> MetaFormer:
    """ConvFormer-S36: stage depths (3, 12, 18, 3), dims (64, 128, 320, 512).

    SepConv token mixer and bias-free LayerNorm2d in every stage; MetaFormer defaults elsewhere.
    """
    arch_args = dict(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        norm_layers=LayerNorm2dNoBias,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('convformer_s36', pretrained=pretrained, **model_kwargs)


@register_model
def convformer_m36(pretrained=False, **kwargs) -> MetaFormer:
    """ConvFormer-M36: stage depths (3, 12, 18, 3), dims (96, 192, 384, 576).

    SepConv token mixer and bias-free LayerNorm2d in every stage; MetaFormer defaults elsewhere.
    """
    arch_args = dict(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=SepConv,
        norm_layers=LayerNorm2dNoBias,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('convformer_m36', pretrained=pretrained, **model_kwargs)


@register_model
def convformer_b36(pretrained=False, **kwargs) -> MetaFormer:
    """ConvFormer-B36: stage depths (3, 12, 18, 3), dims (128, 256, 512, 768).

    SepConv token mixer and bias-free LayerNorm2d in every stage; MetaFormer defaults elsewhere.
    """
    arch_args = dict(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=SepConv,
        norm_layers=LayerNorm2dNoBias,
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('convformer_b36', pretrained=pretrained, **model_kwargs)


@register_model
def caformer_s18(pretrained=False, **kwargs) -> MetaFormer:
    """CAFormer-S18: stage depths (3, 3, 9, 3), dims (64, 128, 320, 512).

    SepConv + bias-free LayerNorm2d in stages 0-1, Attention + bias-free LayerNorm in stages 2-3.
    """
    arch_args = dict(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        norm_layers=[LayerNorm2dNoBias, LayerNorm2dNoBias, LayerNormNoBias, LayerNormNoBias],
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('caformer_s18', pretrained=pretrained, **model_kwargs)


@register_model
def caformer_s36(pretrained=False, **kwargs) -> MetaFormer:
    """CAFormer-S36: stage depths (3, 12, 18, 3), dims (64, 128, 320, 512).

    SepConv + bias-free LayerNorm2d in stages 0-1, Attention + bias-free LayerNorm in stages 2-3.
    """
    arch_args = dict(
        depths=[3, 12, 18, 3],
        dims=[64, 128, 320, 512],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        norm_layers=[LayerNorm2dNoBias, LayerNorm2dNoBias, LayerNormNoBias, LayerNormNoBias],
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('caformer_s36', pretrained=pretrained, **model_kwargs)


@register_model
def caformer_m36(pretrained=False, **kwargs) -> MetaFormer:
    """CAFormer-M36: stage depths (3, 12, 18, 3), dims (96, 192, 384, 576).

    SepConv + bias-free LayerNorm2d in stages 0-1, Attention + bias-free LayerNorm in stages 2-3.
    """
    arch_args = dict(
        depths=[3, 12, 18, 3],
        dims=[96, 192, 384, 576],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        norm_layers=[LayerNorm2dNoBias, LayerNorm2dNoBias, LayerNormNoBias, LayerNormNoBias],
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('caformer_m36', pretrained=pretrained, **model_kwargs)


@register_model
def caformer_b36(pretrained=False, **kwargs) -> MetaFormer:
    """CAFormer-B36: stage depths (3, 12, 18, 3), dims (128, 256, 512, 768).

    SepConv + bias-free LayerNorm2d in stages 0-1, Attention + bias-free LayerNorm in stages 2-3.
    """
    arch_args = dict(
        depths=[3, 12, 18, 3],
        dims=[128, 256, 512, 768],
        token_mixers=[SepConv, SepConv, Attention, Attention],
        norm_layers=[LayerNorm2dNoBias, LayerNorm2dNoBias, LayerNormNoBias, LayerNormNoBias],
    )
    # kwargs may add options but not repeat an arch arg (dict() raises TypeError on duplicates)
    model_kwargs = dict(**arch_args, **kwargs)
    return _create_metaformer('caformer_b36', pretrained=pretrained, **model_kwargs)
