
    Ng}                        d Z ddgZddlmZmZ ddlmZ ddlZddlm	Z	 ddl
m	c mZ ddlmZmZ ddlmZmZmZ d	d
lmZ d	dlmZ d	dlmZ d	dlmZmZ dgdepepefdZ dhdepepede!de!fdZ"de!p	ee!df         de!p	ee!df         fdZ# G d de	j$                  Z% G d de	j$                  Z& G d de	j$                  Z' G d de	j$                  Z( G d  d!e	j$                  Z) G d" d#e	j$                  Z* ee*            G d$ d%e	j$                  Z+ G d& d'e	j$                  Z,	 	 did*e!d+e!d,e!d-e-d.e.d/e.d0e/d1e.fd2Z0 G d3 d4e	j1                  Z2 G d5 d6e	j$                  Z3 G d7 d8e	j$                  Z4 G d9 d:e	j$                  Z5 G d; de	j$                  Z6 G d< de	j$                  Z7djd>Z8 ei d? e8d@A          dB e8d@A          dC e8d@dDdEdFG          dH e8d@dIdJdFG          dK e8d@A          dL e8d@dDdEdFG          dM e8d@dIdJdFG          dN e8d@A          dO e8d@dDdEdFG          dP e8d@dIdJdFG          dQ e8d@dFR          dS e8d@dFR          dT e8d@dDdEdFG          dU e8d@dIdJdFG          dV e8d@dWdXdFG          dY e8d@dFR          dZ e8d@dDdEdFG           e8d@d[d\dFG           e8d@dWdXdFG          d]          Z9dkd^Z:dkd_Z;edkd`            Z<edkda            Z=edkdb            Z>edkdc            Z?edkdd            Z@edkde            ZAedkdf            ZBdS )la   EfficientViT (by MIT Song Han's Lab)

Paper: `Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition`
    - https://arxiv.org/abs/2205.14756

Adapted from official impl at https://github.com/mit-han-lab/efficientvit
EfficientVitEfficientVitLarge    )ListOptional)partialNIMAGENET_DEFAULT_MEANIMAGENET_DEFAULT_STD)SelectAdaptivePool2dcreate_conv2dGELUTanh   )build_model_with_cfg)register_notrace_module)checkpoint_seq)register_modelgenerate_default_cfgsxc                      t           t          t          f          rt                     S  fdt          |          D             S )Nc                     g | ]}S  r   ).0_r   s     X/var/www/html/ai-engine/env/lib/python3.11/site-packages/timm/models/efficientvit_mit.py
<listcomp>zval2list.<locals>.<listcomp>   s    ***!A***    )
isinstancelisttuplerange)r   repeat_times   ` r   val2listr"      sF    !dE]## Aww****u[))****r   min_len
idx_repeatc                      t                      t                     dk    r1 fdt          |t                     z
            D              <   t                     S )Nr   c                      g | ]
}         S r   r   )r   r   r%   r   s     r   r   zval2tuple.<locals>.<listcomp>#   s    #S#S#SaAjM#S#S#Sr   )r"   lenr    r   )r   r$   r%   s   ` `r   	val2tupler)      sb    A
1vvzz#S#S#S#S#S53q66AQ;R;R#S#S#S*Z
 88Or   kernel_size.returnc                     t          | t                    rt          d | D                       S | dz  dk    s
J d            | dz  S )Nc                 ,    g | ]}t          |          S r   )get_same_padding)r   kss     r   r   z$get_same_padding.<locals>.<listcomp>*   s!    AAAr&r**AAAr      r   z kernel size should be odd number)r   r   )r*   s    r   r.   r.   (   sZ    +u%%  AA[AAABBBQ"""$F"""ar   c                   T     e Zd Zddddddej        ej        fdedef fdZd Z xZ	S )	ConvNormAct   r   F        in_channelsout_channelsc           	      Z   t          t          |                                            t          j        |d          | _        t          |||||||          | _        |	r |	|          nt          j                    | _	        |
 |
d          nt          j                    | _
        d S )NFinplace)r*   stridedilationgroupsbias)num_featuresT)superr2   __init__nnDropoutdropoutr   convIdentitynormact)selfr5   r6   r*   r:   r;   r<   r=   rC   
norm_layer	act_layer	__class__s              r   r@   zConvNormAct.__init__1   s     	k4  ))+++z'5999!#
 
 
	 >HZJJL9999R[]]	.7.C99T****r   c                     |                      |          }|                     |          }|                     |          }|                     |          }|S N)rC   rD   rF   rG   rH   r   s     r   forwardzConvNormAct.forwardL   sC    LLOOIIaLLIIaLLHHQKKr   )
__name__
__module____qualname__rA   BatchNorm2dReLUintr@   rO   __classcell__rK   s   @r   r2   r2   0   s        
 >'W WW W W W W W W6      r   r2   c                   `     e Zd Zdddej        ej        fej        dffdedef fdZd Z xZ	S )	DSConvr3   r   FNr5   r6   c                 j   t          t          |                                            t          |d          }t          |d          }t          |d          }t	          ||||||d         |d         |d                   | _        t	          ||d|d         |d         |d                   | _        d S )Nr0   r   )r<   rI   rJ   r=   r   rI   rJ   r=   )r?   rY   r@   r)   r2   
