
    gv              	       2   d Z ddlZddlZddlmZ ddlmZmZm	Z	 ddl
Z
ddlZ
ddl
mZ ddlmZmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZmZmZmZ ddlmZ ddl m!Z!  ej"        e#          Z$dZ%dZ&g dZ'dZ(dZ)e G d de                      Z*e G d de                      Z+e G d de                      Z,e G d de                      Z- G d dej.                  Z/ G d dej.                  Z0dBd!e
j1        d"e2d#e3d$e
j1        fd%Z4 G d& d'ej.                  Z5 G d( d)ej.                  Z6 G d* d+ej.                  Z7 G d, d-ej.                  Z8 G d. d/ej.                  Z9 G d0 d1ej.                  Z: G d2 d3e          Z;d4Z<d5Z= ed6e<           G d7 d8e;                      Z> ed9e<           G d: d;e;                      Z? ed<e<           G d= d>e;                      Z@ ed?e<           G d@ dAe;e                      ZAdS )CzPyTorch FocalNet model.    N)	dataclass)OptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BackboneOutput)PreTrainedModel)ModelOutputadd_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardloggingreplace_return_docstrings)BackboneMixin   )FocalNetConfigr   zmicrosoft/focalnet-tiny)r   1   i   ztabby, tabby catc                       e Zd ZU dZdZej        ed<   dZe	e
ej                          ed<   dZe	e
ej                          ed<   dS )FocalNetEncoderOutputa  
    FocalNet encoder's outputs, with potential hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.

        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlast_hidden_statehidden_statesreshaped_hidden_states)__name__
__module____qualname____doc__r   torchFloatTensor__annotations__r   r   r   r        j/var/www/html/ai-engine/env/lib/python3.11/site-packages/transformers/models/focalnet/modeling_focalnet.pyr   r   8   sm          ( ,0u(///8<M8E%"345<<<AEHU5+<%=>EEEEEr&   r   c                       e Zd ZU dZdZej        ed<   dZe	ej                 ed<   dZ
e	eej                          ed<   dZe	eej                          ed<   dS )FocalNetModelOutputa  
    FocalNet model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nr   pooler_outputr   r   )r   r   r    r!   r   r"   r#   r$   r*   r   r   r   r   r%   r&   r'   r)   r)   S   s          * ,0u(///15M8E-.5558<M8E%"345<<<AEHU5+<%=>EEEEEr&   r)   c                       e Zd ZU dZdZeej                 ed<   dZ	ej        ed<   dZ
eeej                          ed<   dZeeej                          ed<   dS )!FocalNetMaskedImageModelingOutputa  
    FocalNet masked image model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Masked image modeling (MLM) loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed pixel values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlossreconstructionr   r   )r   r   r    r!   r-   r   r"   r#   r$   r.   r   r   r   r%   r&   r'   r,   r,   p   s          * )-D(5$
%,,,(,NE%,,,8<M8E%"345<<<AEHU5+<%=>EEEEEr&   r,   c                       e Zd ZU dZdZeej                 ed<   dZ	ej        ed<   dZ
eeej                          ed<   dZeeej                          ed<   dS )FocalNetImageClassifierOutputaS  
    FocalNet outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nr-   logitsr   r   )r   r   r    r!   r-   r   r"   r#   r$   r1   r   r   r   r%   r&   r'   r0   r0      s          * )-D(5$
%,,, $FE$$$8<M8E%"345<<<AEHU5+<%=>EEEEEr&   r0   c                   ~     e Zd ZdZd	 fd	Z	 d
deej                 deej                 de	ej
                 fdZ xZS )FocalNetEmbeddingszX
    Construct the patch embeddings and layernorm. Optionally, also the mask token.
    Fc           	         t                                                       t          ||j        |j        |j        |j        |j        d          | _        | j        j	        | _
        |r-t          j        t          j        dd|j                            nd | _        t          j        |j        |j                  | _        t          j        |j                  | _        d S )NT)config
