import json
import os
import re
import unicodedata
from typing import List, Optional, Tuple

from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}


def get_pairs(word):
    """
    Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
    strings)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
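# Illustrative example (not part of the original module): for the BPE word
# tuple ("h", "e", "llo</w>"), get_pairs returns {("h", "e"), ("e", "llo</w>")},
# i.e. the set of adjacent symbol pairs that the `bpe` method below scores
# against its merge ranks to pick the next merge.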
} |  dd
} |  dd} |  dd} |  dd} |  dd} |  dd} |  dd} |  dd} |  dd} |  dd} |  dd} |  dd} |  dd} |  d d!} |  d"d#} |  d$d%} |  d&d'} |  d(d)} |  d*d+} |  d,d-} td.d| } |  d/d0} |  d1d2} |  d3d4} |  d5d6} |  d7d8} |  d9d:} |  d;d<} |  d=d>} |  d?d@} | S )Azz
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
    u   ，,u   。\s*z. u   、u   ”"u   “u   ∶:u   ：u   ？?u   《u   》u   ）)u   ！!u   （(u   ；;u   １1u   」u   「u   ０0u   ３3u   ２2u   ５5u   ６6u   ９9u   ７7u   ８8u   ４4u   ．\s*u   ～~u   ’'u   …z...u   ━-u   〈<u   〉>u   【[u   】]u   ％%)replaceresub)textr   r   r   replace_unicode_punct0   sJ    r4   c                 C   s8   g }| D ]$}t |}|dr"q|| qd|S )zw
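# Illustrative example (not from the original source): the normalizer maps CJK
# and fullwidth punctuation/digits to their ASCII counterparts, e.g.
# replace_unicode_punct("２０１８。") == "2018. ".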
def remove_non_printing_char(text):
    """
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
    """
    output = []
    for char in text:
        cat = unicodedata.category(char)
        if cat.startswith("C"):
            continue
        output.append(char)
    return "".join(output)


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens
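# Illustrative example (not from the original source):
# whitespace_tokenize("  Ala ma\tkota ") == ["Ala", "ma", "kota"] -- str.split()
# with no argument collapses runs of any whitespace and drops leading/trailing
# whitespace after the strip().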
class BasicTokenizer:
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
        do_split_on_punc=True,
    ):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents
        self.do_split_on_punc = do_split_on_punc

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*)
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of token not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        # NFC normalization prevents treating the same character encoded with
        # different unicode codepoints as different characters.
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode blocks.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)
        ):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
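# Illustrative usage (not part of the original module); HerbertTokenizer below
# builds exactly this configuration as its `bert_pre_tokenizer`:
#
#   pre_tokenizer = BasicTokenizer(
#       do_lower_case=False, tokenize_chinese_chars=False, strip_accents=False
#   )
#   pre_tokenizer.tokenize("Kot, pies.")  # -> ["Kot", ",", "pies", "."]
#
# Each punctuation character becomes its own token, which is then fed to BPE.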
class HerbertTokenizer(PreTrainedTokenizer):
    """
    Construct a BPE tokenizer for HerBERT.

    Peculiarities:

    - uses BERT's pre-tokenizer: BasicTokenizer splits tokens on spaces, and also on punctuation. Each occurrence of
      a punctuation character is treated separately.

    - such pretokenized input is then BPE subtokenized.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
    to the superclass for more information regarding those methods.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        merges_file,
        tokenizer_file=None,
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        sep_token="</s>",
        bos_token="<s>",
        do_lowercase_and_remove_accent=False,
        additional_special_tokens=[
            "<special0>",
            "<special1>",
            "<special2>",
            "<special3>",
            "<special4>",
            "<special5>",
            "<special6>",
            "<special7>",
            "<special8>",
            "<special9>",
        ],
        lang2id=None,
        id2lang=None,
        **kwargs,
    ):
        try:
            import sacremoses
        except ImportError:
            raise ImportError(
                "You need to install sacremoses to use HerbertTokenizer. See"
                " https://pypi.org/project/sacremoses/ for installation."
            )

        self.sm = sacremoses

        # cache of sm.MosesPunctNormalizer instances, keyed by language
        self.cache_moses_punct_normalizer = {}
        # cache of sm.MosesTokenizer instances, keyed by language
        self.cache_moses_tokenizer = {}
        self.lang_with_custom_tokenizer = {"zh", "th", "ja"}
        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
        self.lang2id = lang2id
        self.id2lang = id2lang
        if lang2id is not None and id2lang is not None:
            assert len(lang2id) == len(id2lang)

        self.ja_word_tokenizer = None
        self.zh_word_tokenizer = None

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        with open(merges_file, encoding="utf-8") as merges_handle:
            merges = merges_handle.read().split("\n")[:-1]
        merges = [tuple(merge.split()[:2]) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            additional_special_tokens=additional_special_tokens,
            lang2id=lang2id,
            id2lang=id2lang,
            do_lowercase_and_remove_accent=do_lowercase_and_remove_accent,
            tokenizer_file=None,
            **kwargs,
        )

        self.bert_pre_tokenizer = BasicTokenizer(
            do_lower_case=False,
            never_split=self.all_special_tokens,
            tokenize_chinese_chars=False,
            strip_accents=False,
        )

    @property
    def do_lower_case(self):
        return self.do_lowercase_and_remove_accent

    def moses_punct_norm(self, text, lang):
        if lang not in self.cache_moses_punct_normalizer:
            punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
            self.cache_moses_punct_normalizer[lang] = punct_normalizer
        else:
            punct_normalizer = self.cache_moses_punct_normalizer[lang]
        return punct_normalizer.normalize(text)

    def moses_tokenize(self, text, lang):
        if lang not in self.cache_moses_tokenizer:
            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
            self.cache_moses_tokenizer[lang] = moses_tokenizer
        else:
            moses_tokenizer = self.cache_moses_tokenizer[lang]
        return moses_tokenizer.tokenize(text, return_str=False, escape=False)

    def moses_pipeline(self, text, lang):
        text = replace_unicode_punct(text)
        text = self.moses_punct_norm(text, lang)
        text = remove_non_printing_char(text)
        return text

    def ja_tokenize(self, text):
        if self.ja_word_tokenizer is None:
            try:
                import Mykytea

                self.ja_word_tokenizer = Mykytea.Mykytea(
                    f"-model {os.path.expanduser('~')}/local/share/kytea/model.bin"
                )
            except (AttributeError, ImportError):
                logger.error(
                    "Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper"
                    " (https://github.com/chezou/Mykytea-python) with the following steps"
                )
                logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
                logger.error("2. autoreconf -i")
                logger.error("3. ./configure --prefix=$HOME/local")
                logger.error("4. make && make install")
                logger.error("5. pip install kytea")
                raise
        return list(self.ja_word_tokenizer.getWS(text))

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        if token in self.cache:
            return self.cache[token]
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        if word == "\n  </w>":
            word = "\n</w>"
        self.cache[token] = word
        return word

    def _tokenize(self, text):
        pre_tokens = self.bert_pre_tokenizer.tokenize(text)

        split_tokens = []
        for token in pre_tokens:
            if token:
                split_tokens.extend(list(self.bpe(token).split(" ")))

        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index, self.unk_token)
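    # Illustrative trace of `bpe` above (not from the original source),
    # assuming a toy merge table ranking ("w", "o") < ("d", "a</w>") <
    # ("wo", "da</w>"):
    #   "woda" -> ("w", "o", "d", "a</w>") -> ("wo", "d", "a</w>")
    #          -> ("wo", "da</w>") -> ("woda</w>",)
    # so bpe("woda") returns "woda</w>". With an empty merge table it would
    # return "w o d a</w>" -- the space-joined symbols, which _tokenize then
    # splits back into individual subtokens.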
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        out_string = "".join(tokens).replace("</w>", " ").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. An XLM sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        bos = [self.bos_token_id]
        sep = [self.sep_token_id]

        if token_ids_1 is None:
            return bos + token_ids_0 + sep
        return bos + token_ids_0 + sep + token_ids_1 + sep
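    # Illustrative example (not from the original source): with hypothetical
    # ids bos=0 and sep=2, build_inputs_with_special_tokens([5, 6], [7])
    # returns [0, 5, 6, 2, 7, 2], i.e. the `<s> A </s> B </s>` layout
    # described in the docstring above.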
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]
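    # Illustrative example (not from the original source):
    # get_special_tokens_mask([5, 6], [7]) == [1, 0, 0, 1, 0, 1] -- the ones
    # mark the <s>/</s> positions that build_inputs_with_special_tokens adds.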
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLM sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
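    # Illustrative example (not from the original source):
    # create_token_type_ids_from_sequences([5, 6], [7]) == [0, 0, 0, 0, 1, 1]:
    # the zeros cover `<s> A </s>`, the ones cover `B </s>`.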
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file

    def __getstate__(self):
        # sacremoses module objects cannot be pickled; drop and re-import on load.
        state = self.__dict__.copy()
        state["sm"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d

        try:
            import sacremoses
        except ImportError:
            raise ImportError(
                "You need to install sacremoses to use XLMTokenizer. See"
                " https://pypi.org/project/sacremoses/ for installation."
            )

        self.sm = sacremoses
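# Minimal smoke-test sketch (not part of the original module). It assumes the
# published "allegro/herbert-base-cased" checkpoint, a transformers install,
# and network access; run this file directly to try the full pipeline.
if __name__ == "__main__":
    from transformers import HerbertTokenizer

    tokenizer = HerbertTokenizer.from_pretrained("allegro/herbert-base-cased")
    encoded = tokenizer("Ala ma kota.")
    print(encoded["input_ids"])
    print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))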