
    Ngܵ                        d Z ddlZddlmZmZ ddlmZ ddlmZm	Z	m
Z
mZmZ ddlZddlZddlmZ ddlmZmZ ddlmZmZmZmZmZmZ ddlmZmZmZmZ d	d
l m!Z! d	dl"m#Z# d	dl$m%Z%m&Z& d	dl'm(Z(m)Z)m*Z* ddgZ+e G d d                      Z,d Z-d[dZ.d\dZ/	 	 	 	 	 d]dZ0	 	 	 	 	 d]dZ1	 	 	 d^dZ2 G d dej3                  Z4 G d dej3                  Z5 G d d ej3                  Z6 G d! dej3                  Z7d_d#Z8d$ Z9 e:d`i d% e,d&d'd(dd)*          d+ e,d&d,d-d.d/*          d0 e,d&d,d-d.d/d12          d3 e,d4d5d6d&d.*          d7 e,d8d9d:d.d.*          d; e,d<d=d>d&d?*          d@ e,dAdBd>d4dC*          dD e,dEdFdGdHdI*          dJ e,dKdLdMd8dN*          dO e,d<dPdQdRdI*          dS e,dTdUdVdWdX*          dY e,dZd[d\d]d/*          d^ e,d_d`dadTdI*          db e,d&d'd(dd)dcd          de e,d4dfdgdd.dcd          dh e,d4didjd.dkdcd          dl e,d8dmdnd.dodcd          dp e,d8dmdnd.dodcd1q          dr e,d4dsdtd&dudcd          dv e,d<dwdxd&dydcd          dz e,dEd{d6d|d/dcd          d} e,dWd~dddCdcd          d e,dddd8dNdcd          d e,dddd8dNdcd1q          d e,dTdUdVdWdXdcd          d e,ddddWd?dcd          d e,ddddddcd          d e,dddndddcd          d e,dddddudcd          d e,dddddudcd          d e,dEd{d6d|d/dcd eed.                    d e,d/dEd{d6d|dcdd          d e,dCdWd~dddcddd	  	        d e,dyd.dddddcdddd          d e,dd4dddddcdddd          d e,dd4dddddcdddd          Z;d Z<dadZ=dadZ>dadZ? e(i d e=dd          d e=dd          d e=dd          d e=dd          d e=dæ          d e=dæ          d e=dæ          d e=ddǬȦ          d e=ddǬȦ          d e=d"˦          d e=dddάϦ          d e=dddάϦ          d e=d"˦          d e=ddddddج٦          d e=ddddddج٦          d e=ddݬ          d e?dd߬          i d e?ddᬻ          d e?dd㬻          d e?dd嬻          d e?dd笻          d e?dd鬻          d e?dd묻          d e?dd          d e?ddכּ          d e?dd񬻦          d e?dd󬻦          d e?dd          d e?dd          d e?dd          d e?dddddd׬          d  e?dddddd׬          d e?dddddd׬          d e?ddd          i d e?ddd          d	 e?dd
d          d e?dddddddא          d e?dddddddא          d e?dddddddא          d e?dddddddא          d e?ddddd          d e?ddddd          d e?ddddd          d e>dæ          d e>dæ          d e>dæ          d  e>dæ          d! e>dæ          d" e>dæ          d# e>dæ          d$ e>dæ          i d% e>dæ          d& e>dæ          d' e>dæ          d( e>dæ          d) e>dæ          d* e>dæ          d+ e>dæ          d, e>dæ          d- e>dæ          d. e>dæ          d/ e>dæ          d0 e>dæ          d1 e>dæ          d2 e>dæ          d3 e>dæ          d4 e>dæ                    Z@e)dbd5e7fd6            ZAe)dbd5e7fd7            ZBe)dbd5e7fd8            ZCe)dbd5e7fd9            ZDe)dbd5e7fd:            ZEe)dbd5e7fd;            ZFe)dbd5e7fd<            ZGe)dbd5e7fd=            ZHe)dbd5e7fd>            ZIe)dbd5e7fd?            ZJe)dbd5e7fd@            ZKe)dbd5e7fdA            ZLe)dbd5e7fdB            ZMe)dbd5e7fdC            ZNe)dbd5e7fdD            ZOe)dbd5e7fdE            ZPe)dbd5e7fdF            ZQe)dbd5e7fdG            ZRe)dbd5e7fdH            ZSe)dbd5e7fdI            ZTe)dbd5e7fdJ            ZUe)dbd5e7fdK            ZVe)dbd5e7fdL            ZWe)dbd5e7fdM            ZXe)dbd5e7fdN            ZYe)dbd5e7fdO            ZZe)dbd5e7fdP            Z[e)dbd5e7fdQ            Z\e)dbd5e7fdR            Z]e)dbd5e7fdS            Z^e)dbd5e7fdT            Z_e)dbd5e7fdU            Z`e)dbd5e7fdV            Zae)dbd5e7fdW            Zbe)dbd5e7fdX            Zce)dbd5e7fdY            Zd e*eedZdi           dS (c  a  RegNet X, Y, Z, and more

