
    g)p                        d dl Z d dlmZmZ d dlZd dlZddlm	Z	 ddl
mZ ddlmZ  ee          ZdZ G d d	          Z G d
 d          Z G d de          Z G d de          Z G d de          Z G d de          Z G d de          Z G d de          Z G d de          Z G d de          Z G d de          Z G d de          Z G d  d!e          Z G d" d#e          Z G d$ d%e          Z dS )&    N)ListTuple   )stable_softmax)add_start_docstrings)
get_loggeraZ  
    Args:
        input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
            search or log softmax for each vocabulary token when using beam search.
        cur_len (`int`):
            The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length
            is the maximum length generate can produce, and we need to know which of its tokens are valid.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional logits processor specific kwargs.

    Return:
        `tf.Tensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
c            	       j    e Zd ZdZ ee          dej        dej        dedej        fd            Z	dS )TFLogitsProcessorzSAbstract base class for all logit processors that can be applied during generation.	input_idsscorescur_lenreturnc                 0    t          | j         d          )z TF method for processing logits.H is an abstract class. Only classes inheriting this class can be called.NotImplementedError	__class__selfr   r   r   s       e/var/www/html/ai-engine/env/lib/python3.11/site-packages/transformers/generation/tf_logits_process.py__call__zTFLogitsProcessor.__call__8   $     "~ggg
 
 	
    N
__name__
__module____qualname____doc__r   $TF_LOGITS_PROCESSOR_INPUTS_DOCSTRINGtfTensorintr    r   r   r
   r
   5   sj        ]]>??
") 
RY 
 
QSQZ 
 
 
 @?
 
 
r   r
   c            	       j    e Zd ZdZ ee          dej        dej        dedej        fd            Z	dS )TFLogitsWarperzjAbstract base class for all logit warpers that can be applied during generation with multinomial sampling.r   r   r   r   c                 0    t          | j         d          )zTF method for warping logits.r   r   r   s       r   r   zTFLogitsWarper.__call__C   r   r   Nr   r#   r   r   r%   r%   @   sj        tt>??
") 
RY 
 
QSQZ 
 
 
 @?
 
 
r   r%   c            	       j    e Zd ZdZ ee          dej        dej        dedej        fd            Z	dS )TFLogitsProcessorListz
    This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor.
    This class inherits from list and adds a specific *__call__* method to apply each [`TFLogitsProcessor`] to the
    inputs.
    r   r   r   r   c                    | D ]}t          j        |j                  j        }t	          |          dk    rt          fdt          |                                          dd          D                       s:t          dt          |                                           d|j	         d           ||||fi } ||||          }|S )N   c              3       K   | ]}|v V  	d S Nr#   ).0argkwargss     r   	<genexpr>z1TFLogitsProcessorList.__call__.<locals>.<genexpr>W   s'      SSS3&=SSSSSSr   r   z,Make sure that all the required parameters: z for z$ are passed to the logits processor.)
inspect	signaturer   
parameterslenalllistkeys
ValueErrorr   )r   r   r   r   r/   	processorfunction_argss       `  r   r   zTFLogitsProcessorList.__call__R   s    
	? 
	?I#-i.@AALM=!!A%%SSSSD9K9K9M9M4N4Nqrr4RSSSSS $UtML^L^L`L`GaGa U U$.U U U   #9fgHHHH"9fg>>r   Nr   r#   r   r   r(   r(   K   sp          >??") RY  [][d    @?  r   r(   c                   T    e Zd ZdZdefdZdej        dej        dedej        fdZ	d	S )
TFTemperatureLogitsWarperz
    [`TFLogitsWarper`] for temperature (exponential scaling output probability distribution).

    Args:
        temperature (`float`):
            The value used to module the logits distribution.
    temperaturec                 n    t          |t                    r|dk    st          d|           || _        d S )Nr   z:`temperature` has to be a strictly positive float, but is )
