
    NgQ                     (   d Z ddlZddlmZ ddlmZmZ ddlZddlm	c m
Z ddlm	Z	 ddlmZmZ ddlmZmZmZmZmZmZmZ dd	lmZ dd
lmZ ddlmZmZ ddlmZm Z  dgZ!e G d de	j"                              Z# G d de	j"                  Z$ G d de	j"                  Z% G d de	j"                  Z& G d de	j"                  Z' G d de	j"                  Z(d-dZ)d Z*d.dZ+d/dZ, e  e,d d!d"           e,d d!d"           e,d d#d$d%           e,d d#d$d%           e,d d#d$d%           e,d d$d"          d&          Z-ed.d'e(fd(            Z.ed.d'e(fd)            Z/ed.d'e(fd*            Z0ed.d'e(fd+            Z1ed.d'e(fd,            Z2dS )0a#   EdgeNeXt

Paper: `EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications`
 - https://arxiv.org/abs/2206.10589

Original code and weights from https://github.com/mmaaz60/EdgeNeXt

Modifications and additions for timm by / Copyright 2022, Ross Wightman
    N)partial)OptionalTuple)nnIMAGENET_DEFAULT_MEANIMAGENET_DEFAULT_STD)trunc_normal_tf_DropPathLayerNorm2dMlpcreate_conv2dNormMlpClassifierHeadClassifierHead   )build_model_with_cfg)register_notrace_module)named_applycheckpoint_seq)register_modelgenerate_default_cfgsEdgeNeXtc                   >     e Zd Zd fd	Zdeeeef         fdZ xZS )PositionalEncodingFourier       '  c                     t                                                       t          j        |dz  |d          | _        dt
          j        z  | _        || _        || _	        || _
        d S )N   r   )kernel_size)super__init__r   Conv2dtoken_projectionmathpiscaletemperature
hidden_dimdim)selfr)   r*   r(   	__class__s       P/var/www/html/ai-engine/env/lib/python3.11/site-packages/timm/models/edgenext.pyr"   z"PositionalEncodingFourier.__init__   s]     "	*q.#1 M M M[
&$    shapec           
      d   | j         j        j        }| j         j        j        }t	          j        |                              |t          j                   }|                    dt          j	                  }|                    dt          j	                  }d}||d d dd d d f         |z   z  | j
        z  }||d d d d dd f         |z   z  | j
        z  }t	          j        | j        t          j        |                              t          j	                  }| j        dt	          j        |dd	          z  | j        z  z  }|d d d d d d d f         |z  }	|d d d d d d d f         |z  }
t	          j        |	d d d d d d d
d df                                         |	d d d d d d dd df                                         fd                              d          }	t	          j        |
d d d d d d d
d df                                         |
d d d d d d dd df                                         fd                              d          }
t	          j        |
|	fd                              d
ddd          }|                      |                    |                    }|S )N)devicedtyper   )r2   r   ư>)r2   r1   floor)rounding_moder      r*      )r$   weightr1   r2   torchzerostoboolcumsumfloat32r'   aranger)   int64r(   divstacksincosflattencatpermute)r+   r/   r1   r2   inv_masky_embedx_embedepsdim_tpos_xpos_yposs               r-   forwardz!PositionalEncodingFourier.forward'   s   &-4%,2K&&))uz)JJJ//!5=/99//!5=/99WQQQQQQY/#56CWQQQ233Y/#56CT_EKOOORRSXS`aa Q5!7)S)S)S%SVZVe%ef111aaa&.111aaa&.111aaaADqD=!%%''111aaaADqD=!%%'')./1 1 118 	 111aaaADqD=!%%''111aaaADqD=!%%'')./1 1 118 	 iA...66q!QBB##CFF5MM22
r.   )r   r   r   )__name__
__module____qualname__r"   r   intrR   __classcell__r,   s   @r-   r   r      sc             U3S=1        r.   r   c            
       b     e Zd Zdddddd eej        d          ej        df	 fd		Zd
 Z xZ	S )	ConvBlockN   r   Tr7   r3   rM           c                    t                                                       |p|}|dk    p||k    | _        t          ||||d|          | _         ||          | _        t          |t          ||z            |	          | _        |dk    r)t          j
        |t          j        |          z            nd | _        |
dk    rt          |
          nt          j                    | _        d S )Nr   T)r    stride	depthwisebias	act_layerr   r]   )r!   r"   shortcut_after_dwr   conv_dwnormr   rV   mlpr   	Parameterr;   onesgammar   Identity	drop_path)r+   r*   dim_outr    r_   	conv_biasexpand_ratiols_init_value