image_size
patch_sizenum_channels	embed_dimuse_conv_embedis_stemr   eps)super__init__FocalNetPatchEmbeddingsr6   r7   r8   r9   r:   patch_embeddings	grid_size
patch_gridr   	Parameterr"   zeros
mask_token	LayerNormlayer_norm_epsnormDropouthidden_dropout_probdropout)selfr5   use_mask_token	__class__s      r'   r?   zFocalNetEmbeddings.__init__   s     7((,&!0!
 !
 !
 /9O]g",u{1a9I'J'JKKKcgL!1v7LMMM	z&"<==r&   Npixel_valuesbool_masked_posreturnc                 f   |                      |          \  }}|                     |          }|                                \  }}}|R| j                            ||d          }|                    d                              |          }	|d|	z
  z  ||	z  z   }|                     |          }||fS )N      ?)rA   rI   sizerF   expand	unsqueezetype_asrL   )
rM   rP   rQ   
embeddingsoutput_dimensions
batch_sizeseq_len_mask_tokensmasks
             r'   forwardzFocalNetEmbeddings.forward   s     )-(=(=l(K(K%
%YYz**
!+!2!2
GQ&/00WbIIK",,R0088EED#sTz2[45GGJ\\*--
,,,r&   )FN)r   r   r    r!   r?   r   r"   r#   
BoolTensorr   Tensorra   __classcell__rO   s   @r'   r3   r3      s         > > > > > >& hl- -$U%67-JRSXScJd-	u|	- - - - - - - -r&   r3   c                   x     e Zd Z	 	 	 d fd	Zd Zdeej                 deej	        ee
         f         fdZ xZS )r@   Fc	                    t                                                       t          |t          j        j                  r|n||f}t          |t          j        j                  r|n||f}|d         |d         z  |d         |d         z  z  }	|| _        || _        || _        |	| _	        |d         |d         z  |d         |d         z  f| _
        |r.|rd}
d}d}nd}
d}d}t          j        |||
||          | _        nt          j        ||||          | _        |r"t          j        ||j        	          | _        d S d | _        d S )
Nr   r            r   )kernel_sizestridepadding)rl   rm   r<   )r>   r?   
isinstancecollectionsabcIterabler6   r7   r8   num_patchesrB   r   Conv2d
projectionrG   rH   rI   )rM   r5   r6   r7   r8   r9   add_normr:   r;   rs   rl   rn   rm   rO   s                r'   r?   z FocalNetPatchEmbeddings.__init__   sq    	#-j+/:R#S#SqZZZdfpYq
#-j+/:R#S#SqZZZdfpYq
!!}
15*Q-:VW=:XY$$(&$Q-:a=8*Q-:VW=:XY 	l  ii[Y`  DOO !iiZ`jkkkDO 	YF4IJJJDIIIDIIIr&   c                 Z   || j         d         z  dk    r@d| j         d         || j         d         z  z
  f}t          j                            ||          }|| j         d         z  dk    rBddd| j         d         || j         d         z  z
  f}t          j                            ||          }|S )Nr   r   )r7   r   
functionalpad)rM   rP   heightwidth
pad_valuess        r'   	maybe_padz!FocalNetPatchEmbeddings.maybe_pad   s    4?1%%**T_Q/%$/!:L2LLMJ=,,\:FFLDOA&&!++Q4?1#5QRAS8S#STJ=,,\:FFLr&   rP   rR   c                 X   |j         \  }}}}|| j        k    rt          d          |                     |||          }|                     |          }|j         \  }}}}||f}|                    d                              dd          }| j        |                     |          }||fS )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.rj   r   )shaper8   
ValueErrorr}   ru   flatten	transposerI   )rM   rP   r^   r8   rz   r{   rZ   r[   s           r'   ra   zFocalNetPatchEmbeddings.forward  s    )5);&<4,,,w   ~~lFEBB__\22
(.1fe#UO''**44Q::
9 :..J,,,r&   )FFF)r   r   r    r?   r}   r   r"   r#   r   rd   intra   re   rf   s   @r'   r@   r@      s         ( ( ( ( ( (T  -HU->$? -E%,X]^aXbJbDc - - - - - - - -r&   r@           Finput	drop_probtrainingrR   c                     |dk    s|s| S d|z
  }| j         d         fd| j        dz
  z  z   }|t          j        || j        | j                  z   }|                                 |                     |          |z  }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    r   r   r   )r   )dtypedevice)r   ndimr"   randr   r   floor_div)r   r   r   	keep_probr   random_tensoroutputs          r'   	drop_pathr     s     CxII[^
