
    g                        d Z ddlZddlmZ ddlmZ ddlmZmZm	Z	m
Z
 ddlZddlZddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZmZmZmZmZmZ ddlmZ  ej        e           Z!dZ"dZ#da$d Z% G d dej&        j'                  Z(d*dZ)d*dZ* G d dej+                  Z, G d dej+                  Z- G d dej+                  Z. G d de          Z/e G d de                      Z0e G d  d!e                      Z1d"Z2d#Z3 ed$e2           G d% d&e/                      Z4 ed'e2           G d( d)e/e                      Z5dS )+zPyTorch RWKV model.    N)	dataclass)Path)ListOptionalTupleUnion)nn)CrossEntropyLoss   )GenerationMixin)PreTrainedModel)ModelOutputadd_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardis_bitsandbytes_availableis_ninja_availableis_torch_cuda_availablelogging   )
RwkvConfigzRWKV/rwkv-4-169m-piler   c                    ddl m} t          t                                                    j        j        j        dz  dz  fddD             }t          t          j        | k    rd S t          	                    d|  d           d	d
ddddd|  g} |d|  |t          j                    t          j        k    |          a| t          _        d S )Nr   )loadkernelsrwkvc                     g | ]}|z  S  r   ).0fkernel_folders     b/var/www/html/ai-engine/env/lib/python3.11/site-packages/transformers/models/rwkv/modeling_rwkv.py
<listcomp>z(load_wkv_cuda_kernel.<locals>.<listcomp>:   s    fffq*fff    )z
wkv_op.cppzwkv_cuda.cuzwkv_cuda_bf16.cuz2Loading CUDA kernel for RWKV at context length of .z
-res-usagez--maxrregcount 60z--use_fast_mathz-O3z-Xptxas -O3z--extra-device-vectorizationz-DTmax=wkv_)namesourcesverboseextra_cuda_cflags)torch.utils.cpp_extensionr   r   __file__resolveparentrwkv_cuda_kernelmax_seq_lengthloggerinfor   get_verbosityDEBUG)context_lengthload_kernelcuda_kernel_filesflagsr    s       @r!   load_wkv_cuda_kernelr8   4   s
   ====== NN**,,3:AIMPVVMffff4efff #(8(G>(Y(Y
KKV^VVVWWW 	&".""E #{$N$$!&((GM9	   '5###r#   c                   >    e Zd Zedd            Zedd            ZdS )RwkvLinearAttentionNFc                 B   |                                 \  }}}	|t          j        k    r t          d| dt          j         d          ||	z  t	          |	d          z  dk    r't          d| d|	 dt	          |	d           d	          |j        | _        |j        j        d
k    s0|j        j        d
k    s |j        j        d
k    s|j        j        d
k    rt          d          t          j
        |                                                                           }|j        t          j        k    r<|                                }|                                }|                                }|                                }|                                }|                                }t          j        |t          j                  }
