
    g}              	          d Z ddlZddlZddlmZmZ ddlZddlZddl	Zddlm
Z
mZ ddlmZmZmZ ddlmZ ddlmZmZmZmZ dd	lmZ dd
lmZmZmZmZmZ ddlm Z  ddl!m"Z"  ej#        e$          Z%dZ&dZ'g dZ(dZ)dZ*d@deee+f         fdZ, G d dej-                  Z. G d dej/                  Z0 G d dej1                  Z2 G d dej3                  Z4 G d dej1                  Z5dAd!ej
        d"e6d#e+dej
        fd$Z7 G d% d&ej1                  Z8dBd(Z9 G d) d*ej1                  Z: G d+ d,ej1                  Z; G d- d.ej1                  Z< G d/ d0ej1                  Z= G d1 d2ej1                  Z> G d3 d4e          Z?d5Z@d6ZA ed7e@           G d8 d9e?                      ZB ed:e@           G d; d<e?                      ZC ed=e@           G d> d?e?e                       ZDdS )Cz9PyTorch BiT model. Also supports backbone for ViT hybrid.    N)OptionalTuple)Tensornn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BackboneOutputBaseModelOutputWithNoAttention(BaseModelOutputWithPoolingAndNoAttention$ImageClassifierOutputWithNoAttention)PreTrainedModel)add_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardloggingreplace_return_docstrings)BackboneMixin   )	BitConfigr   zgoogle/bit-50)r   i      r   z	tiger catr   returnc                 &   d}| |dz
  ||dz
  z  z   dz  } | |fS t          | t                    r`|                                 } | dk    r,|dk    r!||dz
  z  dz  dk    r|dz
  ||dz
  z  z   dz  } nd} d}n| dk    rd} n|dz
  ||dz
  z  z   dz  } | |fS )	al  
    Utility function to get the tuple padding value given the kernel_size and padding.

    Args:
        padding (Union[`str`, `int`], *optional*):
            Padding value, can be either `"same"`, `"valid"`. If a different value is provided the default padding from
            PyTorch is used.
        kernel_size (`int`, *optional*, defaults to 7):
            Kernel size of the convolution layers.
        stride (`int`, *optional*, defaults to 1):
            Stride value of the convolution layers.
        dilation (`int`, *optional*, defaults to 1):
            Dilation value of the convolution layers.
    FNr      samer   Tvalid)
isinstancestrlower)paddingkernel_sizestridedilationdynamics        `/var/www/html/ai-engine/env/lib/python3.11/site-packages/transformers/models/bit/modeling_bit.pyget_padding_valuer(   <   s     GQJ(kAo">>1D'3 I--//f{{K!O <AQFF"QJ(kAo*FF1L GG 
h+/&BBqHGG    c                   6     e Zd ZdZ	 	 	 	 	 	 d fd	Zd Z xZS )	WeightStandardizedConv2dzConv2d with Weight Standardization. Includes TensorFlow compatible SAME padding. Used for ViT Hybrid model.

    Paper: [Micro-Batch Training with Batch-Channel Normalization and Weight
    Standardization](https://arxiv.org/abs/1903.10520v2)
    r   SAMEFư>c
           
          t          ||||          \  }}
t                                          ||||||||           |
rt          |||          | _        nd | _        |	| _        d S )N)r$   r%   )r$   r"   r%   groupsbias)r(   super__init__DynamicPad2dpadeps)self
in_channelout_channelsr#   r$   r"   r%   r/   r0   r5   
is_dynamic	__class__s              r'   r2   z!WeightStandardizedConv2d.__init__l   s     0V^fggg 	 		
 		
 		
  	#KBBDHHDHr)   c           	      |   | j         |                      |          }t          j                            | j                            d| j        d          d d dd| j                                      | j                  }t          j        	                    ||| j
        | j        | j        | j        | j                  }|S )Nr   T        )trainingmomentumr5   )r4   r   
functional
batch_normweightreshaper8   r5   
reshape_asconv2dr0   r$   r"   r%   r/   )r6   hidden_staterB   s      r'   forwardz WeightStandardizedConv2d.forward   s    888L11L))K4#4b994PT_bhlhp * 
 

*T[
!
! 	 }++&$)T[$,W[Wb
 
 r)   )r   r,   r   r   Fr-   __name__