Q 77E
5EL Y Y YYMYYy!!M1FMr&   c                   j     e Zd ZdZd	dee         ddf fdZdej        dej        fdZ	de
fdZ xZS )
FocalNetDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).Nr   rR   c                 V    t                                                       || _        d S rb   )r>   r?   r   )rM   r   rO   s     r'   r?   zFocalNetDropPath.__init__2  s$    "r&   r   c                 8    t          || j        | j                  S rb   )r   r   r   )rM   r   s     r'   ra   zFocalNetDropPath.forward6  s    FFFr&   c                 6    d                     | j                  S )Nzp={})formatr   rM   s    r'   
extra_reprzFocalNetDropPath.extra_repr9  s    }}T^,,,r&   rb   )r   r   r    r!   r   floatr?   r"   rd   ra   strr   re   rf   s   @r'   r   r   /  s        bb# #(5/ #T # # # # # #GU\ Gel G G G G-C - - - - - - - -r&   r   c                   &     e Zd Zd fd	Zd Z xZS )FocalNetModulationrj   Tr   c                    t                                                       || _        |j        |         | _        |j        |         | _        || _        |j        | _        |j	        | _	        t          j        |d|z  | j        dz   z   |          | _        t          j        ||dd|          | _        t          j                    | _        t          j        ||          | _        t          j        |          | _        t          j                    | _        g | _        t/          | j                  D ]}| j        |z  | j        z   }| j                            t          j        t          j        |||d||dz  d          t          j                                         | j                            |           | j        r"t          j        ||j                  | _        d S d S )Nrj   r   )bias)rl   rm   r   F)rl   rm   groupsrn   r   r<   )r>   r?   dimfocal_windowsfocal_windowfocal_levelsfocal_levelfocal_factor use_post_layernorm_in_modulationnormalize_modulatorr   Linearprojection_inrt   projection_contextGELU
activationprojection_outrJ   projection_dropout
ModuleListfocal_layerskernel_sizesrangeappend
SequentialrG   rH   	layernorm)
rM   r5   indexr   r   r   r   krl   rO   s
            r'   r?   zFocalNetModulation.__init__>  s   "07!.u5(060W-#)#= YsAGt7G!7K,LSWXXX"$)C!ATX"Y"Y"Y')) iS11"$*-?"@"@MOOt'(( 
	2 
	2A+a/$2CCK$$ISk!CYdhiYipu   GII	    $$[11110 	J\#63HIIIDNNN	J 	Jr&   c                    |j         d         }|                     |                              dddd                                          }t	          j        |||| j        dz   fd          \  }}| _        d}t          | j                  D ]4} | j	        |         |          }||| j        dd||dz   f         z  z   }5| 
                    |                    dd                              dd                    }||| j        dd| j        df         z  z   }| j        r|| j        dz   z  }|                     |          | _        || j        z  }	|	                    dddd                                          }	| j        r|                     |	          }	|                     |	          }	|                     |	          }	|	S )	z
        Args:
            hidden_state:
                Input features with shape of (batch_size, height, width, num_channels)
        rT   r   r   r   rj   NT)keepdim)r   r   permute
contiguousr"   splitr   gatesr   r   r   meanr   r   	modulatorr   r   r   r   )
rM   hidden_stater8   xqctxctx_alllevel
ctx_globalx_outs
             r'   ra   zFocalNetModulation.forward_  s    $)"- |,,44Q1a@@KKMM"[\<IY\]I],^`abb3
 4+,, 	G 	GE*$#E*3//CdjEEAI4E1E&F FFGG__SXXaX%>%>%C%CAt%C%T%TUU
JAAAt7G7I7I4I)JJJ # 	7!1A!56G 0099DN"aAq))44660 	*NN5))E ##E**''..r&   )rj   Tr   r   r   r    r?   ra   re   rf   s   @r'   r   r   =  sS        J J J J J JB" " " " " " "r&   r   c                   &     e Zd Zd fd	Zd Z xZS )FocalNetMlpNr   c                     t                                                       |p|}|p|}t          j        ||          | _        t
          |j                 | _        t          j        ||          | _        t          j	        |          | _
        d S rb   )r>   r?   r   r   fc1r   
