
    g                         d Z ddlZddlZddlmZmZmZmZ ddlZ	ddl
mZ ddlmZmZmZmZmZ ddlmZmZmZmZmZ  e            rddlZ ej        e          Zdd	iZd
 Zd Z G d de          ZdS )z!Tokenization class for Pop2Piano.    N)ListOptionalTupleUnion   )BatchFeature)
AddedTokenBatchEncodingPaddingStrategyPreTrainedTokenizerTruncationStrategy)
TensorTypeis_pretty_midi_availableloggingrequires_backendsto_numpyvocabz
vocab.jsonc                 4    || z  }|t          ||          }|S N)minnumbercutoff_time_idxcurrent_idxs      p/var/www/html/ai-engine/env/lib/python3.11/site-packages/transformers/models/pop2piano/tokenization_pop2piano.pytoken_time_to_noter   '   s'    6K"+77    c                     ||          9||          }||k     r*|}|                     ||| |g           |dk    rd n|}||| <   n||| <   |S )Nr   )append)	r   current_velocitydefault_velocitynote_onsets_readyr   notes	onset_idx
offset_idxonsets_readys	            r   token_note_to_noter'   /   sr     ,%f-	{""$JLL)Z9IJKKK#3q#8#844kL(4f%$/&!Lr   c                       e Zd ZdZddgZeZ	 	 	 	 	 	 d5 fd
	Zed             Z	d Z
dedefdZd6defdZdej        dededefdZ	 	 	 d7dej        dej        dededef
dZd8dej        dedefdZd9dej        dej        d efd!Zd8d"ed#ee         dee         fd$Z	 	 d:deej        eej                 f         d%ee         d&ee         defd'Z 	 	 d:deej        eej                 f         d%ee         d&ee         defd(Z!	 	 	 	 	 	 	 d;deej        eej                 eeej                          f         d+ee"ee#f         d,ee"eef         d&ee         d-ee         d.ee"         d/eeee$f                  d0e"defd1Z%	 d<d2e&d3e"fd4Z' xZ(S )=Pop2PianoTokenizera  
    Constructs a Pop2Piano tokenizer. This tokenizer does not require training.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab (`str`):
            Path to the vocab file which contains the vocabulary.
        default_velocity (`int`, *optional*, defaults to 77):
            Determines the default velocity to be used while creating midi Notes.
        num_bars (`int`, *optional*, defaults to 2):
            Determines cutoff_time_idx in for each token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"-1"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 1):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 0):
             A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
            attention mechanisms or loss computation.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 2):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
    	token_idsattention_maskM      -1102c                 z   t          |t                    rt          |dd          n|}t          |t                    rt          |dd          n|}t          |t                    rt          |dd          n|}t          |t                    rt          |dd          n|}|| _        || _        t          |d          5 }	t          j        |	          | _        d d d            n# 1 swxY w Y   d | j        	                                D             | _
         t                      j        d||||d| d S )NF)lstriprstriprbc                     i | ]\  }}||	S  r7   ).0kvs      r   
<dictcomp>z/Pop2PianoTokenizer.__init__.<locals>.<dictcomp>s   s    >>>A1>>>r   )	unk_token	eos_token	pad_token	bos_tokenr7   )
isinstancestrr	   r!   num_barsopenjsonloadencoderitemsdecodersuper__init__)selfr   r!   rB   r<   r=   r>   r?   kwargsfile	__class__s             r   rJ   zPop2PianoTokenizer.__init__[   s    JTT]_bIcIcrJyuEEEEir	IST]_bIcIcrJyuEEEEir	IST]_bIcIcrJyuEEEEir	IST]_bIcIcrJyuEEEEir	 0  % 	+$9T??DL	+ 	+ 	+ 	+ 	+ 	+ 	+ 	+ 	+ 	+ 	+ 	+ 	+ 	+ 	+ ?>););)=)=>>> 	
		
 	

 	
 	
 	
 	
 	
s   C**C.1C.c                 *    t          | j                  S )z-Returns the vocabulary size of the tokenizer.)lenrF   rK   s    r   
vocab_sizezPop2PianoTokenizer.vocab_size}   s     4<   r   c                 0    t          | j        fi | j        S )z(Returns the vocabulary of the tokenizer.)dictrF   added_tokens_encoderrQ   s    r   	get_vocabzPop2PianoTokenizer.get_vocab   s    DL>>D$=>>>r   token_idreturnc                     | j                             || j         d          }|                    d          }d                    |dd                   t          |d                   }}||gS )a?  
        Decodes the token ids generated by the transformer into notes.

        Args:
            token_id (`int`):
                This denotes the ids generated by the transformers to be converted to Midi tokens.

        Returns:
            `List`: A list consists of token_type (`str`) and value (`int`).
        _TOKEN_TIME_   Nr   )rH   getr<   splitjoinint)rK   rW   token_type_value