__module____qualname____doc__r2   rG   __classcell__r:   s   @r'   r+   r+   e   sj               :	 	 	 	 	 	 	r)   r+   c                   *     e Zd ZdZd fd	Zd Z xZS )BitGroupNormActivationzQ
    A module that combines group normalization with an activation function.
    h㈵>Tc                     t          t          |                               |j        |||           |rt          |j                 | _        d S t          j                    | _        d S )N)r5   affine)	r1   rP   r2   
num_groupsr   
hidden_act
activationr   Identity)r6   confignum_channelsr5   rS   apply_activationr:   s         r'   r2   zBitGroupNormActivation.__init__   s\    $d++44V5FZ]fl4mmm 	,$V%67DOOO kmmDOOOr)   c                     t           j                            || j        | j        | j        | j                  }|                     |          }|S N)r   r@   
group_normrT   rB   r0   r5   rV   )r6   rF   s     r'   rG   zBitGroupNormActivation.forward   sB    }//dot{\`\egkgopp|44r)   )rQ   TTrH   rN   s   @r'   rP   rP      sV         , , , , , ,      r)   rP   c                   *     e Zd ZdZd fd	Zd Z xZS )r3   z
    A module that wraps dynamic padding of any input, given the parameters of the convolutional layer and the input
    hidden states.
    r   c                 *   t                                                       t          |t                    r||f}t          |t                    r||f}t          |t                    r||f}|| _        || _        || _        || _        d }|| _        d S )Nc                 v    t          t          j        | |z            dz
  |z  |dz
  |z  z   dz   | z
  d          S )Nr   r   )maxmathceil)xr#   r$   r%   s       r'   compute_paddingz.DynamicPad2d.__init__.<locals>.compute_padding   sF    	!f*--1V;{QRZ>ZZ]^^abbdefffr)   )	r1   r2   r   intr#   r$   r%   valuere   )r6   r#   r$   r%   rg   re   r:   s         r'   r2   zDynamicPad2d.__init__   s    k3'' 	5&4Kfc"" 	&f%Fh$$ 	, (+H& 
	g 	g 	g  /r)   c           	         |                                 dd          \  }}|                     || j        d         | j        d         | j        d                   }|                     || j        d         | j        d         | j        d                   }|dk    s|dk    r=t
          j                            ||dz  ||dz  z
  |dz  ||dz  z
  g| j                  }|S )Nr   r   r   )rg   )	sizere   r#   r$   r%   r   r@   r4   rg   )r6   inputinput_heightinput_widthpadding_heightpadding_widths         r'   __call__zDynamicPad2d.__call__   s    $)JJLL$5!k --lD<LQ<OQUQ\]^Q_aeanopaqrr,,[$:J1:Mt{[\~_c_lmn_opp A!2!2M%%!Q&!MQ$66"a'"^q%88	 j & 	 	E r)   )r   )rI   rJ   rK   rL   r2   rp   rM   rN   s   @r'   r3   r3      sV         
/ / / / / /,      r)   r3   c                   <     e Zd ZdZ	 	 	 	 	 	 ddef fd	Zd
 Z xZS )BitMaxPool2dz1Tensorflow like 'SAME' wrapper for 2D max poolingNr   Fr   r   r   Tr#   c                    t          |t          j        j                  r|n||f}t          |t          j        j                  r|n||f}t          |t          j        j                  r|n||f}t	                                          |||||           |rt          ||||          | _        d S t          j	                    | _        d S r\   )
r   collectionsabcIterabler1   r2   r3   r4   r   rW   )	r6   r#   r$   r%   	ceil_moder"   padding_valueuse_dynamic_paddingr:   s	           r'   r2   zBitMaxPool2d.__init__   s     &0[_=U%V%Vvkk]hju\v%fko.FGG]fV\M])(KO4LMMg88T\^fSgfgxKKK 	%#K=QQDHHH{}}DHHHr)   c                     |                      |          }t          j                            || j        | j        | j        | j        | j                  S r\   )	r4   r   r@   
max_pool2dr#   r$   r"   r%   rx   r6   hidden_statess     r'   rG   zBitMaxPool2d.forward   sG    //}''4+T[$,W[We
 
 	
r)   )Nr   Frs   r   T)rI   rJ   rK   rL   rf   r2   rG   rM   rN   s   @r'   rr   rr      ss        ;;
  % %% % % % % %&
 
 
 
 
 
 