hidden_actr   fc2rJ   drop)rM   r5   in_featureshidden_featuresout_featuresr   rO   s         r'   r?   zFocalNetMlp.__init__  sw    #2{)8[9[/:: !239_l;;Jt$$			r&   c                     |                      |          }|                     |          }|                     |          }|                     |          }|                     |          }|S rb   )r   r   r   r   )rM   r   s     r'   ra   zFocalNetMlp.forward  s]    xx--|44yy..xx--yy..r&   )NNr   r   rf   s   @r'   r   r     sL        % % % % % %      r&   r   c                   *     e Zd ZdZd fd	Zd Z xZS )FocalNetLayera  Focal Modulation Network layer (block).

    Args:
        config (`FocalNetConfig`):
            Model config.
        index (`int`):
            Layer index.
        dim (`int`):
            Number of input channels.
        input_resolution (`Tuple[int]`):
            Input resulotion.
        drop_path (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate.
    r   c                     t                                                       || _        || _        || _        |j        | _        |j        | _        t          j	        ||j
                  | _        t          |||| j                  | _        |dk    rt          |          nt          j                    | _        t          j	        ||j
                  | _        t%          ||j        z            }t)          |||| j                  | _        d| _        d| _        |j        rlt          j        |j        t7          j        |          z  d          | _        t          j        |j        t7          j        |          z  d          | _        d S d S )Nr<   )r5   r   r   r   r   )r5   r   r   r   rU   T)requires_grad)r>   r?   r5   r   input_resolutionrK   r   use_post_layernormr   rG   rH   norm1r   
modulationr   Identityr   norm2r   	mlp_ratior   mlpgamma_1gamma_2use_layerscalerD   layerscale_valuer"   ones)rM   r5   r   r   r   r   mlp_hidden_dimrO   s          r'   r?   zFocalNetLayer.__init__  sq     0 .	"(";\#6+@AAA
,#y	
 
 
 9BC))444R[]]\#6+@AAA
S6#3344f#~dhdmnnn  	i<(?%*cBSBS(ScghhhDL<(?%*cBSBS(ScghhhDLLL	i 	ir&   c           	      V   |\  }}|j         \  }}}|}| j        r|n|                     |          }|                    ||||          }|                     |                              |||z  |          }| j        s|n|                     |          }||                     | j        |z            z   }||                     | j        | j        r(|                     | 	                    |                    n'| 	                    |                     |                    z            z   }|S rb   )
r   r   r   viewr   r   r   r   r   r   )	rM   r   input_dimensionsrz   r{   r\   r^   r8   shortcuts	            r'   ra   zFocalNetLayer.forward  s0   (&2&8#
A| (,'>\||DJJ|D\D\#((VULQQ|4499*funVbcc+/+B`||

S_H`H`  $..1L"M"MM#dnnL595Lttzz$((<00111RVRZRZ[_[e[efr[s[sRtRtv'
 '
 

 r&   )r   )r   r   r    r!   r?   ra   re   rf   s   @r'   r   r     s]         i i i i i i@      r&   r   c                   b     e Zd Z fdZdej        deeef         deej                 fdZ xZ	S )FocalNetStagec           
        	 t                                                       | _        t          j                  | _        fdt          | j                  D             }|         | j        dz
  k     r|dz            nd }| j        dz
  k     rt          nd }d t          j	        dj
        t          j                            D             }|t          j        d                    t          j        d dz                               	t          j        	fdt          j                           D                       | _        | |d|dj        d	          | _        nd | _        d| _        d S )
