
    g                     @   d Z ddlZddlmZ ddlmZmZmZ ddlZddl	Zddlm
Z
 ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZmZmZ ddlmZmZ ddlmZ  ej        e          Z  e            rddl!m"Z" ddl#m$Z$m%Z% ndZ" e            r	ddl&m'Z'm(Z( nd\  Z(Z' e)e"e'e(f          Z*dZ+dZ,dej-        de.fdZ/d Z0d Z1 G d d          Z2 G d dej
        j3                  Z4 G d de
j3                  Z5 G d  d!e
j3                  Z6 G d" d#e
j3                  Z7 G d$ d%e          Z8e G d& d'e                      Z9e G d( d)e                      Z:d*Z;d+Z< ed,e;           G d- d.e8                      Z= ed/e;           G d0 d1e8e                      Z>dS )2zPyTorch MAMBA2 model.    N)	dataclass)OptionalTupleUnion)nn)CrossEntropyLoss   )ACT2FN)GenerationMixin)PreTrainedModel)ModelOutputadd_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardlogging)is_causal_conv1d_availableis_mamba_2_ssm_available   )Mamba2Config)selective_state_update)mamba_chunk_scan_combined mamba_split_conv1d_scan_combined)causal_conv1d_fncausal_conv1d_update)NNz!mistralai/mamba-codestral-7B-v0.1r   input_tensorpad_sizec                     t          | j                  dk    r
ddddd|ddfnddd|ddf}t          j        j                            | |dd          S )z
    Padding x tensor with `pad_size` on the seq_len dim (dim=1)

    Assumes that we only have tensors of either size 4 or 3
       r   constant)modevalue)lenshapetorchr   
functionalpad)r   r   	pad_shapes      f/var/www/html/ai-engine/env/lib/python3.11/site-packages/transformers/models/mamba2/modeling_mamba2.pypad_tensor_by_sizer)   ?   sj     47|7I3J3Ja3O3OAq!Q!Q//VWYZ\]_gijlmUnI8""<ST"UUU    c                 "   t          | |          } t          | j                  dk    r.|                     | j        d         d|| j        d                   S |                     | j        d         d|| j        d         | j        d                   S )z
    Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
    simultaneously splitting it into chunk sequences.

    Assumes that we only have tensors of either size 4 or 3
    r	   r      )r)   r"   r#   reshape)r   r   
chunk_sizes      r(   reshape_into_chunksr0   J   s     &lH==L
<!####L$6q$92z<K]^_K`aaa ##q!2z<3Ea3H,J\]^J_
 
 	
r*   c                    |                      d          } | d         j        g |                                  |R  } t          j        t          j        ||| j        t          j                  d          }|                     | d          } t          j        | d          }t          j        t          j        ||| j        t          j                  d          }|                    | t          j	                   }|S )zo
    More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
    r,   .Ndevicedtype)diagonalr   dim)
sizeexpandr$   trilonesr4   boolmasked_fillcumsuminf)r   r/   masktensor_segsums       r(   segment_sumrD   ^   s     ""2&&J 2<	*1S<3D3D3F3FS
SSSL:ejZ@S[`[efffqstttD++TE155LL2666M :ejZ@S[`[efffqrsssD!--teeiZ@@Mr*   c            
           e Zd ZdZej        dfdededej        de	e
         fdZded	ej        d
ej        dej        fdZd ZdS )Mamba2Cachea  
    Arguments:
        config: Mamba2Config
        batch_size: int
        dtype: torch.dtype
        device: torch.device

    Attributes:
        seqlen_offset: int
        dtype: torch.dtype
        conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size]
        ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size]
    Nconfig
batch_sizer5   r4   c                 n    d _          _        j         _        t	          j        j        z             _         fdt          j	                  D              _
        fdt          j	                  D              _        j         _        t          j                  _        d S )Nr   c           
      ~    i | ]9}|t          j        j        d j        z  j        z  z   j                  :S )r-   r3   )r$   zerosintermediate_sizen_groups
