
    Ng]1                        d Z ddlmZ ddlmZ ddlmZ ddlZddlm	Z	 ddl
mZmZmZmZmZmZ ddlmZ dd	lmZ dd
lmZmZmZ dgZ G d de	j                  Z G d de	j                  Z G d de	j                  Zd Zd(dZd)dZ  e e d           e dd           e d           e d           e d           e ddd           e ddd           e ddd           e d           e dd          d
          Z!ed(defd            Z"ed(defd             Z#ed(defd!            Z$ed(defd"            Z% ee&d#d$d%d&d'           dS )*z
TResNet: High Performance GPU-Dedicated Architecture
https://arxiv.org/pdf/2003.13630.pdf

Original model: https://github.com/mrT23/TResNet

    )OrderedDict)partial)OptionalN)SpaceToDepth
BlurPool2dClassifierHeadSEModuleConvNormActDropPath   )build_model_with_cfg)checkpoint_seq)register_modelgenerate_default_cfgsregister_model_deprecationsTResNetc                   4     e Zd ZdZ	 	 	 	 	 d fd	Zd Z xZS )
BasicBlockr   NT        c                    t          t          |                                            || _        || _        t          t          j        d          }t          ||d|||          | _	        t          ||ddd          | _
        t          j        d	          | _        t          || j        z  d
z  d          }	|rt          || j        z  |	          nd | _        |dk    rt#          |          nt          j                    | _        d S )NMbP?negative_slope   kernel_sizestride	act_layeraa_layerr   Fr   r   	apply_actTinplace   @   rd_channelsr   )superr   __init__
downsampler   r   nn	LeakyReLUr
   conv1conv2ReLUactmax	expansionr	   ser   Identity	drop_path)selfinplanesplanesr   r*   use_ser   drop_path_rater   rd_chs	__class__s             O/var/www/html/ai-engine/env/lib/python3.11/site-packages/timm/models/tresnet.pyr)   zBasicBlock.__init__   s     	j$((***$BL>>>	 6q[dowxxx
 QqTYZZZ
74(((Vdn,1266KQ[(6DN2GGGGW[5Ca5G5G.111R[]]    c                 *   | j         |                      |          }n|}|                     |          }|                     |          }| j        |                     |          }|                     |          |z   }|                     |          }|S N)r*   r-   r.   r3   r5   r0   r6   xshortcutouts       r=   forwardzBasicBlock.forward1   s    ?&q))HHHjjmmjjoo7''#,,CnnS!!H,hhsmm
r>   )r   NTNr   __name__
__module____qualname__r2   r)   rE   __classcell__r<   s   @r=   r   r      sh        I [ [ [ [ [ [.      r>   r   c                   6     e Zd ZdZ	 	 	 	 	 	 d fd	Zd Z xZS )	
Bottleneckr$   r   NTr   c	                 L   t          t          |                                            || _        || _        |pt          t          j        d          }t          ||dd|          | _	        t          ||d|||          | _
        t          || j        z  dz  d          }	|rt          ||		          nd | _        t          ||| j        z  ddd
          | _        |dk    rt!          |          nt          j                    | _        t          j        d          | _        d S )Nr   r   r   )r   r   r   r   r      r%   r&   Fr    r   Tr"   )r(   rM   r)   r*   r   r   r+   r,   r
   r-   r.   r1   r2   r	   r3   conv3r   r4   r5   r/   r0   )r6   r7   r8   r   r*   r9   r   r   r:   reduction_chsr<   s             r=   r)   zBottleneck.__init__B   s9    	j$((***$Kd!K!K!K	 f!AL L L
 F&IX`b b b
 FT^3q8"==AGQ(6}====T FT^+1PUW W W
 6Da5G5G.111R[]]74(((r>   c                 T   | j         |                      |          }n|}|                     |          }|                     |          }| j        |                     |          }|                     |          }|                     |          |z   }|                     |          }|S r@   )r*   r-   r.   r3   rP   r5   r0   rA   s       r=   rE   zBottleneck.forward`   s    ?&q))HHHjjmmjjoo7''#,,CjjoonnS!!H,hhsmm