token_typevalues        r   _convert_id_to_tokenz'Pop2PianoTokenizer._convert_id_to_token   st      <++H6T6T6TUU+11#66HH%5abb%9::C@PQR@S<T<TE
E""r   
TOKEN_TIMEc                 f    | j                             | d| t          | j                            S )a  
        Encodes the Midi tokens to transformer generated token ids.

        Args:
            token (`int`):
                This denotes the token value.
            token_type (`str`):
                This denotes the type of the token. There are four types of midi tokens such as "TOKEN_TIME",
                "TOKEN_VELOCITY", "TOKEN_NOTE" and "TOKEN_SPECIAL".

        Returns:
            `int`: returns the id of the token.
        r[   )rF   r]   r`   r<   )rK   tokenrb   s      r   _convert_token_to_idz'Pop2PianoTokenizer._convert_token_to_id   s4     |5 7 7: 7 7T^9L9LMMMr   tokensbeat_offset_idxbars_per_batchr   c                    d}t          t          |                    D ]c}||         }|||z  dz  z   }||z   }	|                     |||	          }
t          |
          dk    rF||
}Kt          j        ||
fd          }d|g S |S )a  
        Converts relative tokens to notes which are then used to generate pretty midi object.

        Args:
            tokens (`numpy.ndarray`):
                Tokens to be converted to notes.
            beat_offset_idx (`int`):
                Denotes beat offset index for each note in generated Midi.
            bars_per_batch (`int`):
                A parameter to control the Midi output generation.
            cutoff_time_idx (`int`):
                Denotes the cutoff time index for each note in generated Midi.
        N   )	start_idxr   r   )axis)rangerP   relative_tokens_ids_to_notesnpconcatenate)rK   ri   rj   rk   r   r#   index_tokens
_start_idx_cutoff_time_idx_notess              r   "relative_batch_tokens_ids_to_notesz5Pop2PianoTokenizer.relative_batch_tokens_ids_to_notes   s    * 3v;;'' 	@ 	@EUmG(5>+AA+EEJ.;66$ 0 7  F 6{{avQ???=Ir   r      beatstepc                     |dn|}|                      ||||          }|                     ||||                   }|S )al  
        Converts tokens to Midi. This method calls `relative_batch_tokens_ids_to_notes` method to convert batch tokens
        to notes then uses `notes_to_midi` method to convert them to Midi.

        Args:
            tokens (`numpy.ndarray`):
                Denotes tokens which alongside beatstep will be converted to Midi.
            beatstep (`np.ndarray`):
                We get beatstep from feature extractor which is also used to get Midi.
            beat_offset_idx (`int`, *optional*, defaults to 0):
                Denotes beat offset index for each note in generated Midi.
            bars_per_batch (`int`, *optional*, defaults to 2):
                A parameter to control the Midi output generation.
            cutoff_time_idx (`int`, *optional*, defaults to 12):
                Denotes the cutoff time index for each note in generated Midi.
        Nr   )ri   rj   rk   r   )
offset_sec)ry   notes_to_midi)rK   ri   r{   rj   rk   r   r#   midis           r   !relative_batch_tokens_ids_to_midiz4Pop2PianoTokenizer.relative_batch_tokens_ids_to_midi   s^    0  /6!!O77+)+	 8 
 
 !!%h>W!XXr   Nrn   c           	           fd|D             }|}d}d t          t          d  j                                        D                       dz             D             }g }|D ]e\  }	}
|	dk    r	|
dk    r nS|	dk    rt	          |
||          }-|	d	k    r|
}6|	d
k    rt          |
| j        |||          }Wt          d          t          |          D ]P\  }}|I||dz   }nt          ||dz             }t          ||          }|
                    ||| j        g           Qt          |          dk    rg S t          j        |          }|dddf         dz  |dddf         z   }||                                         }|S )a  
        Converts relative tokens to notes which will then be used to create Pretty Midi objects.

        Args:
            tokens (`numpy.ndarray`):
                Relative Tokens which will be converted to notes.
            start_idx (`float`):
                A parameter which denotes the starting index.
            cutoff_time_idx (`float`, *optional*):
                A parameter used while converting tokens to notes.
        c                 :    g | ]}                     |          S r7   )rd   )r8   rg   rK   s     r   
<listcomp>zCPop2PianoTokenizer.relative_tokens_ids_to_notes.<locals>.<listcomp>  s'    FFFe**511FFFr   r   c                     g | ]}d S r   r7   r8   is     r   r   zCPop2PianoTokenizer.relative_tokens_ids_to_notes.<locals>.<listcomp>  s    mmmaTmmmr   c                 8    g | ]}|                     d           S )NOTE)endswith)r8   r9   s     r   r   zCPop2PianoTokenizer.relative_tokens_ids_to_notes.<locals>.<listcomp>  s$    5f5f5fQajj6H6H5f5f5fr   r\   TOKEN_SPECIALre   r   TOKEN_VELOCITY
TOKEN_NOTE)r   r    r!   r"   r   r#   zToken type not understood!N   )rp   sumrF   keysr   r'   r!   
ValueError	enumeratemaxr   rP   rr   arrayargsort)rK   ri   rn   r   wordsr   r    r"   r#   rb   r   pitch
note_onsetcutoffr%   
note_orders   `               r   rq   z/Pop2PianoTokenizer.relative_tokens_ids_to_notes   s    GFFFvFFFmm55f5fRVR^RcRcReRe5f5f5f1g1gjk1k+l+lmmm"' 	? 	?J_,,Q;;E |++0!?P[   ///#)  |++*!%5%)%:&7 +   !!=>>>!*+<!=!= 		U 		UE:%"*'!^FF *q.AAF f55
j*eT=RSTTTu::??IHUOOEqqq!ts*U111a4[8J*,,../ELr           r#   r}   c                 ~   t          | dg           t          j        dd          }t          j        d          }g }|D ]F\  }}}	}
t          j        |
|	||         |z
  ||         |z
            }|                    |           G||_        |j                            |           |                                 |S )a  
        Converts notes to Midi.

        Args:
            notes (`numpy.ndarray`):
                This is used to create Pretty Midi objects.
            beatstep (`numpy.ndarray`):
                This is the extrapolated beatstep that we get from feature extractor.
            offset_sec (`int`, *optional*, defaults to 0.0):
                This represents the offset seconds which is used while creating each Pretty Midi Note.
        pretty_midii  g      ^@)