Paper: `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678
Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py

Paper: `Fast and Accurate Model Scaling` - https://arxiv.org/abs/2103.06877
Original Impl: None

Based on original PyTorch impl linked above, but re-wrote to use my own blocks (adapted from ResNet here)
and cleaned up with more descriptive variable names.

Weights from original pycls impl have been modified:
* first layer from BGR -> RGB as most PyTorch models are
* removed training specific dict entries from checkpoints and keep model state_dict only
* remap names to match the ones here

Supports weight loading from torchvision and classy-vision (incl VISSL SEER)

A number of custom timm model definitions additions including:
* stochastic depth, gradient checkpointing, layer-decay, configurable dilation
* a pre-activation 'V' variant
* only known RegNet-Z model definitions with pretrained weights

Hacked together by / Copyright 2020 Ross Wightman
    N)	dataclassreplace)partial)CallableListOptionalUnionTupleIMAGENET_DEFAULT_MEANIMAGENET_DEFAULT_STD)ClassifierHeadAvgPool2dSameConvNormActSEModuleDropPathGroupNormAct)get_act_layerget_norm_act_layercreate_conv2dmake_divisible   )build_model_with_cfg)feature_take_indices)checkpoint_seqnamed_apply)generate_default_cfgsregister_modelregister_model_deprecationsRegNet	RegNetCfgc                      e Zd ZU dZeed<   dZeed<   dZeed<   dZ	eed<   d	Z
eed
<   dZeed<   dZeed<   dZeed<   dZeed<   dZee         ed<   dZeed<   dZeed<   dZeed<   dZeeef         ed<   dZeeef         ed<   dS )r!      depthP   w0q=
ףPE@waHzG@wm   
group_size      ?bottle_ratio        se_ratiogroup_min_ratio    
stem_widthconv1x1
downsampleF
linear_outpreactr   num_featuresrelu	act_layer	batchnorm
norm_layerN)__name__
__module____qualname__r$   int__annotations__r&   r(   floatr*   r,   r.   r0   r1   r3   r5   r   strr6   boolr7   r8   r:   r	   r   r<        N/var/www/html/ai-engine/env/lib/python3.11/site-packages/timm/models/regnet.pyr!   r!   .   s        E3OOOBLLLBBJL%HeOUJ )J)))JFDL#&,IuS(]#,,,'2Jc8m$22222rF   c                 F    t          t          | |z            |z            S )z<Converts a float to the closest non-zero int divisible by q.)r@   round)fqs     rG   quantize_floatrL   A   s    uQU||a   rF   r/   c                    d t          | |          D             }d t          ||          D             }rfdt          ||          D             }nd t          ||          D             }d t          ||          D             } | |fS )z/Adjusts the compatibility of widths and groups.c                 8    g | ]\  }}t          ||z            S rE   r@   ).0wbs      rG   
<listcomp>z-adjust_widths_groups_comp.<locals>.<listcomp>H   s&    KKK1QUKKKrF   c                 4    g | ]\  }}t          ||          S rE   )min)rP   gw_bots      rG   rS   z-adjust_widths_groups_comp.<locals>.<listcomp>I   s$    KKK5c!UmmKKKrF   c                 8    g | ]\  }}t          ||          S rE   )r   )rP   rW   rV   	min_ratios      rG   rS   z-adjust_widths_groups_comp.<locals>.<listcomp>L   s)    pppXUA^E1i@@ppprF   c                 4    g | ]\  }}t          ||          S rE   )rL   )rP   rW   rV   s      rG   rS   z-adjust_widths_groups_comp.<locals>.<listcomp>N   s&    eee(%^E155eeerF   c                 8    g | ]\  }}t          ||z            S rE   rO   )rP   rW   rR   s      rG   rS   z-adjust_widths_groups_comp.<locals>.<listcomp>O   s&    SSSc%!)nnSSSrF   )zip)widthsbottle_ratiosgroupsrY   bottleneck_widthss      ` rG   adjust_widths_groups_compra   F   s    KKFM0J0JKKKKKC8I,J,JKKKF fppppQTUfhnQoQopppeecJ[]cFdFdeeeSSS1BM-R-RSSSF6>rF      c           	         | dk    r|dk    r|dk    r	||z  dk    sJ t          j        |          | z  |z   }t          j        t          j        ||z            t          j        |          z            }t          j        t          j        |t          j        ||          z  |                    |z  }t          t          j        |                    |                                dz   }
}	t          j	        fdt          |	          D                       }|                    t                                                    |	|                    t                                                    fS )z2Generates per block widths from RegNet parameters.r   r   c                     g | ]}S rE   rE   )rP   _r,   s     rG   rS   z#generate_regnet.<locals>.<listcomp>]   s    ===az===rF   )nparangerI   logdividepowerlenuniquemaxarrayrangeastyper@   tolist)width_slopewidth_initial
width_multr$   r,   quantwidths_cont
width_expsr]   
num_stages	max_stager_   s       `       rG   generate_regnetrz   S   sE   ! 1 1j1nnY^I^bcIcIcIcIc )E""[0=@K"&}!<==z@R@RRSSJXbiZ0P0P PRWXXYY\aaF	& 1 122JNN4D4Dq4H	JX====5+<+<===>>F==$$&&
FMM#4F4F4M4M4O4OOOrF   Fc           	          |pt           j        }|dk    r|dk    rdn|}|dk    r|nd}|rt          | ||||          S t          | |||||d          S )Nr   )stridedilationF)r|   r}   r<   	apply_act)nnBatchNorm2dr   r   )in_chsout_chskernel_sizer|   r}   r<   r7   s          rG   downsample_convr   a   s     -r~J{{x1}}!!+K&??xxH 

 
 
 	
 !
 
 
 	
rF   c                 B   |pt           j        }|dk    r|nd}t          j                    }|dk    s|dk    r.|dk    r|dk    rt          nt           j        }	 |	d|dd          }|rt          | |dd          }
nt          | |dd|d          }
t          j        ||
g S )zd AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.r      TF)	ceil_modecount_include_padr|   )r|   r<   r~   )r   r   Identityr   	AvgPool2dr   r   
Sequential)r   r   r   r|   r}   r<   r7   
avg_stridepoolavg_pool_fnconvs              rG   downsample_avgr      s     -r~J#q==aJ;==DzzX\\'1Q8a<<mmR\{1jDERRR aVWa:::67AaJZ_```=4,''rF   r   r   c                     | dv sJ ||k    s|dk    s|d         |d         k    rAt          ||d         ||          }| sd S | dk    rt          ||fi |S t          ||fd|i|S t          j                    S )N)avgr4    Nr   r   )r|   r}   r<   r7   r   r   )dictr   r   r   r   )	downsample_typer   r   r   r|   r}   r<   r7   dargss	            rG   create_shortcutr      s     :::::FaKK8A;(1++E+EFXa[ZX^___ 	V4%%!&';;U;;;"67UUUuUUU{}}rF   c                   Z     e Zd ZdZdddddddej        ej        ddf fd		Zd
 Zd Z	 xZ
S )
Bottleneck RegNet Bottleneck

    This is almost exactly the same as a ResNet Bottlneck. The main difference is the SE block is moved from
    after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
    r   r         ?r4   FNr/   c           	         t          t          |                                            t          |
          }
