
    g                     B   d Z ddlZddlmZ ddlmZmZmZ ddlZddl	Zddlm
Z
 ddlmZmZmZ ddlmZ dd	lmZ dd
lmZmZmZmZmZmZ ddlmZ ddlmZmZmZ ddl m!Z!m"Z"m#Z#m$Z$m%Z%m&Z& ddl'm(Z(  e&j)        e*          Z+dZ,dZ-dZ.da/d Z0d Z1dQdZ2dQdZ3dQdZ4d Z5 G d dej6        j7                  Z8 G d dej6        j7                  Z9 G d d          Z:dRd Z;d! Z<	 	 	 dSd"Z= G d# d$e
j>                  Z? G d% d&e
j>                  Z@ G d' d(e
j>                  ZA G d) d*e
j>                  ZB G d+ d,e
j>                  ZC G d- d.e
j>                  ZD G d/ d0e
j>                  ZE G d1 d2e
j>                  ZF G d3 d4e
j>                  ZG G d5 d6e
j>                  ZH G d7 d8e
j>                  ZI G d9 d:e          ZJd;ZKd<ZL e"d=eK           G d> d?eJ                      ZM e"d@eK           G dA dBeJ                      ZN G dC dDe
j>                  ZO e"dEeK           G dF dGeJ                      ZP e"dHeK           G dI dJeJ                      ZQ e"dKeK           G dL dMeJ                      ZR e"dNeK           G dO dPeJ                      ZSdS )TzPyTorch MRA model.    N)Path)OptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss)load   )ACT2FN)"BaseModelOutputWithCrossAttentionsMaskedLMOutputMultipleChoiceModelOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)apply_chunking_to_forward find_pruneable_heads_and_indicesprune_linear_layer)add_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardis_ninja_availableis_torch_cuda_availablelogging   )	MraConfigzuw-madison/mra-base-512-4r   AutoTokenizerc                      t          t                                                    j        j        j        dz  dz  fd}  | g d          }t	          d|d          ad S )Nkernelsmrac                      fd| D             S )Nc                     g | ]}|z  S  r&   ).0file
src_folders     `/var/www/html/ai-engine/env/lib/python3.11/site-packages/transformers/models/mra/modeling_mra.py
<listcomp>z:load_cuda_kernels.<locals>.append_root.<locals>.<listcomp>@   s    444d
T!444    r&   )filesr)   s    r*   append_rootz&load_cuda_kernels.<locals>.append_root?   s    4444e4444r,   )zcuda_kernel.cuzcuda_launch.cuztorch_extension.cppcuda_kernelT)verbose)r   __file__resolveparentr   mra_cuda_kernel)r.   	src_filesr)   s     @r*   load_cuda_kernelsr6   ;   sv    h''))07>JURJ5 5 5 5 5 WWWXXI=)TBBBOOOr,   c                    t          |                                           dk    rt          d          t          |                                          dk    rt          d          |                     d          dk    rt          d          |                     d          dk    rt          d          |                     d	
          j                            dd	          }|                                }|                                }|                                }t          	                    ||||          \  }}|                    dd	          dddddddf         }||fS )z8
    Computes maximum values for softmax stability.
       z.sparse_qk_prod must be a 4-dimensional tensor.   'indices must be a 2-dimensional tensor.    z>The size of the second dimension of sparse_qk_prod must be 32.r   z=The size of the third dimension of sparse_qk_prod must be 32.dimN)
lensize
ValueErrormaxvalues	transpose
contiguousintr4   	index_max)sparse_qk_prodindicesquery_num_blockkey_num_block
index_valsmax_valsmax_vals_scatters          r*   
sparse_maxrP   G   s`    >  !!Q&&IJJJ
7<<>>aBCCC1##YZZZ1##XYYY###++2<<RDDJ&&((JkkmmG  ""G!0!:!::wP_an!o!oH'11"b99!!!QQQaaa-H%%%r,   r;   c                 B   t          |                                           dk    rt          d          t          |                                          dk    rt          d          | j        d         |j        d         k    rt          d          | j        \  }}||z  }t	          j        |                    d          t          j        |j                  }|                     |||          } | |dddf         ||z                                  ddf         } | S )zN
    Converts attention mask to a sparse mask for high resolution logits.
    r9   z$mask must be a 2-dimensional tensor.r:   r   zBmask and indices must have the same size in the zero-th dimension.dtypedeviceN)	r@   rA   rB   shapetorcharangelongrT   reshape)maskrJ   