|s||Kt          j        ||	dt          j        |j        t          j                  }|d d d d dfxx         dz  cc<   n2t          j        d |D             d                                          }|j        t          j        k    rt          j        }nt          j        } ||||||
|           n<|j        t          j        k    rt          j        nt          j        } ||||||
           |                     |||||
           |!d t          j        |dd          D             }|
                    | j                  |fS )NzCannot process a batch with z+ tokens at the same time, use a maximum of z with this model.    r   zThe product of batch size (z) and hidden size (z") needs to be a round multiple of r$   cudazUCalling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.memory_formatr   )dtypedevicer?      籡*Gc                 8    g | ]}|                     d           S rB   	unsqueezer   ss     r!   r"   z/RwkvLinearAttention.forward.<locals>.<listcomp>   s"    "A"A"Aa1;;q>>"A"A"Ar#   )dimc                 8    g | ]}|                     d           S rE   )squeezerH   s     r!   r"   z/RwkvLinearAttention.forward.<locals>.<listcomp>   s"    HHHaQYYq\\HHHr#   )sizer.   r/   
ValueErrorminr@   input_dtyperA   typetorchexpfloat
contiguousfloat16
empty_likecontiguous_formatzerosfloat32catbfloat16forward_with_state_bf16forward_with_stateforward_bf16forwardsave_for_backwardchunkto)ctx
time_decay
time_firstkeyvaluestatereturn_state
batch_sizeseq_lenhidden_sizeoutputforward_funcs               r!   r`   zRwkvLinearAttention.forwardU   sG   +.88::(
G[%444Fw F F#2F F F   #c+r&:&::a??7j 7 7[ 7 7";337 7 7  
 ) "f,, %//z&((| F**tuuui
 0 0 2 2 = = ? ?@@@
9%%#))++J))++CKKMME**,,
nn  ""!#U5LMMM 	E5,}-:"'"9   aaaAg$&	"A"A5"A"A"AqIIITTVVyEN**/G/BLZeVUKKKK<?I<W<W+88]m]uLLZeVDDDj*c5&IIIHH5;uaQ+G+G+GHHHEyy))500r#   c                    | j         }| j        \  }}}}}t          j        |t          j        |t          j        k    rt          j        nt          j                  }	t          j        |t          j                  }
t          j        |t          j                  }t          j        |t          j                  }|t          j        k    r|                                }|t          j        k    rt          j
        nt          j        } |||||||                                |	|
||
  
         |	                    |          |
                    |          |                    |          |                    |          d d fS )N)r?   r@   r>   )rP   saved_tensorsrR   rW   rX   r\   rZ   rV   rT   r.   backward_bf16backwardrU   rc   )rd   g_outputg_staterP   re   rf   rg   rh   rn   g_time_decayg_time_firstg_keyg_valuebackward_funcs                 r!   rs   zRwkvLinearAttention.backward   se    o585F2
JUF'1$/5>$A$A%..u}
 
 

 '
%BYZZZ E4KLLL"58OPPP%-''~~''H:E:W:W(66]m]v!!	
 	
 	
 OOK((OOK((HH[!!JJ{##
 	
r#   NFN)__name__
__module____qualname__staticmethodr`   rs   r   r#   r!   r:   r:   T   sS        <1 <1 <1 \<1| %
 %
 %
 \%
 %
 %
r#   r:   Fc                    |                                 \  }}}t          j        |          }|t          j        |d d df         t          j                  }	t          j        |d d df         t          j                  }
t          j        |d d df         t          j                  dz
  }n|\  }	}
}t          j        |            } t          |          D ]}|d d |f                                         }|d d |f         }t          j        |||z             }t          j        ||z
            }t          j        ||z   |z
            }||	z  ||z  z   }||
z  |z   }||z                      |j	                  |d d |f<   t          j        || z   |          }t          j        || z   |z
            }t          j        ||z
            }||	z  ||z  z   }	||
