
    NgP                     @   d Z ddlZddlmZmZ ddlmZ ddlmZm	Z	m
Z
 ddlZddlmZ ddlmZmZ ddlmZmZmZmZmZmZmZ dd	lmZ dd
lmZmZ ddlmZm Z  ddl!m"Z"m#Z# e G d d                      Z$e G d d                      Z%dJdZ& G d dej'                  Z( G d dej'                  Z) G d dej'                  Z* G d dej'                  Z+ G d dej'                  Z, G d dej'                  Z-dKd Z.dJd!Z/ e  e/d"d#$           e/d%d#$           e/d&d'$           e/d(d'$           e/d)d'$           e/d*d'd+d,-           e/d.d'd/d,-           e/d0d'd1d,-           e/d2d3$           e/d4d3d+d,-           e/d5d3d/d,-           e/d6d3d1d,-           e/d7d8d+d,-           e/d9d8d/d,-           e/d:d8d1d,-          d;          Z0edLd<e"fd=            Z1edLd<e"fd>            Z2edLd<e"fd?            Z3edLd<e"fd@            Z4edLd<e"fdA            Z5edLd<e"fdB            Z6edLd<e"fdC            Z7edLd<e"fdD            Z8edLd<e"fdE            Z9edLd<e"fdF            Z:edLd<e"fdG            Z;edLd<e"fdH            Z<edLd<e"fdI            Z=dS )Mab   ViTamin

Paper: Designing Scalable Vison Models in the Vision-Language Era
A family of model weights on Huggingface: https://huggingface.co/collections/jienengchen/vitamin-family-661048126b72debdaca060bf

@inproceedings{chen2024vitamin,
  title={ViTamin: Designing Scalable Vision Models in the Vision-language Era},
  author={Chen, Jieneng and Yu, Qihang and Shen, Xiaohui and Yuille, Alan and Chen, Liang-Chieh},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2024}
}

Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin

Modifications and timm support by Jieneng Chen 2024

Reference:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
    N)	dataclassfield)partial)OptionalUnionTupleOPENAI_CLIP_MEANOPENAI_CLIP_STD)create_act_layerget_norm_layerget_norm_act_layercreate_conv2dmake_divisibleDropPathHybridEmbed   )build_model_with_cfg)named_applycheckpoint_seq)register_modelgenerate_default_cfgs)VisionTransformercheckpoint_filter_fnc                       e Zd ZU dZeed<   dZeed<   dZe	ed<   dZ
e	ed<   d	Zeed
<   dZeed<   dZeed<   dZeed<   dZeed<   dZeed<   dZeed<   dZee         ed<   dZeed<   dS )
VitConvCfg      @expand_ratioTexpand_output   kernel_sizer   
group_sizeFpre_norm_actdwstride_modeavg2	pool_typedownsample_pool_typegelu	act_layer 
norm_layergh㈵>norm_epsdown_shortcutmlpN)__name__
__module____qualname__r   float__annotations__r   boolr!   intr"   r#   r%   strr'   r(   r*   r,   r-   r.   r   r/        O/var/www/html/ai-engine/env/lib/python3.11/site-packages/timm/models/vitamin.pyr   r   '   s         L%M4KJL$KIs &#&&&IsJHe$(M8D>(((Cr9   r   c                       e Zd ZU dZeeeeedf         f         df         ed<   dZeeeeedf         f         df         ed<   dZ	eed<    e
e          Zeed	<   d
Zeed<   dS )VitCfg)`           .	embed_dim)   r       rB   depths@   
stem_width)default_factoryconv_cfgr+   	head_typeN)r0   r1   r2   rA   r   r   r6   r4   rD   rF   r   r   rH   rI   r7   r8   r9   r:   r<   r<   8   s         9LIuU3c3h/0#56LLL6BFE%U38_,-s23BBBJ 5<<<Hj<<<Isr9   r<   r+   c                 h   t          | t          j                  r| j        d         | j        d         z  | j        z  }|| j        z  }t          j                            | j        dt          j
        d|z                       | j        (t          j                            | j                   d S d S d S )Nr   r          @)
isinstancennConv2dr!   out_channelsgroupsinitnormal_weightmathsqrtbiaszeros_)modulenameschemefan_outs       r:   
_init_convr\   A   s    &")$$ ($Q'&*<Q*??&BUUFM!
q$)C'M*B*BCCC;"GNN6;'''''( ( #"r9   c                   H     e Zd Z	 	 	 	 ddedededed	ed
ef fdZd Z xZ	S )Stemr)   layernorm2dư>Tin_chsout_chsr*   r,   r-   rV   c                 B   t                                                       t          t          ||          |          }|| _        t          ||dd|          | _         ||          | _        t          ||dd|          | _        t          t          |            d S )Nepsr    rB   striderV   r   )super__init__r   r   rb   r   conv1norm1conv2r   r\   )	selfra   rb   r*   r,   r-   rV   norm_act_layer	__class__s	           r:   ri   zStem.__init__K   s     	 !3J	!J!JPXYYY"67AadKKK
#^G,,
"7GQqtLLL
J%%%%%r9   c                     |                      |          }|                     |          }|                     |          }|S N)rj   rk   rl   rm   xs     r:   forwardzStem.forward^   s4    JJqMMJJqMMJJqMMr9   )r)   r_   r`   T)
r0   r1   r2   r6   r7   r3   r5   ri   rt   __classcell__ro   s   @r:   r^   r^   J   s        
 $+"& && & 	&
 & & & & & & & &&      r9   r^   c            	       <     e Zd Z	 	 d	dedededef fdZd Z xZS )