state_sizeconv_kernel_size).0irH   rG   r4   r5   selfs     r(   
<dictcomp>z(Mamba2Cache.__init__.<locals>.<dictcomp>   sg     	
 	
 	
  u{&V_)<v?P)PP%  	
 	
 	
r*   c                 d    i | ],}|t          j        j        j        j                   -S )r3   )r$   rK   	num_headshead_dimrN   )rP   rQ   rH   rG   r4   r5   s     r(   rS   z(Mamba2Cache.__init__.<locals>.<dictcomp>   sU     
 
 
  u{F,fov?PY_gl  
 
 
r*   )seqlen_offsetr5   conv_kernelrO   intr;   hidden_sizerL   rangenum_hidden_layersconv_states
ssm_states
hidden_act
activationr
   act)rR   rG   rH   r5   r4   s   `````r(   __init__zMamba2Cache.__init__   s     
 & 2!$V]V5G%G!H!H	
 	
 	
 	
 	
 	
 	
 	
 6344	
 	
 	

 
 
 
 
 
 
 6344	
 
 
 !+&+,r*   	layer_idxnew_conv_statecache_positionreturnc                 P   | j         |         }|                    d| j        dz
            }|                    dd          }|                    |j                  |d d d d |f<   | j         |                                          | j         |xx         |z  cc<   | j         |         S )Nr   r   r,   shiftsdims)r]   clamprO   rolltor4   zero_)rR   rc   rd   re   
conv_states        r(   update_conv_statezMamba2Cache.update_conv_state   s     %i0
'--a1F1JKK__BR_88
+9+<+<Z=N+O+O
111aaa'(#))+++###z1###	**r*   c                 j    | j                                          | j                                         d S N)r]   rn   r^   rR   s    r(   resetzMamba2Cache.reset   s1       r*   )__name__
__module____qualname____doc__r$   float16r   rY   r5   r   strrb   Tensor
LongTensorrp   rt    r*   r(   rF   rF   r   s          KP-qu- -"-03-<AK-aijman- - - -6
+
+.3l
+LQL\
+	
+ 
+ 
+ 
+         r*   rF   c                   (     e Zd Zd fd	ZddZ xZS )MambaRMSNormGatedư>c                     t                                                       t          j        t	          j        |                    | _        || _        d S rr   superrb   r   	Parameterr$   r=   weightvariance_epsilonrR   rZ   eps	__class__s      r(   rb   zMambaRMSNormGated.__init__   sB    l5:k#:#:;; #r*   Nc                    |j         }|                    t          j                  }|?|t          j                            |                    t          j                            z  }|                    d                              dd          }|t          j	        || j
        z             z  }| j        |                    |          z  S Nr-   r,   T)keepdim)r5   rm   r$   float32r   r%   silupowmeanrsqrtr   r   )rR   hidden_statesgateinput_dtypevariances        r(   forwardzMambaRMSNormGated.forward   s    #)%((77)BM,>,>twwu}?U?U,V,VVM $$Q'',,R,>>%Ht?T4T(U(UU{]--k::::r*   r   rr   ru   rv   rw   rb   r   __classcell__r   s   @r(   r   r      sQ        $ $ $ $ $ $
	; 	; 	; 	; 	; 	; 	; 	;r*   r   c            
       0    e Zd ZdZdedef fdZ	 	 	 ddej        de	e
         de	ej                 d	e	ej                 fd
Zdde	e
         de	ej                 d	e	ej                 fdZ	 	 	 dde	e
         de	ej                 d	e	ej                 fdZ xZS )Mamba2Mixeru  
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    rG   rc   c           	         t                                                       |j        | _        |j        | _        |j        | _        |j        | _        t          |j	        | j        z            | _
        t          |j                  | _        || _        |j        | _        |j        | _        t           |j                 | _        |j        | _        |j        | _        |j        | _        |j        | _        |j        | _        |j        | _        |j        | _        |j        | _        | j
        d| j        z  | j        z  z   | _        t7          j        | j        | j        |j        |j        | j        |j        dz
            | _        | j
        | j        z   | j        z   }t7          j        | j        ||j                  | _         t7          j!        tE          j#        | j                            | _$        tE          j%        d| j        dz             }t7          j!        tE          j&        |                    | _'        d| j'        _(        tS          | j
        | j                  | _*        t7          j!        tE          j#        | j                            | _+        d| j+        _(        t7          j        | j
        | j        |j                  | _,        |j        | _        tZ          st\          /                    d           d S d S )Nr-   r   )in_channelsout_channelsbiaskernel_sizegroupspaddingr   Tr   a  The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d)0r   rb   rU   rZ   rN   ssm_state_sizerX   rO   rY   r;   rL   time_step_rankrc   use_conv_biasr_   r`   r
   ra   layer_norm_epsilonrms_normrM   rV   r/   time_step_limittime_step_mintime_step_maxconv_dimr   Conv1dconv1dLinearuse_biasin_projr   r$   r=   dt_biasarangelogA_log_no_weight_decayr   normDout_projis_fast_path_availableloggerwarning_once)rR   rG   rc   projection_sizeAr   s        r(   rb   zMamba2Mixer.__init__   s   )!-$/ & 2!$V]T5E%E!F!F!&"788"#1 +&+,"("; +%5#1#1.T]1BTEX1XXi%*=&*
 
 
 04=@4>Qy
 
 
 |EJt~$>$>?? LDNQ.//\%)A,,//
&*
#%d&<$BYZZZ	ej8899"&	$"8$:JQWQ`aaa% 	>    	 	r*   Nr   cache_paramsre   attention_maskc                    |j         \  }}}| j        | j        z  }d| j        z  d| j        z  | j        z  z   | j        z   }	||j        dk    r|                     |                    d                    }
|
j         d         |	z
  dz  }||| j        | j        | j        g}t          j
        |
