
    g,                        d dl Z d dlZd dlZd dlmZmZ d dlmZ d dlm	Z	m
Z
mZmZmZmZmZmZmZmZ d dlZd dlmZ ddlmZmZmZmZ dd	lmZmZmZ e	rdd
lm Z  ddl!m"Z" ddl#m$Z$ ddl%m&Z&  e            rd dl'm(Z(  ej)        e*          Z+dZ,dZ-ej.         G d d                      Z/ G d de          Z0 G d de0e          Z1 G d de1          Z2dS )    N)ABCabstractmethod)OrderedDict)
TYPE_CHECKINGAnyCallableDictIterableListMappingOptionalTupleUnion)version   )
TensorTypeis_torch_availableis_vision_availablelogging   )ParameterFormat compute_effective_axis_dimension"compute_serialized_parameters_size)PretrainedConfigFeatureExtractionMixinImageProcessingMixinPreTrainedTokenizerBase)Image   l        c                   f    e Zd ZU dZeed<   eed<   eed<   dZe	e         ed<   dZ
e	e         ed<   dS )PatchingSpeca  
    Data class that holds patching specifications.

    Args:
        o: Module / object where the op to patch is located
        name: Name of the op to monkey patch
        custom_op: Custom op that patches the original op
        orig_op: Original op that is being patched
        op_wrapper: Wrapper (optional) that wraps both the original and custom ops.
            It is useful for ops that are class or static methods for instance.
    oname	custom_opNorig_op
op_wrapper)__name__
__module____qualname____doc__r   __annotations__strr   r(   r   r)        T/var/www/html/ai-engine/env/lib/python3.11/site-packages/transformers/onnx/config.pyr$   r$   /   sf         
 
 FFF
III"&GXh&&&%)J")))))r1   r$   c                   ,   e Zd ZdZdZdZdZ ej        d          Z	 e
dddd	i           e
d
ddd	i           e
dddd	i           e
ddd	ddd	ddd	d           e
dddd	i           e
dddd	i           e
dddii           e
ddd	ddd	d           e
ddd	ddd	d           e
ddddddi           e
dddd	i           e
dddii           e
dddd	i           e
dddd	i           e
dddd	i          dZdIdddedee         fdZedJdddedd fd            Zeedeeeeef         f         fd                        Zedeeeeef         f         fd            Zedeeeef                  fd             Zedefd!            Zedefd"            Zedefd#            Zedefd$            Zedefd%            Z ede!fd&            Z"e#d'ede!fd(            Z$	 dKd+ed,ed-ed.efd/Z%	 dLd+ed3ed4ed5efd6Z&	 	 	 	 	 	 	 	 	 	 	 	 dMd9e'd:         d+ed;ed<ed=e!d>ee(         d,ed.ed-ed3ed4ed5ed?d@deeef         fdAZ)dBeeef         deeef         fdCZ*dD Z+dE Z,edFedGe-e         de.eef         fdH            Z/dS )N
OnnxConfigzv
    Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
    r         z1.8logitsbatchsequencer   r   last_hidden_state)r7   
pred_boxes
pred_masksr   )r7   r<   )start_logits
end_logits
num_labelsheightwidth)r   r   r      decoder_sequence)z	causal-lmdefaultzimage-classificationzimage-segmentationz	masked-imz	masked-lmmultiple-choicezobject-detectionzquestion-answeringzsemantic-segmentationz
seq2seq-lmzsequence-classificationztoken-classificationzvision2seq-lmzspeech2seq-lmrE   Nconfigr   taskpatching_specsc                 L   || _         || j        vr+t          | d| j                                                   || _        g | _        ||ng D ]S}|}|j        .t          j        |t          |j
        |j                            }| j                            |           Td S )Nz+ is not a supported task, supported tasks: )r(   )_config_tasks_to_common_outputs
ValueErrorkeysrH   _patching_specsr(   dataclassesreplacegetattrr%   r&   append)selfrG   rH   rI   spec
final_specs         r2   __init__zOnnxConfig.__init__o   s    t444jjDDaDfDfDhDhjj   	!&4&@NNb 	4 	4DJ|#(0wtvty?Y?YZZZ
 ''