Nc                 *    g | ]}j         d |z  z  S )rj   )r9   ).0ir5   s     r'   
<listcomp>z*FocalNetStage.__init__.<locals>.<listcomp>  s%    OOO1V%A.OOOr&   r   c                 6    g | ]}|                                 S r%   )item)r   r   s     r'   r   z*FocalNetStage.__init__.<locals>.<listcomp>  s     ^^^Aqvvxx^^^r&   r   c                 r    g | ]3}t          t          t                    r|         n           4S ))r5   r   r   r   r   )r   ro   list)r   r   r5   r   r   r   r   s     r'   r   z*FocalNetStage.__init__.<locals>.<listcomp>  s`     	 	 	  !%5.8D.I.IXilly  	 	 	r&   rj   TF)r5   r6   r7   r8   r9   rv   r:   r;   )r>   r?   r5   lendepths
num_stagesr   r@   r"   linspacedrop_path_ratesumr   r   layersr:   
downsamplepointing)rM   r5   r   r   r9   out_dimr   dprr   r   rO   s    ```    @@r'   r?   zFocalNetStage.__init__  s   fm,,OOOOdo8N8NOOO	+04?Q3F+F+F)EAI&&T1619L1L1L,,SW
 _^63H#fmJ\J\!]!]^^^FM&5&122S{QR{9S5T5TTU	m	 	 	 	 	 	 	 	 v}U344	 	 	
 
 !(j+ !%4	 	 	DOO #DOr&   r   r   rR   c                    |\  }}| j         D ]} |||          }|}| j        U|\  }}|                    dd                              |j        d         d||          }|                     |          \  }}n||||f}|||f}|S )Nr   rj   r   rT   )r   r   r   reshaper   )	rM   r   r   rz   r{   layer_module!hidden_states_before_downsamplingr[   stage_outputss	            r'   ra   zFocalNetStage.forward	  s    ( K 	J 	JL(L8HIIMM,9)?&,MFE)33Aq99AA17:B M 04}/M/M,M,, "( >&(IK\]r&   )
r   r   r    r?   r"   rd   r   r   ra   re   rf   s   @r'   r   r     sw        * * * * *XU\ U3PS8_ Y^_d_kYl        r&   r   c                        e Zd Z fdZ	 	 	 ddej        deeef         dee	         dee	         dee	         d	e
eef         fd
Z xZS )FocalNetEncoderc                 
   t                                                       t          j                  | _        | _        t          j        fdt          | j                  D                       | _	        d| _
        d S )Nc           
      h    g | ].}t          |d          d|z  z  d         d|z  z  f          /S )r   rj   r   )r5   r   r   )r   )r   i_layerr5   rB   s     r'   r   z,FocalNetEncoder.__init__.<locals>.<listcomp>%  se         !!&/lq'z&BIaLUVX_U_D`%a    r&   F)r>   r?   r   r   r   r5   r   r   r   stagesgradient_checkpointing)rM   r5   rB   rO   s    ``r'   r?   zFocalNetEncoder.__init__  s    fm,,m      %T_55  	
 	
 ',###r&   FTr   r   output_hidden_states(output_hidden_states_before_downsamplingreturn_dictrR   c                    |rdnd }|rdnd }|r?|j         \  }}	}
 |j        |g||
R  }|                    dddd          }||fz  }||fz  }t          | j                  D ]\  }}| j        r$| j        r|                     |j        ||          }n |||          }|d         }|d         }|d         }|d         |d         f}|rP|rN|j         \  }}	}
 |j        |g|d         |d         f|
R  }|                    dddd          }||fz  }||fz  }|rA|s?|j         \  }}	}
 |j        |g||
R  }|                    dddd          }||fz  }||fz  }|st          d ||fD                       S t          |||	          S )
Nr%   r   r   r   rj   rT   c              3      K   | ]}||V  	d S rb   r%   )r   vs     r'   	<genexpr>z*FocalNetEncoder.forward.<locals>.<genexpr>g  s"      XXq!-----XXr&   )r   r   r   )r   r   r   	enumerater  r  r   _gradient_checkpointing_func__call__tupler   )rM   r   r   r  r  r  all_hidden_statesall_reshaped_hidden_statesr\   r^   hidden_sizereshaped_hidden_stater   stage_moduler  r  r[   s                    r'   ra   zFocalNetEncoder.forward1  s    #7@BBD+?%IRRT" 	C)6)<&J;$6M$6z$bDT$bVa$b$b$b!$9$A$A!Q1$M$M!-!11&+@*BB&(55  	G  	GOA|* Nt} N $ A A )!$! ! !-]<L M M)!,M0=a0@- -a 0 1" 57H7LM# G(P G-N-T*
A{ )O(I(N)"3A"68I!8L!M)OZ) ) )% )>(E(EaAq(Q(Q%!&G%II!*/D.FF**% G.V G-:-@*
A{(:(::(fHX(fZe(f(f(f%(=(E(EaAq(Q(Q%!m%55!*/D.FF* 	YXX]4E$FXXXXXX$++#=
 
 
 	
r&   )FFT)r   r   r    r?   r"   rd   r   r   r   boolr   r   ra   re   rf   s   @r'   r  r    s        , , , , ,, 05CH&*<
 <
|<
  S/<
 'tn	<

 3;4.<
 d^<
 
u++	,<
 <
 <
 <
 <
 <
 <
 <
r&   r  c                   .    e Zd ZdZeZdZdZdZdgZ	d Z
dS )FocalNetPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    focalnetrP   Tr   c                    t          |t          j        t          j        f          rT|j        j                            d| j        j                   |j	         |j	        j        
                                 dS dS t          |t          j                  r?|j	        j        
                                 |j        j                            d           dS dS )zInitialize the weightsr   )r   stdNrU   )ro   r   r   rt   weightdatanormal_r5   initializer_ranger   zero_rG   fill_)rM   modules     r'   _init_weightsz%FocalNetPreTrainedModel._init_weights}  s    fry")455 	* M&&CT[5R&SSS{& &&((((( '&-- 	*K""$$$M$$S)))))	* 	*r&   N)r   r   r    r!   r   config_classbase_model_prefixmain_input_namesupports_gradient_checkpointing_no_split_modulesr,  r%   r&   r'   r!  r!  q  sN         
 "L"$O&*#()
* 
* 
* 
* 
*r&   r!  aK  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`FocalNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
aB  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
zVThe bare FocalNet Model outputting raw hidden-states without any specific head on top.c                        e Zd Zd fd	Zd Z ee           eee	e
de          	 	 	 	 ddeej                 d	eej                 d
ee         dee         deee	f         f
d                        Z xZS )FocalNetModelTFc                    t                                          |           || _        t          |j                  | _        t          |j        d| j        dz
  z  z            | _        t          ||          | _
        t          || j
        j                  | _        t          j        | j        |j                  | _        |rt          j        d          nd | _        |                                  d S )Nrj   r   )rN   r<   )r>   r?   r5   r   r   r   r   r9   num_featuresr3   rZ   r  rC   encoderr   rG   rH   r   AdaptiveAvgPool1dpooler	post_init)rM   r5   add_pooling_layerrN   rO   s       r'   r?   zFocalNetModel.__init__  s       fm,, 0119L3M MNN,VNSSS&vt/IJJd&7V=RSSS1BLb*1--- 	r&   c                     | j         j        S rb   )rZ   rA   r   s    r'   get_input_embeddingsz"FocalNetModel.get_input_embeddings  s    //r&   vision)