|d          \  }}}}}t          ||j        | j                 | j        j                            d          | j        j        | j                  }t          j
        || j        ||gd          \  }}}t          j        | j                                                   }|d d d df         d d d d d f                             d| j        | j                                      t          j                  }|d d d d d f                             dd| j                  }| j        d d d df                             d| j                  }| j        d d d df                             d| j                  }|                    || j        |j         d         | j        z            }|                    || j        |j         d         | j        z            }|                    || j        | j                  }t9          |j        | j                 ||||||d |d	
  
        }|                    || j        | j        z            }|                     ||          }|                     |          d d d df         }n|N|j         d         dk    r=|j         d         dk    r,|j         }||d d d d d f         z                      |          }|                     |          }t          j        | j                                                   }| j!        d
t)          d          fk    ri nd| j!        i}| j"        r|tG          || j        j                            d          | j        j        | j        |f| j        | j$        d | j        | j        j        | j        j%        | j        j        | j        j        | j        | j        ddd|\  }}ngt          j
        || j        | j        | j        gd          \  }}}tL          	| j        dvr]| '                    |                     |(                    dd                    (                    dd          d d d |f                   }nstM          |(                    dd          | j        j                            d          | j        j        | j                  (                    dd          d d d |f         }t          j
        || j        ||gd          \  }}}|N|j         d         dk    r=|j         d         dk    r,|j         }||d d d d d f         z                      |          }tS          |                    ||d| j                  |||                    ||| j        d          |                    ||| j        d          f| j$        | j        d d d| j        dd|\  }}|'|%|j        | j                 *                    |           |                    ||d          }|                     ||          }|                     |          }|S )Nr-   r   r   r,   r8   .r5   T)zr   dt_softplusg        rA   dt_limitF)r   r/   seq_idxr`   rmsnorm_weightrmsnorm_epsoutproj_weightoutproj_biasheaddimngroupsnorm_before_gatereturn_final_states)r   swish)xr   r   r`   )r/   r   r   r   r   r   r   )+r#   rM   r   rL   rU   rW   r   squeezer   r$   splitr   r]   rc   r   r   r   r`   expr   floatr;   rV   rm   r   r   r   viewr   r^   r   r   r5   r   trainingr   r/   r   r   ra   	transposer   copy_)rR   r   r   re   r   rH   seq_len_groups_time_state_sized_to_removein_projected_statesd_mlpsplit_projection_dimr   hidden_states_B_CdtBCr   r   r   hidden_states_reshapedoutr5   projected_statesdt_limit_kwargs	ssm_state	time_stepscan_outputs                                r(   cuda_kernels_forwardz Mamba2Mixer.cuda_kernels_forward	  sJ    "/!4
GQ!%1D!D$001t}3DtGZ3ZZ]a]kk #(BQ(F(F"&,,}/D/DQ/G/G"H"H(.r2[@QFE$)5$2H$-Y]Yg#h 05<OQekm0n0n0n-Aq$)2 4!(8"**1-- ! ! #(+!')?AWX# # #M1a
 4:++--...A!!!T3,111d
+222t}dFYZZ]]didq]rrAAAAqqq$J&&r2t}==Bl111dC<077DMJJGqqq$|$++B>>Az4=!'!*2MNNAz4=!'!*2MNNA%2%7%7
DNTXTa%b%b"2'7&   M *..z4>DM;YZZM IImT::M--..qqq$|<CC )n.B1.E.I.InNbcdNehiNiNi%+!.111d
1K!K O OPU V V#||M::4:++--...A$($8S%,,<O$O$ObbV`bfbvUwO} F1!5!A$K&..q11K$L" f# ##'9#3 $	 :#'=#7!%!3 M M%*(,#" "$ &%" "YY, 6;[$+T]DNK6 6 62' $+tFW/W/W(,$5$?$?1$E$EFFPPQRTUVVWXWXWXZb[bZbWbc) )%% )9+55a;;#{199!<<![-#'?	) ) )
  i1ooaaa'k)3% ',k%+-CE[\' ' '#q!
 "-.2Fq2IA2M2MR`RfghRilmRmRm)/E%2^AAAqqq$J5O%O$S$STY$Z$ZM)B!&&z7BNNFF:wrBBFF:wrBB*  $f (, L $* * &* *&Y (\-E +DN;AA)LLL)..z7BGG"iiT::mmK00
r*   c                    2 |j         \  }}}|j        }                     |                    d                    }	|	j         d         d j        z  z
  d j        z   j        z  z
   j        z
  dz  }