t	          t          ||z                      }||z  }t          |
|          }t          ||fddi|| _        t          ||fd||d         ||d|| _	        |r7t	          t          ||z                      }t          |||
          | _        nt          j                    | _        t          ||fddd	|| _        |	rt          j                    n	 |
            | _        t!          |||d|||
          | _        |dk    rt%          |          nt          j                    | _        d S )Nr:   r<   r   r      r   )r   r|   r}   r_   
drop_layerrd_channelsr:   F)r   r~   )r   r|   r}   r<   )superr   __init__r   r@   rI   r   r   conv1conv2r   ser   r   conv3act3r   r5   r   	drop_path)selfr   r   r|   r}   r.   r,   r0   r5   r6   r:   r<   
drop_blockdrop_path_ratebottleneck_chsr_   cargsse_channels	__class__s                     rG   r   zBottleneck.__init__   s     	j$((***!),,	U7\#9::;;:-yZ@@@ PPQP%PP
 	
 a[!	
 	
 	
 	

  	$eFX$56677K~;R[\\\DGGkmmDG baSXbb\abb
%/@BKMMMYY[[	)!
 
 
 6Da5G5G.111R[]]rF   c                 b    t           j                            | j        j        j                   d S N)r   initzeros_r   bnweightr   s    rG   zero_init_lastzBottleneck.zero_init_last   s#    
tz}+,,,,,rF   c                 @   |}|                      |          }|                     |          }|                     |          }|                     |          }| j        +|                     |          |                     |          z   }|                     |          }|S r   )r   r   r   r   r5   r   r   r   xshortcuts      rG   forwardzBottleneck.forward   s    JJqMMJJqMMGGAJJJJqMM?& q!!DOOH$=$==AIIaLLrF   r=   r>   r?   __doc__r   ReLUr   r   r   r   __classcell__r   s   @rG   r   r      s           g~1[ 1[ 1[ 1[ 1[ 1[f- - -      rF   r   c                   Z     e Zd ZdZdddddddej        ej        ddf fd		Zd
 Zd Z	 xZ
S )PreBottleneckr   r   r   r   r4   FNr/   c           	         t          t          |                                            t          ||
          }t	          t          ||z                      }||z  } ||          | _        t          ||d          | _         ||          | _	        t          ||d||d         |          | _
        |r7t	          t          ||z                      }t          |||
          | _        nt          j                    | _         ||          | _        t          ||d          | _        t#          |||d||d          | _        |dk    rt'          |          nt          j                    | _        d S )	Nr   )r   r   r   )r   r|   r}   r_   r   T)r   r|   r}   r7   )r   r   r   r   r@   rI   norm1r   r   norm2r   r   r   r   r   norm3r   r   r5   r   r   )r   r   r   r|   r}   r.   r,   r0   r5   r6   r:   r<   r   r   norm_act_layerr   r_   r   r   s                     rG   r   zPreBottleneck.__init__  s     	mT""++---+J	BBU7\#9::;;:-#^F++
"6>qIII
#^N33
"a[
 
 

  	$eFX$56677K~;R[\\\DGGkmmDG#^N33
">7JJJ
)
 
 
 6Da5G5G.111R[]]rF   c                     d S r   rE   r   s    rG   r   zPreBottleneck.zero_init_last3  s    rF   c                    |                      |          }|}|                     |          }|                     |          }|                     |          }|                     |          }|                     |          }|                     |          }| j        +|                     |          |                     |          z   }|S r   )	r   r   r   r   r   r   r   r5   r   r   s      rG   r   zPreBottleneck.forward6  s    JJqMMJJqMMJJqMMJJqMMGGAJJJJqMMJJqMM?& q!!DOOH$=$==ArF   r   r   s   @rG   r   r      s           g~0[ 0[ 0[ 0[ 0[ 0[d        rF   r   c                   .     e Zd ZdZdef fd	Zd Z xZS )RegStagez4Stage (sequence of blocks w/ the same output shape).Nc                 X   t          t          |                                            d| _        |dv rdnd}	t	          |          D ]b}
|
dk    r|nd}|
dk    r|n|}|	|f}|||
         nd}d                    |