block_size
batch_sizeseq_len	num_block	batch_idxs          r*   sparse_maskr`   c   s    499;;1?@@@
7<<>>aBCCCz!}a(((]^^^*J:%IW\\!__EJw~VVVI<<
Iz::D	!!!T'"Wy%8$>$>$@$@!!!CDDKr,   c                 R   |                                  \  }}}|                                 \  }}}||z  dk    rt          d          ||z  dk    rt          d          |                     |||z  ||                              dd          } |                    |||z  ||                              dd          }t	          |                                            dk    rt          d          t	          |                                           dk    rt          d          t	          |                                           d	k    rt          d
          |                      d          dk    rt          d          |                     d          dk    rt          d          |                                 } |                                }|                                }|                                }t                              | ||                                          S )z7
    Performs Sampled Dense Matrix Multiplication.
    r   zTquery_size (size of first dimension of dense_query) must be divisible by block_size.Pkey_size (size of first dimension of dense_key) must be divisible by block_size.r?   r<   r8   z+dense_query must be a 4-dimensional tensor.)dense_key must be a 4-dimensional tensor.r9   r:   r   r;   z.The third dimension of dense_query must be 32.z,The third dimension of dense_key must be 32.)	rA   rB   rY   rE   r@   rF   rG   r4   mm_to_sparse)	dense_query	dense_keyrJ   r[   r\   
query_sizer>   _key_sizes	            r*   rd   rd   z   s    #."2"2"4"4J
C ~~''AxJ!##oppp*!!klll%%j*
2JJX[\\ffgikmnnK!!*h*.DjRUVV``aceghhI
;!##FGGG
9>>!!DEEE
7<<>>aBCCCb  IJJJ~~aBGHHH((**K$$&&IkkmmG  ""G''YNNNr,   c                 "   |                                 \  }}}||z  dk    rt          d          |                      d          |k    rt          d          |                      d          |k    rt          d          |                    |||z  ||                              dd          }t	          |                                            d	k    rt          d
          t	          |                                           d	k    rt          d          t	          |                                           dk    rt          d          |                     d          dk    rt          d          |                                 } |                                }|                                }|                                }t                              | |||          }|                    dd                              |||z  |          }|S )zP
    Performs matrix multiplication of a sparse matrix with a dense matrix.
    r   rb   r9   zQThe size of the second dimension of sparse_query must be equal to the block_size.r   zPThe size of the third dimension of sparse_query must be equal to the block_size.r?   r<   r8   ,sparse_query must be a 4-dimensional tensor.rc   r:   r;   z8The size of the third dimension of dense_key must be 32.)	rA   rB   rY   rE   r@   rF   rG   r4   sparse_dense_mm)	sparse_queryrJ   rf   rK   r[   r\   ri   r>   dense_qk_prods	            r*   rl   rl      s    !* 0 0J#*!!klllz))lmmmz))klll!!*h*.DjRUVV``aceghhI
<1$$GHHH
9>>!!DEEE
7<<>>aBCCC~~aBSTTT**,,LkkmmG  ""G$$&&I#33L'9VeffM!++B33;;JZdHdfijjMr,   c                 f    | |z  |z  t          j        | |d          z                                   S )Nfloorrounding_mode)rV   divrX   )rJ   dim_1_blockdim_2_blocks      r*   transpose_indicesrv      s5    {"k1EIg{bi4j4j4jjpprrrr,   c                   R    e Zd Zed             Zed             Zedd            ZdS )MraSampledDenseMatMulc                 f    t          ||||          }|                     |||           || _        |S N)rd   save_for_backwardr[   )ctxre   rf   rJ   r[   rI   s         r*   forwardzMraSampledDenseMatMul.forward   s:    %k9gzRRk9g>>>#r,   c                 $   | j         \  }}}| j        }|                    d          |z  }|                    d          |z  }t          |||          }t	          |                    dd          |||          }	t	          ||||          }
|
|	d d fS Nr   r?   r<   )saved_tensorsr[   rA   rv   rl   rE   )r|   gradre   rf   rJ   r[   rK   rL   	indices_Tgrad_key
grad_querys              r*   backwardzMraSampledDenseMatMul.backward   s    *-*;'Y^
%**1--;!q))Z7%gNN	"4>>"b#9#99kS`aa$T7IOO
8T4//r,   r;   c                 <    t                               | |||          S rz   )rx   apply)re   rf   rJ   r[   s       r*   operator_callz#MraSampledDenseMatMul.operator_call   s    $**;	7JWWWr,   Nr;   __name__
__module____qualname__staticmethodr}   r   r   r&   r,   r*   rx   rx      sn          \ 0 0 \0 X X X \X X Xr,   rx   c                   P    e Zd Zed             Zed             Zed             ZdS )MraSparseDenseMatMulc                 f    t          ||||          }|                     |||           || _        |S rz   )rl   r{   rK   )r|   rm   rJ   rf   rK   rI   s         r*   r}   zMraSparseDenseMatMul.forward   s;    (w	?[[lGY???-r,   c                    | j         \  }}}| j        }|                    d          |                    d          z  }t          |||          }t	          |                    dd          |||          }t          |||          }	|	d |d fS r   )r   rK   rA   rv   rl   rE   rd   )
r|   r   rm   rJ   rf   rK   rL   r   r   r   s
             r*   r   zMraSparseDenseMatMul.backward   s    +.+<(gy-!q))\->->r-B-BB%gNN	"<#9#9"b#A#A9dTabb!$	7;;
44//r,   c                 <    t                               | |||          S rz   )r   r   )rm   rJ   rf   rK   s       r*   r   z"MraSparseDenseMatMul.operator_call   s    #)),O\\\r,   Nr   r&   r,   r*   r   r      sh          \ 0 0 \0 ] ] \] ] ]r,   r   c                   $    e Zd Zed             ZdS )MraReduceSumc                    |                                  \  }}}}t          |                                            dk    rt          d          t          |                                           dk    rt          d          |                                  \  }}}}|                                 \  }}|                     d                              ||z  |          } t          j        |                     d          t
          j        |j                  }t          j	        ||d	                                          |d d d f         |z  z                       ||z            }	t          j
        ||z  |f| j        | j                  }
|
                    d|	|                               |||          }|                    |||z            }|S )
Nr8   rk   r9   r:   r=   r   rR   rp   rq   )rA   r@   rB   sumrY   rV   rW   rX   rT   rs   zerosrS   	index_add)rm   rJ   rK   rL   r\   r^   r[   rh   r_   global_idxestempoutputs               r*   r   zMraReduceSum.operator_call   s   /;/@/@/B/B,
Iz1|  ""##q((KLLLw||~~!##FGGG*//111j! '
I#''A'..66zI7MzZZLa
7>ZZZ	Ig}GDDDIIKKiXYXYXY[_X_N`crNrr
'*y(
)
) 	 {/):6l>PYeYl
 
 
 <>>FFzSbdnoo
Oj,HIIr,   N)r   r   r   r   r   r&   r,   r*   r   r      s-          \  r,   r   c                    |                                  \  }}}||z  }d}	||                    |||                              d          }
|                     ||||                              d          |
dddddf         dz   z  }|                    ||||                              d          |
dddddf         dz   z  }|?|                    ||||                              d          |
dddddf         dz   z  }	n|t          j        ||t          j        | j                  z  }
|                     ||||                              d          }|                    ||||                              d          }|,|                    ||||                              d          }	t          j        ||	                    dd                    t          j        |          z  }|                    dd          j        }|;|d	|
dddddf         |
dddddf         z  d
k                                     z  z
  }||
||	fS )z/
    Compute low resolution approximation.
    Nr?   r=   r<   ư>rR   T)r>   keepdims     @g      ?)rA   rY   r   rV   onesfloatrT   meanmatmulrE   mathsqrtrC   rD   )querykeyr[   rZ   valuer\   r]   head_dimnum_block_per_row	value_hattoken_count	query_hatkey_hatlow_resolution_logitlow_resolution_logit_row_maxs                  r*   get_low_resolution_logitr     s    %*JJLL!J:-Ill:/@*MMQQVXQYYMM*.?XVVZZ_aZbb111d
#d*
	 ++j*;ZRRVV[]V^^111d
#d*
 j2CZQYZZ^^ce^ffAAAqqq$J'$.I !5:j:KSXS^glgs#t#t#ttMM*.?XVV[[`b[cc	++j*;ZRRWW\^W__j2CZQYZZ__df_ggI <	73D3DR3L3LMMPTPYZbPcPcc#7#;#;T#;#R#R#Y  3;qqq$z+B[QRQRQRTUTUTUW[Q[E\+\`c*c)j)j)l)l#ll 	  .JIUUr,   c                    | j         \  }}}|dk    ra|dz  }t          j        ||| j                  }	t          j        t          j        |	|           |          }
| |
dddddf         dz  z   } |dk    r@| ddd|ddf         dz   | ddd|ddf<   | ddddd|f         dz   | ddddd|f<   t          j        |                     |d          |ddd	
          }|j        }|dk    rD|j	        
                    d          j	        }| |ddddf         k                                    }n|dk    rd}nt          | d          ||fS )zZ
    Compute the indices of the subset of components to be used in the approximation.
    r   r9   rT   )diagonalNg     @r?   TF)r>   largestsortedfullr=   sparsez# is not a valid approx_model value.)rU   rV   r   rT   triltriutopkrY   rJ   rD   minr   rB   )r   