r)   rr   c                   8     e Zd ZdZdef fdZdedefdZ xZS )BitEmbeddingszL
    BiT Embeddings (stem) composed of a single aggressive convolution.
    rX   c                    t                                                       t          |j        |j        ddd|j                  | _        t          dd|j                  | _	        |j        6|j        
                                dk    rt          j                    | _        nt          j        dd	
          | _        |j        dk    st!          ||j                  | _        nt          j                    | _        |j        | _        d S )Nr   r   :0yE>)r#   r$   r5   r"   r
   )r#   r$   rz   r,   )r   r   r   r   r=   )r"   rg   preactivationrY   )r1   r2   r+   rY   embedding_sizeglobal_paddingconvolutionrr   embedding_dynamic_paddingpoolerupperr   rW   r4   ConstantPad2d
layer_typerP   normr6   rX   r:   s     r'   r2   zBitEmbeddings.__init__   s    3!)
 
 
 #qPVPpqqq  ,1F1L1L1N1NRX1X1X{}}DHH'CHHHDH O33.vFDYZZZDIIDI"/r)   pixel_valuesr   c                     |j         d         }|| j        k    rt          d          |                     |          }|                     |          }|                     |          }|                     |          }|S )Nr   zeMake sure that the channel dimension of the pixel values match with the one set in the configuration.)shaperY   
ValueErrorr   r4   r   r   )r6   r   rY   	embeddings       r'   rG   zBitEmbeddings.forward  s    #)!,4,,,w   $$\22	HHY''	IIi((	KK	**	r)   )	rI   rJ   rK   rL   r   r2   r   rG   rM   rN   s   @r'   r   r      sp         0y 0 0 0 0 0 06F v        r)   r   r=   Frk   	drop_probr>   c                     |dk    s|s| S d|z
  }| j         d         fd| j        dz
  z  z   }|t          j        || j        | j                  z   }|                                 |                     |          |z  }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    r=   r   r   )r   )dtypedevice)r   ndimtorchrandr   r   floor_div)rk   r   r>   	keep_probr   random_tensoroutputs          r'   	drop_pathr   *  s     CxII[^
Q 77E
5EL Y Y YYMYYy!!M1FMr)   c                   j     e Zd ZdZd	dee         ddf fdZdej        dej        fdZ	de
fdZ xZS )
BitDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).Nr   r   c                 V    t                                                       || _        d S r\   )r1   r2   r   )r6   r   r:   s     r'   r2   zBitDropPath.__init__B  s$    "r)   r~   c                 8    t          || j        | j                  S r\   )r   r   r>   r}   s     r'   rG   zBitDropPath.forwardF  s    FFFr)   c                 6    d                     | j                  S )Nzp={})formatr   )r6   s    r'   
extra_reprzBitDropPath.extra_reprI  s    }}T^,,,r)   r\   )rI   rJ   rK   rL   r   floatr2   r   r   rG   r    r   rM   rN   s   @r'   r   r   ?  s        bb# #(5/ #T # # # # # #GU\ Gel G G G G-C - - - - - - - -r)   r      c                 x    |}t          |t          | |dz  z             |z  |z            }|d| z  k     r||z  }|S )Nr   g?)ra   rf   )rg   divisor	min_value	new_values       r'   make_divr   M  sP    IIs57Q;#6777BWLMMI3;W	r)   c                   :     e Zd ZdZ	 	 	 	 	 	 	 	 d	 fd	Zd Z xZS )
BitPreActivationBottleneckLayera  Pre-activation (v2) bottleneck block.
    Follows the implementation of "Identity Mappings in Deep Residual Networks":
    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

    Except it puts the stride on 3x3 conv when available.
    N      ?r   r=   Fc           	      T   t                                                       |p|}|p|}t          ||z            }|
rt          ||||d          | _        nd | _        t          ||          | _        t          ||dd|j                  | _	        t          ||          | _
        t          ||d||d|j                  | _        t          ||          | _        t          ||dd|j                  | _        |	d	k    rt          |	          nt          j                    | _        d S )
NTr$   preactr   r   r5   r"   r   r
   )r$   r/   r5   r"   r   )r1   r2   r   BitDownsampleConv