depth_conv
point_conv)	rH   r5   r6   r*   r:   use_biasrI   rJ   rK   s	           r   r@   zDSConv.__init__U   s     	fd$$&&&Xq))z1--
i++	%!!}l!	
 	
 	
 &!!}l!
 
 
r   c                 Z    |                      |          }|                     |          }|S rM   )r\   r]   rN   s     r   rO   zDSConv.forwardw   s)    OOAOOAr   
rP   rQ   rR   rA   rS   ReLU6rU   r@   rO   rV   rW   s   @r   rY   rY   T   s        
 NBN38T" 
  
 
  
  
  
  
  
  
D      r   rY   c                   d     e Zd Zdddddej        ej        fej        dffdedef fdZd Z xZ	S )		ConvBlockr3   r   NFr5   r6   c
           
         t          t          |                                            t          |d          }t          |d          }t          |	d          }	|pt	          ||z            }t          |||||d         |	d         |d                   | _        t          |||d|d         |	d         |d                   | _        d S )Nr0   r   r[   r   )r?   rc   r@   r)   roundr2   conv1conv2rH   r5   r6   r*   r:   mid_channelsexpand_ratior^   rI   rJ   rK   s             r   r@   zConvBlock.__init__~   s     	i'')))Xq))z1--
i++	#Hu[<-G'H'H !!}l!
 
 

 !!!}l!
 
 



r   c                 Z    |                      |          }|                     |          }|S rM   )rf   rg   rN   s     r   rO   zConvBlock.forward   s%    JJqMMJJqMMr   r`   rW   s   @r   rc   rc   }   s        
 NBN38T"#
 #
#
 #
 #
 #
 #
 #
 #
J      r   rc   c            	       |     e Zd Zdddddej        ej        ej        fej        ej        dffdedef fdZd	 Z xZ	S )
MBConvr3   r   N   Fr5   r6   c
                    t          t          |                                            t          |d          }t          |d          }t          |	d          }	|pt	          ||z            }t          ||dd|d         |	d         |d                   | _        t          ||||||d         |	d         |d                   | _        t          ||d|d         |	d         |d                   | _        d S )Nr3   r   r   )r:   rI   rJ   r=   r:   r<   rI   rJ   r=   r0   r[   )	r?   rm   r@   r)   re   r2   inverted_convr\   r]   rh   s             r   r@   zMBConv.__init__   s    	fd$$&&&Xq))z1--
i++	#Hu[<-G'H'H(!!}l!
 
 
 &!!}l!	
 	
 	
 &!!}l!
 
 
r   c                     |                      |          }|                     |          }|                     |          }|S rM   )rq   r\   r]   rN   s     r   rO   zMBConv.forward   s<    q!!OOAOOAr   r`   rW   s   @r   rm   rm      s        
 NBNBNC8RXt,,
 ,
,
 ,
 ,
 ,
 ,
 ,
 ,