|	                    |
|
 j         j	         j        gd          \  }}}}}|]|j
         j                                                 }|                    |j                  }|j        dk    r|j         j                 }t#          j        |dd          }|j        dk    r|d d dd d f         n||d d d d df<   |j         j                                     |           t#          j        |                    |	j                   j        j        d d dd d f         z  d          } j        r| j        j        z  }                     |                              |          d d d df         }n|                    dd          }t8          j                            | j        |j         d         z
  df          }|j         j                                     |                                                     |                              dd                    d d d |d d f         }|N|j         d         dk    r=|j         d         dk    r,|j        }||d d d d d f         z                      |          }nt#          j         | j         j!         j        f|j        |	          }                                          |                    dd                    dd |f                             dd                    }t#          j        | j         j         j        z   j         j        z  gd          \  }}}t#          j"         j#        $                                           }|-|j        dk    r!|j        dk    r|d d d df         n|d d dd d f         d d d df         }|                    dd          %                    ||j         d          j!                  } j&        d
         %                     j&        j         d          j!                  }t"          j        j        '                    ||                    |j                  z             }t#          j(        | j)                  }|d         %                     j         j!         j                                      t"          j*                  }t#          j"        |d
         |z            }|+                    | j        d          dd d d f         }|%                    | j         j         j        z  |j         d                   ,                                }|+                    |d|j         d                   }|d
         |dd d d f         z  }|+                    |d j!                  }||d
         z  }|j
         j                                     |j
         j                 |z  |z              |+                    | j        d          dd d d f         }|%                    | j         j         j        z  |j         d                   ,                                }|+                    |d|j         d                   }|j
         j                                     |j                  }|-                    | j        z   j!         j                  }|-                    | j        z   j        d          }t#          j.        ||          }|-                    | j         j!                  } j/        d
         %                     j/        j         d          j!                  }|||z  z                       |j                  }|+                    |d          d d d df         }nat8          j        '                    | j&        z             }t#          j(        | j)                  }|+                    ||d j!                  $                                }|+                    ||d j                  $                                }|+                    ||d j                  $                                }|0                    dd j         j        z  d          }|0                    dd j         j        z  d          } j1        | j1        z  z
   j1        z  2 j/        d
         te          |2          z  }||d
         z  }|                    |j                  |z  }2 fd||||fD             \  }}}}|3                    dddd          }t#          j4        |d          }t#          j"        tk          |                    }|d d d d d d d d d d d f         |d d d d d d d d d d d f         z  }|                    d          } | d
         |3                    ddddd          d
         z  }!|!                    d          }"|"d
         |d d d d d f         z                      d          }#t#          j"        |d d d d d d dd f         |z
            }$||$3                    dddd          d
         z  }%|%3                    ddddd          d
         |3                    ddddd          dd d d f         z                      d          3                    ddddd          }&|)|j        dk    r|j
         j                 d d d df         }'n t#          j6        |&d d d df                   }'t#          j7        |'|&gd          }&t#          j"        tk          t8          j                            |d d d d d d df         d                              }(|&3                    ddddd          })|(d         |)d d d d d df         z                      d          }*|*3                    ddddd          }+|+d d d df         |+d d df         }}&t#          j"        |          },|dd d d f         |&d d d d d df         z  }-|,3                    dddd          }.|-                    d          |.d
         z  }/|#|/z   }|+                    |d j         j!                  }||z   }2dk    r|d d d |d d d d f         }|+                    ||d          }|'|%|j
         j                                     |            8                    ||          }0 9                    |0                    |                    }1|1S )Nr   r,   r-   r8   r   rh   r	   .r3   r2   ).NNr   c                 <    g | ]}t          |j                  S r}   )r0   r/   )rP   tr   rR   s     r(   
<listcomp>z-Mamba2Mixer.torch_forward.<locals>.<listcomp>  s)    %z%z%z\]&9!Xt&W&W%z%z%zr*   r   )r   r   ):r#   r5   r   r   rL   rM   r   rU   r   r   r^   rc   clonerm   r4   rW   r]   r$   rl   ndimr   sumr   r   r   r   ra   r   r   r%   r&   rO   rK   rV   r   r   r   r;   r   softplusrk   r   r   r.   
contiguousr   bmmr   repeatr/   r)   permuter@   rD   
zeros_likecatr   r   )3rR   input_statesr   re   r   rH   r   r   r5   r   r   r   r   r   r   ro   r   r   r   r   dAdBdBxr^   ssm_states_reshaped
C_reshapedyr   
D_residualA_cumsumLG_intermediateGM_intermediateMY_diagdecay_statesB_decay_contractionstatesprevious_statesdecay_chunkstates_permutedresult
new_statesstate_decay_outC_times_statesstate_decay_out_permutedY_offr   contextualized_statesr   s3   `                                                 @r(   torch_forwardzMamba2Mixer.torch_forward  s;   !-!3
GQ" LL)=)=a)@)@AA!'+a$2H.HHAPTP]L]`d`sLssuy  vD  D  IJ  J(8(>(>t5t~V\^ )? )
 )
%1dM2
 #$/?EEGGI!]%9::I)A--)5dnE
"Z
2BGGG
ANASWXAXAX}QQQ111W'='=^k
111aaa8$(8>>zJJJ %	*--8H8O*P*PSWS^SefgfgfgijlmlmlmfmSn*ntv w w w% 6!T[%55M $ 7 7 : :5 A A!!!T3, O - 7 7! < <]..!*]-@-DDaH 
 (8>>zJJJ $])C)C)M)MaPQ)R)R S STUTUTUW_X_W_abababTb c!-.2Fq2IA2M2MR`RfghRilmRmRm)/E%2^AAAqqq$J5O%O$S$STY$Z$ZMT^T]D<OP$+5  I !HHT[[1H1HA1N1N%O%OPSU]V]U]P]%^%h%hijlm%n%nooM#k-$:PRVR_bfbuRuw{  xE  HL  H[  x[  :\  bd  e  e  eq!Ytz''))***#(BQ(F(F &(W\\AAAtSL!!r!!!Q'{111dC<7PBa##**:rx|T]SSBl9-44T\5G5JDMZZG$--b7::bh3G3G.GHHBR!344B/"))$.$-I\]]``glgt`uuA2i=1,--B
 		*dmR88dAAAFAT]DNdm4SUVU\]_U`aallnnA		*b!'"+66AI3aaa<0B *11*b$-PPM}Y//C #DN399'7"<sB   		*dmR88dAAAFAT]DNdm4SUVU\]_U`aallnnA		*b!'"+66A &0@CCAGLLJ",//*t~2Mt}^b^q"r"r