r>   )r   NTNNr   rF   rK   s   @r=   rM   rM   ?   se        I ) ) ) ) ) )<      r>   rM   c                        e Zd Z	 	 	 	 	 	 	 d fd	ZddZej        j        dd            Zej        j        dd            Z	ej        j        de
j        fd            Zddedee         fdZd ZddefdZd Z xZS )r   r           ?Ffastr   c	                    || _         || _        d| _        t          t          |                                            t          }	t          j        }
t          d|z            | _
        t          d|z            | _        |r$| j
        dz  dz  | _
        | j        dz  dz  | _        d t          j        d|t          |                                        |          D             }t!          |dz  | j        dd|
	          }|                     |rt$          nt&          | j        |d         dd
|	|d                   }|                     |rt$          nt&          | j        dz  |d         dd
|	|d                   }|                     t$          | j        dz  |d         dd
|	|d                   }|                     t$          | j        dz  |d         dd|	|d                   }t          j        t+          dt-                      fd|fd|fd|fd|fd|fg                    | _        t1          | j        dd          t1          | j        |rt$          j        ndz  dd          t1          | j        dz  |rt$          j        ndz  dd          t1          | j        dz  t$          j        z  dd          t1          | j        dz  t$          j        z  dd          g| _        | j        dz  t$          j        z  x| _        | _        t;          | j        |||          | _        |                                 D ]}}tA          |t          j!                  r't          j"        #                    |j$        dd           tA          |t          j%                  r |j$        j&        '                    dd           ~|                                 D ]}tA          |t&                    r.t          j"        (                    |j)        j*        j$                   tA          |t$                    r.t          j"        (                    |j+        j*        j$                   d S ) NFr%   rO   c                 6    g | ]}|                                 S  )tolist).0rB   s     r=   
<listcomp>z$TResNet.__init__.<locals>.<listcomp>   s     ```aqxxzz```r>   r      r   r   )r   r   r   T)r   r9   r   r:      r$   s2dr-   layer1layer2layer3layer4 )num_chs	reductionmodulezbody.layer1zbody.layer2zbody.layer3    zbody.layer4)	pool_type	drop_ratefan_out
leaky_relu)modenonlinearityg{Gz?),num_classesrj   grad_checkpointingr(   r   r)   r   r+   r,   intr7   r8   torchlinspacesumsplitr
   _make_layerrM   r   
Sequentialr   r   bodydictr2   feature_infonum_featureshead_hidden_sizer   headmodules
isinstanceConv2dinitkaiming_normal_weightLineardatanormal_zeros_r.   bnrP   )r6   layersin_chansro   width_factorv2global_poolrj   r:   r   r   dprr-   r`   ra   rb   rc   mr<   s                     r=   r)   zTResNet.__init__p   s?    '""'gt%%'''L	 B-.."|+,, 	/ MQ.2DM+*Q.DK``5>!^S[[#Q#Q#W#WX^#_#_```HrM4;qa[deee!!,JJ*K1TH]`ab]c " e e !!,JJ*K!OVAYqadefag " i i !!K!OVAYqadefag " i i !!K!OVAYqbefgbh " j j
 M+LNN#evvvv/
 # #  	 "===(I
(<(<JVW`mnnnqB,MJ,@,@ANZ[dqrrrq:+??2Vcdddq:+??2Vcddd
 6:[1_
H\4\\D1"4#4k[dmnnn	  	/ 	/A!RY'' ]''y|'\\\!RY'' /%%a...  	2 	2A!Z(( 2qwz0111!Z(( 2qwz0111		2 	2r>   r   TNc                    d }|dk    s| j         ||j        z  k    rfg }	|dk    r+|	                    t          j        dddd                     |	t          | j         ||j        z  ddd          gz  }	t          j        |	 }g }	t          |          D ]h}
|	                     || j         ||
dk    r|nd|
dk    r|nd ||t          |t                    r||
         n|                     ||j        z  | _         it          j        |	 S )	Nr   r^   TF)r   r   	ceil_modecount_include_padr    r   )r   r*   r9   r   r:   )