Downsample2dr&   Tdimdim_outr'   rV   c                     t                                                       t          j        dddd          | _        ||k    rt          j        ||d|          | _        d S t          j                    | _        d S )Nr    rB   r   F)r!   rg   paddingcount_include_padrV   )rh   ri   rM   	AvgPool2dpoolrN   expandIdentity)rm   ry   rz   r'   rV   ro   s        r:   ri   zDownsample2d.__init__f   sn     	LQq!W\]]]	'>>)C!$???DKKK+--DKKKr9   c                 Z    |                      |          }|                     |          }|S rq   )r   r   rr   s     r:   rt   zDownsample2d.forwardu   s%    IIaLLKKNNr9   )r&   T)	r0   r1   r2   r6   r7   r5   ri   rt   ru   rv   s   @r:   rx   rx   e   s        
 $( (( ( 	(
 ( ( ( ( ( (      r9   rx   c                   4     e Zd ZdZ	 	 	 	 	 d fd	Zd Z xZS )	StridedConvz downsample 2d as well
    r    rB   r   r@   c                     t                                                       t          t          d          d          }t	          j        |||||          | _         ||          | _        d S )Nr_   r`   rd   )r!   rg   r|   )rh   ri   r   r   rM   rN   projnorm)rm   r!   rg   r|   in_chansrA   r,   ro   s          r:   ri   zStridedConv.__init__~   sh     	^M::EEE
Ih	{SYcjkkk	Jx((			r9   c                 Z    |                      |          }|                     |          }|S rq   )r   r   rr   s     r:   rt   zStridedConv.forward   s%    IIaLLIIaLLr9   )r    rB   r   r    r@   )r0   r1   r2   __doc__ri   rt   ru   rv   s   @r:   r   r   {   sg          ) ) ) ) ) )      r9   r   c                   f     e Zd ZdZ	 	 	 	 	 	 	 dd	ed
edededededededef fdZddZd Z	 xZ
S )MbConvLNBlockzL Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)
    r           r    r_   r`   r)   r   ra   rb   rg   	drop_pathr!   r,   r-   r*   r   c
           	      8   t          t          |                                            |||c| _        | _        | _        t          ||	z            }
t          t          ||          |          }|dk    rt          ||dd          | _
        n<||k    rt          j        ||dd          | _
        nt          j                    | _
         ||d	          | _        t          j                    | _        t!          ||
ddd
          | _        t%          |d          | _        t!          |
|
||d|
d          | _        t%          |d          | _        t!          |
|dd          | _        |dk    rt/          |          nt          j                    | _        d S )Nrd   rB   avgT)r'   rV   r   r~   F)	apply_actrf   )inplace)rg   dilationrP   rV   r   )rh   r   ri   rg   ra   rb   r   r   r   rx   shortcutrM   rN   r   pre_normdownr   	conv1_1x1r   act1	conv2_kxkact2	conv3_1x1r   r   )rm   ra   rb   rg   r   r!   r,   r-   r*   r   mid_chsprenorm_act_layerro   s               r:   ri   zMbConvLNBlock.__init__   s    	mT""++---17.T[$, <!788#$6z9$M$MS[\\\Q;;(EPTUUUDMMwIfgqtDDDDMMKMMDM))&EBBBKMM	&vw!$OOO$Y===	&Wk&1W[_a a a$Y===	&wFFF09B),,,BKMMr9   r+   c                 N    t          t          t          |          |            d S )N)rZ   )r   r   r\   )rm   rZ   s     r:   init_weightszMbConvLNBlock.init_weights   s%    GJv666=====r9   c                    |                      |          }|                     |          }|                     |          }|                     |          }|                     |          }|                     |          }|                     |          }|                     |          }|                     |          |z   }|S rq   )	r   r   r   r   r   r   r   r   r   )rm   rs   r   s      r:   rt   zMbConvLNBlock.forward   s    ==##MM!IIaLL NN1IIaLL NN1IIaLL NN1NN1(r9   )r   r   r    r_   r`   r)   r   r+   )r0   r1   r2   r   r6   r3   r7   ri   r   rt   ru   rv   s   @r:   r   r      s          ! +"#"% R  R R  R 	 R
  R  R  R  R  R   R  R  R  R  R  RF> > > >      r9   r   c            	       \     e Zd ZdZ	 	 d	dedeeeeef         f         def fdZd Z	 xZ
S )
MbConvStagesz3 MobileConv for stage 1 and stage 2 of ViTamin
       r    cfgimg_sizer   c                 6   t                                                       d| _        t          ||j                  | _        g }t          |j                  | _        t          |j        d d                   D ][\  }|dk    r|j        |dz
           n|j        fdt          |j        |                   D             }|t          j        | gz  }\t          j        | | _        t          d|j        d         |j        d                   | _        d S )NF)ra   rb   rB   r   r   c                 T    g | ]$}t          |d k    rn|d k    rdnd          %S )r   rB   r   )ra   rb   rg   )r   ).0dry   stage_in_chss     r:   