isinstancefloatr8   r=   )r   r=   s     r   __init__z"TFTemperatureLogitsWarper.__init__k   sC    +u-- 	ikAoogZegghhh&r   r   r   r   r   c                     || j         z  }|S r,   )r=   r   s       r   r   z"TFTemperatureLogitsWarper.__call__q   s    $**r   N)
r   r   r   r   r@   rA   r    r!   r"   r   r#   r   r   r<   r<   b   sq         'E ' ' ' '") RY  QSQZ      r   r<   c                   v    e Zd ZdZ ed           dfdededefdZdej        d	ej        d
edej        fdZ	dS )TFTopKLogitsWarpera  
    [`TFLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.

    Args:
        top_k (`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (`float`, *optional*, defaults to -inf):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    Inf   top_kfilter_valuemin_tokens_to_keepc                     t          |t                    r|dk    rt          d|           t          ||          | _        || _        d S )Nr   z6`top_k` has to be a strictly positive integer, but is )r?   r"   r8   maxrG   rH   )r   rG   rH   rI   s       r   rA   zTFTopKLogitsWarper.__init__   sU    %%% 	_!]V[]]^^^ 233
(r   r   r   r   r   c                     t          | j        |j        d                   }|t          j                            ||          d         ddd f         k     }t          j        || j        |          }|S )N)kr   .)minrG   shaper    mathwhererH   )r   r   r   r   rG   indices_to_removenext_scoress          r   r   zTFTopKLogitsWarper.__call__   sf    DJR 011"RW]]6U]%C%CA%FsBCCx%PPh0$2CVLLr   N
r   r   r   r   r@   r"   rA   r    r!   r   r#   r   r   rD   rD   v   s        
 
 ;@%,,bc ) )c ) )\_ ) ) ) )") RY  QSQZ      r   rD   c                   v    e Zd ZdZ ed           dfdededefdZdej        d	ej        d
edej        fdZ	dS )TFTopPLogitsWarpera.  
    [`TFLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to <= prob_cut_off.

    Args:
        top_p (`float`):
            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
            higher are kept for generation.
        filter_value (`float`, *optional*, defaults to -inf):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    rE   rF   top_prH   rI   c                     t          |t                    r|dk     s|dk    rt          d|           t          |t                    r|dk     rt          d|           || _        || _        || _        d S )Nr   g      ?z.`top_p` has to be a float > 0 and < 1, but is rF   z:`min_tokens_to_keep` has to be a positive integer, but is )r?   r@   r8   r"   rX   rH   rI   )r   rX   rH   rI   s       r   rA   zTFTopPLogitsWarper.__init__   s    %'' 	WEAIIUeUUVVV,c22 	p7IA7M7MnZlnnooo
("4r   r   r   r   r   c                    t           j                            ||j        d                   \  }}t          j        |j        | j                  }t           j                            t          |d          d          }|| j        k     }t          j	        t          j
        |j        d         dgt           j                  |d d d df         fd          }t          j	        t          j
        |j        d         | j        gt           j                  |d d | j        d f         fd          }t          j        |||          }	t          j        t          j        t          j        |j        d                   d          d|j        d         g          }
t          j        |
|fd          }t          j        ||	|	j                  }|S )NrM   axisr   rF   dtype)rP   )r    rQ   rG   rP   fillrH   cumsumr   rX   concatonesboolrI   rR   tileexpand_dimsrangestack
scatter_nd)r   r   r   r   topk_scorestopk_indicesmask_scorescumulative_probs
score_masktopk_next_scoresscatter_rowsscatter_indicesrT   s                r   r   zTFTopPLogitsWarper.__call__   s   $&GMM&&,r:J$K$K!\gflD,=>>7>>.2*N*N*NUW>XX%
2
 Y)9!)<a(@ P P PR\]^]^]^`cac`c]cRdelnooo
 Y)!,d.EFbgVVV111d57778 
 
 

 8J[II
 wr~bh|7I!7L.M.MTVWWWZ[]i]opr]sYtuu(L,#?bIIImO5EM]Mcdddr   NrU   r#   r   r   rW   rW      s          =BE%LL=de 5 5e 55 5^a 5 5 5 5") RY  QSQZ      r   rW   c                   |    e Zd ZdZdedefdZdej        dej        fdZdej        dej        d	edej        fd
Z	dS )TFMinLengthLogitsProcessora1  
    [`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.

    Args:
        min_length (`int`):
            The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
        eos_token_id (`int`):
            The id of the *end-of-sequence* token.
    
min_lengtheos_token_idc                     t          |t                    r|dk     rt          d|           t          |t                    r|dk     rt          d|           || _        || _        d S )Nr   z2`min_length` has to be a positive integer, but is z4`eos_token_id` has to be a positive integer, but is )r?   r"   r8   rs   rt   )r   rs   rt   s      r   rA   z#TFMinLengthLogitsProcessor.__init__   s    *c** 	`j1nn^R\^^___,,, 	dq0@0@bT`bbccc$(r   r   r   c                     t          j        |j        d                   | j        k    }t          j        |t          d          |          }|S )NrM   z-inf)r    rf   rP   rt   rR   r@   )r   r   eos_token_id_masks      r   _apply_eos_token_maskz0TFMinLengthLogitsProcessor._apply_eos_token_mask   s@    HV\"%566$:KK+U6]]FCCr   r   r   c                 t     t          j        t          j        | j                   fdfd          S )Nc                  .                                    S r,   )rx   )r   r   s   r   <lambda>z5TFMinLengthLogitsProcessor.__call__.<locals>.<lambda>   s    D..v66 r   c                  ,    t          j                   S r,   r    identityr   s   r   r{   z5TFMinLengthLogitsProcessor.__call__.<locals>.<lambda>       BK'' r   )r    condlessrs   r   s   ` ` r   r   z#TFMinLengthLogitsProcessor.__call__   sH    GGT_--66666''''
 

 r   N)