norm_layerrc   rl   r,   s              r-   r"   zConvBlock.__init__C   s     	.S!'!!=sg~$k&DW`b b bJw''	wL7$: ; ;yQQQJWZ[J[J[R\-%*W2E2E"EFFFae
09B),,,BKMMr.   c                 L   |}|                      |          }| j        r|}|                    dddd          }|                     |          }|                     |          }| j        
| j        |z  }|                    dddd          }||                     |          z   }|S )Nr   r   r9   r   )re   rd   rI   rf   rg   rj   rl   )r+   xshortcuts      r-   rR   zConvBlock.forwardZ   s    LLOO! 	HIIaAq!!IIaLLHHQKK:!
QAIIaAq!!t~~a(((r.   
rS   rT   rU   r   r   	LayerNormGELUr"   rR   rW   rX   s   @r-   rZ   rZ   B   s         wr|666gR R R R R R.      r.   rZ   c                   X     e Zd Z	 	 	 	 d fd	Zd Zej        j        d             Z xZ	S )CrossCovarianceAttn   Fr]   c                    t                                                       || _        t          j        t          j        |dd                    | _        t          j        ||dz  |          | _	        t          j
        |          | _        t          j        ||          | _        t          j
        |          | _        d S )Nr   r9   )ra   )r!   r"   	num_headsr   rh   r;   ri   r(   LinearqkvDropout	attn_dropproj	proj_drop)r+   r*   r|   qkv_biasr   r   r,   s         r-   r"   zCrossCovarianceAttn.__init__l   s     	"<
9a(C(CDD9S#'999I..Ic3''	I..r.   c                 ~   |j         \  }}}|                     |                              ||d| j        d                              ddddd          }|                    d          \  }}}t          j        |d          t          j        |d                              dd          z  | j	        z  }	|	
                    d          }	|                     |	          }	|	|z  }|                    dddd                              |||          }|                     |          }|                     |          }|S )	Nr9   r4   r   r   r7   r   r8   )r/   r~   reshaper|   rI   unbindF	normalize	transposer(   softmaxr   r   r   )
r+   rs   BNCr~   qkvattns
             r-   rR   zCrossCovarianceAttn.forward}   s!   '1ahhqkk!!!Q4>2>>FFq!QPQSTUU**Q--1a A2&&&QB)?)?)?)I)I"b)Q)QQUYUee|||##~~d##AXIIaAq!!))!Q22IIaLLNN1r.   c                     dhS )Nr(    r+   s    r-   no_weight_decayz#CrossCovarianceAttn.no_weight_decay   s
    r.   )rz   Fr]   r]   )
rS   rT   rU   r"   rR   r;   jitignorer   rW   rX   s   @r-   ry   ry   k   s{         / / / / / /"    Y      r.   ry   c                   h     e Zd Zddddddd eej        d          ej        dddf fd	Zd	 Z xZ	S )
SplitTransposeBlockr   rz   r7   Tr3   r\   r]   c           
         t                                                       t          t          t	          j        ||z                      t          t	          j        ||z                                }|| _        t          d|dz
            | _        g }t          | j                  D ])}|
                    t          ||dd|                     *t          j        |          | _        d | _        |rt!          |          | _         |	|          | _        |dk    r)t          j        |t'          j        |          z            nd | _        t-          |||||          | _         |	|d	          | _        t3          |t          ||z            |

          | _        |dk    r)t          j        |t'          j        |          z            nd | _        |dk    rt9          |          nt          j                    | _        d S )Nr   r9   T)r    r`   ra   r8   r   )r|   r   r   r   r3   r\   rb   r]   )r!   r"   maxrV   r%   ceilr5   width
num_scalesrangeappendr   r   
ModuleListconvspos_embdr   norm_xcarh   r;   ri   	gamma_xcary   xcarf   r   rg   rj   r   rk   rl   )r+   r*   r   r|   ro   use_pos_embrn   r   rp   rq   rc   rl   r   r   r   r   ir,   s                    r-   r"   zSplitTransposeBlock.__init__   s     	C	#
"23344c$*SJEV:W:W6X6XYY
aa00t'' 	e 	eALLuedYbcccdddd]5))
 	?5#>>>DM"