\      r   rm   c            	       f     e Zd Zddddddej        ej        fej        dffdedef fdZd	 Z xZ	S )
FusedMBConvr3   r   Nrn   Fr5   r6   c                    t          t          |                                            t          |d          }t          |	d          }	t          |
d          }
|pt	          ||z            }t          ||||||	d         |
d         |d                   | _        t          ||d|	d         |
d         |d                   | _        d S )Nr0   r   rp   r   r[   )r?   rt   r@   r)   re   r2   spatial_convr]   )rH   r5   r6   r*   r:   ri   rj   r<   r^   rI   rJ   rK   s              r   r@   zFusedMBConv.__init__   s     	k4  ))+++Xq))z1--
i++	#Hu[<-G'H'H'!!}l!	
 	
 	
 &!!}l!
 
 
r   c                 Z    |                      |          }|                     |          }|S rM   )rv   r]   rN   s     r   rO   zFusedMBConv.forward  s+    a  OOAr   r`   rW   s   @r   rt   rt      s        
 NBN38T"$
 $
$
 $
 $
 $
 $
 $
 $
L      r   rt   c            	       p     e Zd ZdZdddddej        fdej        ddf	d	ed
edepddef fdZ	d Z
d Z xZS )LiteMLAz(Lightweight multi-scale linear attentionN      ?   FNN)   h㈵>r5   r6   headsheads_ratioc           	      \   t          t          |                                            || _        pt	          ||z  |z            |z  t          d          t          |d          }t          |d          }|| _        t          |dz  dd         |d         |d                   | _        t          j
        fd|
D                       | _         |	d          | _        t          dt          |
          z   z  |dd         |d         |d                   | _        d S )	Nr0   r3   r   r   )r=   rI   rJ   c                     g | ]q}t          j        t          j        d z  d z  |t          |          d z  d                   t          j        d z  d z  dd z  d                             rS )r3   r   )paddingr<   r=   r   )r<   r=   )rA   
SequentialConv2dr.   )r   scaler   	total_dimr^   s     r   r   z$LiteMLA.__init__.<locals>.<listcomp>.  s     %
 %
 %
  M		M	M,U33y=!!   	!i-Y!e)RZ[\R]^^^
 
%
 %
 %
r   Fr8   )r?   ry   r@   epsrU   r)   dimr2   qkvrA   
ModuleListaggregkernel_funcr(   proj)rH   r5   r6   r   r   r   r^   rI   rJ   r   scalesr   r   rK   s      `  `     @r   r@   zLiteMLA.__init__  sa    	gt%%'''>[C/+=>>CK	Xq))z1--
i++		M!!!}l
 
 
 m %
 %
 %
 %
 %
 %
  %
 %
 %
   ';u555S[[)!!!}l
 
 
			r   c                 0   |j         }|                                |                                |                                }}}|                    dd          |z  }||z  }|dd df         |ddd f         | j        z   z  }|                    |          S )Nr#   .)dtypefloat	transposer   to)rH   qkvr   kvouts          r   _attnzLiteMLA._attnG  s    ''))QWWYY		a1[[R  1$"f#ss(ms38}tx78vve}}r   c                 ~   |j         \  }}}}|                     |          }|g}| j        D ] }|                     ||                     !t	          j        |d          }|                    |dd| j        z  ||z                                dd          }|	                    dd          \  }	}
}| 
                    |	          }	| 
                    |
          }
t          j        |ddd	          }t          j                                        sPt	          j        |j        j        d
          5  |                     |	|
|          }d d d            n# 1 swxY w Y   n|                     |	|
|          }|                    dd                              |d||          }|                     |          }|S )Nr   )r   r#   r3   r   )r   r   constantrz   )modevalueF)device_typeenabled)shaper   r   appendtorchcatreshaper   r   chunkr   Fpadjitis_scriptingautocastdevicetyper   r   )rH   r   Br   HWr   multi_scale_qkvopr   r   r   r   s                r   rO   zLiteMLA.forwardO  s   W
1a hhqkk%+ 	, 	,B""22c77++++)O;;;)11!RTXq1uMMWWXZ\^__!''r'221a QQE!V*B777y%%'' 	&AHM5III * *jjAq))* * * * * * * * * * * * * * * **Q1%%C mmB##++Ar1a88iinn
s   5EE E)rP   rQ   rR   __doc__rA   rS   rT   rU   r   r@   r   rO   rV   rW   s   @r   ry   ry     s        22 " ".)G6
 6
6
 6
 {d	6

 6
 6
 6
 6
 6
 6
p        r   ry   c                   D     e Zd Zdddej        ej        f fd	Zd Z xZS )EfficientVitBlockrz          c                 H   t          t          |                                            t          t	          ||||d |f          t          j                              | _        t          t          |||dd d |f||d f          t          j                              | _	        d S )N)r5   r6   r   r   rI   TTF)r5   r6   rj   r^   rI   rJ   )
r?   r   r@   ResidualBlockry   rA   rE   context_modulerm   local_module)rH   r5   r   head_dimrj   rI   rJ   rK   s          r   r@   zEfficientVitBlock.__init__p  s     	&&//111+'(' *-   KMM	
 	
 *'(), $
3$i6   KMM

 

r   c                 Z    |                      |          }|                     |          }|S rM   )r   r   rN   s     r   rO   zEfficientVitBlock.forward  s-    ""a  r   )	rP   rQ   rR   rA   rS   	Hardswishr@   rO   rV   rW   s   @r   r   r   o  sb         >,
 
 
 
 
 
@      r   r   c                   z     e Zd Z	 	 ddeej                 deej                 deej                 f fdZd Z xZS )r   Nmainshortcutpre_normc                     t          t          |                                            ||nt          j                    | _        || _        || _        d S rM   )r?   r   r@   rA   rE   r   r   r   )rH   r   r   r   rK   s       r   r@   zResidualBlock.__init__  sI     	mT""++---$,$8bkmm	 r   c                     |                      |                     |                    }| j        ||                     |          z   }|S rM   )r   r   r   )rH   r   ress      r   rO   zResidualBlock.forward  sB    iia(())=$a(((C
r   r|   )	rP   rQ   rR   r   rA   Moduler@   rO   rV   rW   s   @r   r   r     s         )-(,		! 	!ry!	! 29%	! 29%		! 	! 	! 	! 	! 	!      r   r   Fdefaultr5   r6   r:   rj   rI   rJ   
fewer_norm
block_typec                 L   |dv sJ |dk    rJ|dk    r"t          | |||rdnd|rd |fn||d f          }not          | |||rdnd|rd |fn||d f          }nM|dk    r%t          | ||||rdnd|rd d |fn|||d f          }n"t          | ||||rdnd|rd |fn||d f          }|S )	N)r   largefusedr   r   )TFF)r5   r6   r:   r^   rI   rJ   r   )r5   r6   r:   rj   r^   rI   rJ   )rY   rc   rm   rt   )	r5   r6   r:   rj   rI   rJ   r   r   blocks	            r   build_local_blockr     sb    66666q""')*4?%1;KD*--$d+  EE ')*4?%1;KD*--$d+  EE ""'))0:E,,7AQD$
33z$i6  EE  '))*4?%1;KD*--$d+  E Lr   c                         e Zd Zd fd	Z xZS )Stemr   c                 r   t                                                       d| _        |                     dt	          ||dd||                     d}t          |          D ]S}|                     d| t          t          ||dd|||          t          j	                                         |dz  }Td S )	Nr0   in_convr3   )r*   r:   rI   rJ   r   r   r   )r5   r6   r:   rj   rI   rJ   r   )
r?   r@   r:   
add_moduler2   r    r   r   rA   rE   )
rH   in_chsout_chsdepthrI   rJ   r   
stem_blockr   rK   s
            r   r@   zStem.__init__  s    aJ)  	
 	
 	
 
u 	 	AOO.*..! '!(!")')   1 1    !OJJ	 	r   )r   )rP   rQ   rR   r@   rV   rW   s   @r   r   r     s=                 r   r   c                   (     e Zd Z	 d fd	Zd Z xZS )EfficientVitStageFc	                    t          t          |                                            t          t	          ||d||||          d           g}	|}|r:t          |          D ])}
|	                    t          |||||                     *nZt          d|          D ]I}|	                    t          t	          ||d|||          t          j	                                         Jt          j
        |	 | _        d S )Nr0   )r5   r6   r:   rj   rI   rJ   r   r5   r   rj   rI   rJ   r   )r5   r6   r:   rj   rI   rJ   )r?   r   r@   r   r   r    r   r   rA   rE   r   blocks)rH   r   r   r   rI   rJ   rj   r   	vit_stager   r   irK   s               r   r@   zEfficientVitStage.__init__  sJ    	&&//111"$)%#$   
 
   	5\\ 	 	%$*!)%1#-"+     	 1e__  m%$*%, %1#-"+   KMM
 
 
 
 
 
 mV,r   c                 ,    |                      |          S rM   r   rN   s     r   rO   zEfficientVitStage.forward4      {{1~~r   FrP   rQ   rR   r@   rO   rV   rW   s   @r   r   r     sR         5- 5- 5- 5- 5- 5-n      r   r   c                   *     e Zd Z	 	 d fd	Zd Z xZS )EfficientVitLargeStageFc	                    t          t          |                                            t          t	          ||d|rdnd|||p||rdnd          d           g}	|}|r:t          |          D ])}
|	                    t          ||d||                     *n_t          |          D ]O}|	                    t          t	          ||d	d
||||rdnd          t          j	                                         Pt          j
        |	 | _        d S )Nr0         r   r   )r5   r6   r:   rj   rI   rJ   r   r   rn   r   r   r   )r?   r   r@   r   r   r    r   r   rA   rE   r   r   )rH   r   r   r   rI   rJ   r   r   r   r   r   r   rK   s               r   r@   zEfficientVitLargeStage.__init__9  sq    	$d++44666"$#,4RR"%#$2
(2?99	 	 	 
 
   	5\\ 	 	%$*!)%&#-"+     	 5\\  m%$*%, %&#-"+#-0:#G99	 	 	 KMM      mV,r   c                 ,    |                      |          S rM   r   rN   s     r   rO   zEfficientVitLargeStage.forwards  r   r   )FFr   rW   s   @r   r   r   8  sU         8- 8- 8- 8- 8- 8-t      r   r   c                        e Zd Zddej        ej        ddfdedee         deded	e	d
ef fdZ
dded	ee	         fdZddefdZ xZS )ClassifierHead  r4   avgr~   r5   widthsnum_classesrC   	pool_typenorm_epsc	                 ~   t          t          |                                            || _        |d         | _        |s
J d            t          ||d         d||          | _        t          |d          | _        t          j
        t          j        |d         |d         d	          t          j        |d         |
          | |d          nt          j                    t          j        |d          |dk    rt          j        |d         |d	          nt          j                              | _        d S )Nr#   Cannot disable poolingr   r   )rI   rJ   Tr   flattenFr=   r   r8   )r?   r   r@   r   r>   r2   r   r   global_poolrA   r   Linear	LayerNormrE   rB   
classifier)
rH   r5   r   r   rC   rI   rJ   r   r   rK   s
            r   r@   zClassifierHead.__init__x  s'    	nd##,,..."2J222222";q	1_hiii/)TRRR-IfQi777L111'0'<IId####"+--Jw...<G!OOBIfQi48888QSQ\Q^Q^
 
r   Nc                     |"|s
J d            t          |d          | _        |dk    r&t          j        | j        |d          | j        d<   d S t          j                    | j        d<   d S )Nr   Tr   r   r   r#   )r   r   rA   r   r>   r   rE   )rH   r   r   s      r   resetzClassifierHead.reset  s{     6666663iQUWWWD??"$)D,={QU"V"V"VDOB"$+--DOBr   F
pre_logitsc                 :   |                      |          }|                     |          }|rY | j        d         |          } | j        d         |          } | j        d         |          } | j        d         |          }n|                     |          }|S )Nr   r   r0   r3   )r   r   r   rH   r   r   s      r   rO   zClassifierHead.forward  s    LLOOQ 	#""1%%A""1%%A""1%%A""1%%AA""Ar   rM   r   )rP   rQ   rR   rA   rS   r   rU   r   r   strr@   r   r   boolrO   rV   rW   s   @r   r   r   w  s        
  >,
 

 S	
 	

 
 
 
 
 
 
 
 
40 0 0# 0 0 0 0 T        r   r   c                       e Zd Zdddddej        ej        ddddf fd	Zej        j	        dd
            Z
ej        j	        dd            Zej        j	        dej        fd            Zddedee         fdZd ZddefdZd Z xZS )r   r3   r   r   r   r   r4   r   c                    t          t          |                                            d| _        || _        || _        t          ||d         |d         ||          | _        | j        j        }g | _	        t          j                    | _        |d         }t          t          |dd          |dd                              D ]f\  }\  }}| j                            t!          ||||||||dk                         |dz  }|}| xj	        t#          ||d|           gz  c_	        g|| _        t'          | j        |	||
| j                  | _        | j        j        | _        d S )	NFr   r   r0   )r   rI   rJ   rj   r   r   stages.num_chs	reductionmodule)r   r   rC   r   )r?   r   r@   grad_checkpointingr   r   r   stemr:   feature_inforA   r   stages	enumeratezipr   r   dictr>   r   headhead_hidden_size)rH   in_chansr   depthsr   rj   rI   rJ   r   head_widths	drop_rater   r:   r5   r   wdrK   s                    r   r@   zEfficientVit.__init__  s    	lD!!**,,,"'&& 6!9fQiYOO	! mooQi"3vabbz6!"":#>#>?? 	e 	eIAv1K0%#)!q&	  	  	  	 	 	 aKFK${fUb_`UbUb"c"c"c!dd'"#&
 
 
	 !%	 6r   Fc                 4    t          d|rdnddg          }|S Nz^stemz^stages\.(\d+))z^stages\.(\d+).downsample)r   )z^stages\.(\d+)\.\w+\.(\d+)N)r  r   r  rH   coarsematchers      r   group_matcherzEfficientVit.group_matcher  9    (. $$455
 
 
 r   Tc                     || _         d S rM   r  rH   enables     r   set_grad_checkpointingz#EfficientVit.set_grad_checkpointing      "(r   r+   c                 &    | j         j        d         S Nr#   r  r   rH   s    r   get_classifierzEfficientVit.get_classifier      y#B''r   Nr   r   c                 J    || _         | j                            ||           d S rM   r   r  r   rH   r   r   s      r   reset_classifierzEfficientVit.reset_classifier  &    &	[11111r   c                     |                      |          }| j        r4t          j                                        st          | j        |          }n|                     |          }|S rM   r  r  r   r   r   r   r  rN   s     r   forward_featureszEfficientVit.forward_features  X    IIaLL" 	59+A+A+C+C 	t{A..AAAAr   r   c                 ^    |r|                      ||          n|                      |          S N)r   r  r  s      r   forward_headzEfficientVit.forward_head  -    6@Rtyyzy222diiPQllRr   c                 Z    |                      |          }|                     |          }|S rM   r6  r;  rN   s     r   rO   zEfficientVit.forward   -    !!!$$a  r   r   TrM   )rP   rQ   rR   rA   rS   r   r@   r   r   ignorer!  r'  r   r-  rU   r   r  r2  r6  r  r;  rO   rV   rW   s   @r   r   r     sO        >,27 27 27 27 27 27h Y    Y) ) ) ) Y(	 ( ( ( (2 2C 2hsm 2 2 2 2  S S$ S S S S      r   c                   
    e Zd Zddddej        edddddf fd	Zej        j	        dd
            Z
ej        j	        dd            Zej        j	        dej        fd            Zddedee         fdZd ZddefdZd Z xZS )r   r3   r   r   r   r4   r   gHz>c                     t          t          |                                            d| _        || _        |
| _        || _        t          || j                  }t          ||d         |d         ||d          | _	        | j	        j
        }g | _        t          j                    | _        |d         }t          t!          |dd          |dd                              D ]j\  }\  }}| j                            t%          |||||||dk    |dk    	                     |dz  }|}| xj        t'          ||d