r   r   r   r   r"   rA   r    r!   rx   r   r#   r   r   rr   rr      s         )3 )c ) ) ) )BI ")    
") RY  QSQZ      r   rr   c                       e Zd ZdZdefdZdej        dej        dej        fdZdej        dej        d	e	dej        fd
Z
dS )"TFRepetitionPenaltyLogitsProcessora%  
    [`TFLogitsProcessor`] enforcing an exponential penalty on repeated sequences.

    Args:
        repetition_penalty (`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    penaltyc                 n    t          |t                    r|dk    st          d|           || _        d S )Nr   z6`penalty` has to be a strictly positive float, but is )r?   r@   r8   r   )r   r   s     r   rA   z+TFRepetitionPenaltyLogitsProcessor.__init__   s@    '5)) 	a'A++_V]__```r   r   logitsr   c           	         t          j        ||dd          }t          j        |dk    d| j        z  |          }t          j        |dk     | j        |          }t          j        |j                  }|j        d         }t          j        |          d         }t          j        t          j        t          j        t          j	        |          |          d          t          j        t          j
        |dg          d          fd          }t          j        ||t          j
        |dg                    }|S )NrF   )r\   
batch_dimsr   rM   r[   indicesupdates)r    gatherrR   r   rb   rP   ra   re   repeatrf   reshapetensor_scatter_nd_update)r   r   r   logit_penaltiestoken_penalties
batch_sizeseq_lenindexable_prev_input_idss           r   _create_score_penaltiesz:TFRepetitionPenaltyLogitsProcessor._create_score_penalties   s+    )FIA!LLL(?Q#6DL8H/ZZ(?Q#6oVV '&,//_Q'
(9%%a(#%9ry*)=)=wGGbQQQrz)bT::DDD $
 $
 $
  5%=rzRadfcgGhGh
 
 
 r   r   r   c                     |                      |d d d |f         |          }t          j                            ||          }|S r,   )r   r    rQ   multiply)r   r   r   r   score_penaltiess        r   r   z+TFRepetitionPenaltyLogitsProcessor.__call__  sD    66yHWH7MvVV!!&/::r   N)r   r   r   r   r@   rA   r    r!   r   r"   r   r#   r   r   r   r      s              BI RTR[    4") RY  QSQZ      r   r   c                       e Zd ZdZdeee                  defdZdej        dej        fdZ	dej        d	ej        d