T^ ;T=PRSTTJ	-z::Az4>4=AAA y!((a$-HHA]Q&&**1733A 		*b))!!!T3,7AA ''T\(9::BR!344B)11*gr4=YY__aaM		*gD4GHHNNPPA		*gr43FGGMMOOAAt~>BBAAt~>BBA'DO*CCtVH	*-?x-X-XXJ *ByM9M]())B.A &{%z%z%z%zboqrtuwxay%z%z%z"M1a 		!Q1%%A|A2...H 	+a..))A qqq!!!QQQaaa23a111dAAAqqq!!!8K6LLN""r"**A y\AIIaAq!,D,DY,OON""r"**A 	l]111aaa:%>>CCAFFF !9hqqq!!!QQQ|&<x&GIIL"#l&:&:1aA&F&Fy&Q"Q)11!Q1a@@K}OdOdefhiklnoqrOsOstwy}  @A  @A  @A  uA  PB  B  G  G  LM  G  N  N  V  V  WX  Z[  ]^  `a  cd  e  eF'L,F,J,J"."9$."I!!!TSV,"W"'"26!!!RaR%="A"AY8a@@@F)K0A0A(111aaaQRQRQRTV;BWY_0`0`$a$abbK$nnQ1a;;O!/2_QQQ4QT_5UUZZ_`ZaaF1aA66J *111crc6 2Jqqq"u4EIF $i11OT111oqqq!!!T30GGN'6'>'>q!Q'J'J$#''++.Fy.QQE A		*b$.$-HHAJA!||aaa'111aaa'(		*gr22A$)A'7==iHHHii4((
 !%knnU.C.C D D$$r*   c                 @   t           r0d| j        j        j        j        v r|                     ||||          S |j        }|G|j        d         dk    r6|j        d         dk    r%||d d d d d f         z                      |          }| 	                    ||||          S )Ncudar   r   )
r   r   r   r4   typer   r5   r#   rm   r  )rR   r   r   re   r   r5   s         r(   r   zMamba2Mixer.forwardR  s     " 	jf0C0J0O&O&O,,]L.Zhiii#%.*>q*AA*E*E.J^_`JadeJeJe*^AAAqqq$J-GGKKERRM!!-~~^^^r*   NNN)ru   rv   rw   rx   r   rY   rb   r$   r{   r   rF   r|   r   r  r   r   r   s   @r(   r   r      s        @| @ @ @ @ @ @ @J /35915K K|K {+K !!12	K
 !.K K K K\x% x%8M x%cklql|c} x%  U]  ^c  ^j  Uk x% x% x% x%| /35915_ _ {+_ !!12	_
 !._ _ _ _ _ _ _ _r*   r   c                   &     e Zd Zd fd	Zd Z xZS )Mamba2RMSNormr   c                     t                                                       t          j        t	          j        |                    | _        || _        dS )zM
        Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        Nr   r   s      r(   rb   zMamba2RMSNorm.__init__d  sD     	l5:k#:#:;; #r*   c                    |j         }|                    t          j                  }|                    d                              dd          }|t          j        || j        z             z  }| j        |                    |          z  S r   )	r5   rm   r$   r   r   r   r   r   r   )rR   r   r   r   s       r(   r   zMamba2RMSNorm.forwardl  s|    #)%((77 $$Q'',,R,>>%Ht?T4T(U(UU{]--k::::r*   r   r   r   s   @r(   r  r  c  sL        $ $ $ $ $ $; ; ; ; ; ; ;r*   r  c                   r     e Zd Z fdZ	 	 	 ddee         deej                 deej                 fdZ	 xZ
S )Mamba2Blockc                     t                                                       || _        || _        |j        | _        t          |j        |j                  | _        t          ||          | _
        d S )Nr   rc   )r   rb   rG   rc   residual_in_fp32r  rZ   r   r   r   mixer)rR   rG   rc   r   s      r(   rb   zMamba2Block.__init__u  sd    " & 7!&"4&:STTT	 9===


r*   Nr   re   r   c                    |}|                      |                    | j         j        j                            }| j        r|                    t
          j                  }|                     ||||          }||z   }|S )Nr   r   re   r   )r   rm   r   r5   r%  r$   r   r&  )rR   r   r   re   r   residuals         r(   r   zMamba2Block.forward}  s     !		-"2"29I9O"2"P"PQQ  	2{{5=11H