|           gz  c_        k|| _        t+          | j        ||
|	| j        || j                  | _        | j        j        | _        d S )NFr   r   r   )r   r   r3   r0   )r   rI   rJ   r   r   r   r  r  )r   r   rC   r   rJ   r   )r?   r   r@   r  r   r   r   r   r   r  r:   r  rA   r   r  r  r  r   r   r  r>   r   r  r  )rH   r  r   r  r   rI   rJ   r   r  r  r   r   r:   r5   r   r  r  rK   s                    r   r@   zEfficientVitLarge.__init__  s    	&&//111"'&& ZT];;;
 6!9fQiY[bccc	! mooQi"3vabbz6!"":#>#>?? 	e 	eIAv1K5%#!q&6	  	  	  	 	 	 aKFK${fUb_`UbUb"c"c"c!dd'"#&]
 
 
	 !%	 6r   Fc                 4    t          d|rdnddg          }|S r  r  r  s      r   r!  zEfficientVitLarge.group_matcher?  r"  r   Tc                     || _         d S rM   r$  r%  s     r   r'  z(EfficientVitLarge.set_grad_checkpointingJ  r(  r   r+   c                 &    | j         j        d         S r*  r+  r,  s    r   r-  z EfficientVitLarge.get_classifierN  r.  r   Nr   r   c                 J    || _         | j                            ||           d S rM   r0  r1  s      r   r2  z"EfficientVitLarge.reset_classifierR  r3  r   c                     |                      |          }| j        r4t          j                                        st          | j        |          }n|                     |          }|S rM   r5  rN   s     r   r6  z"EfficientVitLarge.forward_featuresV  r7  r   r   c                 ^    |r|                      ||          n|                      |          S r9  r:  r  s      r   r;  zEfficientVitLarge.forward_head^  r<  r   c                 Z    |                      |          }|                     |          }|S rM   r>  rN   s     r   rO   zEfficientVitLarge.forwarda  r?  r   r   r@  rM   )rP   rQ   rR   rA   rS   r   r@   r   r   rA  r!  r'  r   r-  rU   r   r  r2  r6  r  r;  rO   rV   rW   s   @r   r   r     sM        >67 67 67 67 67 67p Y    Y) ) ) ) Y(	 ( ( ( (2 2C 2hsm 2 2 2 2  S S$ S S S S      r    c           
      2    | dt           t          dddddd	|S )Nr   zstem.in_conv.convzhead.classifier.4gffffff?)r3      rM  )   rN  )	urlr   meanstd