dz             }|                     | |||f|||d|           |}	cd S )	NF)r   r   r   r   r   r/   zb{})r|   r}   r   )r   r   r   grad_checkpointingro   format
add_module)r   r$   r   r   r|   r}   drop_path_ratesblock_fnblock_kwargsfirst_dilationiblock_strideblock_in_chsblock_dilationdprnamer   s                   rG   r   zRegStage.__init__I  s    	h&&((("'&&00au 	& 	&A%&!VV66L%&!VV66L,h7N(7(C/!$$C<<A&&DOO  (+#&  # 
 
 
 &NN#	& 	&rF   c                     | j         rAt          j                                        s#t	          |                                 |          }n"|                                 D ]} ||          }|S r   )r   torchjitis_scriptingr   children)r   r   blocks      rG   r   zRegStage.forwardk  sf    " 	59+A+A+C+C 	t}}22AA  E!HHrF   )r=   r>   r?   r   r   r   r   r   r   s   @rG   r   r   F  s[        >> ! &  &  &  &  &  &D      rF   r   c                        e Zd ZdZ	 	 	 	 	 	 	 d%def fd	Zd&defdZej        j	        d'd            Z
ej        j	        d(d            Zej        j	        dej        fd            Zd)dedee         fdZ	 	 	 	 	 d*dej        deeeee         f                  dededededeeej                 eej        eej                 f         f         fdZ	 	 	 d+deeee         f         dedefd Zd! Zd'd"efd#Zd$ Z xZS ),r    zRegNet-X, Y, and Z Models

    Paper: https://arxiv.org/abs/2003.13678
    Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
    r     r2   r   r/   Tcfgc	           
         t                                                       || _        || _        |dv sJ t	          |fi |	}|j        }
t          |j        |j                  }|j	        rt          ||
dd          | _        nt          ||
dfddi|| _        t          |
dd          g| _        |
}d}|                     |||	          \  }}t          |          d
k    sJ |j	        rt           nt"          }t%          |          D ]v\  }}d                    |dz             }|                     |t+          d||d||           |d         }||d         z  }| xj        t          |||          gz  c_        w|j        r't          ||j        fddi|| _        |j        | _        nK|j        p|j	        }|r t3          |j                              nt5          j                    | _        || _        | j        | _        t;          | j        |||          | _        t?          tA          tB          |          |            dS )a  

        Args:
            cfg (RegNetCfg): Model architecture configuration
            in_chans (int): Number of input channels (default: 3)
            num_classes (int): Number of classifier classes (default: 1000)
            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
            global_pool (str): Global pooling type (default: 'avg')
            drop_rate (float): Dropout rate (default: 0.)
            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
            zero_init_last (bool): Zero-init last weight of residual path
            kwargs (dict): Extra kwargs overlayed onto cfg
        )rb      r2   r   r   r   r   r|   stem)num_chs	reductionmodule)output_strider      zs{}r   )r   r   r   r   )in_featuresnum_classes	pool_type	drop_rate)r   NrE   )"r   r   r   r   r   r3   r   r:   r<   r7   r   r   r   feature_info_get_stage_argsrk   r   r   	enumerater   r   r   r8   
final_convr6   r   r   r   head_hidden_sizer   headr   r   _init_weights)r   r   in_chansr   r   global_poolr   r   r   kwargsr3   na_args
prev_widthcurr_strideper_stage_argscommon_argsr   r   
stage_args
stage_name	final_actr   s                        rG   r   zRegNet.__init__{  s   2 	&"++++c$$V$$ ^
3>JJJ: 	R%h
AaHHHDII#Hj!QQAQQQDI!*&QQQR  
&*&:&:') '; '
 '
#
 >""a''''$'J>==J&~66 	f 	fMAza!e,,JOO %%  ! "	    $I.J:h//K$z[Yc"d"d"d!ee  	+)*c6FaaTUaY`aaDO # 0D4#*I@I\:mCM::<<<r{}}DO *D $ 1")#!	
 
 
	 	GM.III4PPPPPrF   r   c           
         t          j        j        j        j        j                  \  }}}t          j        |d          \  }}	fdt          |          D             }
g }g }d}d}t          |          D ]A}||k    r||z  }d}n|}||z  }|	                    |           |	                    |           Bt          j
        t          j        d|t          |	                    t          j        |	d d                             }t          ||
|j                  \  }}g d	fd
t!          ||||	|
||          D             }t#          j        j        j        j        j                  }||fS )NT)return_countsc                     g | ]	}j         
S rE   )r.   )rP   re   r   s     rG   rS   z*RegNet._get_stage_args.<locals>.<listcomp>  s    @@@C$@@@rF   r   r   r   )rY   )r   r|   r}   r$   r.   r,   r   c                 J    g | ]}t          t          |                     S rE   )r   r\   )rP   params	arg_namess     rG   rS   z*RegNet._get_stage_args.<locals>.<listcomp>  s8     
 
 
-3DY''((
 
 
rF   )r5   r0   r6   r:   r<   )rz   r(   r&   r*   r$   r,   rf   rl   ro   appendsplitlinspacesumcumsumra   r1   r\   r   r5   r0   r6   r:   r<   )r   r   default_strider   r   r]   rx   stage_gsstage_widthsstage_depthsstage_brstage_stridesstage_dilations
net_strider}   re   r|   	stage_dprr   r  r  s    `                  @rG   r   zRegNet._get_stage_args  s   '6svsvsvsyZ]Zh'i'i$
H &(YvT%J%J%J"l@@@@eJ.?.?@@@
z"" 	- 	-A]**N*'f$
  (((""8,,,,HR[NC<M<MNNPRPYZfgjhjgjZkPlPlmm	 ";(H8K"M "M "Mhooo	
 
 
 
m_lHV^`ijj
 
 
 ~\~m~
 
 
 {**rF   Fc                 ,    t          d|rdnd          S )Nz^stemz^s(\d+)z^s(\d+)\.b(\d+))r   blocks)r   )r   coarses     rG   group_matcherzRegNet.group_matcher  s)    !'?::-?
 
 
 	
rF   c                 l    t          |                                           dd         D ]	}||_        
d S )Nr   r  )listr   r   )r   enabless      rG   set_grad_checkpointingzRegNet.set_grad_checkpointing  s?    dmmoo&&qt, 	* 	*A#)A  	* 	*rF   returnc                     | j         j        S r   )r   fcr   s    rG   get_classifierzRegNet.get_classifier  s    y|rF   Nr   r   c                 >    | j                             ||           d S )N)r   )r   reset)r   r   r   s      rG   reset_classifierzRegNet.reset_classifier  s     	{;;;;;rF   NCHWr   indicesnorm
stop_early
output_fmtintermediates_onlyc                    |dv s
J d            g }t          d|          \  }}	d}
|                     |          }|
|v r|                    |           d}|r
|d|	         }|D ]9}|
dz  }
 t          | |          |          }|
|v r|                    |           :|r|S |
dk    r|                     |          }||fS )	a   Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:

        )r*  zOutput shape must be NCHW.   r   s1s2s3s4Nr   r   )r   r   r  getattrr   )r   r   r+  r,  r-  r.  r/  intermediatestake_indices	max_indexfeat_idxlayer_namesns                rG   forward_intermediateszRegNet.forward_intermediates  s   * Y&&&(D&&&"6q'"B"Bi IIaLL|##  ###. 	2%jyj1K 	( 	(AMH a  ##A<''$$Q''' 	!  q==""A-rF   r   
prune_norm
prune_headc                     t          d|          \  }}d}||d         }|D ]$}t          | |t          j                               %|dk     rt          j                    | _        |r|                     dd           |S )z@ Prune layers not required for specified intermediates.
        r1  r2  Nr   r   r   )r   setattrr   r   r   r)  )r   r+  r?  r@  r9  r:  r<  r=  s           rG   prune_intermediate_layersz RegNet.prune_intermediate_layers8  s     #7q'"B"Bi.!)**- 	, 	,AD!R[]]++++q== kmmDO 	)!!!R(((rF   c                    |                      |          }|                     |          }|                     |          }|                     |          }|                     |          }|                     |          }|S r   )r   r3  r4  r5  r6  r   r   r   s     rG   forward_featureszRegNet.forward_featuresK  sc    IIaLLGGAJJGGAJJGGAJJGGAJJOOArF   