checkpointoutput_typer-  modalityexpected_outputNrP   rQ   r  r  rR   c                    ||n| j         j        }||n| j         j        }|t          d          |                     ||          \  }}|                     ||||          }|d         }|                     |          }d}	| j        >|                     |                    dd                    }	t          j
        |	d          }	|s||	f|dd         z   }
|
S t          ||	|j        |j                  S )	z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        Nz You have to specify pixel_values)rQ   r  r  r   r   rj   )r   r*   r   r   )r5   r  use_return_dictr   rZ   r6  r   r8  r   r"   r   r)   r   r   )rM   rP   rQ   r  r  embedding_outputr   encoder_outputssequence_outputpooled_outputr   s              r'   ra   zFocalNetModel.forward  s1   ( %9$D  $+Jj 	 &1%<kk$+B]?@@@-1__\[j_-k-k**,,!5#	 ' 
 
 *!,..99;" KK(A(A!Q(G(GHHM!M-;;M 	%}58KKFM"-')7#2#I	
 
 
 	
r&   )TFNNNN)r   r   r    r?   r<  r   FOCALNET_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr)   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r"   r#   rc   r  r   r   ra   re   rf   s   @r'   r3  r3    s       
     0 0 0 +*+DEE&'$.   596:/3&*.
 .
u01.
 "%"23.
 'tn	.

 d^.
 
u))	*.
 .
 .
  FE.
 .
 .
 .
 .