first_convr   crop_pct
input_size	pool_sizer   )rO  kwargss     r   _cfgrW  g  s6    %#))#   r   zefficientvit_b0.r224_in1kztimm/)	hf_hub_idzefficientvit_b1.r224_in1kzefficientvit_b1.r256_in1k)r3      rY  )r{   r{   rz   )rX  rT  rU  rS  zefficientvit_b1.r288_in1k)r3      rZ  )	   r[  zefficientvit_b2.r224_in1kzefficientvit_b2.r256_in1kzefficientvit_b2.r288_in1kzefficientvit_b3.r224_in1kzefficientvit_b3.r256_in1kzefficientvit_b3.r288_in1kzefficientvit_l1.r224_in1k)rX  rS  zefficientvit_l2.r224_in1kzefficientvit_l2.r256_in1kzefficientvit_l2.r288_in1kzefficientvit_l2.r384_in1k)r3     r\  )   r]  zefficientvit_l3.r224_in1kzefficientvit_l3.r256_in1k)r3   @  r^  )
   r_  )zefficientvit_l3.r320_in1kzefficientvit_l3.r384_in1kc                 |    |                     dd          }t          t          | |fdt          d|          i|}|S Nout_indices)r   r   r0   r3   feature_cfgT)flatten_sequentialrb  )popr   r   r  variant