<listcomp>z)MbConvStages.__init__.<locals>.<listcomp>   s[         -.TT\\s!"#q&&QQa    r9   )rg   r   rA   )rh   ri   grad_checkpointingr^   rF   stemlenrA   
num_stages	enumeraterangerD   rM   
Sequentialstagesr   r   )
rm   r   r   r   r   sblocksry   r   ro   s
          @@r:   ri   zMbConvStages.__init__   s4    	"'N
 
 
	
 cm,,bqb 122 
	/ 
	/FAs1213=1--#.L     sz!}--  F r}f-..FFmV,]1%mA&
 
 
			r9   c                     |                      |          }| j        r4t          j                                        st          | j        |          }n|                     |          }|                     |          }|S rq   )r   r   torchjitis_scriptingr   r   r   rr   s     r:   rt   zMbConvStages.forward   sg    IIaLL" 	59+A+A+C+C 	t{A..AAAAIIaLLr9   )r   r    )r0   r1   r2   r   r<   r   r6   r   ri   rt   ru   rv   s   @r:   r   r      s         
 58	!
 !
!
 CsCx01!
 	!
 !
 !
 !
 !
 !
F      r9   r   c                   *     e Zd Z	 	 d fd	Zd Z xZS )GeGluMlpr)   r   c                 h   t                                                       t          t          d          d          } ||          | _        t          j        ||          | _        t          |          | _	        t          j        ||          | _
        t          j        ||          | _        d S )N	layernormr`   rd   )rh   ri   r   r   r   rM   Linearw0r   actw1w2)rm   in_featureshidden_featuresr*   dropr,   ro   s         r:   ri   zGeGluMlp.__init__  s     	^K88dCCC
J{++	)K99#I..)K99)O[99r9   c                     |                      |          }|                     |                     |                    |                     |          z  }|                     |          }|S rq   )r   r   r   r   r   rr   s     r:   rt   zGeGluMlp.forward  sO    IIaLLHHTWWQZZ  4771::-GGAJJr9   )r)   r   )r0   r1   r2   ri   rt   ru   rv   s   @r:   r   r      sT        
 : : : : : :       r9   r   Fc                 6   |                     dd          }|J t          ||                    dd                    }t          t          |d          |d<   |                    dd	           t          t          | |ft          t          |d
          d|S )Nout_indicesr    r   )r   r   F)backboner   embed_layer
patch_sizer   getter)r   feature_cls)pretrained_filter_fnfeature_cfg)
popr   getr   r   
setdefaultr   r   r   dict)variant
pretrained	embed_cfgkwargsr   r   s         r:   _create_vitaminr     s    **]A..K   	FJJz14M4MNNNH#K(OOOF=
lA&&& 2[hGGG    r9   c                 6    | ddd dddt           t          ddd|S )	Ni  )r    r   r   g?bicubicTzpatch_embed.backbone.stem.conv1head)urlnum_classes
input_size	pool_sizecrop_pctinterpolationfixed_input_sizemeanstd
first_conv
classifierr	   )r   r   s     r:   _cfgr   )  s7    =t 7   r9   zjienengchen/ViTamin-S-LTTr?   )	hf_hub_idr   zjienengchen/ViTamin-Szjienengchen/ViTamin-B-LTTr@   zjienengchen/ViTamin-Bzjienengchen/ViTamin-L-224pxzjienengchen/ViTamin-L-256px)r       r   g      ?)r   r   r   r   zjienengchen/ViTamin-L-336px)r    P  r   zjienengchen/ViTamin-L-384px)r    r?   r?   zjienengchen/ViTamin-L2-224px   zjienengchen/ViTamin-L2-256pxzjienengchen/ViTamin-L2-336pxzjienengchen/ViTamin-L2-384pxzjienengchen/ViTamin-XL-256px  zjienengchen/ViTamin-XL-336pxzjienengchen/ViTamin-XL-384px)z%vitamin_small_224.datacomp1b_clip_lttz!vitamin_small_224.datacomp1b_clipz$vitamin_base_224.datacomp1b_clip_lttz vitamin_base_224.datacomp1b_clipz!vitamin_large_224.datacomp1b_clipz!vitamin_large_256.datacomp1b_clipz!vitamin_large_336.datacomp1b_clipz!vitamin_large_384.datacomp1b_clipz"vitamin_large2_224.datacomp1b_clipz"vitamin_large2_256.datacomp1b_clipz"vitamin_large2_336.datacomp1b_clipz"vitamin_large2_384.datacomp1b_clipz"vitamin_xlarge_256.datacomp1b_clipz"vitamin_xlarge_336.datacomp1b_clipz"vitamin_xlarge_384.datacomp1b_clipreturnc           
          t          dddt          dd          d          }t          d	d
dt          ddd|          }t	          dd| it          |fi |}|S )N)rE      r?   rB      r   rE   r_   r`   r,   r-   1drA   rD   rF   rH   rI   r?         rK   Fr   rA   depth	num_heads	mlp_layer	mlp_ratioclass_tokenglobal_poolr   vitamin_small_224r   )r   r<   r   r   r   r   r   r   r   
model_argsmodels        r:   r   r   `  s     $
 
 
 	 	 	I R1Bu	  J eeJe$zJdJd]cJdJdeeELr9   c           
          t          dddt          dd          d          }t          d	d
dt          ddd|          }t	          dd| it          |fi |}|S )N)r   r   r@   r   r   r_   r`   r   r   r   r@   r      rK   Fr   r   vitamin_base_224r   )r  r  r  s        r:   r  r  t  s    !$
 
 
 	 	 	I R2Ru	C C CJ dd:djIcIc\bIcIcddELr9   c           
          t          dddt          dd          d          }t          d	d
