
"""PyTorch XGLM model."""

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from .configuration_xglm import XGLMConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/xglm-564M"
_CONFIG_FOR_DOC = "XGLMConfig"


XGLM_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`XGLMConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

XGLM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
        encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
            Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
            selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class XGLMScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


class XGLMSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        if hasattr(self, "weights"):
            # in forward, put the weights on the correct dtype and device of the existing buffer
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)

        self.register_buffer("weights", emb_weights, persistent=False)

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad the last dimension if embedding_dim is odd
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(self, position_ids: torch.Tensor = None, past_key_values_length: int = 0):
        bsz, seq_len = position_ids.size()
        position_ids += self.offset

        # Expand embeddings if needed. `position_ids.max()` is NOT used to keep torch.fx compatibility.
        max_pos = 2 + seq_len + past_key_values_length
        if max_pos > self.weights.size(0):
            self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)

        return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()


class XGLMAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k, v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states.
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # upcast to fp32 if the weights are in fp16 to avoid numerical issues in the softmax
        if attn_weights.dtype == torch.float16:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
        else:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value


class XGLMDecoderLayer(nn.Module):
    def __init__(self, config: XGLMConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = XGLMAttention(
            embed_dim=self.embed_dim,
            num_heads=config.attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        if config.add_cross_attention:
            self.encoder_attn = XGLMAttention(
                embed_dim=self.embed_dim,
                num_heads=config.attention_heads,
                dropout=config.attention_dropout,
                is_decoder=True,
            )
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class XGLMPreTrainedModel(PreTrainedModel):
    config_class = XGLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["XGLMDecoderLayer"]

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


@add_start_docstrings(
    "The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.",
    XGLM_START_DOCSTRING,
)
class XGLMModel(XGLMPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_layers* layers. Each layer is a [`XGLMDecoderLayer`]

    Args:
        config: XGLMConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = XGLMScaledWordEmbedding(
                config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
            )

        self.embed_positions = XGLMSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            config.pad_token_id,
        )
        self.layers = nn.ModuleList([XGLMDecoderLayer(config) for _ in range(config.num_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                input_shape[-1] + past_key_values_length,
                dtype=torch.long,
                device=input_ids.device if input_ids is not None else inputs_embeds.device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length)
        hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if output_attentions and encoder_hidden_states is not None else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )


@add_start_docstrings(
    """
    The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    XGLM_START_DOCSTRING,
)
class XGLMForCausalLM(XGLMPreTrainedModel, GenerationMixin):
    base_model_prefix = "model"
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = XGLMModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            # shift labels and add a pad token to the end
            shift_labels = labels.new_zeros(labels.shape)
            shift_labels[:, :-1] = labels[:, 1:].clone()
            shift_labels[:, -1] = self.config.pad_token_id

            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
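

# Example usage (a minimal sketch; assumes the `facebook/xglm-564M` checkpoint named in
# `_CHECKPOINT_FOR_DOC` and its matching tokenizer are available via `AutoTokenizer`).
# Passing `labels=input_ids` works because `XGLMForCausalLM.forward` shifts the labels
# internally before computing the cross-entropy loss:
#
#     from transformers import AutoTokenizer, XGLMForCausalLM
#
#     tokenizer = AutoTokenizer.from_pretrained("facebook/xglm-564M")
#     model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
#
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     outputs = model(**inputs, labels=inputs["input_ids"])
#     loss, logits = outputs.loss, outputs.logits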