downsamplerP   norm1r+   r   conv1norm2conv2norm3conv3r   r   rW   r   )r6   rX   in_channelsr8   bottle_ratior$   r%   first_dilationr/   drop_path_rateis_first_layermid_channelsr:   s               r'   r2   z(BitPreActivationBottleneckLayer.__init__]  sD    	'38#2{| ;<< 		#/  DOO #DO+FK@@
-k<PT^d^sttt
+FNNN
-,&T[a[p
 
 

 ,FLAA
-lL!QU_e_tuuu
8F8J8J^444PRP[P]P]r)   c                 f   |                      |          }|}| j        |                     |          }|                     |          }|                     |                     |                    }|                     |                     |                    }|                     |          }||z   S r\   )r   r   r   r   r   r   r   r   )r6   r~   hidden_states_preactshortcuts       r'   rG   z'BitPreActivationBottleneckLayer.forward  s    #zz-88 !?&';<<H 

#788

4::m#<#<==

4::m#<#<==}55x''r)   Nr   r   r   Nr   r=   FrH   rN   s   @r'   r   r   U  sw          *^ *^ *^ *^ *^ *^X( ( ( ( ( ( (r)   r   c                   :     e Zd ZdZ	 	 	 	 	 	 	 	 d	 fd	Zd Z xZS )
BitBottleneckLayerz\Non Pre-activation bottleneck block, equivalent to V1.5/V1b bottleneck. Used for ViT Hybrid.Nr   r   r=   Fc           
         t                                                       |p|}|p|}t          ||z            }|
rt          ||||d          | _        nd | _        t          ||dd|j                  | _        t          ||          | _	        t          ||d|||d|j                  | _
        t          ||          | _        t          ||dd|j                  | _        t          ||d	          | _        |	d
k    rt          |	          nt          j                    | _        t$          |j                 | _        d S )NFr   r   r   r   r   r
   )r$   r%   r/   r5   r"   rY   rZ   r   )r1   r2   r   r   r   r+   r   r   rP   r   r   r   r   r   r   r   rW   r   r   rU   rV   )r6   rX   r   r8   r   r$   r%   r   r/   r   r   mid_chsr:   s               r'   r2   zBitBottleneckLayer.__init__  sb    	'38#2{<,677 		#/  DOO #DO-k7A4Y_Ynooo
+FIII
-#)	
 	
 	

 ,FIII
-g|QDZ`Zoppp
+F`efff
8F8J8J^444PRP[P]P] !23r)   c                    |}| j         |                      |          }|                     |          }|                     |          }|                     |          }|                     |          }|                     |          }|                     |          }|                     |          }|                     ||z             }|S r\   )	r   r   r   r   r   r   r   r   rV   )r6   r~   r   s      r'   rG   zBitBottleneckLayer.forward  s     ?&}55H 

=11

=11

=11

=11

=11