num_blocksapprox_modeinitial_prior_first_n_blocksinitial_prior_diagonal_n_blocksr\   total_blocks_per_rowrh   offset	temp_maskdiagonal_mask
top_k_valsrJ   	thresholdhigh_resolution_masks                  r*   get_block_idxesr   B  s    +?*D'J$a&**0A5J35IRfRmnnn	
5:i6'#J#J#JU[\\\3mD!!!QQQJ6ORU6UU#a'' $A%A$A111!DEK 	QQQ =!= =qqq@A !AAA'D(D'D!DEK 	QQQ#@$@#@@A $$Z44jbRV_d  J  Gf%))b)118	 4	!!!T4-8P PWWYY		 	 #KLLLMMM(((r,   c	                    t           &t          j        |                                           S |                                 \  }	}
}}|	|
z  }||z  dk    rt          d          ||z  }|                     |||          } |                    |||          }|                    |||          }|6| |dddddf         z  } ||dddddf         z  }||dddddf         z  }|dk    rt          | ||||          \  }}}}nX|dk    rCt          j                    5  t          | |||          \  }}}}ddd           n# 1 swxY w Y   nt          d          t          j                    5  ||z
  }t          |||||          \  }}ddd           n# 1 swxY w Y   t                              | |||          t          j        |          z  }t          ||||          \  }}||z
  }|)|dd	t!          ||          dddddddf         z
  z  z
  }t          j        |          }t$                              ||||          }t&                              ||||          }|dk    rt          j        ||z
  d|z  z
            |dddddf         z  }t          j        ||          dddddddf                             d	d	|d	                              |||          }|                    d
          dddddf                             d	d	|                              ||          }|                    d	d	|                              ||          |z
  } || |z  } t          j        | | dk                                    z            }!||!dddddf         z  }||!z  }t          j        |  | dk                                    z            }"||"dddddf         z  }||"z  }||z   |dddddf         |dddddf         z   dz   z  }#n+|dk    r||dddddf         dz   z  }#nt          d          ||#|dddddf         z  }#|#                    |	|
||          }#|#S )z0
    Use Mra to approximate self-attention.
    Nr   z4sequence length must be divisible by the block_size.r   r   z&approx_mode must be "full" or "sparse")r[   r   r   r?   r=   r   z-config.approx_mode must be "full" or "sparse")r4   rV   
zeros_likerequires_grad_rA   rB   rY   r   no_grad	Exceptionr   rx   r   r   r   rP   r`   expr   r   r   repeatr   r   )$r   r   r   rZ   r   r   r[   r   r   r\   num_headr]   r   
meta_batchr   r   r   r   r   rh   low_resolution_logit_normalizedrJ   r   high_resolution_logitrN   rO   high_resolution_attnhigh_resolution_attn_outhigh_resolution_normalizerlow_resolution_attnlow_resolution_attn_outlow_resolution_normalizerlog_correctionlow_resolution_corrhigh_resolution_corrcontext_layers$                                       r*   mra2_attentionr   h  s2    &&55777.3jjll+J'8h&Jq  OPPP:-MM*gx88E
++j'8
4
4CMM*gx88EQQQ4Z((DAAAt$$QQQ4Z((fUm3
D%V
 V
Rk+G 
	 	 ]__ 	 	QisJR RN +/KQ	 	 	 	 	 	 	 	 	 	 	 	 	 	 	
 @AAA	 
 
*>A]*]'(7+(+)
 )
%%
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 2??sG
 @  	( ",,A7L]_p!q!qH14DD 5q;tU\C]C]^_^_^_abababdededegk^kCl?l8m m 9%:;;3AAgu.?    ".!;!;g'8:K" " fI*-IICRfLffgg!!!T111*%& 	 L,i88AAAtQQQGVAq*a((WZ(33 	   ###++AAAqqq$J7>>q!ZPPXXYcelmm 	" 6<<Q:NNVVWacjkknvv+d2N#i.A:M9T9T9V9V(VWW"9<OPQPQPQSTSTSTVZPZ<["[$=@S$S!$y.NQ<N;U;U;W;W)WXX#;>RSTSTSTVWVWVWY]S]>^#^ %?BV%V"14KK&qqq!!!Tz25NqqqRSRSRSUYz5ZZ]aa
 
	 	 04NqqqRSRSRSUYz4Z]a4abGHHH%QQQ4Z(88!))*hRRMs$   EE	E0FFFc                   *     e Zd ZdZ fdZddZ xZS )MraEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    t                                                       t          j        |j        |j        |j                  | _        t          j        |j        dz   |j                  | _	        t          j        |j
        |j                  | _        t          j        |j        |j                  | _        t          j        |j                  | _        |                     dt%          j        |j                                      d          dz              t+          |dd          | _        |                     dt%          j        | j                                        t$          j        | j        j        	          d
           d S )N)padding_idxr9   epsposition_ids)r   r?   position_embedding_typeabsolutetoken_type_idsrR   F)
persistent)super__init__r   	Embedding
vocab_sizehidden_sizepad_token_idword_embeddingsmax_position_embeddingsposition_embeddingstype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsDropouthidden_dropout_probdropoutregister_bufferrV   rW   expandgetattrr   r   r   rA   rX   rT   selfconfig	__class__s     r*   r   zMraEmbeddings.__init__  sQ   !|F,=v?Q_e_rsss#%<0NQR0RTZTf#g#g %'\&2H&J\%]%]" f&8f>STTTz&"<== 	^U\&:X-Y-Y-`-`ah-i-ilm-mnnn'.v7PR\']']$K)..00
4K\Kcddd 	 	
 	
 	
 	
 	
r,   Nc                    ||                                 }n|                                 d d         }|d         }|| j        d d d |f         }|mt          | d          r2| j        d d d |f         }|                    |d         |          }|}n+t          j        |t
          j        | j        j                  }|| 	                    |          }| 
                    |          }	||	z   }
| j        dk    r|                     |          }|
|z  }
|                     |
          }
|                     |
          }
|
S )Nr?   r   r   r   rR   r   )rA   r   hasattrr   r   rV   r   rX   rT   r   r   r   r   r   r   )r   	input_idsr   r   inputs_embedsinput_shape
seq_lengthbuffered_token_type_ids buffered_token_type_ids_expandedr   
embeddingsr   s               r*   r}   zMraEmbeddings.forward  sb    #..**KK',,..ss3K ^
,QQQ^<L
 !t-.. m*.*=aaa*n*M'3J3Q3QR]^_R`bl3m3m0!A!&[
SWSdSk!l!l!l  00;;M $ : :> J J"%::
':55"&":":<"H"H--J^^J//
\\*--
r,   )NNNNr   r   r   __doc__r   r}   __classcell__r  s   @r*   r   r     sR        QQ
 
 
 
 
(               r,   r   c                   .     e Zd Zd fd	Zd ZddZ xZS )MraSelfAttentionNc                 4   t                                                       |j        |j        z  dk    r0t	          |d          s t          d|j         d|j         d          t          d u}t                      rTt                      rF|sD	 t                       n4# t          $ r'}t                              d|            Y d }~nd }~ww xY w|j        | _        t          |j        |j        z            | _        | j        | j        z  | _        t!          j        |j        | j                  | _        t!          j        |j        | j                  | _        t!          j        |j        | j                  | _        t!          j        |j                  | _        ||n|j        | _        |j        dz  |j        z  | _        t9          | j        t          |j        dz  dz                      | _        |j        | _        |j        | _        |j        | _        d S )	Nr   embedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()zGCould not load the custom kernel for multi-scale deformable attention: r;   r9   ) r   r   r   num_attention_headsr  rB   r4   r   r   r6   r   loggerwarningrG   attention_head_sizeall_head_sizer   Linearr   r   r   r   attention_probs_dropout_probr   r   r   block_per_rowr^   r   r   r   r   )r   r   r   kernel_loadeder  s        r*   r   zMraSelfAttention.__init__  s.    ::a??PVXhHiHi?8F$6 8 8 48 8 8  
 (t3"$$ 	n);)=)= 	nm 	nn!#### n n nlijllmmmmmmmmn $*#= #&v'9F<V'V#W#W !58PPYv143EFF
9V/1CDDYv143EFF
z&"EFF'>'J##PVPn 	$ !8B>&BVVT^S&2PTV2V[\1\-]-]^^!-,2,O)/5/U,,,s   B 
C&CCc                     |                                 d d         | j        | j        fz   } |j        | }|                    dddd          S )Nr?   r   r9   r   r   )rA   r  r  viewpermute)r   layernew_layer_shapes      r*   transpose_for_scoresz%MraSelfAttention.transpose_for_scores9  sN    **,,ss+t/GIa.bb
O,}}Q1a(((r,   c           
      *   |                      |          }|                     |                     |                    }|                     |                     |                    }|                     |          }|                                \  }}}	}