3333		4 	4r1   returnc                      | ||          S )z
        Instantiate a OnnxConfig for a specific model

        Args:
            config: The model's configuration to use when exporting to ONNX

        Returns:
            OnnxConfig for this model
        )rH   r0   clsrG   rH   s      r2   from_model_configzOnnxConfig.from_model_config   s     s6%%%%r1   c                     t                      )z
        Mapping containing the axis definition of the input tensors to provide to the model

        Returns:
            For each input: its name associated to the axes symbolic name and the axis position within the tensor
        )NotImplementedErrorrT   s    r2   inputszOnnxConfig.inputs   s     "###r1   c                 N    | j         | j                 }t          j        |          S )z
        Mapping containing the axis definition of the output tensors to provide to the model

        Returns:
            For each output: its name associated to the axes symbolic name and the axis position within the tensor
        )rL   rH   copydeepcopy)rT   common_outputss     r2   outputszOnnxConfig.outputs   s#     6tyA}^,,,r1   c                 8    t          | j        d          rddiS dS )z
        Dictionary of keys to override in the model's config before exporting

        Returns:
            Dictionary with the keys (and their corresponding values) to override
        	use_cacheFN)hasattrrK   r_   s    r2   values_overridezOnnxConfig.values_override   s(     4<-- 	(''tr1   c                     t           j        S )zp
        The default batch size to use if no other indication

        Returns:
            Integer > 0
        )r4   default_fixed_batchr_   s    r2   default_batch_sizezOnnxConfig.default_batch_size   s     --r1   c                     t           j        S )zu
        The default sequence length to use if no other indication

        Returns:
            Integer > 0
        )r4   default_fixed_sequencer_   s    r2   default_sequence_lengthz"OnnxConfig.default_sequence_length   s     00r1   c                     t           j        S )zw
        The default number of choices to use if no other indication

        Returns:
            Integer > 0
        )r4   default_fixed_num_choicesr_   s    r2   default_num_choiceszOnnxConfig.default_num_choices   s     33r1   c                     t           S )z{
        Which onnx opset to use when exporting the model

        Returns:
            Integer ONNX Opset version
        )DEFAULT_ONNX_OPSETr_   s    r2   default_onnx_opsetzOnnxConfig.default_onnx_opset   s
     "!r1   c                     dS )z
        What absolute tolerance value to use during model conversion validation.

        Returns:
            Float absolute tolerance value.
        gh㈵>r0   r_   s    r2   atol_for_validationzOnnxConfig.atol_for_validation   s	     tr1   c                 x    t                      r+ddlm} t          j         |                      | j        k    S dS )z
        The minimum PyTorch version required to export the model.

        Returns:
            `bool`: Whether the installed version of PyTorch is compatible with the model.
        r   )get_torch_versionF)r   transformers.utilsry   r   parsetorch_onnx_minimum_version)rT   ry   s     r2   is_torch_support_availablez%OnnxConfig.is_torch_support_available   sK      	<<<<<<=!2!2!4!4559XXX5r1   num_parametersc                 H    t          | t          j                  t          k    S )a  
        Flag indicating if the model requires using external data format

        Args:
            num_parameters: Number of parameter on the model

        Returns:
            True if model.num_parameters() * size_of(float32) >= 2Gb False otherwise
        )r   r   FloatEXTERNAL_DATA_FORMAT_SIZE_LIMIT)r~   s    r2   use_external_data_formatz#OnnxConfig.use_external_data_format   s!     /~?TUU./	
