"""PyTorch Swinv2 Transformer model."""

import collections.abc
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
    torch_int,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_swinv2 import Swinv2Config


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "Swinv2Config"

# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/swinv2-tiny-patch4-window8-256"
_EXPECTED_OUTPUT_SHAPE = [1, 64, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "microsoft/swinv2-tiny-patch4-window8-256"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"


@dataclass
class Swinv2EncoderOutput(ModelOutput):
    """
    Swinv2 encoder's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class Swinv2ModelOutput(ModelOutput):
    """
    Swinv2 model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class Swinv2MaskedImageModelingOutput(ModelOutput):
    """
    Swinv2 masked image model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Masked image modeling (MLM) loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed pixel values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    loss: Optional[torch.FloatTensor] = None
    reconstruction: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None

    @property
    def logits(self):
        warnings.warn(
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead.",
            FutureWarning,
        )
        return self.reconstruction


@dataclass
class Swinv2ImageClassifierOutput(ModelOutput):
    """
    Swinv2 outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None


def window_partition(input_feature, window_size):
    """
    Partitions the given input into windows.
    """
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.view(
        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
    )
    windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows


def window_reverse(windows, window_size, height, width):
    """
    Merges windows to produce higher resolution features.
    """
    num_channels = windows.shape[-1]
    windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
    windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
    return windows
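# A quick sanity check (illustrative sketch only, not part of the modeling code):
# `window_partition` and `window_reverse` are exact inverses whenever height and
# width are multiples of `window_size`. The tensor sizes below are made up for
# the example (batch of 2, an 8x8 feature map, window_size=4, 96 channels):
#
#     feature = torch.randn(2, 8, 8, 96)
#     windows = window_partition(feature, 4)       # -> (8, 4, 4, 96): 2 * (8//4) * (8//4) windows
#     restored = window_reverse(windows, 4, 8, 8)  # -> (2, 8, 8, 96)
#     assert torch.equal(feature, restored)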
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


class Swinv2DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)
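# Behavior sketch for stochastic depth (illustrative only; the tensor sizes are
# hypothetical): during training, `Swinv2DropPath` zeroes the residual branch for
# a random subset of samples and rescales the kept ones by 1 / (1 - p), so the
# expected value is unchanged; at inference it is the identity.
#
#     module = Swinv2DropPath(drop_prob=0.5)
#     module.train()
#     x = torch.ones(8, 64, 96)
#     y = module(x)                     # each sample is either all 0.0 or all 2.0
#     module.eval()
#     assert torch.equal(module(x), x)  # identity at inference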
            e Zd ZdZd fd	Zdej        dededej        fdZ	 	 dd
e	ej
                 de	ej                 dedeej                 fdZ xZS )Swinv2EmbeddingszW
    Construct the patch and position embeddings. Optionally, also the mask token.
    Fc                 <   t                                                       t          |          | _        | j        j        }| j        j        | _        |r-t          j        t          j
        dd|j                            nd | _        |j        r6t          j        t          j
        d|dz   |j                            | _        nd | _        t          j        |j                  | _        t          j        |j                  | _        |j        | _        || _        d S )Nr   )rg   rh   Swinv2PatchEmbeddingspatch_embeddingsnum_patches	grid_size
patch_gridr   	Parameterr(   zeros	embed_dim
mask_tokenuse_absolute_embeddingsposition_embeddings	LayerNormnormDropouthidden_dropout_probdropout
patch_sizeconfig)r:   r   use_mask_tokenry   ri   s       r-   rh   zSwinv2Embeddings.__init__  s     5f = =+7/9O]g",u{1a9I'J'JKKKcg) 	,')|EK;QR?TZTd4e4e'f'fD$$'+D$L!122	z&"<== +r,   
embeddingsrL   rM   rW   c                    |j         d         dz
  }| j        j         d         dz
  }t          j                                        s||k    r||k    r| j        S | j        ddddf         }| j        ddddf         }|j         d         }|| j        z  }	|| j        z  }
t          |dz            }|                    d|||          }|                    dddd          }t          j
                            ||	|
fdd	
          }|                    dddd                              dd|          }t          j        ||fd          S )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   NrC         ?r   r   r@   bicubicF)sizemodealign_cornersdim)rE   r   r(   jit
is_tracingr   r   reshaperG   r   
functionalinterpolaterF   cat)r:   r   rL   rM   ry   num_positionsclass_pos_embedpatch_pos_embedr   
new_height	new_widthsqrt_num_positionss               r-   interpolate_pos_encodingz)Swinv2Embeddings.interpolate_pos_encoding-  sr    !&q)A-06q9A= y##%% 	,+*F*F6UZ??++2111bqb592111abb59r"t.
T_,	&}c'9::)11!5GI[]`aa)11!Q1==-33i(	 4 
 
 *11!Q1==BB1b#NNy/?;CCCCr,   Npixel_valuesbool_masked_posr   c                    |j         \  }}}}|                     |          \  }}	|                     |          }|                                \  }
}}|R| j                            |
|d          }|                    d                              |          }|d|z
  z  ||z  z   }| j        '|r|| 	                    |||          z   }n
|| j        z   }| 
                    |          }||	fS )NrC         ?)rE   rx   r   r   r   expand	unsqueezetype_asr   r   r   )r:   r   r   r   _rN   rL   rM   r   output_dimensionsrK   seq_lenmask_tokensmasks                 r-   rl   zSwinv2Embeddings.forwardU  s	    *6);&<(,(=(=l(K(K%
%YYz**
!+!2!2
GQ&/00WbIIK",,R0088EED#sTz2[45GGJ#/' C'$*G*G
TZ\a*b*bb

'$*BB
\\*--
,,,r,   )FNF)r$   r%   r&   r'   rh   r(   r   intr   r   r)   
BoolTensorboolr   rl   rr   rs   s   @r-   ru   ru     s              &&D5< &D &DUX &D]b]i &D &D &D &DV 7;).	- -u01- "%"23- #'	-
 
u|	- - - - - - - -r,   ru   c                   t     e Zd ZdZ fdZd Zdeej                 de	ej
        e	e         f         fdZ xZS )rw   z
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.embed_dim
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def maybe_pad(self, pixel_values, height, width):
        if width % self.patch_size[1] != 0:
            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        if height % self.patch_size[0] != 0:
            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
        embeddings = embeddings.flatten(2).transpose(1, 2)

        return embeddings, output_dimensions


class Swinv2PatchMerging(nn.Module):
    """
    Patch Merging Layer.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)

    def maybe_pad(self, input_feature, height, width):
        should_pad = (height % 2 == 1) or (width % 2 == 1)
        if should_pad:
            pad_values = (0, 0, 0, width % 2, 0, height % 2)
            input_feature = nn.functional.pad(input_feature, pad_values)

        return input_feature

    def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
        height, width = input_dimensions
        # `dim` is height * width
        batch_size, dim, num_channels = input_feature.shape

        input_feature = input_feature.view(batch_size, height, width, num_channels)
        # pad input to be divisible by width and height, if needed
        input_feature = self.maybe_pad(input_feature, height, width)
        # [batch_size, height/2, width/2, num_channels]
        input_feature_0 = input_feature[:, 0::2, 0::2, :]
        input_feature_1 = input_feature[:, 1::2, 0::2, :]
        input_feature_2 = input_feature[:, 0::2, 1::2, :]
        input_feature_3 = input_feature[:, 1::2, 1::2, :]
        # [batch_size, height/2 * width/2, 4*num_channels]
        input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
        input_feature = input_feature.view(batch_size, -1, 4 * num_channels)

        input_feature = self.reduction(input_feature)
        input_feature = self.norm(input_feature)

        return input_feature


class Swinv2SelfAttention(nn.Module):
    def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[0, 0]):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = int(dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.window_size = (
            window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
        )
        self.pretrained_window_size = pretrained_window_size
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        # mlp to generate continuous relative position bias
        self.continuous_position_bias_mlp = nn.Sequential(
            nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
        )

        # get relative_coords_table
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
        relative_coords_table = (
            torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
            .permute(1, 2, 0)
            .contiguous()
            .unsqueeze(0)
        )  # [1, 2*window_height - 1, 2*window_width - 1, 2]
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
            relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
        elif window_size > 1:
            relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
            relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = (
            torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
        )
        # set to same dtype as mlp weight
        relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
        self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=False)
        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        batch_size, dim, num_channels = hidden_states.shape
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # cosine attention
        attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize(
            key_layer, dim=-1
        ).transpose(-2, -1)
        logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
        attention_scores = attention_scores * logit_scale
        relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view(
            -1, self.num_attention_heads
        )
        # [window_height*window_width, window_height*window_width, num_attention_heads]
        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )
        # [num_attention_heads, window_height*window_width, window_height*window_width]
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        attention_scores = attention_scores + relative_position_bias.unsqueeze(0)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in Swinv2Model forward() function)
            mask_shape = attention_mask.shape[0]
            attention_scores = attention_scores.view(
                batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
            )
            attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
            attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class Swinv2SelfOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class Swinv2Attention(nn.Module):
    def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=0):
        super().__init__()
        self.self = Swinv2SelfAttention(
            config=config,
            dim=dim,
            num_heads=num_heads,
            window_size=window_size,
            pretrained_window_size=pretrained_window_size
            if isinstance(pretrained_window_size, collections.abc.Iterable)
            else (pretrained_window_size, pretrained_window_size),
        )
        self.output = Swinv2SelfOutput(config, dim)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class Swinv2Intermediate(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class Swinv2Output(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class Swinv2Layer(nn.Module):
    def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0):
        super().__init__()
        self.input_resolution = input_resolution
        window_size, shift_size = self._compute_window_shift(
            (config.window_size, config.window_size), (shift_size, shift_size)
        )
        self.window_size = window_size[0]
        self.shift_size = shift_size[0]
        self.attention = Swinv2Attention(
            config=config,
            dim=dim,
            num_heads=num_heads,
            window_size=self.window_size,
            pretrained_window_size=pretrained_window_size
            if isinstance(pretrained_window_size, collections.abc.Iterable)
            else (pretrained_window_size, pretrained_window_size),
        )
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.drop_path = Swinv2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
        self.intermediate = Swinv2Intermediate(config, dim)
        self.output = Swinv2Output(config, dim)
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)

    def _compute_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
        shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, target_window_size, target_shift_size)]
        return window_size, shift_size

    def get_attn_mask(self, height, width, dtype):
        if self.shift_size > 0:
            # calculate attention mask for shifted window multihead self-attention
            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask

    def maybe_pad(self, hidden_states, height, width):
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states

        # pad hidden_states to multiples of window size
        hidden_states = hidden_states.view(batch_size, height, width, channels)
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
        _, height_pad, width_pad, _ = hidden_states.shape
        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states

        # partition windows
        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
        if attn_mask is not None:
            attn_mask = attn_mask.to(hidden_states_windows.device)

        attention_outputs = self.attention(
            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
        )

        attention_output = attention_outputs[0]

        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)

        # reverse cyclic shift
        if self.shift_size > 0:
            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            attention_windows = shifted_windows

        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)
        hidden_states = self.layernorm_before(attention_windows)
        hidden_states = shortcut + self.drop_path(hidden_states)

        layer_output = self.intermediate(hidden_states)
        layer_output = self.output(layer_output)
        layer_output = hidden_states + self.drop_path(self.layernorm_after(layer_output))

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs


class Swinv2Stage(nn.Module):
    def __init__(
        self, config, dim, input_resolution, depth, num_heads, drop_path, downsample, pretrained_window_size=0
    ):
        super().__init__()
        self.config = config
        self.dim = dim
        blocks = []
        for i in range(depth):
            block = Swinv2Layer(
                config=config,
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                shift_size=0 if (i % 2 == 0) else config.window_size // 2,
                pretrained_window_size=pretrained_window_size,
            )
            blocks.append(block)
        self.blocks = nn.ModuleList(blocks)

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
        else:
            self.downsample = None

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        height, width = input_dimensions
        for i, layer_module in enumerate(self.blocks):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs


class Swinv2Encoder(nn.Module):
    def __init__(self, config, grid_size, pretrained_window_sizes=(0, 0, 0, 0)):
        super().__init__()
        self.num_layers = len(config.depths)
        self.config = config
        if self.config.pretrained_window_sizes is not None:
            pretrained_window_sizes = config.pretrained_window_sizes
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]

        layers = []
        for i_layer in range(self.num_layers):
            stage = Swinv2Stage(
                config=config,
                dim=int(config.embed_dim * 2**i_layer),
                input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
                depth=config.depths[i_layer],
                num_heads=config.num_heads[i_layer],
                drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
                downsample=Swinv2PatchMerging if (i_layer < self.num_layers - 1) else None,
                pretrained_window_size=pretrained_window_sizes[i_layer],
            )
            layers.append(stage)
        self.layers = nn.ModuleList(layers)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, Swinv2EncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            batch_size, _, hidden_size = hidden_states.shape
            # rearrange b (h w) c -> b c h w
            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    input_dimensions,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]
            output_dimensions = layer_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange b (h w) c -> b c h w
                # here we use the original (not downsampled) height and width
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states]
                if v is not None
            )

        return Swinv2EncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )


class Swinv2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Swinv2Config
    base_model_prefix = "swinv2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Swinv2Stage"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
SWINV2_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Swinv2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

SWINV2_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, default `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
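# Minimal usage sketch for `Swinv2Model` below (illustrative only, kept as a
# comment so nothing runs on import; it assumes the
# microsoft/swinv2-tiny-patch4-window8-256 checkpoint and a 256x256 RGB image
# already loaded into `image`):
#
#     from transformers import AutoImageProcessor, Swinv2Model
#
#     processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
#     model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
#     inputs = processor(images=image, return_tensors="pt")
#     outputs = model(**inputs)
#     outputs.last_hidden_state.shape  # torch.Size([1, 64, 768])
#     outputs.pooler_output.shape      # torch.Size([1, 768])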
ede          	 	 	 	 	 	 	 dd	eej                 d
eej                 deej                 dee         dee         dedee         deee
f         fd                        Z xZS )Swinv2ModelTFc                    t                                          |           || _        t          |j                  | _        t          |j        d| j        dz
  z  z            | _        t          ||          | _
        t          || j
        j                  | _        t          j        | j        |j                  | _        |rt          j        d          nd | _        |                                  d S )Nr@   r   )r   rK  )rg   rh   r   r4  r  r  r   r~   num_featuresru   r   r  r{   encoderr   r   rP  	layernormAdaptiveAvgPool1dpooler	post_init)r:   r   add_pooling_layerr   ri   s       r-   rh   zSwinv2Model.__init__  s       fm,, 0119L3M MNN*6.QQQ$VT_-GHHd&7V=RSSS1BLb*1--- 	r,   c                     | j         j        S rf   r   rx   r9   s    r-   get_input_embeddingsz Swinv2Model.get_input_embeddings"      //r,   c                     |                                 D ]/\  }}| j        j        |         j                            |           0dS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  layerrO  r8  )r:   heads_to_pruner  r6  s       r-   _prune_headszSwinv2Model._prune_heads%  sU    
 +0022 	C 	CLE5Lu%/;;EBBBB	C 	Cr,   vision)
checkpointoutput_typer  modalityexpected_outputNr   r   r  r  r  r   r  rW   c                 ~   ||n| j         j        }||n| j         j        }||n| j         j        }|t	          d          |                     |t          | j         j                            }|                     |||          \  }}	| 	                    ||	||||          }
|
d         }| 
                    |          }d}| j        >|                     |                    dd                    }t          j        |d          }|s||f|
dd         z   }|S t          |||
j        |
j        |
j                  S )	z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        Nz You have to specify pixel_values)r   r   )r  r  r  r  r   r   r@   )r    r0   r!   r"   r#   )r   r  r  use_return_dictr   get_head_maskr4  r  r   r  r  r  r   r(   r   r/   r!   r"   r#   )r:   r   r   r  r  r  r   r  embedding_outputr   encoder_outputssequence_outputpooled_outputra   s                 r-   rl   zSwinv2Model.forward-  s   , 2C1N--TXT_Tq$8$D  $+Jj 	 &1%<kk$+B]?@@@ &&y#dk6H2I2IJJ	-1__/Tl .= .
 .
** ,,/!5# ' 
 
 *!,..99;" KK(A(A!Q(G(GHHM!M-;;M 	%}58KKFM -')7&1#2#I
 
 
 	
r,   )TFNNNNNFN)r$   r%   r&   rh   r  r  r   SWINV2_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr/   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r(   r)   r   r   r   r   rl   rr   rs   s   @r-   r  r    sS            0 0 0C C C +*+BCC&%$.   596:15,0/3).&*>
 >
u01>
 "%"23>
 E-.	>

 $D>>
 'tn>
 #'>
 d^>
 
u''	(>
 >
 >
  DC>
 >
 >
 >
 >
@add_start_docstrings(
    """Swinv2 Model with a decoder on top for masked image modeling, as proposed in
    [SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    """,
    SWINV2_START_DOCSTRING,
)
class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.swinv2 = Swinv2Model(config, add_pooling_layer=False, use_mask_token=True)

        num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
            ),
            nn.PixelShuffle(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Swinv2MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Swinv2MaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, Swinv2ForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
        >>> model = Swinv2ForMaskedImageModeling.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 256, 256]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.swinv2(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # Reshape to (batch_size, num_channels, height, width)
        sequence_output = sequence_output.transpose(1, 2)
        batch_size, num_channels, sequence_length = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            mask = (
                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                .repeat_interleave(self.config.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[2:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return Swinv2MaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
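

# --- Added commentary (not part of the original file) -------------------------------------------
# The SimMIM loss above weights a per-pixel L1 loss with `bool_masked_pos`, which lives on the
# patch grid and therefore has to be expanded to pixel resolution first. Assuming the default
# image_size=256 and patch_size=4 of the tiny checkpoint, the bookkeeping works out as follows:
#
#     import torch
#
#     image_size, patch_size = 256, 4
#     size = image_size // patch_size                 # 64 patches per side
#     bool_masked_pos = torch.randint(0, 2, (1, size * size)).bool()
#     mask = (
#         bool_masked_pos.reshape(-1, size, size)
#         .repeat_interleave(patch_size, 1)           # 64 rows -> 256 rows
#         .repeat_interleave(patch_size, 2)           # 64 cols -> 256 cols
#         .unsqueeze(1)                               # add a channel axis
#     )
#     print(mask.shape)  # torch.Size([1, 1, 256, 256]); broadcasts against (1, 3, 256, 256)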
@add_start_docstrings(
    """
    Swinv2 Model transformer with an image classification head on top (a linear layer on top of the average-pooled
    final hidden state), e.g. for ImageNet.

    <Tip>

    Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by
    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
    position embeddings to the higher resolution (see the commented sketch after this class).

    </Tip>
    """,
    SWINV2_START_DOCSTRING,
)
class Swinv2ForImageClassification(Swinv2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.swinv2 = Swinv2Model(config)

        # Classifier head
        self.classifier = (
            nn.Linear(self.swinv2.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=Swinv2ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Swinv2ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.swinv2(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return Swinv2ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
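

# --- Added usage sketch (not part of the original file) -----------------------------------------
# Hedged illustration of the higher-resolution fine-tuning path mentioned in the <Tip> above:
# passing `interpolate_pos_encoding=True` resizes the pre-trained position embeddings so the
# classifier can consume inputs larger than the pre-training resolution (384x384 here is an
# assumed example, not a tested configuration).
#
#     from transformers import Swinv2ForImageClassification
#     import torch
#
#     model = Swinv2ForImageClassification.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
#     pixel_values = torch.randn(1, 3, 384, 384)  # larger than the checkpoint's native 256x256
#     with torch.no_grad():
#         logits = model(pixel_values, interpolate_pos_encoding=True).logits
#     print(logits.shape)  # (1, config.num_labels), e.g. torch.Size([1, 1000]) for ImageNet-1k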
@add_start_docstrings(
    """
    Swinv2 backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    SWINV2_START_DOCSTRING,
)
class Swinv2Backbone(Swinv2PreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
        self.embeddings = Swinv2Embeddings(config)
        self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
        >>> model = AutoBackbone.from_pretrained(
        ...     "microsoft/swinv2-tiny-patch4-window8-256", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 8, 8]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output, input_dimensions = self.embeddings(pixel_values)

        outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=None,
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            return_dict=return_dict,
        )

        hidden_states = outputs.reshaped_hidden_states if return_dict else outputs[-1]

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                feature_maps += (hidden_state,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs[1],)
            if output_attentions:
                output += (outputs[2],)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
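

# --- Added usage sketch (not part of the original file) -----------------------------------------
# Stage selection in the backbone: with the tiny checkpoint at 256x256, the four stages emit
# (batch, channels, height, width) maps of (1, 96, 64, 64), (1, 192, 32, 32), (1, 384, 16, 16)
# and (1, 768, 8, 8); `out_features` picks which of them are returned, e.g. only the two
# coarsest maps for an FPN-style neck. Shapes assume this specific checkpoint and input size.
#
#     from transformers import AutoBackbone
#     import torch
#
#     backbone = AutoBackbone.from_pretrained(
#         "microsoft/swinv2-tiny-patch4-window8-256", out_features=["stage3", "stage4"]
#     )
#     feature_maps = backbone(torch.randn(1, 3, 256, 256)).feature_maps
#     print([tuple(f.shape) for f in feature_maps])  # [(1, 384, 16, 16), (1, 768, 8, 8)]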
r,   r  )rS   F)Nr'   collections.abcr   r   r6   dataclassesr   typingr   r   r   r(   torch.utils.checkpointr   r   torch.nnr	   r
   r   activationsr   modeling_outputsr   modeling_utilsr   pytorch_utilsr   r   r   utilsr   r   r   r   r   r   r   utils.backbone_utilsr   configuration_swinv2r   
get_loggerr$   loggerr  r  r  r
  r  r   r/   r2   r>   rP   rR   rp   r   rb   r   rd   ru   rw   r   r   r%  r.  r>  rE  rI  r  r  r  SWINV2_START_DOCSTRINGr  r  r  r  r  r+   r,   r-   <module>r'     s   ( '       ! ! ! ! ! ! ) ) ) ) ) ) ) ) ) )              A A A A A A A A A A ! ! ! ! ! ! . . . . . . - - - - - - [ [ [ [ [ [ [ [ [ [                  2 1 1 1 1 1 . . . . . . 
	H	%	% ! A %  E -  K K K K K+ K K K@  K  K  K  K  K  K  K  KF )# )# )# )# )#k )# )# )#X  K  K  K  K  K+  K  K  KH	 	 	   U\ e T V[Vb    *- - - - -RY - - -Y- Y- Y- Y- Y-ry Y- Y- Y-z(- (- (- (- (-BI (- (- (-V3 3 3 3 3 3 3 3l~ ~ ~ ~ ~") ~ ~ ~D
 
 
 
 
ry 
 
 
+ + + + +bi + + +^        	 	 	 	 	29 	 	 	x x x x x") x x xv< < < < <") < < <~f
 f
 f
 f
 f
BI f
 f
 f
T* * * * *O * * *2	  0 f 
a
 a
 a
 a
 a
' a
 a
 
a
H 	  g
 g
 g
 g
 g
#8 g
 g
 g
T    V
 V
 V
 V
 V
#8 V
 V
!  V
r  	 W
 W
 W
 W
 W
*M W
 W
 W
 W
 W
r,   