^dr # 
 
 !=0r*   r  )ru   rv   rw   rb   r   rF   r$   r|   r{   r   r   r   s   @r(   r"  r"  t  s        > > > > > /35915  {+ !!12	
 !.       r*   r"  c                   .    e Zd ZdZeZdZdgZdZdZ	d Z
dS )Mamba2PreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    backboner"  Tc                 R   t          |t                    r?d|j        _        d|j        _        t          j        t          j        | j        j	                  t          j        | j        j                  t          j        | j        j                  z
  z  t          j        | j        j                  z                                 | j        j                  }|t          j        t          j        |                      z   }t          j                    5  |j                            |           ddd           n# 1 swxY w Y   d|j        _        t          |t*          j                  rB|j        :t1          |j        dd          s$t*          j                            |j                   nJt          |t*          j                  r0t*          j                            |j        | j        j                   | j        j        r|                                 D ]\  }}|dv rt*          j        !                    |t          j"        d          	           t          j                    5  |t          j"        | j        j#                  z  }ddd           n# 1 swxY w Y   dS dS )
zInitialize the weights.T)minN
_no_reinitF)std)zout_proj.weight   )a)$
isinstancer   r   r   r   r$   r   randrG   rU   mathr   r   r   rk   time_step_floorexpm1no_gradr   r   r/  r   r   r   getattrinitzeros_	Embeddingnormal_r   initializer_rangerescale_prenorm_residualnamed_parameterskaiming_uniform_sqrtr\   )rR   moduler   inv_dtnameps         r(   _init_weightsz#Mamba2PreTrainedModel._init_weights  s   fk** 	-,0FL)(,FH%
4;0118DK566$+B[9\9\\^(4;4556  e3e44	  %)U["%5%5$5666F - -$$V,,,- - - - - - - - - - - - - - -(,FN%fbi(( 	N{&v{L%@@ 0GNN6;///-- 	NGOOFMt{/LOMMM;/ 	F "2244 F Fa...
 G,,Q$)A,,,??? F FTYt{'DEEEF F F F F F F F F F F F F F F	F 	FF Fs$   D??EE+"JJ	 J	N)ru   rv   rw   rx   r   config_classbase_model_prefix_no_split_modulessupports_gradient_checkpointing_is_statefulrG  r}   r*   r(   r+  r+    sR         
  L"&&*#L(F (F (F (F (Fr*   r+  c                       e Zd ZU dZdZeej                 ed<   dZ	ee
         ed<   dZeeej                          ed<   dS )Mamba2Outputa%  
    Class for the MAMBA2 model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`Mamba2Cache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    Nlast_hidden_stater   r   )ru   rv   rw   rx   rO  r   r$   FloatTensor__annotations__r   rF   r   r   r}   r*   r(   rN  rN    si          $ 6:x 12999*.L(;'...8<M8E%"345<<<<<r*   rN  c                       e Zd ZU dZdZeej                 ed<   dZ	eej                 ed<   dZ
ee         ed<   dZeeej                          ed<   dS )Mamba2CausalLMOutputa  
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`Mamba2Cache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    Nlosslogitsr   r   )ru   rv   rw   rx   rT  r   r$   rP  rQ  rU  r   rF   r   r   r}   r*   r(   rS  rS    s          ( )-D(5$
%,,,*.FHU&'...*.L(;'...8<M8E%"345<<<<<r*   rS  a@  

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Mamba2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a  
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            Indices of input sequence tokens in the vocabulary.

            If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        cache_params (`Mamba2Cache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z`The bare MAMBA2 Model transformer outputting raw hidden-states without any specific head on top.c                   P    e Zd Z fdZd Zd Zd Z ee           e	e
ee          	 	 	 	 	 	 	 	 ddeej                 deej                 d	ee         d
ee         dee         dee         deej                 deej                 deeef         fd                        Z xZS )Mamba2Modelc                    t                                                     t          j        j        j                  | _        t          j        fdt          j	                  D                       | _
        d| _        t          j        j                  | _        |                     | j                   |                                  d S )Nc                 2    g | ]}t          |           S )r$  )r"  )rP   idxrG   s     r(   r   z(Mamba2Model.__init__.<locals>.<listcomp>4  s&    $s$s$sC[3%G%G%G$s$s$sr*   Fr   )r   rb   r   r<  