resolutioninitial_tempor   )program)velocityr   startend)	r   r   
PrettyMIDI
InstrumentNoter   r#   instrumentsremove_invalid_notes)rK   r#   r{   r}   new_pmnew_inst	new_notesr$   r%   r   r   new_notes               r   r~   z Pop2PianoTokenizer.notes_to_midi4  s     	$000'3eLLL)!444	6; 	' 	'2Iz5("'!y)J6Z(:5	  H X&&&&"!!(+++##%%%r   save_directoryfilename_prefixc                    t           j                            |          s t                              d| d           dS t           j                            ||r|dz   ndt          d         z             }t          |d          5 }|                    t          j
        | j                             ddd           n# 1 swxY w Y   |fS )a}  
        Saves the tokenizer's vocabulary dictionary to the provided save_directory.

        Args:
            save_directory (`str`):
                A path to the directory where to saved. It will be created if it doesn't exist.
            filename_prefix (`Optional[str]`, *optional*):
                A prefix to add to the names of the files saved by the tokenizer.
        zVocabulary path (z) should be a directoryN- r   w)ospathisdirloggererrorr_   VOCAB_FILES_NAMESrC   writerD   dumpsrF   )rK   r   r   out_vocab_filerM   s        r   save_vocabularyz"Pop2PianoTokenizer.save_vocabularyT  s    w}}^,, 	LLT^TTTUUUF oM_s222QbcjQkk
 
 .#&& 	1$JJtz$,//000	1 	1 	1 	1 	1 	1 	1 	1 	1 	1 	1 	1 	1 	1 	1   s   -B>>CCtruncation_strategy