=11}55(@AAr)   r   rH   rN   s   @r'   r   r     sm        ff /4 /4 /4 /4 /4 /4b      r)   r   c                   *     e Zd Z	 	 d fd	Zd Z xZS )r   r   Tc                     t                                                       t          ||d|d|j                  | _        |rt          j                    nt          ||d          | _        d S )Nr   r   )r$   r5   r"   Fr   )	r1   r2   r+   r   convr   rW   rP   r   )r6   rX   r   r8   r$   r   r:   s         r'   r2   zBitDownsampleConv.__init__  st     	,qT6K`
 
 
	
 cBKMMM'\\abbb 				r)   c                 R    |                      |                     |                    S r\   )r   r   )r6   rd   s     r'   rG   zBitDownsampleConv.forward  s    yy1&&&r)   )r   T)rI   rJ   rK   r2   rG   rM   rN   s   @r'   r   r     sT         
 
 
 
 
 
$' ' ' ' ' ' 'r)   r   c                   >     e Zd ZdZ	 	 d	 fd	Zd ZdedefdZ xZS )
BitStagez7
    A ResNet v2 stage composed by stacked layers.
    r   Nc	                    t                                                       |dv rdnd}	|j        dk    rt          }
nt          }
|}t          j                    | _        t          |          D ][}| 	                    |||          \  }}}| j        
                    t          |           |
|||||||	||	  	                   |}|}	\d S )N)r   r   r   r   
bottleneck)r$   r%   r   r   r   r   )r1   r2   r   r   r   r   
Sequentiallayersrange_get_updated_hyperparameters
add_moduler    )r6   rX   r   r8   r$   r%   depthr   layer_dropoutr   	layer_clsprev_chs	layer_idxr   r   r:   s                  r'   r2   zBitStage.__init__  s     	&&00a ,,*II7Imoou 	& 	&I595V5V6=6 62FNN K""I	 !%!-#1#1#1
 
 
   $H%NN+	& 	&r)   c                 B    |r	||         }nd}|dk    rd}|dk    }|||fS )zt
        Get the new hyper-parameters with respect to the previous ones and the index of the current layer.
        r=   r   r    )r6   r   r$   r   r   r   s         r'   r   z%BitStage._get_updated_hyperparameters,  sA      	!*95NN N>>F"a~~55r)   rk   r   c                 T    |}t          | j                  D ]\  }} ||          }|S r\   )	enumerater   )r6   rk   rF   _layers        r'   rG   zBitStage.forward<  s;    !$+.. 	/ 	/HAu 5..LLr)   )r   N)	rI   rJ   rK   rL   r2   r   r   rG   rM   rN   s   @r'   r   r     s          ,& ,& ,& ,& ,& ,&\6 6 6 V         r)   r   c            	       F     e Zd Zdef fdZd Z	 ddededed	efd
Z	 xZ
S )
BitEncoderrX   c           
      z   t                                                       t          j        g           | _        |j        }d}d}d t          j        t          j	        d|j
        t          |j                                                          |j                  D             }t          t          |j        |j        |                    D ]k\  }\  }}}	|                     |||||          \  }
}}t%          |||
||||	          }|
}||z  }| j                            t)          |          |           ld S )N   r   c                 6    g | ]}|                                 S r   )tolist).0rd   s     r'   
<listcomp>z'BitEncoder.__init__.<locals>.<listcomp>N  s0     
 
 
 HHJJ
 
 
r)   r   )r$   r%   r   r   )r1   r2   r   
ModuleListstagesr   r   r   nplinspacer   sumdepthssplitr   ziphidden_sizesr   r   r   r    )r6   rX   r   current_strider%   layer_dropouts	stage_idxcurrent_depthcurrent_hidden_sizer   r8   r$   stager:   s                r'   r2   zBitEncoder.__init__D  sd   mB''( 
 
\"+a1FFMHZHZ"["[\\bbcicpqq
 
 

 OXv2NCCO
 O
 	: 	:JIJ':M .2-N-N>+>&. .*L&( !#+  E $Hf$NK""3y>>59999+	: 	:r)   c                 r    t          ||j        z            }|dk    rdnd}||j        k    r||z  }d}|||fS )Nr   r   r   )r   width_factoroutput_stride)r6   r   r   r   r%   rX   r8   r$   s           r'   r   z'BitEncoder._get_updated_hyperparametersj  sS     3f6I IJJ1nn!V111HFVX--r)   FTrF   output_hidden_statesreturn_dictr   c                     |rdnd }| j         D ]}|r||fz   } ||          }|r||fz   }|st          d ||fD                       S t          ||          S )Nr   c              3      K   | ]}||V  	d S r\   r   )r   vs     r'   	<genexpr>z%BitEncoder.forward.<locals>.<genexpr>  s"      SSqQ]]]]]SSr)   )last_hidden_stater~   )r   tupler   )r6   rF   r   r  r~   stage_modules         r'   rG   zBitEncoder.forwardr  s     3< K 	6 	6L# @ - ?'<55LL 	<)\O;M 	TSS\=$ASSSSSS-*'
 
 
 	
r)   )FT)rI   rJ   rK   r   r2   r   r   boolr   rG   rM   rN   s   @r'   r   r   C  s        $:y $: $: $: $: $: $:L. . . ]a
 
"
:>
UY
	'
 
 
 
 
 
 
 
r)   r   c                   *    e Zd ZdZeZdZdZdgZd Z	dS )BitPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    bitr   r   c                    t          |t          j                  r)t          j                            |j        dd           d S t          |t          j                  rt          j                            |j        t          j	        d                     |j
        ot          j                            |j                  \  }}|dk    rdt          j	        |          z  nd}t          j                            |j
        | |           d S d S t          |t          j        t          j        f          rLt          j                            |j        d           t          j                            |j
        d           d S d S )Nfan_outrelu)modenonlinearity   )ar   r   )r   r   Conv2dinitkaiming_normal_rB   Linearkaiming_uniform_rb   sqrtr0   _calculate_fan_in_and_fan_outuniform_BatchNorm2d	GroupNorm	constant_)r6   modulefan_inr   bounds        r'   _init_weightsz BitPreTrainedModel._init_weights  sF   fbi(( 	.G##FM	PV#WWWWW	** 	.G$$V]dill$CCC{&GAA&-PP	17!DIf----  ufe<<<<< '&  >?? 	.GfmQ///Gfk1-----	. 	.r)   N)
rI   rJ   rK   rL   r   config_classbase_model_prefixmain_input_name_no_split_modulesr"  r   r)   r'   r  r    sH         
 L$O(). . . . .r)   r  aE  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`BitConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