edej        fdZ
dS )TFNoBadWordsLogitsProcessora"  
    [`TFLogitsProcessor`] that enforces that specified sequences will never be sampled.

    Args:
        bad_words_ids (`List[List[int]]`):
            List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
            that should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing
            the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space`
            argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from
            `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
        eos_token_id (`int`):
            The id of the *end-of-sequence* token.
    bad_words_idsrt   c                    t          |t                    rt          |          dk    rt          d| d          t	          d |D                       rt          d| d          t	          d |D                       rt          d| d          t
          j                            |                              d	          | _	        d
 |D             }t	          d |D                       rt          d| d          t          j
        |t
          j                  | _        t          j
        d |D                       | _        d S )Nr   z3`bad_words_ids` has to be a non-empty list, but is .c              3   B   K   | ]}t          |t                     V  d S r,   )r?   r6   r-   bad_word_idss     r   r0   z7TFNoBadWordsLogitsProcessor.__init__.<locals>.<genexpr>2  s/      TTl:lD111TTTTTTr   z2`bad_words_ids` has to be a list of lists, but is c              3   H   K   | ]}t          d  |D                       V  dS )c              3   f   K   | ],}t          |t          t          j        f           p|d k     V  -dS r   N)r?   r"   npinteger)r-   token_ids     r   r0   zATFNoBadWordsLogitsProcessor.__init__.<locals>.<genexpr>.<genexpr>5  s@      kkRZZ3
*;<<<L1kkkkkkr   N)anyr   s     r   r0   z7TFNoBadWordsLogitsProcessor.__init__.<locals>.<genexpr>4  sN       
 
 kk^jkkkkk
 
 
 
 
 
r   zKEach list in `bad_words_ids` has to be a list of positive integers, but is rM   )default_valuec                 ,    g | ]}t          |          S r#   )r4   r-   	bad_wordss     r   
<listcomp>z8TFNoBadWordsLogitsProcessor.__init__.<locals>.<listcomp>@  s    KKK	S^^KKKr   c              3   "   K   | ]
}|d k    V  dS r   r#   )r-   word_lens     r   r0   z7TFNoBadWordsLogitsProcessor.__init__.<locals>.<genexpr>A  s&      ??x1}??????r   zBanned words token sequences z cannot have an empty listr]   c                     g | ]
}|d          S )rM   r#   r   s     r   r   z8TFNoBadWordsLogitsProcessor.__init__.<locals>.<listcomp>E  s    9g9g9gI)B-9g9g9gr   )r?   r   r4   r8   r   r    raggedconstant	to_tensorbad_word_seqs_idsconvert_to_tensorint32bad_word_seqs_lenseq_forbidden_tokens)r   r   rt   r   s       r   rA   z$TFNoBadWordsLogitsProcessor.__init__/  s   -.. 	e#m2D2D2I2IcS`cccdddTTmTTTTT 	dbR_bbbccc 
 
 -
 
 
 
 
 	 n^knnn   "$!3!3M!B!B!L!L[]!L!^!^KK]KKK??->????? 	hf]fffggg!#!56Grx!X!X!X$&$89g9gYf9g9g9g$h$h!!!r   row_input_idsr   c                       fd}t          j        |t          j         j        j        d                   t           j                  } j        |         }|S )Nc                 N      fd} fd fd |            }|S )Nc                      t          j        t           j                            j                 d          d            S )NrF   c                  B    t          j        dt           j                  S Nr#   r]   r    rb   rc   r#   r   r   r{   zrTFNoBadWordsLogitsProcessor._calc_row_banned_bad_tokens.<locals>._tokens_match.<locals>._len_one.<locals>.<lambda>M      BGBbg666 r   )r    r   rQ   equalr   )_len_greater_than_cur_lenbad_word_seq_numberr   s   r   _len_onez`TFNoBadWordsLogitsProcessor._calc_row_banned_bad_tokens.<locals>._tokens_match.<locals>._len_oneI  s<    wGMM$"89L"MqQQ66-  r   c                      t          j        t           j                            j                 t          j                  d                   d            S )Nr   c                  B    t          j        dt           j                  S r   r    zerosrc   r#   r   r   r{   zTFNoBadWordsLogitsProcessor._calc_row_banned_bad_tokens.<locals>._tokens_match.<locals>._len_greater_than_cur_len.<locals>.<lambda>U      BHRrw777 r   )r    r   rQ   greaterr   rP   )_match_foundr   r   r   s   r   r   zqTFNoBadWordsLogitsProcessor._calc_row_banned_bad_tokens.<locals>._tokens_match.<locals>._len_greater_than_cur_lenQ  sO    wGOOD$:;N$OQSQYZgQhQhijQkll77   r   c                      j                  dz
  } t          j        t          j                            t          j                            |  d          j        d | f                             d d           S )NrF   c                  B    t          j        dt           j                  S r   r   r#   r   r   r{   zvTFNoBadWordsLogitsProcessor._calc_row_banned_bad_tokens.<locals>._tokens_match.<locals>._match_found.<locals>.<lambda>c  r   r   c                  B    t          j        dt           j                  S r   r   r#   r   r   r{   zvTFNoBadWordsLogitsProcessor._calc_row_banned_bad_tokens.<locals>._tokens_match.<locals>._match_found.<locals>.<lambda>d  r   r   )r   r    r   rQ   