max_lengthc                 `   t          | dg           t          |d         t          j                  r2t	          j        d |D                                           dd          }t	          j        |                              t          j	                  }|ddddf         
                                }d t          |d	z             D             }|D ]A\  }}}	}
||                             |	|
g           ||                             |	dg           Bg }d}t          |          D ]\  }}t          |          dk    r|                    |                     |d
                     |D ]r\  }	}
t!          |
dk              }
||
k    r+|
}|                    |                     |
d                     |                    |                     |	d                     st          |          }|t"          j        k    r |r||k    r | j        d|||z
  |d|\  }}}t)          d|i          S )a  
        This is the `encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
        generated token ids. It only works on a single batch, to process multiple batches please use
        `batch_encode_plus` or `__call__` method.

        Args:
            notes (`numpy.ndarray` of shape `[sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
                This represents the midi notes. If `notes` is a `numpy.ndarray`:
                    - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
                If `notes` is a `list` containing `pretty_midi.Note` objects:
                    - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
            truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
                Indicates the truncation strategy that is going to be used during truncation.
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).

        Returns:
            `BatchEncoding` containing the tokens ids.
        r   r   c                 B    g | ]}|j         |j        |j        |j        gS r7   )r   r   r   r   )r8   	each_notes     r   r   z2Pop2PianoTokenizer.encode_plus.<locals>.<listcomp>  s+    nnn[d)/9=)/9CUVnnnr   rm   Nr-   c                     g | ]}g S r7   r7   r   s     r   r   z2Pop2PianoTokenizer.encode_plus.<locals>.<listcomp>  s    777777r   r\   re   r   r   )idsnum_tokens_to_remover   r*   r7   )r   r@   r   r   rr   r   reshaperoundastypeint32r   rp   r   r   rP   rh   r`   r   DO_NOT_TRUNCATEtruncate_sequencesr
   )rK   r#   r   r   rL   max_time_idxtimesonsetoffsetr   r   ri   r    r   time	total_lenr[   s                    r   encode_pluszPop2PianoTokenizer.encode_plusk  sy   6 	$000 eAh 011 	Hnnhmnnn gb!nn 
 &&rx00QQQU|''))77UL1$466777.3 	- 	-*E65(%L 1222&M  %,,,, '' 		N 		NGAt4yyA~~MM$33A|DDEEE#' N Nxx!|,,#x//'/$MM$";";HFV"W"WXXXd77|LLMMMMN KK	 "4"DDDDXadnXnXn242 %.%;$7  	 LFAq k62333r   c           	          g }t          t          |                    D ]2}|                     | j        ||         f||d|d                    3t	          d|i          S )a  
        This is the `batch_encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
        generated token ids. It works on multiple batches by calling `encode_plus` multiple times in a loop.

        Args:
            notes (`numpy.ndarray` of shape `[batch_size, sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
                This represents the midi notes. If `notes` is a `numpy.ndarray`:
                    - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
                If `notes` is a `list` containing `pretty_midi.Note` objects:
                    - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
            truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
                Indicates the truncation strategy that is going to be used during truncation.
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).

        Returns:
            `BatchEncoding` containing the tokens ids.
        )r   r   r*   )rp   rP   r   r   r
   )rK   r#   r   r   rL   encoded_batch_token_idsr   s          r   batch_encode_plusz$Pop2PianoTokenizer.batch_encode_plus  s    4 #%s5zz"" 	 	A#**  !H(;)  	 
     k+BCDDDr   FTpadding
truncationpad_to_multiple_ofreturn_attention_maskreturn_tensorsverbosec	           	      D   t          |t          j                  r|j        dk    nt          |d         t                    }
 | j        d|||||d|	\  }}}}	|