aA  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`BitImageProcessor.__call__`]
            for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
zLThe bare BiT model outputting raw features without any specific head on top.c                        e Zd Z fdZ ee           eeee	de
          	 d
dedee         dee         defd	                        Z xZS )BitModelc                    t                                          |           || _        t          |          | _        t          |          | _        |j        dk    rt          ||j	        d                   nt          j                    | _        t          j        d          | _        |                                  d S )Nr   r<   r   )r   r   )r1   r2   rX   r   embedderr   encoderr   rP   r   r   rW   r   AdaptiveAvgPool2dr   	post_initr   s     r'   r2   zBitModel.__init__  s       %f--!&))  O33 #68KB8OPPPP 		 *622r)   vision)
checkpointoutput_typer#  modalityexpected_outputNr   r   r  r   c                 P   ||n| j         j        }||n| j         j        }|                     |          }|                     |||          }|d         }|                     |          }|                     |          }|s||f|dd          z   S t          |||j                  S )Nr   r  r   r   )r  pooler_outputr~   )	rX   r   use_return_dictr*  r+  r   r   r   r~   )r6   r   r   r  embedding_outputencoder_outputsr  pooled_outputs           r'   rG   zBitModel.forward  s     %9$D  $+Jj 	 &1%<kk$+B]==66,,3GU` ' 
 
 ,A. II&788$566 	L%}58KKK7/')7
 
 
 	
r)   NN)rI   rJ   rK   r2   r   BIT_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r   r	  rG   rM   rN   s   @r'   r(  r(    s        
    " +*+?@@&<$.   pt
 
"
:B4.
^fgk^l
	1
 
 
  A@
 
 
 
 
r)   r(  z
    BiT Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    c                        e Zd Z fdZ ee           eeee	e
          	 	 	 	 d
deej                 deej                 dee         dee         def
d	                        Z xZS )BitForImageClassificationc                    t                                          |           |j        | _        t          |          | _        t          j        t          j                    |j        dk    r%t          j        |j	        d         |j                  nt          j
                              | _        |                                  d S )Nr   r<   )r1   r2   
num_labelsr(  r  r   r   Flattenr  r   rW   
classifierr-  r   s     r'   r2   z"BitForImageClassification.__init__   s        +F##-JLLEKEVYZEZEZBIf)"-v/@AAA`b`k`m`m
 

 	r)   )r/  r0  r#  r2  Nr   labelsr   r  r   c                    ||n| j         j        }|                     |||          }|r|j        n|d         }|                     |          }d}|Z| j         j        f| j        dk    rd| j         _        nN| j        dk    r7|j        t          j	        k    s|j        t          j
        k    rd| j         _        nd| j         _        | j         j        dk    rWt                      }	| j        dk    r1 |	|                                |                                          }n |	||          }n| j         j        dk    rGt                      }	 |	|                    d| j                  |                    d                    }n*| j         j        dk    rt                      }	 |	||          }|s|f|dd         z   }