reduce_allr   r   )compare_lenr   r   r   s    r   r   zdTFNoBadWordsLogitsProcessor._calc_row_banned_bad_tokens.<locals>._tokens_match.<locals>._match_foundY  s     #45HIAMwG&&);,--8$:PQdfrgrfrQr:s  
 7677  r   r#   )r   r   matchr   r   r   r   s   `  @@r   _tokens_matchzNTFNoBadWordsLogitsProcessor._calc_row_banned_bad_tokens.<locals>._tokens_matchH  s                        HJJELr   r   fn_output_signature)r    map_fnrf   r   rP   rc   r   )r   r   r   
match_maskrow_banned_tokenss   ``   r   _calc_row_banned_bad_tokensz7TFNoBadWordsLogitsProcessor._calc_row_banned_bad_tokensG  si     	  	  	  	  	  	F Y}bht7M7STU7V.W.Wmomtuuu
 5jA  r   r   r   r   c                      dt           t          j                 dt          j        f fd}t          j        |||ft          j                  }|S )N
row_inputsr   c                 0   | \  }}                     |d                    }t          j        t          j        |d          t          j        |t          j                  |j                  }t          j        |t          d           |          }|S )NrM   r[   r]   )r   r   rP   inf)	r   r    rh   re   	ones_likerc   rP   rR   r@   )r   r   	row_scorebanned_tokensbanned_tokens_maskr   r   s        r   _get_row_updated_scorezDTFNoBadWordsLogitsProcessor.__call__.<locals>._get_row_updated_scoreu  s    '1$M9 <<]8G8=TUUM!#}2>>>]"'BBBo" " "
 !3eEll]INNIr   r   )r   r    r!   r   float32)r   r   r   r   r   s   `  ` r   r   z$TFNoBadWordsLogitsProcessor.__call__o  se    		uRY/? 		BI 		 		 		 		 		 		 		 1Iv3F\^\fgggr   N)r   r   r   r   r   r"   rA   r    r!   r   r   r#   r   r   r   r      s         id49o iS i i i i0&! &!ry &! &! &! &!P") RY  QSQZ      r   r   c                   Z    e Zd ZdZdefdZd Zdej        dej        dedej        fd	Z	d
S )TFNoRepeatNGramLogitsProcessora7  
    [`TFLogitsProcessor`] that enforces no repetition of n-grams. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.
    
ngram_sizec                 n    t          |t                    r|dk    rt          d|           || _        d S )Nr   z;`ngram_size` has to be a strictly positive integer, but is )r?   r"   r8   r   )r   r   s     r   rA   z'TFNoRepeatNGramLogitsProcessor.__init__  s@    *c** 	ijAoog[egghhh$r   c                 H   	
 dz    j         k     rd t          |          D             S d t          |          D             |d d d f         t          |          D ]}|                                                                         