3JWZ[J[J[mejoo&EFFFae&9x9Xac c c Js---	sCs 233yIIIFSVWFWFWR\-%*S//"ABBB]a
09B),,,BKMMr.   c           	      &   |}|                     t          | j                  dz   d          }g }|d         }t          | j                  D ]6\  }}|dk    r|||         z   } ||          }|                    |           7|                    |d                    t          j        |d          }|j        \  }}	}
}|                    ||	|
|z            	                    ddd          }| j
        R| 
                    ||
|f                              |d|j        d                   	                    ddd          }||z   }||                     | j        |                     |                     |                    z            z   }|                    ||
||	          }|                     |          }|                     |          }| j        
| j        |z  }|	                    dddd          }||                     |          z   }|S )Nr   r8   r   r4   r   r9   )chunklenr   	enumerater   r;   rH   r/   r   rI   r   rl   r   r   r   rf   rg   rj   )r+   rs   rt   spxspospr   convr   r   HWpos_encodings                r-   rR   zSplitTransposeBlock.forward   s    ggc$*oo)qg11V ,, 	 	GAt1uu#a&[bBJJrNNNN

3r7Ic1 W
1aIIaAE""**1a33=$==!Q33;;Ar171:NNVVWXZ[]^__LL At~q9I9I0J0JJKKKIIaAq!! IIaLLHHQKK:!
QAIIaAq!!t~~a(((r.   ru   rX   s   @r-   r   r      s         wr|666g%R %R %R %R %R %RN! ! ! ! ! ! !r.   r   c                   n     e Zd Zdddddddddddde eej        d	
          ej        f fd	Zd Z	 xZ
S )EdgeNeXtStager   r   r7   r[   FT      ?Nr3   r\   c                 R   t                                                       d| _        |s|dk    rt          j                    | _        n<t          j         ||          t          j        ||dd|                    | _        |}g }t          |          D ]|}|||z
  k     r=|	                    t          |||r|dk    r|nd|||	|||         ||
  
                   n2|	                    t          ||||	|
||||         ||
  
                   |}}t          j        | | _        d S )NFr   r   r    r_   ra   r   )
r*   rm   r_   rn   r    ro   rp   rl   rq   rc   )
r*   r   r|   ro   r   rn   rp   rl   rq   rc   )r!   r"   grad_checkpointingr   rk   
downsample
Sequentialr#   r   r   rZ   r   blocks)r+   in_chsout_chsr_   depthnum_global_blocksr|   scalesr    ro   r   downsample_blockrn   rp   drop_path_ratesrq   norm_layer_clrc   stage_blocksr   r,   s                       r-   r"   zEdgeNeXtStage.__init__   sy   ( 	"' 	v{{ kmmDOO m
6""	&'qSSS DO Fu 	 	A5,,,,##" ')9Ka1ffvv!"+$/%1&3"1!"4#0"+      ##'"#)"+%1$/"+&3"1!"4#0"+     FFm\2r.   c                     |                      |          }| j        r4t          j                                        st          | j        |          }n|                     |          }|S N)r   r   r;   r   is_scriptingr   r   r+   rs   s     r-   rR   zEdgeNeXtStage.forward"  sZ    OOA" 	59+A+A+C+C 	t{A..AAAAr.   )rS   rT   rU   r   r   r   rv   rw   r"   rR   rW   rX   s   @r-   r   r      s        
 " "!'",D999g%A3 A3 A3 A3 A3 A3F      r.   r   c                       e Zd Zddddddddd	d
dddddddej        ddf fd	Zej        j        dd            Z	ej        j        d d            Z
ej        j        dej        fd            Zd!dedee         fdZd ZddefdZd Z xZS )"r   r9     avg   0   X      r9   r9   	   r9   )r   r   r   r   )r9      r[   r   )rz   rz   rz   rz   )r   r   r9   r7   )FTFFr3   r   r7   FTpatchr]   c           
         t                                                       || _        || _        || _        t          t          d          }t          t          j        d          }g | _	        |dv sJ |dk    rGt          j
        t          j        ||d         dd|           ||d                             | _        nGt          j
        t          j        ||d         ddd|	           ||d                             | _        d}g }d
 t          j        d|t          |                                        |          D             }|d         }t#          d          D ]}|dk    s|dk    rdnd}||z  }|                    t'          d$i d|d||         d|d||         d||         d||         d||         d|	|         d|d||         d|
|         d|d|d|d|d|d|           ||         }| xj	        t)          ||d|           gz  c_	        t          j
        | | _        |d          x| _        | _        |r8 || j                  | _        t3          | j        ||| j        !          | _        n;t          j                    | _        t9          | j        ||| j        |"          | _        t;          t          t<          |#          |            d S )%Nr3   r\   )r   overlapr   r   r7   r   r   )r    r_   paddingra   c                 6    g | ]}|                                 S r   )tolist).0rs   s     r-   