||f|
z   n|
S t!          |||j        	          S )
a0  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr4  r   
regressionsingle_label_classificationmulti_label_classificationr<   r   )losslogitsr~   )rX   r6  r  r5  rD  problem_typerB  r   r   longrf   r	   squeezer   viewr   r   r~   )r6   r   rE  r   r  outputsr9  rK  rJ  loss_fctr   s              r'   rG   z!BitForImageClassification.forward  s   & &1%<kk$+B]((<>R`k(ll1<L--'!*//{'/?a''/;DK,,_q((flej.H.HFL\a\eLeLe/LDK,,/KDK,{'<77"99?a''#8FNN$4$4fnn6F6FGGDD#8FF33DD)-JJJ+--xB @ @&++b//RR)-III,..x// 	DY,F'+'7D7V##VC3f\c\qrrrrr)   )NNNN)rI   rJ   rK   r2   r   r;  r   _IMAGE_CLASS_CHECKPOINTr   r=  _IMAGE_CLASS_EXPECTED_OUTPUTr   r   FloatTensor
LongTensorr	  rG   rM   rN   s   @r'   r@  r@    s        
 
 
 
 
 +*+?@@*8$4	   59-1/3&*/s /su01/s )*/s 'tn	/s
 d^/s 
./s /s /s  A@/s /s /s /s /sr)   r@  zL
    BiT backbone, to be used with frameworks like DETR and MaskFormer.
    c                        e Zd Z fdZ ee           eee          	 d	de	de
e         de
e         defd                        Z xZS )
BitBackbonec                    t                                          |           t                                          |           t          |          | _        |j        g|j        z   | _        |                                  d S r\   )	r1   r2   _init_backboner(  r  r   r   num_featuresr-  r   s     r'   r2   zBitBackbone.__init__L  sp       v&&&F###23f6II 	r)   )r0  r#  Nr   r   r  r   c                 @   ||n| j         j        }||n| j         j        }|                     |dd          }|j        }d}t          | j                  D ]\  }}|| j        v r|||         fz  }|s|f}	|r|	|j        fz  }	|	S t          ||r|j        ndd          S )a`  
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("google/bit-50")
        >>> model = AutoBackbone.from_pretrained("google/bit-50")

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```NTr4  r   )feature_mapsr~   
attentions)	rX   r6  r   r  r~   r   stage_namesout_featuresr   )
r6   r   r   r  rP  r~   r\  idxr   r   s
             r'   rG   zBitBackbone.forwardV  s    2 &1%<kk$+B]$8$D  $+Jj 	 ((<dPT(UU-#D$455 	6 	6JC)))s!3 55 	"_F# 37022M%3GQ'//T
 
 
 	
r)   r:  )rI   rJ   rK   r2   r   r;  r   r   r=  r   r   r	  rG   rM   rN   s   @r'   rW  rW  E  s             +*+?@@>XXXos/
 /
"/
:B4./
^fgk^l/
	/
 /
 /
 YX A@/
 /
 /
 /
 /
r)   rW  )Nr   r   r   )r=   F)r   )ErL   ru   rb   typingr   r   numpyr   r   torch.utils.checkpointr   r   torch.nnr   r   r	   activationsr   modeling_outputsr   r   r   r   modeling_utilsr   utilsr   r   r   r   r   utils.backbone_utilsr   configuration_bitr   
get_loggerrI   loggerr=  r<  r>  rR  rS  r	  r(   r  r+   r  rP   Moduler3   	MaxPool2drr   r   r   r   r   r   r   r   r   r   r   r  BIT_START_DOCSTRINGr;  r(  r@  rW  r   r)   r'   <module>rp     si   @ ?      " " " " " " " "                  A A A A A A A A A A ! ! ! ! ! !            . - - - - -              2 1 1 1 1 1 ( ( ( ( ( ( 
	H	%	%  & (  * * & &ERWY]R]L^ & & & &R- - - - -ry - - -`    R\   $0 0 0 0 029 0 0 0f
 
 
 
 
2< 
 
 
:/ / / / /BI / / /f U\ e T V[Vb    *- - - - -") - - -   A( A( A( A( A(bi A( A( A(HF F F F F F F FR' ' ' ' '	 ' ' '.G G G G Gry G G GTC
 C
 C
 C
 C
 C
 C
 C
L. . . . . . . .4	   R 5
 5
 5
 5
 5
! 5
 5
	 5
p   Cs Cs Cs Cs Cs 2 Cs Cs CsL  	 <
 <
 <
 <
 <
$m <
 <
 <
 <
 <
r)   