r7   r2   appendr+   	AvgPool2dr
   rw   ranger   list)r6   blockr8   blocksr   r9   r   r:   r*   r   is              r=   rv   zTResNet._make_layer   sH   
Q;;$-6EO+CCCF{{blqdfklllmmm{v7Qq\ac c c d dF/Jv 
	5 
	5AMM%%!"avvQ)*a::T!4>~t4T4Th~a00Zh      #U_4DMM}f%%r>   c                 0    t          d|rdnd          }|S )Nz^body\.conv1z^body\.layer(\d+)z^body\.layer(\d+)\.(\d+))stemr   )ry   )r6   coarsematchers      r=   group_matcherzTResNet.group_matcher   s%    OF4s4H4HXstttr>   c                     || _         d S r@   )rp   )r6   enables     r=   set_grad_checkpointingzTResNet.set_grad_checkpointing   s    "(r>   returnc                     | j         j        S r@   )r}   fc)r6   s    r=   get_classifierzTResNet.get_classifier   s    y|r>   ro   r   c                 >    | j                             ||           d S )N)ri   )r}   reset)r6   ro   r   s      r=   reset_classifierzTResNet.reset_classifier   s     	{;;;;;r>   c                 `   | j         rt          j                                        ss| j                            |          }| j                            |          }t          | j        j        | j        j	        | j        j
        | j        j        g|d          }n|                     |          }|S )NT)flatten)rp   rr   jitis_scriptingrx   r_   r-   r   r`   ra   rb   rc   r6   rB   s     r=   forward_featureszTResNet.forward_features   s    " 
	59+A+A+C+C 
		a  A	""A	 	 	 	 	 "
 4! ! !AA 		!Ar>   
pre_logitsc                 ^    |r|                      ||          n|                      |          S )N)r   )r}   )r6   rB   r   s      r=   forward_headzTResNet.forward_head   s-    6@Rtyyzy222diiPQllRr>   c                 Z    |                      |          }|                     |          }|S r@   )r   r   r   s     r=   rE   zTResNet.forward   s-    !!!$$a  r>   )r   rT   rU   FrV   r   r   )r   TNr   F)Tr@   )rG   rH   rI   r)   rv   rr   r   ignorer   r   r+   Moduler   rq   r   strr   r   boolr   rE   rJ   rK   s   @r=   r   r   o   sY        K2 K2 K2 K2 K2 K2Z& & & &2 Y    Y) ) ) ) Y	    < <C <hsm < < < <  S S$ S S S S      r>   c                 L   d| v r| S dd l }|                     d|           } |                     d|           } i }|                                 D ]\  }}|                    dd |          }|                    dd |          }|                    d	d
 |          }|                    dd |          }|                    dd |          }|                    dd |          }|                    d          r'|                                                    d          }|||<   |S )Nzbody.conv1.conv.weightr   model
state_dictzconv(\d+)\.0.0c                 N    dt          |                     d                     dS Nconvr   .convrq   grouprB   s    r=   <lambda>z&checkpoint_filter_fn.<locals>.<lambda>  s"    0Ms1771::0M0M0M r>   zconv(\d+)\.0.1c                 N    dt          |                     d                     dS Nr   r   .bnr   r   s    r=   r   z&checkpoint_filter_fn.<locals>.<lambda>  s"    0Ks1771::0K0K0K r>   zconv(\d+)\.0c                 N    dt          |                     d                     dS r   r   r   s    r=   r   z&checkpoint_filter_fn.<locals>.<lambda>  s"    .KS__.K.K.K r>   zconv(\d+)\.1c                 N    dt          |                     d                     dS r   r   r   s    r=   r   z&checkpoint_filter_fn.<locals>.<lambda>	  s"    .IS__.I.I.I r>   zdownsample\.(\d+)\.0c                 N    dt          |                     d                     dS )Ndownsample.r   r   r   r   s    r=   r   z&checkpoint_filter_fn.<locals>.<lambda>
  s#    6ZCPQ