d|dz  z   }|                                                    d|d                              ||z  |	                                          }d}|
|k     r|||	||
z
  f}t          j
        |t          j        ||j                  gd          }t          j
        |t          j        ||j                  gd          }t          j
        |t          j        ||j                  gd          }t          |                                |                                |                                |                                | j        | j        | j        | j                  }|
|k     r|d d d d d d d |
f         }|                    |||	|
          }|                    d	d
dd                                          }|                                d d         | j        fz   } |j        | }|f}|S )N      ?r   r   r;   r   r?   r=   )r   r   r   r   r9   r   r<   )r   r#  r   r   rA   squeezer   rY   rG   rV   catr   rT   r   r   r^   r   r   r   r   rF   r  r  )r   hidden_statesattention_maskmixed_query_layer	key_layervalue_layerquery_layerr\   	num_headsr]   r   gpu_warp_sizepad_sizer   new_context_layer_shapeoutputss                   r*   r}   zMraSelfAttention.forward>  s    JJ}55--dhh}.E.EFF	//

=0I0IJJ//0ABB3>3C3C3E3E0
Iw ~77""$$++Ay!<<DDZR[E[]deeiikk 	 m##!9g}x7OOH)[%+h{Oa2b2b2b$ciklllK	9ek(9K[.\.\.\"]cefffI)[%+h{Oa2b2b2b$ciklllK&OO  ""N()-)J,0,P	
 	
 	
 m##)!!!QQQ9H9*<=M%--j)WhWW%--aAq99DDFF"/"4"4"6"6ss";t?Q>S"S**,CD "r,   rz   )r   r   r   r   r#  r}   r  r  s   @r*   r  r    sg        !V !V !V !V !V !VF) ) )
0 0 0 0 0 0 0 0r,   r  c                   P     e Zd Z fdZdej        dej        dej        fdZ xZS )MraSelfOutputc                    t                                                       t          j        |j        |j                  | _        t          j        |j        |j                  | _        t          j        |j	                  | _
        d S Nr   )r   r   r   r  r   denser   r   r   r   r   r   s     r*   r   zMraSelfOutput.__init__s  sf    Yv163EFF
f&8f>STTTz&"<==r,   r(  input_tensorreturnc                     |                      |          }|                     |          }|                     ||z             }|S rz   r7  r   r   r   r(  r8  s      r*   r}   zMraSelfOutput.forwardy  @    

=11]33}|'CDDr,   r   r   r   r   rV   Tensorr}   r  r  s   @r*   r4  r4  r  i        > > > > >U\  RWR^        r,   r4  c                   .     e Zd Zd fd	Zd ZddZ xZS )MraAttentionNc                     t                                                       t          ||          | _        t	          |          | _        t                      | _        d S )N)r   )r   r   r  r   r4  r   setpruned_heads)r   r   r   r  s      r*   r   zMraAttention.__init__  sO    $VE\]]]	#F++EEr,   c                    t          |          dk    rd S t          || j        j        | j        j        | j                  \  }}t          | j        j        |          | j        _        t          | j        j        |          | j        _        t          | j        j	        |          | j        _	        t          | j
        j        |d          | j
        _        | j        j        t          |          z
  | j        _        | j        j        | j        j        z  | j        _        | j                            |          | _        d S )Nr   r   r=   )r@   r   r   r  r  rE  r   r   r   r   r   r7  r  union)r   headsindexs      r*   prune_headszMraAttention.prune_heads  s    u::??F7490$)2OQUQb
 