r1   rC   (   
batch_sizenum_channelsimage_heightimage_widthc                    g }t          |          D ]s}t          j                            |||          dz  }|                    t          j        |                    d                                        d                     t|S )N   uint8RGB)	rangenprandomrandrS   r!   	fromarrayastypeconvert)rT   r   r   r   r   images_datas           r2   _generate_dummy_imagesz!OnnxConfig._generate_dummy_images  s|     z"" 	P 	PA9>>,\JJSPDMM%/$++g*>*>??GGNNOOOOr1   "V        @   sampling_ratetime_duration	frequencyc           	          g }t          |          D ]g}t          j        d|t          ||z            d          }|                    dt          j        dt          j        z  |z  |z            z             h|S )Nr   F)endpointg      ?r   )r   r   linspaceintrS   sinpi)rT   r   r   r   r   
audio_datar   ts           r2   _generate_dummy_audioz OnnxConfig._generate_dummy_audio  s     
z"" 	G 	GAA}c--2O.P.P[`aaaA cBF1ru9y+@1+D$E$EEFFFFr1   Fpreprocessor)r    r   r   
seq_lengthnum_choicesis_pair	framework	tokenizerr    c                    ddl m} ddlm} ddlm} t          ||          r|t          d          |6t          j	        dt                     t                              d           |}t          ||          rPt          |t          j        d	
          }|                    |          }t          |t          j        |
          }|j        t'          |j                  d	k    r|j        nd}d                    |g          |z  g|z  }| j        dk    rt          t          j        d	
          |z  } |||          }|                                D ]3\  }fdt1          d	t'                              D             ||<   4t3          |                    |                    S t3           |||                    S t          ||          r|j        d	         dk    r*t          d|j        j         d|j        d	                    t          |t          j                  }|                     |||	|          }t3           |||                    S t          ||          r^|j        d	         dk    rMt          |t          j                  }|                     |||	|          }t3           |||                    S t          ||          r^|j        d	         dk    rMt          |t          j                  }|                     ||
||          }t3           |||                    S t          d          )am  
        Generate inputs to provide to the ONNX exporter for the specific framework

        Args:
            preprocessor: ([`PreTrainedTokenizerBase`], [`FeatureExtractionMixin`], or [`ImageProcessingMixin`]):
                The preprocessor associated with this model configuration.
            batch_size (`int`, *optional*, defaults to -1):
                The batch size to export the model for (-1 means dynamic axis).
            num_choices (`int`, *optional*, defaults to -1):
                The number of candidate answers provided for multiple choice task (-1 means dynamic axis).
            seq_length (`int`, *optional*, defaults to -1):
                The sequence length to export the model for (-1 means dynamic axis).
            is_pair (`bool`, *optional*, defaults to `False`):
                Indicate if the input is a pair (sentence 1, sentence 2)
            framework (`TensorType`, *optional*, defaults to `None`):
                The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for.
            num_channels (`int`, *optional*, defaults to 3):
                The number of channels of the generated images.
            image_width (`int`, *optional*, defaults to 40):
                The width of the generated images.
            image_height (`int`, *optional*, defaults to 40):
                The height of the generated images.
            sampling_rate (`int`, *optional* defaults to 22050)
                The sampling rate for audio data generation.
            time_duration (`float`, *optional* defaults to 5.0)
                Total seconds of sampling for audio data generation.
            frequency (`int`, *optional* defaults to 220)
                The desired natural frequency of generated audio.

        Returns:
            Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
        r   r   r   r   NzPYou cannot provide both a tokenizer and a preprocessor to generate dummy inputs.ztThe `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.zSOverwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.r   )fixed_dimensionnum_token_to_add0 rF   )	text_pairc                 *    g | ]}||z            S r0   r0   ).0ir   vs     r2   