<listcomp>z%EdgeNeXt.__init__.<locals>.<listcomp>Y  s     eee1AHHJJeeer.   r   r   r   r   r_   r   r   r|   r   r   ro   r    r   rp   r   rn   rq   r   rc   zstages.)num_chs	reductionmoduler4   )	pool_type	drop_rate)r   r   rq   )head_init_scaler   )r!   r"   num_classesglobal_poolr   r   r   r   rv   feature_infor   r#   stemr;   linspacesumsplitr   r   r   dictstagesnum_featureshead_hidden_sizenorm_prer   headrk   r   r   _init_weights)r+   in_chansr   r   dimsdepthsglobal_block_countskernel_sizesheads	d2_scalesr   rp   r   ro   r   rn   	stem_typehead_norm_firstrc   drop_path_rater   rq   r   curr_strider   dp_ratesr   r   r_   r,   s                                r-   r"   zEdgeNeXt.__init__,  s   . 	&&"[d333
$77700000	(DG19UUU
47## DII
 	(DG1f[deee
47## DI
 eeq.#f++(V(V(\(\]c(d(deeeaq 	e 	eA%**a!eeQQF6!KMM-   vQ v Qii	
 #6a"8"8  (( !) !|| *\ )OO (NN ,m "2!1 $) &:  ,m!" $)#   ( !WF$vUb_`UbUb"c"c"c!ddmV,48H<D1 	&Jt'899DM&!%.	  DII KMMDM-!%.%  DI 	GM?KKKTRRRRRr.   c                 0    t          d|rdng d          S )Nz^stemz^stages\.(\d+)))z^stages\.(\d+)\.downsample)r   )z^stages\.(\d+)\.blocks\.(\d+)N)z	^norm_pre)i )r   r   )r   )r+   coarses     r-   group_matcherzEdgeNeXt.group_matcher  s9    (. $$ 5 5 5
 
 
 	
r.   c                 (    | j         D ]	}||_        
d S r   )r   r   )r+   enabless      r-   set_grad_checkpointingzEdgeNeXt.set_grad_checkpointing  s(     	* 	*A#)A  	* 	*r.   returnc                     | j         j        S r   )r   fcr   s    r-   get_classifierzEdgeNeXt.get_classifier  s    y|r.   Nr   r   c                 J    || _         | j                            ||           d S r   )r   r   reset)r+   r   r   s      r-   reset_classifierzEdgeNeXt.reset_classifier  s&    &	[11111r.   c                     |                      |          }|                     |          }|                     |          }|S r   )r   r   r   r   s     r-   forward_featureszEdgeNeXt.forward_features  s6    IIaLLKKNNMM!r.   
pre_logitsc                 ^    |r|                      |d          n|                      |          S )NT)r  )r   )r+   rs   r  s      r-   forward_headzEdgeNeXt.forward_head  s,    0:Ltyyty,,,		!Lr.   c                 Z    |                      |          }|                     |          }|S r   )r  r  r   s     r-   rR   zEdgeNeXt.forward  s-    !!!$$a  r.   F)Tr   )rS   rT   rU   r   rw   r"   r;   r   r   r   r  Moduler  rV   r   strr  r  r>   r  rR   rW   rX   s   @r-   r   r   +  st        " ,%"3"!g+_S _S _S _S _S _SB Y
 
 
 
 Y* * * * Y	    2 2C 2hsm 2 2 2 2  M M$ M M M M      r.   r   c                     t          | t          j                  rEt          | j        d           | j        &t          j                            | j                   d S d S t          | t          j                  rt          | j        d           t          j                            | j                   |rFd|v rD| j        j	        
                    |           | j        j	        
                    |           d S d S d S d S )Ng{Gz?)stdhead.)