u
 -TY_eDD	*49=%@@	,TY_eDD	.t{/@%QOOO )-	(EE

(R	%"&)"?$)B_"_	 -33E::r,   c                     |                      ||          }|                     |d         |          }|f|dd          z   }|S Nr   r   )r   r   )r   r(  r)  self_outputsattention_outputr2  s         r*   r}   zMraAttention.forward  sH    yy??;;|AFF#%QRR(88r,   rz   )r   r   r   r   rJ  r}   r  r  s   @r*   rB  rB    s`        " " " " " "; ; ;$       r,   rB  c                   B     e Zd Z fdZdej        dej        fdZ xZS )MraIntermediatec                    t                                                       t          j        |j        |j                  | _        t          |j        t                    rt          |j                 | _        d S |j        | _        d S rz   )r   r   r   r  r   intermediate_sizer7  
isinstance
hidden_actstrr   intermediate_act_fnr   s     r*   r   zMraIntermediate.__init__  sn    Yv163KLL
f'-- 	9'-f.?'@D$$$'-'8D$$$r,   r(  r9  c                 Z    |                      |          }|                     |          }|S rz   )r7  rV  r   r(  s     r*   r}   zMraIntermediate.forward  s,    

=1100??r,   r>  r  s   @r*   rP  rP    s^        9 9 9 9 9U\ el        r,   rP  c                   P     e Zd Z fdZdej        dej        dej        fdZ xZS )	MraOutputc                    t                                                       t          j        |j        |j                  | _        t          j        |j        |j                  | _        t          j	        |j
                  | _        d S r6  )r   r   r   r  rR  r   r7  r   r   r   r   r   r   s     r*   r   zMraOutput.__init__  sf    Yv79KLL
f&8f>STTTz&"<==r,   r(  r8  r9  c                     |                      |          }|                     |          }|                     ||z             }|S rz   r;  r<  s      r*   r}   zMraOutput.forward  r=  r,   r>  r  s   @r*   rZ  rZ    r@  r,   rZ  c                   ,     e Zd Z fdZddZd Z xZS )MraLayerc                     t                                                       |j        | _        d| _        t	          |          | _        |j        | _        t          |          | _        t          |          | _
        d S Nr   )r   r   chunk_size_feed_forwardseq_len_dimrB  	attentionadd_cross_attentionrP  intermediaterZ  r   r   s     r*   r   zMraLayer.__init__  si    '-'E$%f--#)#= +F33''r,   Nc                     |                      ||          }|d         }|dd          }t          | j        | j        | j        |          }|f|z   }|S rL  )rc  r   feed_forward_chunkra  rb  )r   r(  r)  self_attention_outputsrN  r2  layer_outputs          r*   r}   zMraLayer.forward  sd    !%~!N!N1!4(,0#T%A4CSUe
 
  /G+r,   c                 \    |                      |          }|                     ||          }|S rz   )re  r   )r   rN  intermediate_outputri  s       r*   rg  zMraLayer.feed_forward_chunk  s2    "//0@AA{{#68HIIr,   rz   )r   r   r   r   r}   rg  r  r  s   @r*   r^  r^    s[        ( ( ( ( (         r,   r^  c                   .     e Zd Z fdZ	 	 	 	 ddZ xZS )
MraEncoderc                     t                                                       | _        t          j        fdt          j                  D                       | _        d| _        d S )Nc                 .    g | ]}t                    S r&   )r^  )r'   rh   r   s     r*   r+   z'MraEncoder.__init__.<locals>.<listcomp>  s!    #^#^#^HV$4$4#^#^#^r,   F)	r   r   r   r   
ModuleListrangenum_hidden_layersr!  gradient_checkpointingr   s    `r*   r   zMraEncoder.__init__  s`    ]#^#^#^#^eFD\>]>]#^#^#^__
&+###r,   NFTc                 <   |rdnd }t          | j                  D ]L\  }}|r||fz   }| j        r$| j        r|                     |j        ||          }	n |||          }	|	d         }M|r||fz   }|st          d ||fD                       S t          ||          S )Nr&   r   c              3      K   | ]}||V  	d S rz   r&   )r'   vs     r*   	<genexpr>z%MraEncoder.forward.<locals>.<genexpr>   s"      XXq!-----XXr,   )last_hidden_stater(  )	enumerater!  rs  training_gradient_checkpointing_func__call__tupler   )
r   r(  r)  	head_maskoutput_hidden_statesreturn_dictall_hidden_statesilayer_modulelayer_outputss
             r*   r}   zMraEncoder.forward  s     #7@BBD(44 	- 	-OA|# I$58H$H!* Lt} L $ A A )!"! ! !-]N K K)!,MM 	E 1]4D D 	YXX]4E$FXXXXXX1++
 
 
 	
r,   )NNFT)r   r   r   r   r}   r  r  s   @r*   rm  rm    sZ        , , , , , "!
 !
 !
 !
 !
 !
 !
 !
r,   rm  c                   B     e Zd Z fdZdej        dej        fdZ xZS )MraPredictionHeadTransformc                 V   t                                                       t          j        |j        |j                  | _        t          |j        t                    rt          |j                 | _
        n|j        | _
        t          j        |j        |j                  | _        d S r6  )r   r   r   r  r   r7  rS  rT  rU  r   transform_act_fnr   r   r   s     r*   r   z#MraPredictionHeadTransform.__init__	  s    Yv163EFF
f'-- 	6$*6+<$=D!!$*$5D!f&8f>STTTr,   r(  r9  c                     |                      |          }|                     |          }|                     |          }|S rz   )r7  r  r   rX  s     r*   r}   z"MraPredictionHeadTransform.forward  s=    