z  |z   }
|}|s||	|
|g}||fS )Nr   )r@   rC   )
rM   rR   
zeros_likerZ   rS   rangerT   maximumrc   r@   )re   rf   rg   rh   ri   rj   _
seq_lengthrn   	num_state	den_state	max_statecurrent_indexcurrent_keycurrent_valuemax_for_outpute1e2	numeratordenominatormax_for_states                        r!   rwkv_linear_attention_cpur      s$    xxzzAz1c""F}$SAYemDDD	$SAYemDDD	$SAYemDDDtK		*/'	9i
 )J'''Jz** " "!!!]*+1133aaa./ y+
2JKKYy>122Y{Z/.@AANR-%77	9nr)$-$;#?#?#M#Mqqq-  i*&<kJJYy:-=>>Y{]233NR-%77	NR'	!		 2u(Iy15=r#   c                     t          d | |||fD                       }|                    d          dk    }t          |s|rt          | |||||          S t                              | |||||          S )Nc              3   6   K   | ]}|j         j        d k    V  dS )r=   N)rA   rQ   )r   ts     r!   	<genexpr>z(rwkv_linear_attention.<locals>.<genexpr>   s+      XXa!(-6)XXXXXXr#   r   ri   rj   )anyrM   r.   r   r:   apply)re   rf   rg   rh   ri   rj   no_cuda	one_tokens           r!   rwkv_linear_attentionr      s    XXJ
CQV3WXXXXXG q I7i(ZeSXgstttt"((ZeUT`aaar#   c                   0     e Zd Zd fd	ZddZd	dZ xZS )
RwkvSelfAttentionr   c                 d   t                                                       || _        t          d uot          j        |j        k    }t                      rPt                      rB|s@	 t          |j                   n*# t          $ r t                              d           Y nw xY w|| _        |j        }|j        |j        n|}|| _        t          j        t#          j        |                    | _        t          j        t#          j        |                    | _        t          j        t#          j        dd|                    | _        t          j        t#          j        dd|                    | _        t          j        t#          j        dd|                    | _        t          j        d          | _        t          j        ||d          | _        t          j        ||d          | _        t          j        ||d          | _        t          j        ||d          | _        d S )Nz9Could not load the custom CUDA kernel for RWKV attention.r   r   r   r   Fbias)super__init__configr.   r/   r4   r   r   r8   	Exceptionr0   r1   layer_idrm   attention_hidden_sizer	   	ParameterrR   emptyre   rf   time_mix_keytime_mix_valuetime_mix_receptance	ZeroPad2d
time_shiftLinearrg   rh   
receptancern   )selfr   r   kernel_loadedrm   r   	__class__s         r!   r   zRwkvSelfAttention.__init__   s   (4q9I9X\b\q9q 	Y$;$=$= 	Ym 	YY$V%:;;;; Y Y YWXXXXXY (,2,H,TF((Ze 	 &;",u{3H'I'IJJ,u{3H'I'IJJLQ;)G)GHH l5;q![+I+IJJ#%<Aq+0N0N#O#O ,}559[*?eLLLY{,ANNN
)K1FUSSSi 5{OOOs   &A; ;$B"!B"Nc                 b   |                     d          dk    r||d         d d d d | j        f         }n8|                     |          }|!|d         d d d d | j        f         |d d df<   || j        z  |d| j        z
  z  z   }|| j        z  |d| j        z
  z  z   }|| j        z  |d| j        z
  z  z   }|                     |          }|                     |          }t          j	        | 
                    |                    }|!|d d df         |d         d d d d | j        f<   ||||fS Nr   r   r   )rM   r   r   r   r   r   rg   rh   rR   sigmoidr   )r   hiddenri   shiftedrg   rh   r   s          r!   extract_key_valuez#RwkvSelfAttention.extract_key_value  sW   ;;q>>Q5#4Ahqqq!!!T]23GGoof--G  %aAAAt})< =1t((7a$:K6K+LL,,w!d>Q:Q/RRd66AH`D`9aa
hhsmm

5!!]4??:#>#>??
,2111b5ME!HQQQ4=()3u,,r#   Fc                                           ||          \  }}}}|#t           fd|dd          D                       nd }t           j         j        ||||          \  }}|W|d         |d         d d d d  j        f<   |d         |d         d d d d  j        f<   |d         |d         d d d d  j        f<                        ||z            |fS )	Nri   c              3   @   K   | ]}|d d d d j         f         V  d S r|   r   )r   rI   r   s     r!   r   z,RwkvSelfAttention.forward.<locals>.<genexpr>)  s9      FFqAaaaDM12FFFFFFr#   rB   r   r   r   r      )r   tupler   re   rf   r   rn   )	r   r   ri   	use_cacher   rg   rh   layer_stater   s	   `        r!   r`   zRwkvSelfAttention.forward'  s   (,(>(>vU(>(S(S%
CJOJ[eFFFFE!""IFFFFFFae1OO"
 
 
k ",7NE!HQQQ4=(),7NE!HQQQ4=(),7NE!HQQQ4=(){{:,--u44r#   r   r|   r{   )r}   r~   r   r   r   r`   __classcell__r   s   @r!   r   r      sk        P P P P P P<- - - -&5 5 5 5 5 5 5 5r#   r   c                   (     e Zd Zd fd	ZddZ xZS )RwkvFeedForwardr   c                 0   t                                                       || _        || _        |j        }|j        |j        n	d|j        z  }t          j        d          | _        t          j	        t          j        dd|                    | _        t          j	        t          j        dd|                    | _        t          j        ||d          | _        t          j        ||d          | _        t          j        ||d          | _        d S )Nr   r   r   Fr   )r   r   r   r   rm   intermediate_sizer	   r   r   r   rR   r   r   r   r   rg   r   rh   )r   r   r   rm   r   r   s        r!   r   zRwkvFeedForward.__init__<  s     ((.(@(LF$$RSV\VhRh 	 ,}55LQ;)G)GHH#%<Aq+0N0N#O#O 9[*;%HHH)K5IIIY0+EJJJ


r#   Nc                 |   |                     d          dk    r||d         d d d d | j        f         }n8|                     |          }|!|d         d d d d | j        f         |d d df<   || j        z  |d| j        z
  z  z   }|| j        z  |d| j        z
  z  z   }t          j        t          j        |                     |                              }| 	                    |          }t          j
        |                     |                    }|!|d d df         |d         d d d d | j        f<   ||z  |fS r   )rM   r   r   r   r   rR   squarerelurg   rh   r   r   )r   r   ri   r   rg   r   rh   s          r!   r`   zRwkvFeedForward.forwardM  sK   ;;q>>Q5#4Ahqqq!!!T]23GGoof--G  %aAAAt})< =1t((7a$:K6K+LLd66AH`D`9aa
l5:dhhsmm4455

3]4??:#>#>??
,2111b5ME!HQQQ4=()E!5((r#   r   r|   r}   r~   r   r   r`   r   r   s   @r!   r   r   ;  sW        K K K K K K") ) ) ) ) ) ) )r#   r   c                   &     e Zd Z fdZddZ xZS )	RwkvBlockc                    t                                                       || _        || _        |dk    r%t	          j        |j        |j                  | _        t	          j        |j        |j                  | _	        t	          j        |j        |j                  | _
        t          ||          | _        t          ||          | _        d S )Nr   )eps)r   r   r   r   r	   	LayerNormrm   layer_norm_epsilonpre_lnln1ln2r   	attentionr   feed_forward)r   r   r   r   s      r!   r   zRwkvBlock.__init__b  s     q==,v'9v?XYYYDK< 28QRRR< 28QRRR*68<<+FH==r#   NFc                 4   | j         dk    r|                     |          }|                     |                     |          ||          \  }}||z   }|                     |                     |          |          \  }}||z   }||f}|r||fz  }n|dz  }|S )Nr   )ri   r   r   r|   )r   r   r   r   r   r   )r   r   ri   r   output_attentionsr   r   outputss           r!   r`   zRwkvBlock.forwardp  s    =A[[((F>>$((6*:*:%S\>]]	5)#"//0@0@/NNe,&5/ 		|#GGwGr#   )NFFr   r   s   @r!   r   r   a  sL        > > > > >       r#   r   c                   6    e Zd ZdZeZdZdgZddgZdZ	dZ
d ZdS )	RwkvPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    r   r   re   rf   Tc                   	
 t          |t                    r|j        }|j        j        }|j        j        	|j        ||dz
  z  
d||z  z
  }t          j        	fdt          	          D             |j
        j        |j
        j                  }|ddddf         }
fdt                    D             }t          j        ||j        j        |j        j                  }t          j        d t                    D             |j        j        |j        j                  dz  }t          j                    5  ||j        _        t          j        |j        t%          j        d	          z  |z             |j        _        t          j        ||          |j
        _        t          j        ||          d	
z  z   |j        _        t          j        |d|z            |j        _        ddd           dS # 1 swxY w Y   dS t          |t.                    r|j        }|j        j        }|j        j        	d||z  z
  }t          j        	fd
t          	          D             |j
        j        |j
        j                  }|ddddf         }t          j                    5  t          j        ||          |j
        _        t          j        ||          |j        _        ddd           dS # 1 swxY w Y   dS dS )zInitialize the weights.r   g      ?c                     g | ]}|z  S r   r   r   irm   s     r!   r"   z5RwkvPreTrainedModel._init_weights.<locals>.<listcomp>      ===Q[===r#   r@   rA   Nc                 >    g | ]}d d|dz
  z  ddz  z   z  z  z   S )   r   gffffff?g?r   )r   hr   ratio_0_to_1s     r!   r"   z5RwkvPreTrainedModel._init_weights.<locals>.<listcomp>  sM        Q!4q89sS<EW?WXXX  r#   c                 $    g | ]}|d z   dz  d z
  S )r   r   r   )r   r   s     r!   r"   z5RwkvPreTrainedModel._init_weights.<locals>.<listcomp>  s$    KKKa!eq[1_KKKr#   g      ?g333333?c                     g | ]}|z  S r   r   r   s     r!   r"   z5RwkvPreTrainedModel._init_weights.<locals>.<listcomp>  r   r#   )
isinstancer   r   r   num_hidden_layersrm   r   rR   tensorr   r   r@   rA   re   rf   no_graddata	ones_likemathlogpowr   r   r   )r   moduler   r   ratio_1_to_almost0time_weightdecay_speedzigzagr   rm   r   s           @@@r!   _init_weightsz!RwkvPreTrainedModel._init_weights  s   f/00 5	]H & ? -3K$*$@!#'81'<=L!$3D(D!E,====%*<*<===)/*1  K
 &dD!!!m4K    455  K  ,{&:K:QZ`ZkZrsssKKKe4I.J.JKKK +1!,3  
    c c)4!&).9JTXVY]]9Z]c9c)d)d!&+09[BT+U+U#(-2Y{DV-W-WZ]`lZl-l%*27)KOaIa2b2b*/c c c c c c c c c c c c c c c c c c 00 	]H & ? -3K!$3D(D!E,====%*<*<===)/*1  K
 &dD!!!m4K ] ]+09[BT+U+U#(27)KI[2\2\*/] ] ] ] ] ] ] ] ] ] ] ] ] ] ] ] ] ]	] 	]s%    B.G;;G?G?(?K44K8;K8N)r}   r~   r   __doc__r   config_classbase_model_prefix_no_split_modules_keep_in_fp32_modulessupports_gradient_checkpointing_is_statefulr   r   r#   r!   r   r     s]         
 L$)<8&*#L7] 7] 7] 7] 7]r#   r   c                       e Zd ZU dZdZej        ed<   dZe	e
ej                          ed<   dZe	eej        df                  ed<   dZe	eej        df                  ed<   dS )
RwkvOutputa  
    Class for the RWKV model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nlast_hidden_stateri   .hidden_states
attentions)r}   r~   r   r   r  rR   FloatTensor__annotations__ri   r   r   r  r   r  r   r#   r!   r  r    s          , ,0u(////3E8D*+,333=AM8E%"3S"89:AAA:>Ju0#567>>>>>r#   r  c                       e Zd ZU dZdZeej                 ed<   dZ	ej        ed<   dZ
eeej                          ed<   dZeeej        df                  ed<   dZeeej        df                  ed<   dS )	RwkvCausalLMOutputa|  
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nlosslogitsri   .r  r  )r}   r~   r   r   r  r   rR   r  r	  r  ri   r   r  r   r  r   r#   r!   r  r    s          0 )-D(5$
%,,, $FE$$$/3E8D*+,333=AM8E%"3S"89:AAA:>Ju0#567>>>>>r#   r  a>  

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`RwkvConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a
  
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            This is currently not used by `RwkvModel`, but will be supported in the future.

            [What are attention masks?](../glossary#attention-mask)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the last state is returned and can be used to quickly generate the next logits.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z^The bare RWKV Model transformer outputting raw hidden-states without any specific head on top.c                   b    e Zd Z fdZd Zd Z ee           ee	e
e          	 	 	 	 	 	 	 	 ddeej                 deej                 deej                 d	eeej                          d
ee         dee         dee         dee         deee
f         fd                        Zd Zd Z xZS )	RwkvModelc                    t                                                     t          j        j        j                  | _        t          j        fdt          j	                  D                       | _
        t          j        j                  | _        d| _        d| _        |                                  d S )Nc                 2    g | ]}t          |           S )r   )r   )r   idxr   s     r!   r"   z&RwkvModel.__init__.<locals>.<listcomp>M  s&    $p$p$pYv%D%D%D$p$p$pr#   F)r   r   r	   	Embedding
vocab_sizerm   
embeddings
ModuleListr   r   blocksr   ln_outlayers_are_rescaledgradient_checkpointing	post_initr   r   r   s    `r!   r   zRwkvModel.__init__I  s       ,v'8&:LMMm$p$p$p$pPUV\VnPoPo$p$p$pqql6#566#( &+# 	r#   c                     | j         S r|   r  r   s    r!   get_input_embeddingszRwkvModel.get_input_embeddingsW  s
    r#   c                     || _         d S r|   r  r   new_embeddingss     r!   set_input_embeddingszRwkvModel.set_input_embeddingsZ  s    (r#   
checkpointoutput_typer   N	input_idsattention_maskinputs_embedsri   r   r   output_hidden_statesreturn_dictreturnc	                    ||n| j         j        }||n| j         j        }||n| j        s| j         j        nd}||n| j         j        }|t                              d           | j        | j        k    r| 	                                 |t          d          |t          d          |                     |          |rZ|X                    d          | j         j        | j         j        ffdt          d          D             }|dxx         d	z  cc<   | j        r%| j        r|rt                              d
           d}}	|rdnd }
|rdnd }t#          | j                  D ]\  }}| j        r*| j        r#|                     |j        |	|||          \  }	}}n ||	|||          \  }	}}| j        r+| j         j        dk    r|dz   | j         j        z  dk    r|	dz  }	|r||	fz   }|r|
|fz   }
|                     |	          }	|r||	fz   }|st/          d |	|||
fD                       S t1          |	|||
          S )NFz<`attention_mask` was passed, but it is unused in this model.zDYou cannot specify both input_ids and inputs_embeds at the same timez5You have to specify either input_ids or inputs_embedsr   c                 l    g | ]0}t          j        |d k    rj        nt           j        j        d1S )r   r   )rR   rY   r@   rZ   rA   )r   r   r*  shapes     r!   r"   z%RwkvModel.forward.<locals>.<listcomp>  sY         a-"5"5U][h[o    r#      r   gꌠ9Y>)FzZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...r   )ri   r   r   r   rB   c              3      K   | ]}||V  	d S r|   r   )r   xs     r!   r   z$RwkvModel.forward.<locals>.<genexpr>  s(      ttqfgfsfsfsfsfsttr#   )r  ri   r  r  )r   r   r+  trainingr   use_return_dictr0   warning_oncer  _rescale_layersrN   r  rM   rm   r   r   r  	enumerater  _gradient_checkpointing_func__call__rescale_everyr  r   r  )r   r(  r)  r*  ri   r   r   r+  r,  r  all_self_attentionsall_hidden_statesr  blockr  r0  s      `           @r!   r`   zRwkvModel.forward]  sF   " 2C1N--TXT_Tq$8$D  $+Jj 	 "+!6IIZ^Zg=rT[=R=Rmr	%0%<kk$+B]! ^___=D444  """ ]%>cddd=#8TUUU  OOI66M 	"''**DK,CT[EbcE     q	  E !HHHHHH& 	"4= 	" "##p   "	%$5?bb4"6@BBD#DK00 	J 	JJC* t} 373T3TNM5)EV4 40ujj 495!)Wh4 4 40uj
 (2K-111W 99Q>> - 1# I$58H$H!  J&9ZM&I#M22 	E 1]4D D 	utt]E;LNa$btttttt++*	
 
 
 	
r#   c           	      ,   | j         | j         k    rd S | j        j        dk    rbt	          j                    5  t          | j                  D ] \  }}| j        r|j        j	        j
                            dt          || j        j        z            z             |j        j        j
                            dt          || j        j        z            z             t          |j        j	        j
        d          r|j        j	        j
        j                            dt          || j        j        z            z             |j        j        j
        j                            dt          || j        j        z            z             =t          |j        j	        j
        d          rB|                     |j        j	        |           |                     |j        j        |           |j        j	        j
                            dt          || j        j        z            z             |j        j        j
                            dt          || j        j        z            z             "	 d d d            n# 1 swxY w Y   | j         | _         d S )Nr   rB   SCBquant_state)r  r4  r   r;  rR   r   r8  r  r   rn   weightmul_intr   rh   hasattrr@  div_ _bnb_4bit_dequantize_and_rescale)r   block_idr>  s      r!   r7  zRwkvModel._rescale_layers  s   #DM(9::F;$q(( r r'0'='= r rOHe} r.5::1HPTP[PiDi@j@j;jkkk*07<<Q#hRVR]RkFkBlBl=lmmmm #5?#9#@%HH r!O29=BB1HX\XcXqLqHrHrCrsss!.4;?DDQ#hZ^ZeZsNsJtJtEtuuuu$U_%;%BMRR r AA%/BXZbccc AA%BTBZ\deeee!O29>>qCTXT_TmHmDnDn?nooo!.4;@@c(VZVaVoJoFpFpApqqqqrr r r r r r r r r r r r r r r" (,}#4   s   H7I<<J J c                    t                      st          d          ddl}|j                            |j        j        |j        j                  }|                    dt          || j
        j        z            z             |j                            |                    d          d                              |j                  }t!          |d|           dS )	z
        Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
        be quantized again.
        z/Please install bitsandbytes to use this method.r   NrB   cpuF)requires_gradrB  )r   ImportErrorbitsandbytes
functionaldequantize_4bitrB  r   rA  rF  rD  r   r;  r	   
Params4bitrc   rA   setattr)r   target_layerrH  bnbdequant_weightsquant_weights         r!   rG  z*RwkvModel._bnb_4bit_dequantize_and_rescale  s    
 )** 	QOPPP"""".889L9QS_SfSrssQ#h$+2K&K"L"LLMMM v((););E)B)BRW(XX[[\k\rssh55555r#   )NNNNNNNN)r}   r~   r   r   r   r$  r   RWKV_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr  _CONFIG_FOR_DOCr   rR   
LongTensorr  r   boolr   r   r`   r7  rG  r   r   s   @r!   r  r  D  s       
      ) ) ) +*+@AA&$   15595937$(,0/3&*Y
 Y
E,-Y
 !!12Y
   12	Y

 U./0Y
 D>Y
 $D>Y
 'tnY
 d^Y
 
uj 	!Y
 Y
 Y
  BAY
v5 5 506 6 6 6 6 6 6r#   r  z
    The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    c                       e Zd ZdgZ fdZd Zd ZddZ ee	           e
eee          	 	 	 	 	 	 	 	 	 ddeej                 d	eej                 d
eej                 deeej                          deej                 dee         dee         dee         dee         deeef         fd                        Z xZS )RwkvForCausalLMzhead.weightc                     t                                          |           t          |          | _        t	          j        |j        |j        d          | _        | 	                                 d S )NFr   )
r   r   r  r   r	   r   rm   r  headr  r  s     r!   r   zRwkvForCausalLM.__init__  s`       f%%	If0&2C%PPP	 	r#   c                     | j         S r|   r^  r  s    r!   get_output_embeddingsz%RwkvForCausalLM.get_output_embeddings  s
    yr#   c                     || _         d S r|   r`  r"  s     r!   set_output_embeddingsz%RwkvForCausalLM.set_output_embeddings  s    "			r#   Nc                 v    ||d d df                              d          }||d|i}nd|i}||d<   ||d<   |S )Nr   r*  r(  ri   r   rF   )r   r(  ri   r*  r   kwargsmodel_inputss          r!   prepare_inputs_for_generationz-RwkvForCausalLM.prepare_inputs_for_generation  sf     !!!!R%(22266I $+];LL'3L %W$-[!r#   r%  r(  r)  r*  ri   labelsr   r   r+  r,  r-  c
           	      p   |	|	n| j         j        }	|                     |||||||	          }
|
d         }|                     |          }d}||                    |j                  }|dddddf                                         }|dddf                                         }t                      } ||                    d|	                    d                    |                    d                    }|	s|f|
dd         z   }||f|z   n|S t          |||
j        |
j        |
j                  S )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        N)r*  ri   r   r   r+  r,  r   .r   r   )r  r  ri   r  r  )r   r5  r   r^  rc   rA   rU   r
   viewrM   r  ri   r  r  )r   r(  r)  r*  ri   rh  r   r   r+  r,  rwkv_outputsr  r  r  shift_logitsshift_labelsloss_fctrn   s                     r!   r`   zRwkvForCausalLM.forward  sl   0 &1%<kk$+B]yy'/!5# ! 
 
 %Q=))YYv}--F!#ssAAA+.99;;L!#qrr'?5577L'))H8L--b,2C2CB2G2GHH,J[J[\^J_J_``D 	FYabb!11F)-)9TGf$$vE!$&4#.
 
 
 	
r#   )NNN)	NNNNNNNNN)r}   r~   r   _tied_weights_keysr   ra  rc  rg  r   rV  r   rW  r  rX  r   rR   rY  r  r   rZ  r   r   r`   r   r   s   @r!   r\  r\    s        (      # # #   " +*+@AA&&$   15595937-1$(,0/3&*6
 6
E,-6
 !!126
   12	6

 U./06
 )*6
 D>6
 $D>6
 'tn6
 d^6
 
u((	)6
 6
 6
  BA6
 6
 6
 6
 6
r#   r\  r{   )6r   r   dataclassesr   pathlibr   typingr   r   r   r   rR   torch.utils.checkpointr	   torch.nnr
   
generationr   modeling_utilsr   utilsr   r   r   r   r   r   r   r   configuration_rwkvr   
get_loggerr}   r0   rW  rX  r.   r8   autogradFunctionr:   r   r   Moduler   r   r   r   r  r  RWKV_START_DOCSTRINGrV  r  r\  r   r#   r!   <module>r~     s       ! ! ! ! ! !       / / / / / / / / / / / /            % % % % % % ) ) ) ) ) ) - - - - - -	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 + * * * * * 
	H	%	%-   5 5 5@g
 g
 g
 g
 g
%.1 g
 g
 g
T) ) ) )Xb b b bC5 C5 C5 C5 C5	 C5 C5 C5L#) #) #) #) #)bi #) #) #)L    	   DD] D] D] D] D]/ D] D] D]N ? ? ? ? ? ? ? ?: ? ? ? ? ? ? ? ?@  ( V d a6 a6 a6 a6 a6# a6 a6	 a6H   ^
 ^
 ^
 ^
 ^
)? ^
 ^
 ^
 ^
 ^
r#   