<listcomp>z4OnnxConfig.generate_dummy_inputs.<locals>.<listcomp>r  s'    )h)h)hQ!AK,?*@)h)h)hr1   )tensor_type)return_tensorspixel_valuesz*The `preprocessor` is an image processor (zC) and expects `model_input_names[0]` to be "pixel_values", but got )r   )r   r   input_featuresz\Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor.) feature_extraction_utilsr   image_processing_utilsr   tokenization_utils_baser    
isinstancerM   warningswarnFutureWarningloggerwarningr   r4   rk   num_special_tokens_to_addrn   	unk_tokenlenjoinrH   rq   itemsr   dictconvert_to_tensorsmodel_input_names	__class__r*   r   r   )rT   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    token_to_addinput_tokendummy_inputtokenized_inputkr   s       `                 @r2   generate_dummy_inputsz OnnxConfig.generate_dummy_inputs  s   ` 	FEEEEEAAAAAAEEEEEEl$;<< 	qAVoppp M+  
 NNpqqq$Ll$;<< 8	9J,J]^  J (AA'JJL9J,M`l  J !*63|?U;V;VYZ;Z;Z && 
 88[M22Z?@:MKy--- ?1Uhi   *K7".,{k"R"R"R+1133 i iDAq)h)h)h)h)h%PQSVWXSYSY[fJgJg)h)h)hOA&&O>>9>UUVVV[KKKLLL&:;; 	-a0NBB qAWA` q qMYMklmMnq q  
 :*V`VtuuuJ55j,P\^ijjKK	RRRSSS&<== 	,B`abBcguBuBu9*V`VtuuuJ55j,P\^ijjKK	RRRSSS|%;<<
	AMA_`aAbfvAvAv :*V`VtuuuJ44ZP]_hiiK[KKKLLLn  r1   reference_model_inputsc                     |S )a  
        Generate inputs for ONNX Runtime using the reference model inputs. Override this to run inference with seq2seq
        models which have the encoder and decoder exported as separate ONNX files.

        Args:
            reference_model_inputs ([`Mapping[str, Tensor]`):
                Reference inputs for the model.

        Returns:
            `Mapping[str, Tensor]`: The mapping holding the kwargs to provide to the model's forward function
        r0   )rT   r   s     r2   !generate_dummy_inputs_onnxruntimez,OnnxConfig.generate_dummy_inputs_onnxruntime  s
     &%r1   c                     | j         D ]E}|j        |j        n|                    |j                  }t          |j        |j        |           Fd S N)rO   r)   r'   setattrr%   r&   )rT   rU   r'   s      r2   	patch_opszOnnxConfig.patch_ops  sY    ( 	2 	2D*./*AtW[WeGfGfIDFDIy1111	2 	2r1   c                     | j         D ]E}|j        |j        n|                    |j                  }t          |j        |j        |           Fd S r   )rO   r)   r(   r   r%   r&   )rT   rU   r(   s      r2   restore_opszOnnxConfig.restore_ops  sY    ( 	0 	0D&*o&=dll4??SWS_C`C`GDFDIw////	0 	0r1   r&   fieldc                 l    ddl m} fdt          |                    |                    D             S )a  
        Flatten any potential nested structure expanding the name of the field with the index of the element within the
        structure.

        Args:
            name: The name of the nested structure
            field: The structure to, potentially, be flattened

        Returns:
            (Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.

        r   )chainc                 &    i | ]\  }} d | |S ).r0   )r   idxitemr&   s      r2   
<dictcomp>zAOnnxConfig.flatten_output_collection_property.<locals>.<dictcomp>  s)    ]]])#t4#]]]r1   )	itertoolsr   	enumeratefrom_iterable)r[   r&   r   r   s    `  r2   "flatten_output_collection_propertyz-OnnxConfig.flatten_output_collection_property  sH     	$#####]]]]yATATUZA[A[7\7\]]]]r1   )rE   NrE   )r   rC   r   r   )r   r   r   r   )r   r   r   FNrC   r   r   r   r   r   N)0r*   r+   r,   r-   rk   rn   rq   r   r{   r|   r   rL   r/   r   r$   rW   classmethodr\   propertyr   r   r   r`   re   r   r   ri   rl   ro   rr   ru   floatrw   boolr}   staticmethodr   r   r   r   r   r   r   r   r   r
   r	   r   r0   r1   r2   r4   r4   D   s~          !!.u!5!5 [(J,G,G!HII; 3Z5P5PQRR +X7z7R7R,S T T)k%*55")j99")j99 
 
 ![(J,G,G!HII [(J,G,G!HII&;1g,'?@@'K%*55")j99 
 
 *k$+
 ; ;")j99 
 
 "-hGYafm8n8n-o!p!p!k8=O-P-P"QRR#.;1g,/G#H#H +X7z7R7R,S T T$hG
0K0K%LMM$hG
0K0K%LMM?     D4 41 4 4Z^_kZl 4 4 4 4  
& 
&'9 
& 
&Ua 
& 
& 
& [
& $WS#X%6 67 $ $ $ ^ X$ -gc3h&7!78 - - - X- 
'#s(*;!< 
 
 
 X
 .C . . . X. 1 1 1 1 X1 4S 4 4 4 X4 "C " " " X" U    X D    X 
 
 
 
 
 \
" fh 14HK_b    mp 25NSfi     *.""/3v vghv v 	v
 v v J'v v v v v v v -v 
c	v v v vp&PSUXPXHY &^efiknfn^o & & & &2 2 2
0 0 0
 ^c ^(3- ^TXY\^aYaTb ^ ^ ^ [^ ^ ^r1   r4   c                       e Zd Z	 	 	 d dddedee         def fd	Zed!ddded
d fd            Z	e
d
eeeeef         f         f fd            Ze
d
eeeef                  fd            Ze
d
efd            Ze
d
efd            Z	 	 	 	 d"dddedededee         d
eeef         f fdZ	 d#deeeeef         f         dedefdZd Zdedee         d
eeef         f fdZ xZS )$OnnxConfigWithPastrE   NFrG   r   rH   rI   use_pastc                 ^    t                                          |||           || _        d S )N)rH   rI   )superrW   r   )rT   rG   rH   rI   r   r   s        r2   rW   zOnnxConfigWithPast.__init__  s/     	d>JJJ r1   rX   c                      | ||d          S )z
        Instantiate a OnnxConfig with `use_past` attribute set to True

        Args:
            config: The underlying model's config to use when exporting to ONNX

        Returns:
            OnnxConfig with `.use_past = True`
        T)rH   r   r0   rZ   s      r2   	with_pastzOnnxConfigWithPast.with_past  s     s6t4444r1   c                 j    t                      j        }| j        r|                     |d           |S )Nre   	direction)r   re   r   fill_with_past_key_values_)rT   rd   r   s     r2   re   zOnnxConfigWithPast.outputs  s8    = 	Q++Ni+PPPr1   c                 B    t          | j        d          r	d| j        iS d S )Nrg   )rh   rK   r   r_   s    r2   ri   z"OnnxConfigWithPast.values_override  s(    4<-- 	0//tr1   c                 b    t          | j        d          st          d          | j        j        S )z
        The number of layers attribute retrieved from the model config. Override this for model configs where the
        number of layers attribute is not called `num_layers`.
        
num_layerszcould not find the number of layers attribute in the model configuration, override the num_layers property of the model OnnxConfig to solve this)rh   rK   AttributeErrorr   r_   s    r2   r   zOnnxConfigWithPast.num_layers  s=     t|\22 	 B   |&&r1   c                 b    t          | j        d          st          d          | j        j        S )z
        The number of attention heads attribute retrieved from the model config. Override this for model configs where
        the number of attention heads attribute is not called `num_attention_heads`.
        num_attention_headszcould not find the number of attention heads attribute in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this)rh   rK   r   r   r_   s    r2   r   z&OnnxConfigWithPast.num_attention_heads  s>     t|%:;; 	 V   |//r1   r   r   r    r   r   r   r   c                 X   t                                          |||||          }| j        rt                      st	          d          dd l}|d         j        \  }}	|	dz   }
|| j        |
| j        j	        | j        z  f}d|v rE|d         j
        }|                    |d         |                    ||
|          gd	          |d<   g |d
<   t          | j                  D ]E}|d
                             |                    |          |                    |          f           F|S )Nr   r   r   r   ACannot generate dummy past_keys inputs without PyTorch installed.r   	input_idsr   attention_mask)dtyper   )dimpast_key_values)r   r   r   r   rM   torchshaper   rK   hidden_sizer  catonesr   r   rS   zeros)rT   r   r   r   r   r   common_inputsr  r8   seqlenpast_key_values_lengthr	  
mask_dtyper   r   s                 r2   r   z(OnnxConfigWithPast.generate_dummy_inputs  sg    55*W`i 6 
 
 = 	b%''  !deee)+6<ME6%+aZ"(&(D,DD	E  =00*+;<B
27))"#34ejjH^fpj6q6qr 3< 3 3./
 02M+,4?++ b b/077U9K9KU[[Y^M_M_8`aaaar1   inputs_or_outputsr   inverted_values_shapec                     |dvrt          d| d          |dk    rdnd}t          | j                  D ]/}ddd	|| d
| d<   |rddd|| d
| d<   !ddd	|| d
| d<   0dS )a  
        Fill the input_or_outputs mapping with past_key_values dynamic axes considering.

        Args:
            inputs_or_outputs: The mapping to fill.
            direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the
                output mapping, this is important for axes naming.
            inverted_values_shape:
                If `True`, store values on dynamic axis 1, else on axis 2.

        r`   re   4direction must either be "inputs" or "outputs", but 
 was givenr`   r  presentr8   zpast_sequence + sequencer   r   r   .keyr:   .valueN)rM   r   r   )rT   r  r   r  r&   r   s         r2   r   z-OnnxConfigWithPast.fill_with_past_key_values_&  s     111iT]iiijjj$-$9$9  yt'' 	e 	eA7>C]3^3^/////0$ e=DIc9d9d!T"5"5A"5"5"566=DIc9d9d!T"5"5A"5"5"566	e 	er1   c                 J    |d         || d| d<   |d         || d| d<   d S )Nr   r   r  r   r  r0   rT   flattened_outputr&   r   r   s        r2   _flatten_past_key_values_z,OnnxConfigWithPast._flatten_past_key_values_?  sH    01!D,,3,,,-23A$D..3...///r1   r&   r   c                     i }|dv r.t          |          D ]\  }}|                     ||||           n"t                                          ||          }|S )N)r  r  )r   r  r   r   )rT   r&   r   r  r   r   r   s         r2   r   z5OnnxConfigWithPast.flatten_output_collection_propertyC  sy    111#E** O OQ../?sANNNNO  %wwII$PUVVr1   )rE   NFr   r   r   FN)F)r*   r+   r,   r/   r   r$   r   rW   r   r   r   r   r   re   r   r   ri   r   r   r   r   r   r  r
   r	   r   __classcell__r   s   @r2   r   r     s        -1! !"! ! \*	!
 ! ! ! ! ! ! 
5 
51 
5 
5Ma 
5 
5 
5 [
5 gc3h&7!78      X '#s(*;!<    X 
'C 
' 
' 
' X
' 
0S 
0 
0 
0 X
0 *.( (,( ( 	(
 ( J'( 
c	( ( ( ( ( (V qve e!(gc3h.?)?!@eMPeime e e e27 7 7 s  8C=  UYZ]_bZbUc                    r1   r   c                   ,    e Zd Zedeeeeef         f         f fd            Zedee         f fd            Z	edee         f fd            Z
	 	 	 	 ddd	d
edededee         deeef         f fdZdeeeeef         f         defdZd Z xZS )OnnxSeq2SeqConfigWithPastrX   c                    t          t          |           j        }|                                D ]6\  }}d|v rdnd}|                                D ]\  }}d|v r|||<   |||<   7| j        r|                     |d           |S )Nencoderencoder_sequencerD   r9   re   r   )r   r   re   r   r   r   )rT   rd   r&   
axes_namessequence_nameaxis_idxr   s         r2   re   z!OnnxSeq2SeqConfigWithPast.outputsO  s    1488@ . 4 4 6 6 	0 	0D*2;t2C2C..I[M","2"2"4"4 0 0$%%+8Jx(( ,0Jx((0 = 	Q++Ni+PPPr1   c                     	 t                      j        }||f}nb# t          $ rU t          | j        d          r.t          | j        d          r| j        j        | j        j        f}nt          d          Y nw xY w|S )Nencoder_layersdecoder_layerszcould not find the number of encoder and decoder layers attributes in the model configuration, override the num_layers property of the model OnnxConfig to solve this)r   r   r   rh   rK   r-  r.  )rT   r   r   s     r2   r   z$OnnxSeq2SeqConfigWithPast.num_layers`  s    
	+J$j1JJ 	 	 	t|%566 74<Qa;b;b "l94<;VW

$^   
	     AA:9A:c                     	 t                      j        }||f}nb# t          $ rU t          | j        d          r.t          | j        d          r| j        j        | j        j        f}nt          d          Y nw xY w|S )Nencoder_attention_headsdecoder_attention_headszcould not find the number of attention heads for the encoder and the decoder attributes in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this)r   r   r   rh   rK   r1  r2  )rT   r   r   s     r2   r   z-OnnxSeq2SeqConfigWithPast.num_attention_headsp  s    	"'''"=#68K"L 	 	 	t|%>?? GDLZsDtDt '+|'KT\Mq&r##$   $#	 #"r/  r   FNr   r    r   r   r   r   c           	         t          t          |                               |||||          }| j        s|nd}t          t          |                               |||||          }d |                                D             }t          di ||}	| j        rt                      st          d          dd l}
|	d         j	        d         }|	d         j	        d         }|	d         j	        d         }| j
        \  }}|||| j        j        |z  f}|||dz   | j        j        |z  f}g |	d	<   | j        \  }}t          ||          }t          ||          |z
  }||k    rd
nd}t!          |          D ]m}|	d	                             |
                    |          |
                    |          |
                    |          |
                    |          f           n|d
k    r|n|}t!          ||          D ]E}|	d	                             |
                    |          |
                    |          f           F|	S )Nr  r   c                      i | ]\  }}d | |S )decoder_r0   )r   r&   tensors      r2   r   zCOnnxSeq2SeqConfigWithPast.generate_dummy_inputs.<locals>.<dictcomp>  s'    ___f+T++V___r1   r  r   r  decoder_input_idsrC   r  r'  decoderr0   )r   r   r   r   r   r   r   rM   r  r	  r   rK   r
  r   minmaxr   rS   r  )rT   r   r   r   r   r   encoder_inputsdecoder_seq_lengthdecoder_inputsr  r  r8   encoder_seq_lengthnum_encoder_attention_headsnum_decoder_attention_headsencoder_shapedecoder_shapenum_encoder_layersnum_decoder_layersmin_num_layersmax_num_layersremaining_side_namer   r	  r   s                           r2   r   z/OnnxSeq2SeqConfigWithPast.generate_dummy_inputs  s    1488NN*W`i O 
 

 04}CZZ!1488NN*9KU\hq O 
 
 `_H\H\H^H^___@@~@@@= -	b%''  !deee!+.4Q7E!.{!;!A!!D!./B!C!I!!LGKG_D')D+"(,GG	M +"Q&(,GGM 02M+,59_2 2 !35GHHN !35GHH>YN/ADV/V/V))\e>** 
 
 /077M22M22M22M22	    &9I%E%EMM=E>>:: b b/077U9K9KU[[Y^M_M_8`aaaar1   r  r   c           	         |dvrt          d| d          |dk    rdnd}| j        \  }}t          ||          }t          ||          |z
  }||k    rdnd}d	}	|dk    rd
nd}
t	          |          D ]:}d|
d|| d| d<   d|
d|| d| d<   d|	d|| d| d<   d|	d|| d| d<   ;t	          ||          D ]!}|dk    rd|	d}nd|
d}||| d| d| d<   "d S )Nr  r  r  r`   r  r  r'  r8  past_encoder_sequencepast_decoder_sequencez past_decoder_sequence + sequencer8   r  r   .decoder.key.decoder.value.encoder.key.encoder.valuer  )rM   r   r9  r:  r   )rT   r  r   r&   rC  rD  rE  rF  rG  r(  rD   r   	axes_infos                r2   r   z4OnnxSeq2SeqConfigWithPast.fill_with_past_key_values_  s   111iT]iiijjj$-$9$9  y 26../1CDD/1CDD~U+=@R+R+RiiXa26?86K6K22Qs~&& 	_ 	_A?FK[;\;\777778AHM]=^=^99999:?FK[;\;\777778AHM]=^=^99999::~~66 	S 	SA"i// ',<==		 ',<==	IREEEE,?EEEFF	S 	Sr1   c                     |d         || d| d<   |d         || d| d<   |d         || d| d<   |d         || d| d	<   d S )
Nr   r   rK  r   rL  r   rM  rC   rN  r0   r  s        r2   r  z3OnnxSeq2SeqConfigWithPast._flatten_past_key_values_  s    89!D4434445:;A$D663666789!D4434445:;A$D663666777r1   r!  )r*   r+   r,   r   r   r/   r   re   r   r   r   r   r   r   r   r   r   r  r"  r#  s   @r2   r%  r%  N  s       gc3h&7!78      X  E#J      X #U3Z # # # # # X#$ *.C C,C C 	C
 C J'C 
c	C C C C C CJSGCQTVYQYIZDZ<[ Shk S S S S8? ? ? ? ? ? ?r1   r%  )3rb   rP   r   abcr   r   collectionsr   typingr   r   r   r	   r
   r   r   r   r   r   numpyr   	packagingr   utilsr   r   r   r   r   r   r   configuration_utilsr   r   r   r   r   r   r    PILr!   
get_loggerr*   r   rt   r   	dataclassr$   r4   r   r%  r0   r1   r2   <module>r[     sx         # # # # # # # # # # # # # # f f f f f f f f f f f f f f f f f f f f f f f f           P P P P P P P P P P P P h h h h h h h h h h  B666666AAAAAA======AAAAAA  		H	%	%   #9  * * * * * * * *(t^ t^ t^ t^ t^ t^ t^ t^nP  P  P  P  P S P  P  P fW? W? W? W? W? 2 W? W? W? W? W?r1   