OO6Z6Z6Z r>   zdownsample\.(\d+)\.1c                 N    dt          |                     d                     dS )Nr   r   r   r   r   s    r=   r   z&checkpoint_filter_fn.<locals>.<lambda>  s#    6XCPQ

OO6X6X6X r>   z	bn.weightgh㈵>)regetitemssubendswithabsadd)r   r   r   out_dictkvs         r=   checkpoint_filter_fnr      s?   :--III44Jj99JH  "" 
 
1FF$&M&MqQQFF$&K&KQOOFF?$K$KQOOFF?$I$I1MMFF*,Z,Z\]^^FF*,X,XZ[\\::k"" 	"D!!AOr>   Fc                 X    t          t          | |ft          t          dd          d|S )N)r   r^   r   r$   T)out_indicesflatten_sequential)pretrained_filter_fnfeature_cfg)r   r   r   ry   )variant
pretrainedkwargss      r=   _create_tresnetr     sF     2\dKKK    r>   rd   c                      | ddddddddd	d

|S )NrT   )r      r   )   r   g      ?bilinear)r   r   r   )rU   rU   rU   zbody.conv1.convzhead.fc)
urlro   
input_size	pool_sizecrop_pctinterpolationmeanstd
first_conv
classifierrY   )r   r   s     r=   _cfgr     s4    4}SYJ\'y	 
  r>   ztimm/)	hf_hub_idi+  )r   ro   )r     r   )   r   )r   r   r   )
ztresnet_m.miil_in21k_ft_in1ktresnet_m.miil_in21kztresnet_m.miil_in1kztresnet_l.miil_in1kztresnet_xl.miil_in1ktresnet_m.miil_in1k_448tresnet_l.miil_in1k_448tresnet_xl.miil_in1k_448ztresnet_v2_l.miil_in21k_ft_in1kztresnet_v2_l.miil_in21kr   c           	      X    t          g d          }t          dd| it          |fi |S )N)r   r$      r   )r   	tresnet_mr   )r   ry   r   r   r   
model_argss      r=   r   r   =  s@    ]]]+++J\\:\jA[A[TZA[A[\\\r>   c           	      Z    t          g dd          }t          dd| it          |fi |S )N)r$         r   g333333?r   r   	tresnet_lr   )r   r   r   s      r=   r   r   C  sB    ]]]===J\\:\jA[A[TZA[A[\\\r>   c           	      Z    t          g dd          }t          dd| it          |fi |S )N)r$   r      r   g?r   
tresnet_xlr   )r  r   r   s      r=   r  r  I  sB    ]]]===J]]J]$zB\B\U[B\B\]]]r>   c           	      \    t          g ddd          }t          dd| it          |fi |S )N)r   r$      r   rU   T)r   r   r   tresnet_v2_lr   )r  r   r   s      r=   r  r  O  sD    ]]]FFFJ__j_DD^D^W]D^D^___r>   r   r   r   r   )tresnet_m_miil_in21ktresnet_m_448tresnet_l_448tresnet_xl_448r   )rd   )'__doc__collectionsr   	functoolsr   typingr   rr   torch.nnr+   timm.layersr   r   r   r	   r
   r   _builderr   _manipulater   	_registryr   r   r   __all__r   r   rM   r   r   r   r   default_cfgsr   r   r  r  rG   rY   r>   r=   <module>r     s    $ # # # # #                    a a a a a a a a a a a a a a a a * * * * * * ' ' ' ' ' ' Y Y Y Y Y Y Y Y Y Y+% % % % % % % %P- - - - - - - -`K K K K Kbi K K K\  ,       %$$(D7$;$;$; D7FFF4'2224'222 D7333#t H       $t H      !% H! ! ! (,tg'>'>'>#tg5III#& &  * ] ]W ] ] ] ]
 ] ]W ] ] ] ]
 ^ ^g ^ ^ ^ ^
 ` ` ` ` ` `
  H2..0	' '     r>   