pretrainedrV  rb  models        r   _create_efficientvitrj    sZ    **]L99K   DkJJJ	
  E Lr   c                 |    |                     dd          }t          t          | |fdt          d|          i|}|S ra  )re  r   r   r  rf  s        r   _create_efficientvit_largerl    sZ    **]L99K   DkJJJ	
  E Lr   c           	      Z    t          dddd          }t          dd| it          |fi |S )	N)r{   r   r   @      )r   r0   r0   r0   r0   r   )   i   r   r  r   r  efficientvit_b0rh  )rr  r  rj  rh  rV  
model_argss      r   rr  rr    sN    #ObVbd d dJggjgDQ[LfLf_eLfLfgggr   c           	      Z    t          dddd          }t          dd| it          |fi |S )	N)r   r   rn  ro  rY  )r   r0   r3   r3   r   r   )i   i@  rq  efficientvit_b1rh  )rw  rs  rt  s      r   rw  rw    N    %oXdf f fJggjgDQ[LfLf_eLfLfgggr   c           	      Z    t          dddd          }t          dd| it          |fi |S )	N)r   0   `      r\  )r   r3   r   r   rn   r   i 	  i 
  rq  efficientvit_b2rh  )r~  rs  rt  s      r   r~  r~    rx  r   c           	      Z    t          dddd          }t          dd| it          |fi |S )	Nr   rn  ro  rY     )r   r   rn   rn   r[  r   r}  rq  efficientvit_b3rh  )r  rs  rt  s      r   r  r    sN    &Yeg g gJggjgDQ[LfLf_eLfLfgggr   c           	      Z    t          dddd          }t          dd| it          |fi |S )	Nr  )r   r   r   rn   rn   r   i   i  rq  efficientvit_l1rh  )r  r  rl  rt  s      r   r  r    O    &Yeg g gJ%mmJmRVWaRlRlekRlRlmmmr   c           	      Z    t          dddd          }t          dd| it          |fi |S )	Nr  r   r0   r0   r{   r{   r   r  rq  efficientvit_l2rh  )r  r  rt  s      r   r  r    r  r   c           	      Z    t          dddd          }t          dd| it          |fi |S )	N)rn  ro  rY  r  rp  r  r   )i   i   rq  efficientvit_l3rh  )r  r  rt  s      r   r  r    sO    (2[gi i iJ%mmJmRVWaRlRlekRlRlmmmr   )r   )r   r#   )Fr   )rK  r   )Cr   __all__typingr   r   	functoolsr   r   torch.nnrA   torch.nn.functional