=11--m<<}55r,   r>  r  s   @r*   r  r    sc        U U U U UU\ el        r,   r  c                   *     e Zd Z fdZd Zd Z xZS )MraLMPredictionHeadc                 >   t                                                       t          |          | _        t	          j        |j        |j        d          | _        t	          j	        t          j        |j                            | _        | j        | j        _        d S )NF)bias)r   r   r  	transformr   r  r   r   decoder	ParameterrV   r   r  r   s     r*   r   zMraLMPredictionHead.__init__  sz    3F;; y!3V5FUSSSLV->!?!?@@	 !Ir,   c                 (    | j         | j        _         d S rz   )r  r  r   s    r*   _tie_weightsz MraLMPredictionHead._tie_weights(  s     Ir,   c                 Z    |                      |          }|                     |          }|S rz   )r  r  rX  s     r*   r}   zMraLMPredictionHead.forward+  s*    }55]33r,   )r   r   r   r   r  r}   r  r  s   @r*   r  r    sV        & & & & && & &      r,   r  c                   B     e Zd Z fdZdej        dej        fdZ xZS )MraOnlyMLMHeadc                 p    t                                                       t          |          | _        d S rz   )r   r   r  predictionsr   s     r*   r   zMraOnlyMLMHead.__init__3  s/    .v66r,   sequence_outputr9  c                 0    |                      |          }|S rz   )r  )r   r  prediction_scoress      r*   r}   zMraOnlyMLMHead.forward7  s     ,,_==  r,   r>  r  s   @r*   r  r  2  s^        7 7 7 7 7!u| ! ! ! ! ! ! ! ! !r,   r  c                   $    e Zd ZdZeZdZdZd ZdS )MraPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    r#   Tc                    t          |t          j                  rT|j        j                            d| j        j                   |j         |j        j        	                                 dS dS t          |t          j
                  r_|j        j                            d| j        j                   |j        +|j        j        |j                 	                                 dS dS t          |t          j                  r?|j        j        	                                 |j        j                            d           dS dS )zInitialize the weightsg        )r   stdNr%  )rS  r   r  weightdatanormal_r   initializer_ranger  zero_r   r   r   fill_)r   modules     r*   _init_weightsz MraPreTrainedModel._init_weightsG  s)   fbi(( 	* M&&CT[5R&SSS{& &&((((( '&-- 	*M&&CT[5R&SSS!-"6#56<<>>>>> .--- 	*K""$$$M$$S)))))	* 	*r,   N)	r   r   r   r  r   config_classbase_model_prefixsupports_gradient_checkpointingr  r&   r,   r*   r  r  =  s@         
 L&*#* * * * *r,   r  aF  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`MraConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
ak	  
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z]The bare MRA Model transformer outputting raw hidden-states without any specific head on top.c                       e Zd Z fdZd Zd Zd Z ee	                    d                     e
eee          	 	 	 	 	 	 	 	 ddeej                 d	eej                 d
eej                 deej                 deej                 deej                 dee         dee         deeef         fd                        Z xZS )MraModelc                     t                                          |           || _        t          |          | _        t          |          | _        |                                  d S rz   )r   r   r   r   r
  rm  encoder	post_initr   s     r*   r   zMraModel.__init__  sX       '//!&)) 	r,   c                     | j         j        S rz   r
  r   r  s    r*   get_input_embeddingszMraModel.get_input_embeddings  s    ..r,   c                     || j         _        d S rz   r  )r   r   s     r*   set_input_embeddingszMraModel.set_input_embeddings  s    */'''r,   c                     |                                 D ]/\  }}| j        j        |         j                            |           0dS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  r!  rc  rJ  )r   heads_to_pruner!  rH  s       r*   _prune_headszMraModel._prune_heads  sU    
 +0022 	C 	CLE5Lu%/;;EBBBB	C 	Cr,   batch_size, sequence_length
checkpointoutput_typer  Nr  r)  r   r   r~  r  r  r  r9  c	                    ||n| j         j        }||n| j         j        }||t          d          |+|                     ||           |                                }	n.||                                d d         }	nt          d          |	\  }
}||j        n|j        }|t          j        |
|f|          }|gt          | j
        d          r1| j
        j        d d d |f         }|                    |
|          }|}n!t          j        |	t          j        |          }|                     ||	          }|                     || j         j                  }| 
                    ||||          }|                     |||||          }|d	         }|s|f|d
d          z   S t'          ||j        |j        |j                  S )NzDYou cannot specify both input_ids and inputs_embeds at the same timer?   z5You have to specify either input_ids or inputs_embedsr   r   rR   )r  r   r   r  )r)  r~  r  r  r   r   )rx  r(  
attentionscross_attentions)r   r  use_return_dictrB   %warn_if_padding_and_no_attention_maskrA   rT   rV   r   r  r
  r   r   r   rX   get_extended_attention_maskget_head_maskrr  r  r   r(  r  r  )r   r  r)  r   r   r~  r  r  r  r  r\   r  rT   r  r	  extended_attention_maskembedding_outputencoder_outputsr  s                      r*   r}   zMraModel.forward  s2   $ %9$D  $+Jj 	 &1%<kk$+B] ]%>cddd"66y.QQQ#..**KK&',,..ss3KKTUUU!,
J%.%:!!@T!"Z*j)A6RRRN!t(899 [*./*HKZK*X'3J3Q3QR\^h3i3i0!A!&[
SY!Z!Z!Z 150P0PQ_al0m0m &&y$+2OPP	??%)'	 + 
 
 ,,2!5# ' 
 
 *!, 	<#%(;;;1-)7&1,=	
 
 
 	
r,   )NNNNNNNN)r   r   r   r   r  r  r  r   MRA_INPUTS_DOCSTRINGformatr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOCr   rV   r?  boolr   r   r}   r  r  s   @r*   r  r    s       
    / / /0 0 0C C C +*+?+F+FGd+e+eff&6$   -11515/3,004/3&*J
 J
EL)J
 !.J
 !.	J

 u|,J
 EL)J
  -J
 'tnJ
 d^J
 
u88	9J
 J
 J
  gfJ
 J
 J
 J
 J
r,   r  z1MRA Model with a `language modeling` head on top.c                       e Zd ZddgZ fdZd Zd Z ee	                    d                     e
eee          	 	 	 	 	 	 	 	 	 dd	eej                 d
eej                 deej                 deej                 deej                 deej                 deej                 dee         dee         deeef         fd                        Z xZS )MraForMaskedLMzcls.predictions.decoder.weightzcls.predictions.decoder.biasc                     t                                          |           t          |          | _        t	          |          | _        |                                  d S rz   )r   r   r  r#   r  clsr  r   s     r*   r   zMraForMaskedLM.__init__  sQ       F##!&)) 	r,   c                 $    | j         j        j        S rz   )r  r  r  r  s    r*   get_output_embeddingsz$MraForMaskedLM.get_output_embeddings  s    x#++r,   c                 T    || j         j        _        |j        | j         j        _        d S rz   )r  r  r  r  )r   new_embeddingss     r*   set_output_embeddingsz$MraForMaskedLM.set_output_embeddings  s%    '5$$2$7!!!r,   r  r  Nr  r)  r   r   r~  r  labelsr  r  r9  c
           
         |	|	n| j         j        }	|                     ||||||||	          }
|
d         }|                     |          }d}|Kt	                      } ||                    d| j         j                  |                    d                    }|	s|f|
dd         z   }||f|z   n|S t          |||
j        |
j	                  S )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        Nr)  r   r   r~  r  r  r  r   r?   r   losslogitsr(  r  )
r   r  r#   r  r	   r  r   r   r(  r  )r   r  r)  r   r   r~  r  r  r  r  r2  r  r  masked_lm_lossloss_fctr   s                   r*   r}   zMraForMaskedLM.forward  s   0 &1%<kk$+B](())%'!5#  	
 	
 "!* HH_55'))H%X&7&<&<RAW&X&XZ`ZeZefhZiZijjN 	Z')GABBK7F3A3M^%..SYY$!/)	
 
 
 	
r,   	NNNNNNNNN)r   r   r   _tied_weights_keysr   r  r  r   r  r  r   r  r   r  r   rV   r?  r  r   r   r}   r  r  s   @r*   r  r    s       :<Z[    , , ,8 8 8 +*+?+F+FGd+e+eff&"$   -11515/3,004)-/3&*0
 0
EL)0
 !.0
 !.	0

 u|,0
 EL)0
  -0
 &0
 'tn0
 d^0
 
un$	%0
 0
 0
  gf0
 0
 0
 0
 0
r,   r  c                   (     e Zd ZdZ fdZd Z xZS )MraClassificationHeadz-Head for sentence-level classification tasks.c                 "   t                                                       t          j        |j        |j                  | _        t          j        |j                  | _        t          j        |j        |j	                  | _
        || _        d S rz   )r   r   r   r  r   r7  r   r   r   
num_labelsout_projr   r   s     r*   r   zMraClassificationHead.__init__S  sj    Yv163EFF
z&"<==	&"4f6GHHr,   c                 
   |d d dd d f         }|                      |          }|                     |          }t          | j        j                 |          }|                      |          }|                     |          }|S )Nr   )r   r7  r   r   rT  r  )r   featureskwargsxs       r*   r}   zMraClassificationHead.forward[  st    QQQ111WLLOOJJqMM4;)*1--LLOOMM!r,   r  r  s   @r*   r  r  P  sM        77          r,   r  zMRA Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks.c                       e Zd Z fdZ ee                    d                     eee	e
          	 	 	 	 	 	 	 	 	 ddeej                 deej                 deej                 deej                 d	eej                 d
eej                 deej                 dee         dee         deee	f         fd                        Z xZS )MraForSequenceClassificationc                     t                                          |           |j        | _        t          |          | _        t          |          | _        |                                  d S rz   )r   r   r  r  r#   r  
classifierr  r   s     r*   r   z%MraForSequenceClassification.__init__k  s[        +F##/77 	r,   r  r  Nr  r)  r   r   r~  r  r  r  r  r9  c
           
         |	|	n| j         j        }	|                     ||||||||	          }
|
d         }|                     |          }d}|Z| j         j        f| j        dk    rd| j         _        nN| j        dk    r7|j        t          j        k    s|j        t          j	        k    rd| j         _        nd| j         _        | j         j        dk    rWt                      }| j        dk    r1 ||                                |                                          }n |||          }n| j         j        dk    rGt                      } ||                    d| j                  |                    d                    }n*| j         j        dk    rt                      } |||          }|	s|f|
dd         z   }||f|z   n|S t          |||
j        |
j        	          S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr  r   r   
regressionsingle_label_classificationmulti_label_classificationr?   r  )r   r  r#   r  problem_typer  rS   rV   rX   rG   r
   r&  r	   r  r   r   r(  r  )r   r  r)  r   r   r~  r  r  r  r  r2  r  r  r  r  r   s                   r*   r}   z$MraForSequenceClassification.forwardt  s   0 &1%<kk$+B](())%'!5#  	
 	
 "!*11{'/?a''/;DK,,_q((flej.H.HFL\a\eLeLe/LDK,,/KDK,{'<77"99?a''#8FNN$4$4fnn6F6FGGDD#8FF33DD)-JJJ+--xB @ @&++b//RR)-III,..x// 	FY,F)-)9TGf$$vE'!/)	
 
 
 	
r,   r  )r   r   r   r   r   r  r  r   r  r   r  r   rV   r?  r  r   r   r}   r  r  s   @r*   r  r  e  sx            +*+?+F+FGd+e+eff&,$   -11515/3,004)-/3&*A
 A
EL)A
 !.A
 !.	A

 u|,A
 EL)A
  -A
 &A
 'tnA
 d^A
 
u..	/A
 A
 A
  gfA
 A
 A
 A
 A
r,   r  zMRA Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks.c                       e Zd Z fdZ ee                    d                     eee	e
          	 	 	 	 	 	 	 	 	 ddeej                 deej                 deej                 deej                 d	eej                 d
eej                 deej                 dee         dee         deee	f         fd                        Z xZS )MraForMultipleChoicec                     t                                          |           t          |          | _        t	          j        |j        |j                  | _        t	          j        |j        d          | _        | 	                                 d S r`  )
r   r   r  r#   r   r  r   pre_classifierr  r  r   s     r*   r   zMraForMultipleChoice.__init__  sr       F## i(:F<NOO)F$6:: 	r,   z(batch_size, num_choices, sequence_lengthr  Nr  r)  r   r   r~  r  r  r  r  r9  c
           
         |	|	n| j         j        }	||j        d         n|j        d         }
|)|                    d|                    d                    nd}|)|                    d|                    d                    nd}|)|                    d|                    d                    nd}|)|                    d|                    d                    nd}|=|                    d|                    d          |                    d                    nd}|                     ||||||||	          }|d         }|dddf         }|                     |          } t          j                    |          }| 	                    |          }|                    d|
          }d}|t                      } |||          }|	s|f|dd         z   }||f|z   n|S t          |||j        |j                  S )aJ  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        Nr   r?   r<   r  r   r  )r   r  rU   r  rA   r#   r  r   ReLUr  r	   r   r(  r  )r   r  r)  r   r   r~  r  r  r  r  num_choicesr2  hidden_statepooled_outputr  reshaped_logitsr  r  r   s                      r*   r}   zMraForMultipleChoice.forward  sP   0 &1%<kk$+B],5,Aioa((}GZ[\G]>G>SINN2y~~b'9'9:::Y]	M[Mg,,R1D1DR1H1HIIImqM[Mg,,R1D1DR1H1HIIImqGSG_|((\->->r-B-BCCCei ( r=#5#5b#9#9=;M;Mb;Q;QRRR 	 (())%'!5#  	
 	
 qz$QQQT*++M::!		-00// ++b+66'))H8OV44D 	F%''!""+5F)-)9TGf$$vE("!/)	
 
 
 	
r,   r  )r   r   r   r   r   r  r  r   r  r   r  r   rV   r?  r  r   r   r}   r  r  s   @r*   r  r    sx            +*+?+F+FGq+r+rss&-$   -11515/3,004)-/3&*@
 @
EL)@
 !.@
 !.	@

 u|,@
 EL)@
  -@
 &@
 'tn@
 d^@
 
u//	0@
 @
 @
  ts@
 @
 @
 @
 @
r,   r  zMRA Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.c                       e Zd Z fdZ ee                    d                     eee	e
          	 	 	 	 	 	 	 	 	 ddeej                 deej                 deej                 deej                 d	eej                 d
eej                 deej                 dee         dee         deee	f         fd                        Z xZS )MraForTokenClassificationc                 6   t                                          |           |j        | _        t          |          | _        t          j        |j                  | _        t          j	        |j
        |j                  | _        |                                  d S rz   )r   r   r  r  r#   r   r   r   r   r  r   r  r  r   s     r*   r   z"MraForTokenClassification.__init__  sy        +F##z&"<==)F$68IJJ 	r,   r  r  Nr  r)  r   r   r~  r  r  r  r  r9  c
           
         |	|	n| j         j        }	|                     ||||||||	          }
|
d         }|                     |          }|                     |          }d}|t                      }||                    d          dk    }|                    d| j                  }t          j	        ||                    d          t          j
        |j                                      |                    } |||          }n8 ||                    d| j                  |                    d                    }|	s|f|
dd         z   }||f|z   n|S t          |||
j        |
j                  S )z
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        Nr  r   r?   r   r  )r   r  r#   r   r  r	   r  r  rV   wheretensorignore_indextype_asr   r(  r  )r   r  r)  r   r   r~  r  r  r  r  r2  r  r  r  r  active_lossactive_logitsactive_labelsr   s                      r*   r}   z!MraForTokenClassification.forward(  s   , &1%<kk$+B](())%'!5#  	
 	
 "!*,,7711'))H),11"55: &B @ @ %R%,x?T2U2U2]2]^d2e2e! !  x}==xB @ @&++b//RR 	FY,F)-)9TGf$$vE$!/)	
 
 
 	
r,   r  )r   r   r   r   r   r  r  r   r  r   r  r   rV   r?  r  r   r   r}   r  r  s   @r*   r  r    se       	 	 	 	 	 +*+?+F+FGd+e+eff&)$   -11515/3,004)-/3&*9
 9
EL)9
 !.9
 !.	9

 u|,9
 EL)9
  -9
 &9
 'tn9
 d^9
 
u++	,9
 9
 9
  gf9
 9
 9
 9
 9
r,   r  zMRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).c                       e Zd Z fdZ ee                    d                     eee	e
          	 	 	 	 	 	 	 	 	 	 ddeej                 deej                 deej                 deej                 d	eej                 d
eej                 deej                 deej                 dee         dee         deee	f         fd                        Z xZS )MraForQuestionAnsweringc                    t                                          |           d|_        |j        | _        t          |          | _        t          j        |j        |j                  | _        | 	                                 d S )Nr9   )
r   r   r  r  r#   r   r  r   
qa_outputsr  r   s     r*   r   z MraForQuestionAnswering.__init__p  sm        +F##)F$68IJJ 	r,   r  r  Nr  r)  r   r   r~  r  start_positionsend_positionsr  r  r9  c           
      f   |
|
n| j         j        }
|                     |||||||	|
          }|d         }|                     |          }|                    dd          \  }}|                    d          }|                    d          }d}||t          |                                          dk    r|                    d          }t          |                                          dk    r|                    d          }|                    d          }|                    d|          }|                    d|          }t          |          } |||          } |||          }||z   dz  }|
s||f|dd         z   }||f|z   n|S t          ||||j        |j        	          S )
a  
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        Nr  r   r   r?   r=   )r  r9   )r  start_logits
end_logitsr(  r  )r   r  r#   r	  splitr&  r@   rA   clampr	   r   r(  r  )r   r  r)  r   r   r~  r  r
  r  r  r  r2  r  r  r  r  
total_lossignored_indexr  
start_lossend_lossr   s                         r*   r}   zMraForQuestionAnswering.forward|  s   : &1%<kk$+B](())%'!5#  	
 	
 "!*11#)<<r<#:#: j#++B//''++

&=+D?''))**Q.."1"9"9""="==%%''((1,, - 5 5b 9 9(--a00M-33A}EEO)//=AAM']CCCH!,@@Jx
M::H$x/14J 	R"J/'!""+=F/9/EZMF**6Q+%!!/)
 
 
 	
r,   )
NNNNNNNNNN)r   r   r   r   r   r  r  r   r  r   r  r   rV   r?  r  r   r   r}   r  r  s   @r*   r  r  j  s       
 
 
 
 
 +*+?+F+FGd+e+eff&0$   -11515/3,0042604/3&*F
 F
EL)F
 !.F
 !.	F

 u|,F
 EL)F
  -F
 "%,/F
  -F
 'tnF
 d^F
 
u22	3F
 F
 F
  gfF
 F
 F
 F
 F
r,   r  r   )NN)r;   r   r   )Tr  r   pathlibr   typingr   r   r   rV   torch.utils.checkpointr   torch.nnr   r	   r
   torch.utils.cpp_extensionr   activationsr   modeling_outputsr   r   r   r   r   r   modeling_utilsr   pytorch_utilsr   r   r   utilsr   r   r   r   r   r   configuration_mrar   
get_loggerr   r  r  r  _TOKENIZER_FOR_DOCr4   r6   rP   r`   rd   rl   rv   autogradFunctionrx   r   r   r   r   r   Moduler   r  r4  rB  rP  rZ  r^  rm  r  r  r  r  MRA_START_DOCSTRINGr  r  r  r  r  r  r  r  r&   r,   r*   <module>r&     s            ) ) ) ) ) ) ) ) ) )            A A A A A A A A A A * * * * * * ! ! ! ! ! !                . - - - - - l l l l l l l l l l                ) ( ( ( ( ( 
	H	%	%1 $  	C 	C 	C& & &8   .%O %O %O %OP% % % %Ps s sX X X X XEN3 X X X0] ] ] ] ]5>2 ] ] ].       :%V %V %V %VP#) #) #)Z !"$%p p p pf7 7 7 7 7BI 7 7 7tY Y Y Y Yry Y Y Yz    BI       29   B    bi        	       ry   :(
 (
 (
 (
 (
 (
 (
 (
X       $    ")   0! ! ! ! !RY ! ! !* * * * * * * *6	 , ^ c i
 i
 i
 i
 i
! i
 i
	 i
X MObccI
 I
 I
 I
 I
' I
 I
 dcI
Z    BI   * / 
Q
 Q
 Q
 Q
 Q
#5 Q
 Q
 
Q
h H 
Q
 Q
 Q
 Q
 Q
- Q
 Q
 
Q
h P 
K
 K
 K
 K
 K
 2 K
 K
 
K
\ h 
Y
 Y
 Y
 Y
 Y
0 Y
 Y
 
Y
 Y
 Y
r,   