|         }t	          
fdt           j                   D              D ]<}t          |d d                   }|                    |g           |d         gz   ||<   = fd		fdt          |          D             }|S )NrF   c                     g | ]}g S r#   r#   r-   _s     r   r   zKTFNoRepeatNGramLogitsProcessor.calc_banned_ngram_tokens.<locals>.<listcomp>  s    1111B111r   c                     g | ]}i S r#   r#   r   s     r   r   zKTFNoRepeatNGramLogitsProcessor.calc_banned_ngram_tokens.<locals>.<listcomp>  s    9991B999r   c                 $    g | ]}|d          S r,   r#   )r-   i
gen_tokenss     r   r   zKTFNoRepeatNGramLogitsProcessor.calc_banned_ngram_tokens.<locals>.<listcomp>  s!    NNN!z!""~NNNr   rM   c                     dz   j         z
  }t          | |f                                                                                   }|                              |g           S )NrF   )r   tuplenumpytolistget)hypo_idx	start_idx	ngram_idxr   generated_ngramsprev_input_idsr   s      r   _get_generated_ngramszVTFNoRepeatNGramLogitsProcessor.calc_banned_ngram_tokens.<locals>._get_generated_ngrams  sb    !do5InXy7H-HIOOQQXXZZ[[I#H-11)R@@@r   c                 &    g | ]} |          S r#   r#   )r-   r   r   s     r   r   zKTFNoRepeatNGramLogitsProcessor.calc_banned_ngram_tokens.<locals>.<listcomp>  s%    ZZZX..x88ZZZr   )r   rf   r   r   zipr   r   )r   r   	num_hyposr   idxgenerated_ngramngramprev_ngram_tupler   r   r   r   r   s   `  `     @@@@r   calc_banned_ngram_tokensz7TFNoRepeatNGramLogitsProcessor.calc_banned_ngram_tokens  s   Q;((11i 0 0111199i(8(8999"111hwh;/## 	l 	lC',2244;;==J.s3ONNNNuT_7M7MNNNO l l#(ss#4#4 4C4G4GHXZ\4]4]afgiaj`k4k 011l	A 	A 	A 	A 	A 	A 	A 	A [ZZZyIYIYZZZr   r   r   r   r   c                 |   t          j                    st          d          |j        \  }}|                     |||          }g }|D ]0|                    fdt          |          D                        1t          j        t          j        |t           j	                  t          d           |          }|S )NzGTFNoRepeatNGramLogitsProcessor is only implemented for eager execution.c                      g | ]
}|v rd ndS )TFr#   )r-   tokenbanned_tokens_slices     r   r   z;TFNoRepeatNGramLogitsProcessor.__call__.<locals>.<listcomp>  s(    ```U"5555```r   r]   r   )r    executing_eagerlyr   rP   r  appendrf   rR   r   rc   r@   )	r   r   r   r   r   
vocab_sizer   banned_tokens_indices_maskr  s	           @r   r   z'TFNoRepeatNGramLogitsProcessor.__call__  s     #%% 	q%&oppp!'
J55iWUU &("#0 	 	&--````eT^N_N_```    "./IQSQXYYY\abg\h\h[hjpqqr   N)
r   r   r   r   r"   rA   r  r    r!   r   r#   r   r   r   r     s         %3 % % % %
  0") RY  QSQZ      r   r   c                   T    e Zd ZdZdefdZdej        dej        dedej        fdZd	S )
TFForcedBOSTokenLogitsProcessorz
    [`TFLogitsProcessor`] that enforces the specified token as the first generated token.

    Args:
        bos_token_id (`int`):
            The id of the token to force as the first generated token.
    bos_token_idc                 D    |dk     rt          d|           || _        d S )Nr   z=The forced bos token id  must be a non-negative integer, got )r8   r  )r   r  s     r   rA   z(TFForcedBOSTokenLogitsProcessor.__init__  s3    !k]ikklll(r   r   r   r   r   c           	         |dk    r|j         \  }}t          j        |df          }| j        dk    r@t          j        t          j        t          d           || j        f          |fd          }| j        |dz
  k     rFt          j        |t          j        t          d           ||dz
  | j        z
  f          fd          }|S NrF   r   r   rM   r[   )rP   r    r   r  ra   broadcast_tor@   r   r   r   r   r   