r&   r3  a|  FocalNet Model with a decoder on top for masked image modeling.

    This follows the same implementation as in [SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    c                        e Zd Z fdZ ee           eee          	 	 	 	 d
de	e
j                 de	e
j                 de	e         de	e         deeef         f
d	                        Z xZS )FocalNetForMaskedImageModelingc                    t                                          |           t          |dd          | _        t	          |j                  | _        t          |j        d| j        dz
  z  z            }t          j
        t          j        ||j        dz  |j        z  d          t          j        |j                            | _        |                                  d S )NFT)r:  rN   rj   r   )in_channelsout_channelsrl   )r>   r?   r3  r"  r   r   r   r   r9   r   r   rt   encoder_strider8   PixelShuffledecoderr9  )rM   r5   r5  rO   s      r'   r?   z'FocalNetForMaskedImageModeling.__init__  s       %fVZ[[[fm,,6+aDOa4G.HHII}I(v7La7ORXRe7est   OF122	
 
 	r&   r?  r-  NrP   rQ   r  r  rR   c                    ||n| j         j        }|                     ||||          }|d         }|                    dd          }|j        \  }}}	t          j        |	dz            x}
}|                    |||
|          }|                     |          }d}|| j         j	        | j         j
        z  }|                    d||          }|                    | j         j
        d                              | j         j
        d                              d                                          }t          j                            ||d	          }||z                                  |                                d
z   z  | j         j        z  }|s|f|dd         z   }||f|z   n|S t'          |||j        |j                  S )aQ  
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, FocalNetConfig, FocalNetForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-base-simmim-window6-192")
        >>> config = FocalNetConfig()
        >>> model = FocalNetForMaskedImageModeling(config)

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 192, 192]
        ```N)rQ   r  r  r   r   rj   g      ?rT   none)	reductiongh㈵>)r-   r.   r   r   )r5   rD  r"  r   r   mathfloorr  rU  r6   r7   repeat_interleaverX   r   r   rx   l1_lossr   r8   r,   r   r   )rM   rP   rQ   r  r  outputsrG  r\   r8   sequence_lengthrz   r{   reconstructed_pixel_valuesmasked_im_lossrV   r`   reconstruction_lossr   s                     r'   ra   z&FocalNetForMaskedImageModeling.forward  s   N &1%<kk$+B]--+!5#	   
 
 "!*)33Aq994C4I1
L/OS$8999)11*lFTYZZ &*\\/%B%B"&;)T[-CCD-55b$EEO11$+2H!LL""4;#91==1	  #%-"7"7F`lr"7"s"s1D8==??488::PTCTUX\XcXppN 	Z02WQRR[@F3A3M^%..SYY05!/#*#A	
 
 
 	
r&   rI  )r   r   r    r?   r   rJ  r   r,   rL  r   r"   r#   rc   r  r   r   ra   re   rf   s   @r'   rO  rO    s            " +*+DEE+L[jkkk 596:/3&*N
 N
u01N
 "%"23N
 'tn	N

 d^N
 
u77	8N
 N
 N
 lk FEN
 N
 N
 N
 N
r&   rO  z
    FocalNet Model with an image classification head on top (a linear layer on top of the pooled output) e.g. for
    ImageNet.
    c                        e Zd Z fdZ ee           eeee	e
          	 	 	 	 d
deej                 deej                 dee         dee         deeef         f
d	                        Z xZS )FocalNetForImageClassificationc                 @   t                                          |           |j        | _        t          |          | _        |j        dk    r$t          j        | j        j        |j                  nt          j                    | _	        | 
                                 d S )Nr   )r>   r?   
num_labelsr3  r"  r   r   r5  r   
classifierr9  rM   r5   rO   s     r'   r?   z'FocalNetForImageClassification.__init__o  s        +%f-- IOHY\]H]H]BIdm0&2CDDDcecncpcp 	
 	r&   )r>  r?  r-  rA  NrP   labelsr  r  rR   c                    ||n| j         j        }|                     |||          }|d         }|                     |          }d}|Z| j         j        f| j        dk    rd| j         _        nN| j        dk    r7|j        t          j        k    s|j        t          j	        k    rd| j         _        nd| j         _        | j         j        dk    rWt                      }	| j        dk    r1 |	|                                |                                          }n |	||          }n| j         j        dk    rGt                      }	 |	|                    d| j                  |                    d                    }n*| j         j        dk    rt                      }	 |	||          }|s|f|dd         z   }
||f|
z   n|
S t          |||j        |j        	          S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        NrC  r   
regressionsingle_label_classificationmulti_label_classificationrT   rj   )r-   r1   r   r   )r5   rD  r"  rg  problem_typerf  r   r"   longr   r
   squeezer	   r   r   r0   r   r   )rM   rP   ri  r  r  r^  rH  r1   r-   loss_fctr   s              r'   ra   z&FocalNetForImageClassification.forward}  s   ( &1%<kk$+B]--!5#   
 
  