vocab_sizerZ   
embeddings
ModuleListr[   r\   layersgradient_checkpointingr  r   norm_f"_register_load_state_dict_pre_hook	load_hook	post_initrR   rG   r   s    `r(   rb   zMamba2Model.__init__0  s       ,v'8&:LMMm$s$s$s$sSXY_YqSrSr$s$s$stt&+##F$6F<UVVV//???r*   c                 v    |D ]5}d|v r/|                     |          ||                    dd          <    d S 6d S )Nz
embedding.zembeddings.)popreplace)rR   
state_dictprefixargsks        r(   rb  zMamba2Model.load_hook<  sW     	 	Aq  EO^^TUEVEV
199\=AAB !	 	r*   c                     | j         S rr   r\  rs   s    r(   get_input_embeddingsz Mamba2Model.get_input_embeddingsB  s
    r*   c                     || _         d S rr   rm  rR   new_embeddingss     r(   set_input_embeddingsz Mamba2Model.set_input_embeddingsE  s    (r*   
checkpointoutput_typerH  N	input_idsinputs_embedsr   	use_cacheoutput_hidden_statesreturn_dictre   r   rf   c	                 f   ||n| j         j        }||n| j        s| j         j        nd}||n| j         j        }|d u |d uz  rt          d          ||                     |          }| j        r| j        r|rd}|rp|\t          | j         |	                    d          |j
        |j                  }t          j        d| j         j        |j
                  }n|t          d          nd }|}
|rdnd }| j        D ]F}| j        r&| j        r|                     |j        |
|||          }
n ||
|||          }
|r||
fz   }G|r|xj        |j        d	         z  c_        |                     |
          }
|r||
fz   }|st+          d
 |
||fD                       S t-          |
|r|nd |          S )NFz:You must specify exactly one of input_ids or inputs_embedsr   r3   r4   zYou have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will be initialized for you automaticallyr}   r(  r   c              3      K   | ]}||V  	d S rr   r}   )rP   vs     r(   	<genexpr>z&Mamba2Model.forward.<locals>.<genexpr>  s(      ffqXYXeXeXeXeXeffr*   )rO  r   r   )rG   ry  r   rx  use_return_dict
ValueErrorr\  r_  rF   r:   r4   r5   r$   r   rX   r^  _gradient_checkpointing_func__call__rW   r#   r`  tuplerN  )rR   rv  rw  r   rx  ry  rz  re   r   kwargsr   all_hidden_statesmixer_blocks                r(   r   zMamba2Model.forwardH  s   & %9$D  $+Jj 	 "+!6IIZ^Zg=rT[=R=Rmr	%0%<kk$+B]-t";< 	[YZZZ  OOI66M& 	4= 	Y 	I 	 #*K!3!3A!6!6}?S[h[n      "'a1HQ^Qe!f!f!f' !;  	 (  L%"6@BBD; 	I 	IK* 
t} 
 $ A A(-~We! ! !,!!-#1#1	! ! ! $ I$58H$H! 	A&&-*=a*@@&&M22 	E 1]4D D 	gff]LBS$Tffffff+)2<+
 
 
 	
r*   )NNNNNNNN)ru   rv   rw   rb   rb  rn  rr  r   MAMBA2_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCrN  _CONFIG_FOR_DOCr   r$   r|   rF   r>   r{   r   r   r   r   r   s   @r(   rW  rW  +  s~       

 
 
 
 
    ) ) ) +*+BCC& $   1548.2$(/3&*5915N
 N
E,-N
   01N
 {+	N

 D>N
 'tnN
 d^N
 !!12N
 !.N
 
ul"	#N
 N
 N
  DCN
 N
 N
 N
 N
r*   rW  z
    The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input
    embeddings).
    c                       e Zd Zg Z fdZd Zd Zd Zd Z	 	 	 	 	 dde	e
         de	ej                 d	e	ej                 fd
Z ee           eeee          	 	 	 	 	 	 	 	 	 dde	ej                 de	ej                 de	e
         de	ej                 de	e         de	e         de	e         de	ej                 d	e	ej                 deeef         fd                        Z xZS )Mamba2ForCausalLMc                     t                                          |           t          |          | _        t	          j        |j        |j        d          | _        | 	                                 d S )NFr   )