pre_logitsc                 ^    |r|                      ||          n|                      |          S )N)rG  )r   )r   r   rG  s      rG   forward_headzRegNet.forward_headT  s-    6@Rtyyzy222diiPQllRrF   c                 Z    |                      |          }|                     |          }|S r   )rF  rI  rE  s     rG   r   zRegNet.forwardW  s-    !!!$$a  rF   )r   r   r2   r   r/   r/   T)r   r2   r/   F)Tr   )NFFr*  F)r   FT)r=   r>   r?   r   r!   r   r   r   r   ignorer  r"  r   Moduler&  r@   r   rC   r)  Tensorr	   r   rD   r
   r>  rC  rF  rI  r   r   r   s   @rG   r    r    t  s{         QQ QQQQ QQ QQ QQ QQ QQf%+ %+9 %+ %+ %+ %+N Y
 
 
 
 Y* * * * Y	    < <C <hsm < < < < 8<$$',.  . |.  eCcN34.  	. 
 .  .  !%.  
tEL!5tEL7I)I#JJ	K.  .  .  . d ./$#	 3S	>*  	   &  S S$ S S S S      rF   r   c                    t          | t          j                  r| j        d         | j        d         z  | j        z  }|| j        z  }| j        j                            dt          j
        d|z                       | j         | j        j                                         d S d S t          | t          j                  rVt          j                            | j        dd           | j        &t          j                            | j                   d S d S |r&t!          | d          r|                                  d S d S d S )Nr   r          @r/   g{Gz?)meanstdr   )
isinstancer   Conv2dr   out_channelsr_   r   datanormal_mathsqrtbiaszero_Linearr   r   hasattrr   )r   r   r   fan_outs       rG   r   r   ]  s<   &")$$  $Q'&*<Q*??&BUUFM!""1dig&>&>???;"K""$$$$$ #"	FBI	&	&  
CT:::;"GNN6;''''' #"	  GF,<==         rF   c                    |                      d|           } g d}d| v rdd l}| d         d         d         } i }| d                                         D ]\  }}|                    dd          }|                    d	d
          }|                    dd |          }|                    dd|          }|D ]\  }}|                    ||          }|||<   | d                                         D ])\  }}d|v sd|v r|                    dd          }|||<   *|S d| v rdd l}i }|                                 D ]\  }}|                    dd          }|                    dd
          }|                    dd |          }|D ]\  }}|                    ||          }|                    dd          }|||<   |S | S )Nmodel))zf.a.0z