r|dn|} | j        d|||d|	}n | j        d|||d|	}|                     |||||||          }|S )	a  
        This is the `__call__` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer generated
        token ids.

        Args:
            notes (`numpy.ndarray` of shape `[batch_size, max_sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
                This represents the midi notes.

                If `notes` is a `numpy.ndarray`:
                    - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
                If `notes` is a `list` containing `pretty_midi.Note` objects:
                    - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
            padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
                Activates and controls padding. Accepts the following values:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
                Activates and controls truncation. Accepts the following values:

                - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
                  to the maximum acceptable input length for the model if that argument is not provided. This will
                  truncate token by token, removing a token from the longest sequence in the pair if a pair of
                  sequences (or a batch of pairs) is provided.
                - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
                  greater than the model maximum admissible input size).
            max_length (`int`, *optional*):
                Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to
                `None`, this will use the predefined model maximum length if a maximum length is required by one of the
                truncation/padding parameters. If the model has no specific maximum input length (like XLNet)
                truncation/padding to a maximum length will be deactivated.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
                the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
            return_attention_mask (`bool`, *optional*):
                Whether to return the attention mask. If left to the default, will return the attention mask according
                to the specific tokenizer's default, defined by the `return_outputs` attribute.

                [What are attention masks?](../glossary#attention-mask)
            return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            verbose (`bool`, *optional*, defaults to `True`):
                Whether or not to print more information and warnings.

        Returns:
            `BatchEncoding` containing the token_ids.
        r   r   )r   r   r   r   r   NT)r#   r   r   )r   r   r   r   r   r   r7   )	r@   rr   ndarrayndimlist"_get_padding_truncation_strategiesr   r   pad)rK   r#   r   r   r   r   r   r   r   rL   
is_batchedpadding_strategyr   r*   s                 r   __call__zPop2PianoTokenizer.__call__  s9   d )35"*(E(EeUZ1__:V[\]V^`dKeKe
 ElDDk E
!!1E
 E
 E
 E
A-z6  	,A,IDDOd!.. $7%  	 II )( $7%  	 I HH$!1"7)  
 
	 r   feature_extractor_outputreturn_midic                 8   t          t          |d          ot          |d          ot          |d                    }|s&|d         j        d         dk    rt          d          |rt	          |d         dddf         dk              |d         j        d         k    s(|d         j        d         |d	         j        d         k    rEt          d
|j        d          d|d         j        d          d|d	         j        d                    |d         j        d         |j        d         k    r1t          d|d         j        d          d|j        d                    nf|d         j        d         dk    s|d	         j        d         dk    r8t          d|d         j        d          d|d	         j        d          d          |r/t          j        |d         dddf         dk              d         }n|j        d         g}g }g }d}t          |          D ]\  }	}
|||
         }|dddt          j        t          j        |t          | j
                  k              d                   dz   f         }|d         |	         }|d	         |	         }|r|d         |	         }|d         |	         }|dt          j        t          j        |dk              d                   dz            }|dt          j        t          j        |dk              d                   dz            }t          |          }t          |          }t          |          }|                     ||| j        | j        dz   dz            }|j        d         j        D ]C}|xj        |d         z  c_        |xj        |d         z  c_        |                    |           D|                    |           ||
dz   z  }|rt'          ||d          S t'          d|i          S )aF  
        This is the `batch_decode` method for `Pop2PianoTokenizer`. It converts the token_ids generated by the
        transformer to midi_notes and returns them.

        Args:
            token_ids (`Union[np.ndarray, torch.Tensor, tf.Tensor]`):
                Output token_ids of `Pop2PianoConditionalGeneration` model.
            feature_extractor_output (`BatchFeature`):
                Denotes the output of `Pop2PianoFeatureExtractor.__call__`. It must contain `"beatstep"` and
                `"extrapolated_beatstep"`. Also `"attention_mask_beatsteps"` and
                `"attention_mask_extrapolated_beatstep"`
                 should be present if they were returned by the feature extractor.
            return_midi (`bool`, *optional*, defaults to `True`):
                Whether to return midi object or not.
        Returns:
            If `return_midi` is True:
                - `BatchEncoding` containing both `notes` and `pretty_midi.pretty_midi.PrettyMIDI` objects.
            If `return_midi` is False:
                - `BatchEncoding` containing `notes`.
        r+   attention_mask_beatsteps$attention_mask_extrapolated_beatstep	beatstepsr   r\   zattention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep must be present for batched inputs! But one of them were not present.Nextrapolated_beatstepzbLength mistamtch between token_ids, beatsteps and extrapolated_beatstep! Found token_ids length - z, beatsteps shape - z$ and extrapolated_beatsteps shape - z!Found attention_mask of length - z but token_ids of length - zLength mistamtch of beatsteps and extrapolated_beatstep! Since attention_mask is not present the number of examples must be 1, But found beatsteps length - z", extrapolated_beatsteps length - .rm   )ri   r{   rk   r   )r#   pretty_midi_objectsr#   )boolhasattrshaper   r   rr   wherer   r   r`   r=   r   r   rB   r   r#   r   r   r   r
   )rK   r*   r   r   attention_masks_present	batch_idx