//{'/?a''/;DK,,_q((flej.H.HFL\a\eLeLe/LDK,,/KDK,{'<77"99?a''#8FNN$4$4fnn6F6FGGDD#8FF33DD)-JJJ+--xB @ @&++b//RR)-III,..x// 	FY,F)-)9TGf$$vE,!/#*#A	
 
 
 	
r&   rI  )r   r   r    r?   r   rJ  r   _IMAGE_CLASS_CHECKPOINTr0   rL  _IMAGE_CLASS_EXPECTED_OUTPUTr   r"   r#   
LongTensorr  r   r   ra   re   rf   s   @r'   rd  rd  f  s             +*+DEE*1$4	   59-1/3&*9
 9
u019
 )*9
 'tn	9

 d^9
 
u33	49
 9
 9
  FE9
 9
 9
 9
 9
r&   rd  zG
    FocalNet backbone, to be used with frameworks like X-Decoder.
    c                        e Zd Zdef fdZ ee           eee	          	 	 d
de
j        dee         dee         defd	                        Z xZS )FocalNetBackboner5   c                    t                                          |           t                                          |           |j        g|j        z   | _        t          |          | _        |                                  d S rb   )	r>   r?   _init_backboner9   hidden_sizesr5  r3  r"  r9  rh  s     r'   r?   zFocalNetBackbone.__init__  sp       v&&&#-.1DD%f-- 	r&   rV  NrP   r  r  rR   c                 @   ||n| j         j        }||n| j         j        }|                     |dd          }|j        }d}t          | j                  D ]\  }}|| j        v r|||         fz  }|s|f}	|r|	|j        fz  }	|	S t          ||r|j        ndd          S )a|  
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny-lrf")
        >>> model = AutoBackbone.from_pretrained("microsoft/focalnet-tiny-lrf")

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```NTrC  r%   )feature_mapsr   
attentions)
r5   rD  r  r"  r   r  stage_namesr   r   r   )
rM   rP   r  r  r^  r   r{  idxstager   s
             r'   ra   zFocalNetBackbone.forward  s    8 &1%<kk$+B]$8$D  $+Jj 	 --4UY-ZZ6#D$455 	6 	6JC)))s!3 55 	"_F# 37022M%3GQ'//T
 
 
 	
r&   )NN)r   r   r    r   r?   r   rJ  r   r   rL  r"   rd   r   r  ra   re   rf   s   @r'   rv  rv    s        ~       +*+DEE>XXX 04&*	2
 2
l2
 'tn2
 d^	2

 
2
 2
 2
 YX FE2
 2
 2
 2
 2
r&   rv  )r   F)Br!   collections.abcrp   rZ  dataclassesr   typingr   r   r   r"   torch.utils.checkpointr   torch.nnr   r	   r
   activationsr   modeling_outputsr   modeling_utilsr   utilsr   r   r   r   r   r   utils.backbone_utilsr   configuration_focalnetr   
get_loggerr   loggerrL  rK  rM  rr  rs  r   r)   r,   r0   Moduler3   r@   rd   r   r  r   r   r   r   r   r   r  r!  FOCALNET_START_DOCSTRINGrJ  r3  rO  rd  rv  r%   r&   r'   <module>r     s          ! ! ! ! ! ! ) ) ) ) ) ) ) ) ) )            A A A A A A A A A A ! ! ! ! ! ! . . . . . . - - - - - -                2 1 1 1 1 1 2 2 2 2 2 2 
	H	%	% # 0 %  4 1  F F F F FK F F F4 F F F F F+ F F F8 F F F F F F F F8 F F F F FK F F F8%- %- %- %- %- %- %- %-PD- D- D- D- D-bi D- D- D-P U\ e T V[Vb    *- - - - -ry - - -D D D D D D D DN    ")   &B B B B BBI B B BJ? ? ? ? ?BI ? ? ?DO
 O
 O
 O
 O
bi O
 O
 O
f* * * * *o * * *2	   \ I
 I
 I
 I
 I
+ I
 I
	 I
X 
  b
 b
 b
 b
 b
%< b
 b
 b
J   P
 P
 P
 P
 P
%< P
 P
 P
f  	 ?
 ?
 ?
 ?
 ?
. ?
 ?
 ?
 ?
 ?
r&   