r   rb   rW  r,  r   r   rZ   r[  lm_headrc  rd  s     r(   rb   zMamba2ForCausalLM.__init__  s^       #F++y!3V5FUSSSr*   c                     | j         S rr   r  rs   s    r(   get_output_embeddingsz'Mamba2ForCausalLM.get_output_embeddings  s
    |r*   c                     || _         d S rr   r  rp  s     r(   set_output_embeddingsz'Mamba2ForCausalLM.set_output_embeddings  s    %r*   c                 4    | j                                         S rr   )r,  rn  rs   s    r(   rn  z&Mamba2ForCausalLM.get_input_embeddings  s    }11333r*   c                 6    | j                             |          S rr   )r,  rr  rp  s     r(   rr  z&Mamba2ForCausalLM.set_input_embeddings  s    }11.AAAr*   Nr   re   r   c                    ||j         d         |j         d         z   }n|j         d         }|r|t          d          |d         dk    r%|d d df         d         }|d d df         d         }nst          j        d||j                  }t          j        |                    d          ||j         d         z
  |j                  }	t          j        ||	gd          }d }|j         d         |k     rUt          j        |                    d          ||j         d         z
  |j                  }	t          j        ||	gd          }||d|i}
nd	|i}
|
                    ||||d
           |
S )Nr   z`cache_position` should not be None as it should have been initialized in `model.generate`, you are responsible for passing in a valid `cache_position` if you are calling `prepare_inputs_for_generation` directly with `use_cache=True`r   r,   r2   r|  r8   rw  rv  )r   r   rx  re   )	r#   r  r$   r   r4   r=   r:   r   update)rR   rv  rw  rx  r   re   r   r  past_lenextended_maskmodel_inputss              r(   prepare_inputs_for_generationz/Mamba2ForCausalLM.prepare_inputs_for_generation  s    $$*1-	0BBHH q)H 	$% e   a 1$$%aaae,Y7	!/2!6y!A "'a)BR!S!S!S !&
"''**H~7KA7N,NWeWl! ! ! "'NM+JPQ!R!R!R#"X--
 "J##A&&>3G3J(JSaSh  M #Y'FANNNN$)=+];LL'3L"0 ,&"0	 	
 	
 	
 r*   rs  rv  rw  labelsry  rz  rx  rf   c
           
         ||n| j         j        }|                     ||||||||	          }|d         }|                     |                    | j        j        j                                                            }d}||                    |j                  }|dddddf         	                                }|dddf         	                                }t                      } ||                    d|                    d                    |                    d                    }|s|f|dd         z   }||f|z   n|S t          |||j        |j                  S )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        N)r   rw  ry  rz  rx  re   r   r   .r,   r   )rT  rU  r   r   )rG   r  r,  r  rm   r   r5   r   r4   r   r   r   r:   rS  r   r   )rR   rv  rw  r   r  ry  rz  rx  re   r   r  mamba2_outputsr   rU  rT  shift_logitsshift_labelsloss_fctoutputs                      r(   r   zMamba2ForCausalLM.forward  s   2 &1%<kk$+B]%'!5#)) ' 	
 	
 'q)m..t|/B/HIIJJPPRRYYv}--F!#ssAAA+.99;;L!#qrr'?5577L'))H8L--b,2C2CB2G2GHH,J[J[\^J_J_``D 	FY!33F)-)9TGf$$vE#'4(6	
 
 
 	
r*   )NNNNN)	NNNNNNNNN)ru   rv   rw   _tied_weights_keysrb   r  r  rn  rr  r   rF   r$   r|   r{   r  r   r  r   r  rS  r  rP  r>   r   r   r   r   r   s   @r(   r  r    s               & & &4 4 4B B B .25915@ @
 {+@ !!12@ !.@ @ @ @D +*+BCC&($   1559.2-1/3&*$(15157
 7
E,-7
   127
 {+	7

 )*7
 'tn7
 d^7
 D>7
 !.7
 !.7
 
u**	+7
 7
 7
  DC7
 7
 7
 7
 7
r*   r  )?rx   r5  dataclassesr   typingr   r   r   r$   torch.utils.checkpointr   torch.nnr   activationsr
   
generationr   modeling_utilsr   utilsr   r   r   r   r   utils.import_utilsr   r   configuration_mamba2r   
get_loggerru   r   +mamba_ssm.ops.triton.selective_state_updater   !mamba_ssm.ops.triton.ssd_combinedr   r   causal_conv1dr   r   allr   r  r  r{   rY   r)   r0   rD   rF   Moduler   r   r  r"  r+  rN  rS  MAMBA2_START_DOCSTRINGr  rW  r  r}   r*   r(   <module>r     s      ! ! ! ! ! ! ) ) ) ) ) ) ) ) ) )            % % % % % % ! ! ! ! ! ! ) ) ) ) ) ) - - - - - -              W V V V V V V V . . . . . . 
	H	%	%  "RRRRRRmmmmmmmmm! 8DDDDDDDDD-7**46FH\]^^ 9  VU\ VS V V V V
 
 
(  (8  8  8  8  8  8  8  8 v; ; ; ; ; ; ; ;$a_ a_ a_ a_ a_") a_ a_ a_H; ; ; ; ;BI ; ; ;"    ")   84F 4F 4F 4F 4FO 4F 4F 4Fn = = = = =; = = =0 = = = = =; = = =6   : f m
 m
 m
 m
 m
' m
 m
	 m
`   U
 U
 U
 U
 U
- U
 U
 U
 U
 U
r*   