notes_listpretty_midi_objects_listrn   rt   end_idxeach_tokens_idsr   r   r   r   pretty_midi_objectnotes                     r   batch_decodezPop2PianoTokenizer.batch_decodeU  sF   8 #',.>?? Z02LMMZ02XYY#
 #
 ' 	+CK+P+VWX+Y\]+]+]H   # 	 ,-=>qqq!tDIJJ+K8>qAB B+K8>qA+,CDJ1MN N !w*3/!*<w wRjkvRwR}~  SAw w:RSj:k:qrs:tw w  
 ((89?BioVWFXXX  ]8PQa8b8hij8k  ]  ]  IR  IX  YZ  I[  ]  ]   Y )5;A>!CC+,CDJ1MQRRR D4L[4Y4_`a4bD D G_  `w  Gx  G~  @  GAD D D  
 # 	-!9:J!KAAAqD!QUV!VWWXYZII"+,I
#% 	'	22 #	% #	%NE7'	'(9:O-aaa1r26"(?VYZ^ZhViViCi:j:jkl:m3n3nqr3r1r.rsO0=eDI$<=T$UV[$\! ' +CD^+_`e+f(7O:884 &&^rx8PTU8U/V/VWX/Y(Z(Z]^(^&^_	(=XbfRX&Ja&OPPQRSTTWXXX)% '77O ++I$,-B$C$C!!%!G!G&.#}!%!2a 7	 "H " " +6q9? ( (

il*

IaL(!!$''''$++,>???1$II 	i :Nf!g!ghhhgz2333r   )r,   r-   r.   r/   r0   r1   )re   )r   r-   rz   r   )r   )NN)FNNNNNT)T))__name__
__module____qualname____doc__model_input_namesr   vocab_files_namesrJ   propertyrR   rV   r`   r   rd   rh   rr   r   ry   r   floatrq   r~   rA   r   r   r   r   r   r   r   r   r
   r   r   r   r   r   r   r   r   __classcell__)rN   s   @r   r)   r)   >   s(        2 %&67)
  
  
  
  
  
  
D ! ! X!? ? ?#S #T # # # #$N Nc N N N N *
* * 	*
 * * * *`  !!   
  *  	 
          H: :2: :% :bg : : : :x 2:  QT    @! !c !HSM !]bcf]g ! ! ! !4 =A$(	E4 E4RZk&6!778E4 &&89E4 SM	E4 
E4 E4 E4 E4T =A$(	%E %ERZk&6!778%E &&89%E SM	%E 
%E %E %E %E\ 6;;?$(,004;?z zJ!"k&'(*
z tS/12z $%778z SMz %SMz  (~z !sJ!78z z 
z z z z@ !	w4 w4 #/w4 	w4 w4 w4 w4 w4 w4 w4 w4r   r)   ) r   rD   r   typingr   r   r   r   numpyrr   feature_extraction_utilsr   tokenization_utilsr	   r
   r   r   r   utilsr   r   r   r   r   r   
get_loggerr   r   r   r   r'   r)   r7   r   r   <module>r	     sR   ( '  				 / / / / / / / / / / / /     4 4 4 4 4 4 u u u u u u u u u u u u u u _ _ _ _ _ _ _ _ _ _ _ _ _ _  		H	%	% \ 
    N
4 N
4 N
4 N
4 N
4, N
4 N
4 N
4 N
4 N
4r   