num_tokenss         r   r   z(TFForcedBOSTokenLogitsProcessor.__call__  s    a<<%+\"J
Xz1o..F 1$$BOU5\\MJPTPaCb$c$cek#lsuvvv JN33R_eEll]Z*WX.\`\mIm<noop   r   N	r   r   r   r   r"   rA   r    r!   r   r#   r   r   r  r    sq         )S ) ) ) )
") RY  QSQZ      r   r  c                   X    e Zd ZdZdedefdZdej        dej        dedej        fd	Zd
S )TFForcedEOSTokenLogitsProcessorac  
    [`TFLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.

    Args:
        max_length (`int`):
            The maximum length of the sequence to be generated.
        eos_token_id (`int`):
            The id of the token to force as the last generated token when `max_length` is reached.
    
max_lengthrt   c                 R    || _         |dk     rt          d|           || _        d S )Nr   z<The forced eos token id must be a non-negative integer, got )r  r8   rt   )r   r  rt   s      r   rA   z(TFForcedEOSTokenLogitsProcessor.__init__  s:    $!j\hjjkkk(r   r   r   r   r   c           	         || j         dz
  k    r|j        \  }}t          j        |df          }| j        dk    r@t          j        t          j        t          d           || j        f          |fd          }| j        |dz
  k     rFt          j        |t          j        t          d           ||dz
  | j        z
  f          fd          }|S r  )r  rP   r    r   rt   ra   r  r@   r  s         r   r   z(TFForcedEOSTokenLogitsProcessor.__call__  s    do)))%+\"J
Xz1o..F 1$$BOU5\\MJPTPaCb$c$cek#lsuvvv JN33R_eEll]Z*WX.\`\mIm<noop   r   Nr  r#   r   r   r  r    sx         )3 )c ) ) ) )") RY  QSQZ      r   r  c                   N    e Zd ZdZd Zdej        dej        dedej        fdZdS )	&TFSuppressTokensAtBeginLogitsProcessora!  
    [`TFSuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
    generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` at not
    sampled at the begining of the generation.
    c                 <    t          |          | _        || _        d S r,   )r6   begin_suppress_tokensbegin_index)r   r  r  s      r   rA   z/TFSuppressTokensAtBeginLogitsProcessor.__init__  s!    %)*?%@%@"&r   r   r   r   r   c                 L    g  j         D ]Lj        d         k     r9                    fdt          j        d                   D                        Mt	                    dk    r6t          j        t          j        | j                   fdfd          S )NrM   c                     g | ]}|gS r#   r#   r-   r   r  s     r   r   zCTFSuppressTokensAtBeginLogitsProcessor.__call__.<locals>.<listcomp>      *V*V*V!Au:*V*V*Vr   r   c                      t          j         d t           j        d         t	          j                  z            D                       S )Nc                 .    g | ]}t          d            S r   r@   r   s     r   r   zUTFSuppressTokensAtBeginLogitsProcessor.__call__.<locals>.<lambda>.<locals>.<listcomp>  s     mmmqeEll]mmmr   r   r   )r    r   rf   rP   r4   r  )r   r   suppressed_indicess   r   r{   zATFSuppressTokensAtBeginLogitsProcessor.__call__.<locals>.<lambda>  sO    3.mmE&,q/CPTPjLkLk:k4l4lmmm   r   c                       S r,   r#   r   s   r   r{   zATFSuppressTokensAtBeginLogitsProcessor.__call__.<locals>.<lambda>       r   )	r  rP   extendrf   r4   r    r   r   r  r   r   r   r   r&  r  s   ` ` @@r   r   z/TFSuppressTokensAtBeginLogitsProcessor.__call__
  s    / 	X 	XEv|B'''"))*V*V*V*VuV\RS_?U?U*V*V*VWWW!""Q&&W$"233     
  F r   N	r   r   r   r   rA   r    r!   r"   r   r#   r   r   r  r    se         ' ' '") RY  QSQZ      r   r  c                   N    e Zd ZdZd Zdej        dej        dedej        fdZdS )	TFSuppressTokensLogitsProcessorzThis processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they
    are not sampled.c                 .    t          |          | _        d S r,   )r6   suppress_tokens)r   r/  s     r   rA   z(TFSuppressTokensLogitsProcessor.__init__!  s    #O44r   r   r   r   r   c                     g } j         D ]L|j        d         k     r9|                    fdt          |j        d                   D                        Mt	          |          dk    rrt          j        | fdt          |j        d                   D             d t          |j        d         t	           j                   z            D                       }|S )NrM   c                     g | ]}|gS r#   r#   r   s     r   r   z<TFSuppressTokensLogitsProcessor.__call__.<locals>.<listcomp>(  r!  r   r   c                 ,    g | ]}j         D ]}||gS r#   )r/  )r-   r   r  r   s      r   r   z<TFSuppressTokensLogitsProcessor.__call__.<locals>.<listcomp>-  s.    fffQUQeff!Uffffr   c                 .    g | ]}t          d            S r$  r%  r   s     r   r   z<TFSuppressTokensLogitsProcessor.__call__.<locals>.<listcomp>.  s     ccc1%,,cccr   r   )r/  rP   r)  rf   r4   r    r   r*  s   `    @r   r   z(TFSuppressTokensLogitsProcessor.__call__$  s    ) 	X 	XEv|B'''"))*V*V*V*VuV\RS_?U?U*V*V*VWWW!""Q&&0ffffU6<?-C-Cfffccfl1oDL`HaHa6a0b0bccc  F
 r   Nr+  r#   r   r   r-  r-    se         5 5 5") RY  QSQZ      r   r-  c                   l    e Zd ZdZdeee                  fdZdej        dej        dedej        fdZ	d	S )
TFForceTokensLogitsProcessora$  This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
    indices that will be forced before sampling. The processor will set their log probs to `0` and all other tokens to
    `-inf` so that they are sampled at their corresponding index.force_token_mapc                 :   t          |          }t          j        t          |                                          dz   t          j                  dz  }|                                D ]\  }}||||<   t          j        |t          j                  | _	        d S )NrF   r]   rM   )
dictr   rb   rK   r7   r   itemsr    r   force_token_array)r   r6  r:  indexr  s        r   rA   z%TFForceTokensLogitsProcessor.__init__8  s    // GS)=)=)?)?%@%@1%DRXVVVY[[+1133 	1 	1LE5 +0!%(!#!56Grx!X!X!Xr   r   r   r   r   c                       fdt          j        t          j        t          j         j                  d                   fd fd          S )Nc                    j         d         }j        |          }t          j        j                  t          j        j        j        g          z   }t          j        t          j        |          t          j	        |g|g          fd          }t          j
        |fj                  }t          j        |||          }|S )Nr   r]   rF   r[   )rP   r:  r    
zeros_liker^   r   rO   rg   rf   rd   r   r   )generation_idxr   current_token
new_scoresr   r   r   r   s         r   _force_tokenz;TFForceTokensLogitsProcessor.__call__.<locals>._force_tokenD  s    aJ 2>BMvV\BBBR[RXR^RbQcEdEddJh 4 4bg}oPZ|6\6\]defffGh
}FLAAAG4Z'RRJr   r   c                  ,    t          j                   S r,   r}   r   s   r   r{   z7TFForceTokensLogitsProcessor.__call__.<locals>.<lambda>Q  r   r   c                  z    t          j        t          j        j                 d           fdfd          S )Nr   c                                  S r,   r#   )rB  r   s   r   r{   zITFForceTokensLogitsProcessor.__call__.<locals>.<lambda>.<locals>.<lambda>V  s    W-- r   c                       S r,   r#   r   s   r   r{   zITFForceTokensLogitsProcessor.__call__.<locals>.<lambda>.<locals>.<lambda>X  r(  r   )r    r   greater_equalr:  )rB  r   r   r   s   r   r{   z7TFForceTokensLogitsProcessor.__call__.<locals>.<lambda>S  sC    BG !7!@!DD-----  r   )r    r   rG  rP   r:  )r   r   r   r   rB  s   ` ``@r   r   z%TFForceTokensLogitsProcessor.__call__C  s    	 	 	 	 	 	 Wbht/E&F&Fq&IJJ''''      
 
 r   N)
r   r   r   r   r   r"   rA   r    r!   r   r#   r   r   r5  r5  3  s        E E	YT#Y 	Y 	Y 	Y 	Y") RY  QSQZ      r   r5  )!r1   typingr   r   r   r   
tensorflowr    tf_utilsr   utilsr   utils.loggingr   r   loggerr   r
   r%   r6   r(   r<   rD   rW   rr   r   r   r   r  r  r  r-  r5  r#   r   r   <module>rN     s"                     % % % % % % ( ( ( ( ( ( & & & & & & 
H		( $.
 
 
 
 
 
 
 

 
 
 
 
 
 
 
    D   .       (       85 5 5 5 5 5 5 5p! ! ! ! !!2 ! ! !H/ / / / /): / / /da a a a a"3 a a aH9 9 9 9 9%6 9 9 9x    &7   <    &7   B    ->   <    &7   ,( ( ( ( (#4 ( ( ( ( (r   