dt          ddd|          }t	          dd| it          |fi |}|S )N   i@  r   r   r
  r_   r`   r   r   r   r         rK   Fr   r   vitamin_large_224r   )r  r  r  s        r:   r  r    s    "$
 
 
 	 	 	I bB(bu	  J eeJe$zJdJd]cJdJdeeELr9   c                     t          dddt          dd          d          }t          d	d
ddt          ddd|	  	        }t	          dd| it          |fi |}|S )Nr	  r   r
  r_   r`   r   r   r   r   r   r  r  rK   Fr   	r   rA   r   r   r   r   r   r   r   vitamin_large_256r   )r  r  r  s        r:   r  r        "$
 
 
 	 	 	I B"\^u	C C CJ eeJe$zJdJd]cJdJdeeELr9   c                     t          dddt          dd          d          }t          d	d
ddt          ddd|	  	        }t	          dd| it          |fi |}|S )Nr	  r   r
  r_   r`   r   r   r   r   r   r  r  rK   Fr   r  vitamin_large_336r   )r  r  r  s        r:   r  r    s    "$
 
 
 	 	 	I B"\^u	  J eeJe$zJdJd]cJdJdeeELr9   c                     t          dddt          dd          d          }t          d	d
ddt          ddd|	  	        }t	          dd| it          |fi |}|S )Nr	  r   r
  r_   r`   r   r   r   r?   r   r  r  rK   Fr   r  vitamin_large_384r   )r  r  r  s        r:   r  r    r  r9   c           
          t          dddt          dd          d          }t          d	d