functionalr   	timm.datar	   r
   timm.layersr   r   r   _builderr   _features_fxr   _manipulater   	_registryr   r   r   r   anyr"   rU   r)   r.   r   r2   rY   rc   rm   rt   ry   r   r   r   r  r  r   r   r   r   r   r   r   r   rW  default_cfgsrj  rl  rr  rw  r~  r  r  r  r  r   r   r   <module>r     s    .
/ ! ! ! ! ! ! ! !                       A A A A A A A A E E E E E E E E E E * * * * * * 1 1 1 1 1 1 ' ' ' ' ' ' < < < < < < < <+ +$$ + + + + %%#  S     #"8sCx  S=SE#s(O        ! ! ! ! !") ! ! !H& & & & &RY & & &R) ) ) ) )	 ) ) )X3 3 3 3 3RY 3 3 3l* * * * *") * * *Z] ] ] ] ]bi ] ] ]@        $ $ $ $ $	 $ $ $N    BI   4 !#3 333 3 	3
 3 3 3 3 3 3 3l    2=   :9 9 9 9 9	 9 9 9x< < < < <RY < < <~/ / / / /RY / / /dZ Z Z Z Z29 Z Z Zz^ ^ ^ ^ ^	 ^ ^ ^B    %$ X&" " "X&  " " "	X&   FS" " "X&   FS" " "X&  " " "X&$   FS" " "%X&,   FS" " "-X&4  " " "5X&:   FS" " ";X&B   FS" " "CX&J  " " "KX&R  " " "SX&Z   FS" " "[X&b   FS" " "cX&j   Hs" " "kX&r  " " "sX&z   FS" " "{X&B "& Hs" " " "& Hs" " "KX& X& X& X Xv	 	 	 		 	 	 	 h h h h h h h h h h h h h h h h n n n n n n n n n n n n n nr   