conv1.conv)zf.a.1zconv1.bn)zf.b.0z
conv2.conv)zf.b.1zconv2.bn)z
f.final_bnconv3.bn)zf.se.excitation.0zse.fc1)zf.se.excitation.2zse.fc2)zf.ser   )zf.c.0
conv3.conv)zf.c.1ra  )zf.crb  )zproj.0downsample.conv)zproj.1zdownsample.bn)projrc  classy_state_dictr   
base_modeltrunkz_feature_blocks.conv1.stem.0	stem.convz_feature_blocks.conv1.stem.1zstem.bnz&^_feature_blocks.res\d.block(\d)-(\d+)c                     dt          |                     d                     dt          |                     d                    dz    S )Nr!  r   .br   r@   groupr   s    rG   <lambda>z_filter_fn.<locals>.<lambda>  <    Fc!''!**ooFFQWWQZZ11DFF rF   zs(\d)\.b(\d+)\.bnzs\1.b\2.downsample.bnheadsprojection_head
prototypesz0.clf.0head.fczstem.0.weightzstem.0zstem.1z)trunk_output.block(\d)\.block(\d+)\-(\d+)c                     dt          |                     d                     dt          |                     d                    dz    S )Nr!  r   rj  r   rk  rm  s    rG   rn  z_filter_fn.<locals>.<lambda>  ro  rF   zfc.zhead.fc.)getreitemsr   sub)
state_dictreplacesrv  outkvr!  rs           rG   
_filter_fnr  l  s:   44J  H  j((			 34\B7K
w'--// 		 		DAq		8+FFA		8)DDA9FFK KA +-EqIIA  $ $1IIaOOCFFw'--// 	 	DAq A%%):):		)Y//ACFF
*$$			$$&& 		 		DAq		(K00A		(I..A<FFK KA ! $ $1IIaOO		%,,ACFF
rF   regnetx_002r+   gQ8B@gQ@   )r&   r(   r*   r,   r$   regnetx_004g{Gz8@gRQ@r      regnetx_004_tvg?)r&   r(   r*   r,   r$   r1   regnetx_0060   g\(|B@gQ@regnetx_0088   g=
ףpA@g=
ףp=@regnetx_016r%   gzGA@g      @   regnetx_032X   g(\O:@   regnetx_040`   g33333SC@gq=
ףp@(      regnetx_064   g
ףp=jN@g(\ @   regnetx_080gHzH@g
ףp=
@x   regnetx_120   gףp=
WR@g(\@p      regnetx_160   gQK@g @   regnetx_320@  gףp=
wQ@rP  regnety_002r   )r&   r(   r*   r,   r$   r0   regnety_004gp=
;@gQ @regnety_006gQE@@g(\@   regnety_008gQkC@g333333@   regnety_008_tv)r&   r(   r*   r,   r$   r0   r1   regnety_016g(\µ4@g333333@   regnety_032r'   r)   r#   regnety_040g)\h?@@   regnety_064g\(@@g)\(@H   regnety_080   gGz4S@gQ@regnety_080_tvregnety_120regnety_160   gQZ@gףp=
@regnety_320   g)\\@g=
ףp=@   regnety_640i`  g(\ob@iH  regnety_1280i  g(\d@g)\(@i  regnety_2560i  g(\l@iu  regnety_040_sgnsilu)r,   )r&   r(   r*   r,   r$   r0   r:   r<   regnetv_040T)r$   r&   r(   r*   r,   r0   r7   r:   regnetv_064r   )	r$   r&   r(   r*   r,   r0   r7   r:   r5   regnetz_005gffffff%@gGz@r   g      @i   )r$   r&   r(   r*   r,   r.   r0   r5   r6   r8   r:   regnetz_040   g      -@g+@regnetz_040_hi   c                 P    t          t          | |ft          |          t          d|S )N)	model_cfgpretrained_filter_fn)r   r    
model_cfgsr  )variant
pretrainedr   s      rG   _create_regnetr    s:    W%'  	  rF   c                 8    | dddddddt           t          dd	d
|S )Nr   r      r     r  )r      r  gffffff?r-   bicubicrh  rs  )urlr   
input_size	pool_sizetest_input_sizecrop_pcttest_crop_pctinterpolationrQ  rR  
first_conv
classifierr   r  r   s     rG   _cfgr    s:    4}SY(dS",AJ^!	 
  rF   c                 8    | dddddt           t          dddd	d
|S )Nr   r  r  g      ?r  rh  rs  mitz)https://github.com/facebookresearch/pyclsr  r   r  r  r  r  rQ  rR  r  r  license
origin_urlr   r  s     rG   _cfgpycr    s=    4}SYI%.B!(S 
 X^ rF   c                 8    | dddddt           t          dddd	d
|S )Nr   r  r  gzG?r  rh  rs  zbsd-3-clausez!https://github.com/pytorch/visionr  r   r  s     rG   _cfgtv2r    s=    4}SYI%.B!!1T 
 Y_ rF   zregnety_032.ra_in1kztimm/znhttps://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth)	hf_hub_idr  zregnety_040.ra3_in1kzshttps://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_040_ra3-670e1166.pthzregnety_064.ra3_in1kzshttps://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_064_ra3-aa26dc7d.pthzregnety_080.ra3_in1kzshttps://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_080_ra3-1fdc4344.pthzregnety_120.sw_in12k_ft_in1k)r  zregnety_160.sw_in12k_ft_in1kzregnety_160.lion_in12k_ft_in1kzregnety_120.sw_in12ki-.  )r  r   zregnety_160.sw_in12kzregnety_040_sgn.untrained)r  zregnetv_040.ra3_in1kzshttps://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pthr   )r  r  r  zregnetv_064.ra3_in1kzshttps://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pthzregnetz_005.untrainedzregnetz_040.ra3_in1kzshttps://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040_ra3-9007edf5.pth)r      r  )rb   rb   r-   )r   r  r  )r  r  r  r  r  r  zregnetz_040_h.ra3_in1kzthttps://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040h_ra3-f594343b.pthzregnety_160.deit_in1kz<https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pthzregnetx_004_tv.tv2_in1kz?https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pthzregnetx_008.tv2_in1kz?https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pthzregnetx_016.tv2_in1kz?https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pthzregnetx_032.tv2_in1kz?https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pthzregnetx_080.tv2_in1kz=https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pthzregnetx_160.tv2_in1kz>https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pthzregnetx_320.tv2_in1kz>https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pthzregnety_004.tv2_in1kz?https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pthzregnety_008_tv.tv2_in1kz?https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pthzregnety_016.tv2_in1kz?https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pthzregnety_032.tv2_in1kz?https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pthzregnety_080_tv.tv2_in1kz=https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pthzregnety_160.tv2_in1kz>https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pthzregnety_320.tv2_in1kz>https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pthzregnety_160.swag_ft_in1kzChttps://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pthzcc-by-nc-4.0)r     r  )   r  )r  r  r  r  r  r  zregnety_320.swag_ft_in1kzChttps://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pthzregnety_1280.swag_ft_in1kzDhttps://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pthzregnety_160.swag_lc_in1kzFhttps://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth)r  r  r  zregnety_320.swag_lc_in1kzFhttps://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pthzregnety_1280.swag_lc_in1kzGhttps://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pthzregnety_320.seer_ft_in1kotherz)https://github.com/facebookresearch/visslzhttps://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet32_finetuned_in1k_model_final_checkpoint_phase78.torch)r  r  r  r  r  r  r  zregnety_640.seer_ft_in1kzhttps://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet64_finetuned_in1k_model_final_checkpoint_phase78.torchzregnety_1280.seer_ft_in1kzhttps://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet128_finetuned_in1k_model_final_checkpoint_phase78.torchzregnety_2560.seer_ft_in1kzhttps://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet256_finetuned_in1k_model_final_checkpoint_phase38.torchzregnety_320.seerzihttps://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet32d/seer_regnet32gf_model_iteration244000.torch)r  r  r   r  r  zregnety_640.seerzphttps://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet64/seer_regnet64gf_model_final_checkpoint_phase0.torchzregnety_1280.seerzhttps://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_regnet128Gf_cnstant_bs32_node16_sinkhorn10_proto16k_syncBN64_warmup8k/model_final_checkpoint_phase0.torchzregnetx_002.pycls_in1kzregnetx_004.pycls_in1kzregnetx_006.pycls_in1kzregnetx_008.pycls_in1kzregnetx_016.pycls_in1kzregnetx_032.pycls_in1kzregnetx_040.pycls_in1kzregnetx_064.pycls_in1kzregnetx_080.pycls_in1kzregnetx_120.pycls_in1kzregnetx_160.pycls_in1kzregnetx_320.pycls_in1kzregnety_002.pycls_in1kzregnety_004.pycls_in1kzregnety_006.pycls_in1kzregnety_008.pycls_in1kzregnety_016.pycls_in1kzregnety_032.pycls_in1kzregnety_040.pycls_in1kzregnety_064.pycls_in1kzregnety_080.pycls_in1kzregnety_120.pycls_in1kzregnety_160.pycls_in1kzregnety_320.pycls_in1kr#  c                     t          d| fi |S )zRegNetX-200MFr  r  r  r   s     rG   r  r         ->>v>>>rF   c                     t          d| fi |S )zRegNetX-400MFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )z+RegNetX-400MF w/ torchvision group roundingr  r  r  s     rG   r  r         *JAA&AAArF   c                     t          d| fi |S )zRegNetX-600MFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )zRegNetX-800MFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )zRegNetX-1.6GFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )zRegNetX-3.2GFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )zRegNetX-4.0GFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )zRegNetX-6.4GFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )zRegNetX-8.0GFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )zRegNetX-12GFr  r  r  s     rG   r  r     r  rF   c                     t          d| fi |S )zRegNetX-16GFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )zRegNetX-32GFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )zRegNetY-200MFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )zRegNetY-400MFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )zRegNetY-600MFr  r  r  s     rG   r  r    r  rF   c                     t          d| fi |S )zRegNetY-800MFr  r  r  s     rG   r  r  $  r  rF   c                     t          d| fi |S )z+RegNetY-800MF w/ torchvision group roundingr  r  r  s     rG   r  r  *  r  rF   c                     t          d| fi |S )zRegNetY-1.6GFr  r  r  s     rG   r  r  0  r  rF   c                     t          d| fi |S )zRegNetY-3.2GFr  r  r  s     rG   r  r  6  r  rF   c                     t          d| fi |S )zRegNetY-4.0GFr  r  r  s     rG   r  r  <  r  rF   c                     t          d| fi |S )zRegNetY-6.4GFr  r  r  s     rG   r  r  B  r  rF   c                     t          d| fi |S )zRegNetY-8.0GFr  r  r  s     rG   r  r  H  r  rF   c                     t          d| fi |S )z+RegNetY-8.0GF w/ torchvision group roundingr  r  r  s     rG   r  r  N  r  rF   c                     t          d| fi |S )zRegNetY-12GFr  r  r  s     rG   r  r  T  r  rF   c                     t          d| fi |S )zRegNetY-16GFr  r  r  s     rG   r  r  Z  r  rF   c                     t          d| fi |S )zRegNetY-32GFr  r  r  s     rG   r  r  `  r  rF   c                     t          d| fi |S )zRegNetY-64GFr  r  r  s     rG   r  r  f  r  rF   c                     t          d| fi |S )zRegNetY-128GFr  r  r  s     rG   r  r  l       .*?????rF   c                     t          d| fi |S )zRegNetY-256GFr  r  r  s     rG   r  r  r  r	  rF   c                     t          d| fi |S )zRegNetY-4.0GF w/ GroupNorm r  r  r  s     rG   r  r  x  s     +ZBB6BBBrF   c                     t          d| fi |S )zRegNetV-4.0GF (pre-activation)r  r  r  s     rG   r  r  ~  r  rF   c                     t          d| fi |S )zRegNetV-6.4GF (pre-activation)r  r  r  s     rG   r  r    r  rF   c                 "    t          d| fddi|S )zRegNetZ-500MF
    NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
    but it's not clear it is equivalent to paper model as not detailed in the paper.
    r  r   Fr  r  s     rG   r  r    !     -TTETVTTTrF   c                 "    t          d| fddi|S )RegNetZ-4.0GF
    NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
    but it's not clear it is equivalent to paper model as not detailed in the paper.
    r  r   Fr  r  s     rG   r  r    r  rF   c                 "    t          d| fddi|S )r  r  r   Fr  r  s     rG   r  r    s!     /:VVeVvVVVrF   regnetz_040h)r/   )rb   )r   r   r   NF)r   NF)r   FrE   )r   rK  )fr   rX  dataclassesr   r   	functoolsr   typingr   r   r   r	   r
   numpyrf   r   torch.nnr   	timm.datar   r   timm.layersr   r   r   r   r   r   r   r   r   r   _builderr   	_featuresr   _manipulater   r   	_registryr   r   r   __all__r!   rL   ra   rz   r   r   r   rM  r   r   r   r    r   r  r   r  r  r  r  r  default_cfgsr  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r=   rE   rF   rG   <module>r!     s   2  * * * * * * * *       9 9 9 9 9 9 9 9 9 9 9 9 9 9            A A A A A A A A d d d d d d d d d d d d d d d d X X X X X X X X X X X X * * * * * * + + + + + + 4 4 4 4 4 4 4 4 Y Y Y Y Y Y Y Y Y Y[
! 3 3 3 3 3 3 3 3$! ! !

 
 
 
P P P P" 
 
 
 
F ( ( ( (:    .H H H H H H H HVI I I I IBI I I IX+ + + + +ry + + +\f f f f fRY f f fR       6 6 6t T = = =	REdqKKKK= 	REdrLLLL= 9u"B`cdddd	=
 	REdrLLLL= 	REdrLLLL= 	REdrLLLL= 	REdrLLLL= 	REdrLLLL= 	SUt"MMMM= 	REds"MMMM= 	SUt2NNNN= 	SUss"MMMM= 	SUss"MMMM=" 	REdqUYZZZZ#=$ 	REdqUYZZZZ%=& 	REdrVZ[[[['=( 	REcbUYZZZZ)=* 9u2X\nqrrrr+=, 	REdrVZ[[[[-=. 	REdrVZ[[[[/=0 	REdrVZ[[[[1=2 	SUt"W[\\\\3=4 	SUt"W[\\\\5=6 9$2RZ^pstttt7=8 	SUt2X\]]]]9=: 	SVBY]^^^^;=< 	SVBY]^^^^==> 	SV2X\]]]]?=@ cf#RZ^____A=B cf#RZ^____C=J I%DRrDWW\b%I%I%IK K K KK=T 	REdrDQUagi i i iU=X 	SUtTRVbh   Y=b 	RDTacTXDtv   c=j 	RDUqsUYDqF   k=r )RDUqsUYDtv   s=
@            %$ u&44|~ ~ ~u&
 DD BC C Cu& DD BC C Cu& DD BC C Cu& #DD7$;$;$;u& #DD7$;$;$;u&  %ddW&=&=&=!u&& DD  'u&, DD  -u&6  "7u&8 DD B  9u&@ DD B  Au&J TTb\\\Ku&L DD B FSR_a a aMu&T dd C FSR_a a aUu&` TT]_ _ _au&f wwM O  O  Ogu& u&l GGMO O Omu&r GGMO O Osu&x GGMO O Oyu&~ GGKM M Mu&D GGLN N NEu&J GGLN N NKu&R GGMO O OSu&X wwM O  O  OYu&^ GGMO O O_u&d GGMO O Oeu&j wwK M  M  Mku&p GGLN N Nqu&v GGLN N Nwu&~ Q[i Hs!D !D !Du&F Q[i Hs!D !D !DGu&N  R\j Hs"D "D "DOu&X T^l!n !n !nYu& u& u&^ T^l!n !n !n_u&d  U_m"o "o "oeu&l $O N Hs	!D !D !Dmu&v $O N Hs	!D !D !Dwu&@  $O O Hs	"D "D "DAu&J  $O O Hs	"D "D "DKu&V ww3^` ` `Wu&^ ~w3^` ` `_u&f  qw3^` ` `gu&x gg888yu&z gg888{u&| gg888}u&~ gg888u&@ gg888Au&B gg888Cu&D gg888Eu&F gg888Gu& u& u&H gg888Iu&J gg888Ku&L gg888Mu&N gg888Ou&R gg888Su&T gg888Uu&V gg888Wu&X gg888Yu&Z gg888[u&\ gg888]u&^ gg888_u&` gg888au&b gg888cu&d gg888eu&f gg888gu&h gg888iu& u& u up ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 B B& B B B B
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 B B& B B B B
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 B B& B B B B
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 @ @ @ @ @ @
 @ @ @ @ @ @
 C C6 C C C C
 ? ?v ? ? ? ?
 ? ?v ? ? ? ?
 U Uv U U U U U Uv U U U U W W W W W W  HO'     rF   