dt          ddd|          }t	          dd| it          |fi |}|S )Nr	  r   r
  r_   r`   r   r   r   r   r  r  rK   Fr   r   vitamin_large2_224r   )r  r  r  s        r:   r  r    s    "$
 
 
 	 	 	I bB(bu	  J ffZf4PZKeKe^dKeKeffELr9   c                     t          dddt          dd          d          }t          d	d
ddt          ddd|	  	        }t	          dd| it          |fi |}|S )Nr	  r   r
  r_   r`   r   r   r   r   r   r  r  rK   Fr   r  vitamin_large2_256r   )r  r  r  s        r:   r  r        "$
 
 
 	 	 	I B"\^u	C C CJ ffZf4PZKeKe^dKeKeffELr9   c                     t          dddt          dd          d          }t          d	d
ddt          ddd|	  	        }t	          dd| it          |fi |}|S )Nr	  r   r
  r_   r`   r   r   r   r   r   r  r  rK   Fr   r  vitamin_large2_336r   )r  r  r  s        r:   r  r    s    "$
 
 
 	 	 	I B"\^u	  J ffZf4PZKeKe^dKeKeffELr9   c                     t          dddt          dd          d          }t          d	d
ddt          ddd|	  	        }t	          dd| it          |fi |}|S )Nr	  r   r
  r_   r`   r   r   r   r?   r   r  r  rK   Fr   r  vitamin_large2_384r   )r  r  r  s        r:   r  r    r  r9   c                     t          dddt          dd          d          }t          d	d
ddt          dddd|
  
        }t	          	 dd| it          |fi |}|S )Nr>   r?   r   r   r>   r_   r`   r   r   r   r   r       r  rK   Fr   none
r   rA   r   r   r   r   r   r   	pos_embedr   vitamin_xlarge_256r   r%  r  r  s        r:   r%  r%  #  s    "$
 
 
 	 	 	I B"\^u)U U UJ S S)3S7;J7Q7Q&7Q7QS SELr9   c                     t          dddt          dd          d          }t          d	d
ddt          dddd|
  
        }t	          dd| it          |fi |}|S )Nr   r   r>   r_   r`   r   r   r   r   r   r!  r  rK   Fr   r"  r#  r%  r   r&  r  r  s        r:   vitamin_xlarge_336r(  7      "$
 
 
 	 	 	I B"\^u)U U UJ ffZf4PZKeKe^dKeKeffELr9   c                     t          dddt          dd          d          }t          d	d
ddt          dddd|
  
        }t	          dd| it          |fi |}|S )Nr   r   r>   r_   r`   r   r   r   r?   r   r!  r  rK   Fr   r"  r#  vitamin_xlarge_384r   )r+  r  r  s        r:   r+  r+  J  r)  r9   r   )FN)F)>r   rT   dataclassesr   r   	functoolsr   typingr   r   r   r   torch.nnrM   	timm.datar
   r   timm.layersr   r   r   r   r   r   r   _builderr   _manipulater   r   	_registryr   r   vision_transformerr   r   r   r<   r\   Moduler^   rx   r   r   r   r   r   r   default_cfgsr   r  r  r  r  r  r  r  r  r  r%  r(  r+  r8   r9   r:   <module>r8     s*   *  ( ( ( ( ( ( ( (       ) ) ) ) ) ) ) ) ) )        7 7 7 7 7 7 7 7* * * * * * * * * * * * * * * * * * * * * * * * 4 4 4 4 4 4 4 4 < < < < < < < < G G G G G G G G                 ( ( ( (    29   6    29   ,    ")   .; ; ; ; ;BI ; ; ;|- - - - -29 - - -`    ry   0   "	 	 	 	 %$-1T-3.@ .@ .@)-)s*< *< *<,0D-3-@ -@ -@(,)s)< )< )<)-/S*B *B *B)-/S 3*0 *0 *0 *./S 3*0 *0 *0 *./S 3*0 *0 *0 +/$0d+D +D +D*.$0d 3+0 +0 +0 +/$0d 3+0 +0 +0 +/$0d 3+0 +0 +0 +/$0d 3+0 +0 +0 +/$0d 3+0 +0 +0 +/$0d 3+0 +0 +0K(& (& ( (V  5F    &  4E    $  5F    &  5F    $  5F    &  5F    $  6G    &  6G    $  6G    &  6G    $  6G    &  6G    $  6G      r9   