isinstancer   r#   r
   r:   ra   initzeros_r}   datamul_)r   namer   s      r-   r   r     s    &")$$ 	3C0000;"GNN6;''''' #"	FBI	&	& 3C0000
v{### 	3GtOOM##O444K!!/222223 3	3 	3OOr.   c                    d| v sd| v r| S d| v r	| d         } nd| v r	| d         } nd| v r| d         } i }ddl }|                                 D ]\  }}|                    dd	          }|                    d
d|          }|                    dd|          }|                    dd          }|                    dd          }|                    dd          }|                    d          r|                    dd          }|j        dk    r8d|vr4|                                |         j        }|                    |          }|||<   |S )z Remap FB checkpoints -> timm zhead.norm.weightznorm_pre.weight	model_emamodel
state_dictr   Nzdownsample_layers.0.zstem.zstages.([0-9]+).([0-9]+)zstages.\1.blocks.\2z#downsample_layers.([0-9]+).([0-9]+)zstages.\1.downsample.\2dwconvre   pwconvzmlp.fcr  zhead.fc.znorm.rf   z	head.normr   r   )	reitemsreplacesub
startswithndimr!  r/   r   )r!  r   out_dictr$  r   r   model_shapes          r-   checkpoint_filter_fnr,    s   Z''+<
+J+J j  ,

	J		(

		#	#-
HIII  ""  1II,g66FF.0FJJFF9;UWXYYIIh	**IIh))IIgz**<<   	/		&+..A6Q;;6??**,,Q/5K		+&&AOr.   Fc                 \    t          t          | |ft          t          dd          d|}|S )N)r   r   r   r9   T)out_indicesflatten_sequential)pretrained_filter_fnfeature_cfg)r   r   r,  r   )variant
pretrainedkwargsr   s       r-   _create_edgenextr5    sF     ':1\dKKK  	 E
 Lr.    c                 4    | dddddt           t          ddd
|S )	Nr   )r9      r8  )rz   rz   g?bicubiczstem.0zhead.fc)
urlr   
input_size	pool_sizecrop_pctinterpolationmeanr  
first_conv
classifierr   )r:  r4  s     r-   _cfgrB    s5    =v)%.Bi   r.   ztimm/)r9      rC  )	hf_hub_idtest_input_sizetest_crop_pctgffffff?)r9   @  rG  )rD  r=  rE  rF  )zedgenext_xx_small.in1kzedgenext_x_small.in1kzedgenext_small.usi_in1kzedgenext_base.usi_in1kzedgenext_base.in21k_ft_in1kzedgenext_small_rw.sw_in1kr  c           	      X    t          ddd          }t          dd| it          |fi |S )N)r   r      r   r   r7   r7   r7   r7   r   r   r   edgenext_xx_smallr3  )rL  r   r5  r3  r4  
model_argss      r-   rL  rL    sC     \0AVVVJeeJe$zJdJd]cJdJdeeer.   c           	      X    t          ddd          }t          dd| it          |fi |S )Nr   )r   @   d      rJ  rK  edgenext_x_smallr3  )rT  rM  rN  s      r-   rT  rT    sC     \0B,WWWJdd:djIcIc\bIcIcdddr.   c           	      V    t          dd          }t          dd| it          |fi |S )Nr   )r   `      i0  r   r   edgenext_smallr3  )rY  rM  rN  s      r-   rY  rY  $  sA     \0BCCCJbbbtJGaGaZ`GaGabbbr.   c           	      ^    t          g dg d          }t          dd| it          |fi |S )Nr   )P   rW  rC  iH  rX  edgenext_baser3  )r\  rM  rN  s      r-   r\  r\  /  sK     \\\0C0C0CDDDJaa
ad:F`F`Y_F`F`aaar.   c           	      \    t          ddddd          }t          d	d| it          |fi |S )
Nr   )r   rV  rS  i  TFr   )r   r   r   rn   r   edgenext_small_rwr3  )r^  rM  rN  s      r-   r^  r^  :  sR    "4)E E EJ eeJe$zJdJd]cJdJdeeer.   )Nr   r  )r6  )3__doc__r%   	functoolsr   typingr   r   r;   torch.nn.functionalr   
functionalr   	timm.datar   r	   timm.layersr
   r   r   r   r   r   r   _builderr   _features_fxr   _manipulater   r   	_registryr   r   __all__r  r   rZ   ry   r   r   r   r   r,  r5  rB  default_cfgsrL  rT  rY  r\  r^  r   r.   r-   <module>rl     s           " " " " " " " "                 A A A A A A A A* * * * * * * * * * * * * * * * * * * * * * * * 1 1 1 1 1 1 4 4 4 4 4 4 4 4 < < < < < < < <, ! ! ! ! !	 ! ! !H& & & & &	 & & &R$ $ $ $ $") $ $ $NI I I I I") I I IXJ J J J JBI J J JZF F F F Fry F F FR
3 
3 
3 
3  >       %$"d%S: : : "T%S: : :  $t}C      #d}C   $(4}C$ $ $ "&%S" " "'& &  4 f fX f f f f e eH e e e e c c( c c c c b b b b b b f fX f f f f f fr.   