U
    4Af                     @   s\  d dl Z d dlZd dlmZ d dlmZmZ d dlZd dl	m	Z	 ddl
mZ ddlmZmZmZ ddlmZmZmZ d	d
lmZ ddddhZe rd dlZd dlmZ e rd dlZeeZdd Zdd Zdd Z dd Z!dd Z"edddZ#d.d d!Z$G d"d# d#eZ%G d$d% d%e%Z&G d&d' d'e%Z'G d(d) d)Z(G d*d+ d+Z)G d,d- d-Z*dS )/    N)partial)Pool	cpu_count)tqdm   )whitespace_tokenize)BatchEncodingPreTrainedTokenizerBaseTruncationStrategy)is_tf_availableis_torch_availablelogging   )DataProcessorZrobertaZ	camembertZbartZmpnet)TensorDatasetc           	      C   sp   d ||}t||d D ]H}t||d dD ]2}d | ||d  }||kr2||f    S q2q||fS )zFReturns tokenized answer spans that better match the annotated answer. r   )jointokenizerange)	
doc_tokensZinput_startZ	input_end	tokenizerZorig_answer_textZtok_answer_textZ	new_startZnew_endZ	text_span r   F/tmp/pip-unpacked-wheel-zw5xktn0/transformers/data/processors/squad.py_improve_answer_span+   s    r   c                 C   s   d}d}t | D ]l\}}|j|j d }||jk r4q||kr>q||j }|| }	t||	d|j  }
|dkst|
|kr|
}|}q||kS ):Check if this is the 'max context' doc span for the token.Nr   {Gz?)	enumeratestartlengthminZ	doc_spansZcur_span_indexpositionZ
best_scoreZbest_span_indexZ
span_indexZdoc_spanendZnum_left_contextZnum_right_contextZscorer   r   r   _check_is_max_context8   s    

r$   c                 C   s   d}d}t | D ]v\}}|d |d  d }||d k r:q||krDq||d  }|| }	t||	d|d   }
|dks~|
|kr|
}|}q||kS )r   Nr   r   r   r   )r   r    r!   r   r   r   _new_check_is_max_contextL   s    r%   c                 C   s4   | dks,| dks,| dks,| dks,t | dkr0dS dS )Nr   	
i/   TF)ord)cr   r   r   _is_whitespaceb   s    ,r+   c           1      C   s  g }|rl| j sl| j}| j}d| j||d  }	dt| j}
|	|
dkrlt	d|	 d|
 d g S g }g }g }t
| jD ]Z\}}|t| tjjdkrtj|dd	}n
t|}|D ]}|| || qq|rB| j sB|| j }| jt| jd k r || jd  d }nt|d }t|||t| j\}}g }tj| jd
d|d}ttjdd }|tkrtjtj d n
tjtj }tjtj }|}t|| t|k rtjdkr|}|}tjj}n|}|}tjj}tj |||||d|| t| | dd}t!t|t||  |t| | }tj"|d krtjdkrt|d d |d #tj" }n>t|d d |d d d d #tj" } |d | d d  }n|d }t$|}!i }"t%|D ]>}tjdkrt|| | n|}#|t|| |  |"|#< q||d< |!|d< |"|d< t|| |d< i |d< t|| |d< ||d< || d|ksd|krt|d dkrq|d }qt%t|D ]b}$t%||$ d D ]J}%t&||$|$| |% }&tjdkr|%n||$ d |% }#|&||$ d |#< qq|D ]}'|'d #tj'}(t()|'d })tjdkrTd|)t|| d < n d|)t|'d  t||  < t(*|'d tj"k}*t(+tj,|'d dd- }+d|)|*< d|)|+< d|)|(< | j },d}d}|rX|,sX|'d }-|'d |'d  d }.d
}/||-kr
||.ksd}/|/r"|(}|(}d},n6tjdkr4d}0nt|| }0||- |0 }||- |0 }|t.|'d |'d |'d |(|)/ dd|'d |'d |'d |'d |||,| j0d q|S )Nr   r   r   zCould not find answer: 'z' vs. '')ZRobertaTokenizerZLongformerTokenizerZBartTokenizerZRobertaTokenizerFastZLongformerTokenizerFastZBartTokenizerFastT)Zadd_prefix_spaceF)Zadd_special_tokens
truncation
max_length	Tokenizer right)r-   paddingr.   Zreturn_overflowing_tokensZstrideZreturn_token_type_ids	input_idsparagraph_lentokenstoken_to_orig_mapZ*truncated_query_with_special_tokens_lengthtoken_is_max_contextr   r   Zoverflowing_tokensr   lefttoken_type_ids)Zalready_has_special_tokensattention_mask)
example_index	unique_idr4   r7   r5   r6   start_positionend_positionis_impossibleqas_id)1r?   r=   r>   r   r   r   answer_textfindloggerwarningr   appendlenr   	__class____name__r   r   encodequestion_texttypereplacelowerMULTI_SEP_TOKENS_TOKENIZERS_SETZmodel_max_lengthZmax_len_single_sentenceZmax_len_sentences_pairZpadding_sider
   ZONLY_SECONDvalueZ
ONLY_FIRSTZencode_plusr    Zpad_token_idindexZconvert_ids_to_tokensr   r%   Zcls_token_idnpZ	ones_likewhereZasarrayZget_special_tokens_maskZnonzeroSquadFeaturestolistr@   )1examplemax_seq_length
doc_stridemax_query_lengthpadding_strategyis_trainingfeaturesr=   r>   Zactual_textZcleaned_answer_textZtok_to_orig_indexZorig_to_tok_indexZall_doc_tokensitokenZ
sub_tokensZ	sub_tokenZtok_start_positionZtok_end_positionZspansZtruncated_queryZtokenizer_typeZsequence_added_tokensZsequence_pair_added_tokensZspan_doc_tokensZtextspairsr-   Zencoded_dictr4   Znon_padded_idsZlast_padding_id_positionr5   r6   rP   Zdoc_span_indexjZis_max_contextspan	cls_indexp_maskZpad_token_indicesZspecial_token_indicesZspan_is_impossibleZ	doc_startZdoc_endZout_of_spanZ
doc_offsetr   r   r   !squad_convert_example_to_featuresh   s4   



       

(
 


 rc   Ztokenizer_for_convertc                 C   s   | a d S N)r   rd   r   r   r   &squad_convert_example_to_features_init7  s    rf   r.   FTc
              	      sp  g  t |t }t|t|fd@}
tt|||||d}tt|
j|| ddt	| d|	 d W 5 Q R X g }d}d}t t	 d	|	 dD ]:}|sq|D ]"}||_
||_|| |d
7 }q|d
7 }q| ~|dkr t stdtjdd  D tjd}tjdd  D tjd}tjdd  D tjd}tjdd  D tjd}tjdd  D tjd}tjdd  D tjd}|stj|dtjd}t||||||}nJtjdd  D tjd}tjdd  D tjd}t||||||||} |fS |dkrht std fdd}d|jkrtjtjtjtjtjdtjtjtjtjtjdf}tdgtdgtdgtg tg dtg tg tg tdgtg df}ntjtjtjtjdtjtjtjtjtjdf}tdgtdgtg tg dtg tg tg tdgtg df}tjj|||S  S dS ) a  
    Converts a list of examples into a list of features that can be directly given as input to a model. It is
    model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.

    Args:
        examples: list of [`~data.processors.squad.SquadExample`]
        tokenizer: an instance of a child of [`PreTrainedTokenizer`]
        max_seq_length: The maximum sequence length of the inputs.
        doc_stride: The stride used when the context is too large and is split across several features.
        max_query_length: The maximum length of the query.
        is_training: whether to create features for model evaluation or model training.
        padding_strategy: Default to "max_length". Which padding strategy to use
        return_dataset: Default False. Either 'pt' or 'tf'.
            if 'pt': returns a torch.data.TensorDataset, if 'tf': returns a tf.data.Dataset
        threads: multiple processing threads.


    Returns:
        list of [`~data.processors.squad.SquadFeatures`]

    Example:

    ```python
    processor = SquadV2Processor()
    examples = processor.get_dev_examples(data_dir)

    features = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=not evaluate,
    )
    ```)ZinitializerZinitargs)rV   rW   rX   rY   rZ       )	chunksizez"convert squad examples to features)totaldescdisablei ʚ;r   zadd example index and unique idr   ptz6PyTorch must be installed to return a PyTorch dataset.c                 S   s   g | ]
}|j qS r   )r3   .0fr   r   r   
<listcomp>  s     z6squad_convert_examples_to_features.<locals>.<listcomp>)Zdtypec                 S   s   g | ]
}|j qS r   )r:   rm   r   r   r   rp     s     c                 S   s   g | ]
}|j qS r   )r9   rm   r   r   r   rp     s     c                 S   s   g | ]
}|j qS r   )ra   rm   r   r   r   rp     s     c                 S   s   g | ]
}|j qS r   )rb   rm   r   r   r   rp     s     c                 S   s   g | ]
}|j qS r   )r?   rm   r   r   r   rp     s     c                 S   s   g | ]
}|j qS r   )r=   rm   r   r   r   rp     s     c                 S   s   g | ]
}|j qS r   )r>   rm   r   r   r   rp     s     tfz<TensorFlow must be installed to return a TensorFlow dataset.c                  3   s   t  D ]x\} }|jd krL|j|j| |jd|j|j|j|j|j	dfV  q|j|j|j| |jd|j|j|j|j|j	dfV  qd S )Nr3   r:   feature_indexr@   Zstart_positionsZend_positionsra   rb   r?   r3   r:   r9   rs   r@   )
r   r9   r3   r:   r@   r=   r>   ra   rb   r?   )r\   exr[   r   r   gen  s6    
z/squad_convert_examples_to_features.<locals>.genr9   ru   rt   Nrr   ) r    r   r   rf   r   rc   listr   imaprF   r;   r<   rE   r   RuntimeErrortorchZtensorlongfloatZarangesizer   r   Zmodel_input_namesrq   Zint32Zint64stringZTensorShapedataZDatasetZfrom_generator)examplesr   rV   rW   rX   rZ   rY   Zreturn_datasetthreadsZtqdm_enabledpZ	annotate_Znew_featuresr<   r;   Zexample_featuresZexample_featureZall_input_idsZall_attention_masksZall_token_type_idsZall_cls_indexZ
all_p_maskZall_is_impossibleZall_feature_indexdatasetZall_start_positionsZall_end_positionsrx   Ztrain_typesZtrain_shapesr   rw   r   "squad_convert_examples_to_features<  s    0	   




     
%






r   c                   @   sH   e Zd ZdZdZdZdddZdddZddd	Zdd
dZ	dd Z
dS )SquadProcessorz
    Processor for the SQuAD data set. overridden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and
    version 2.0 of SQuAD, respectively.
    NFc              	   C   s   |s8|d d d   d}|d d d   }g }n,dd t|d d |d d D }d }d }t|d   d|d	   d|d
   d|||d   d|dS )Nanswerstextr   utf-8answer_startc                 S   s(   g | ] \}}|  |  d dqS )r   )r   r   )numpydecode)rn   r   r   r   r   r   rp   ,  s   z@SquadProcessor._get_example_from_tensor_dict.<locals>.<listcomp>idquestioncontexttitle)r@   rJ   context_textrA   start_position_characterr   r   )r   r   zipSquadExample)selftensor_dictevaluateanswerr   r   r   r   r   _get_example_from_tensor_dict&  s$    z,SquadProcessor._get_example_from_tensor_dictc                 C   s@   |r|d }n|d }g }t |D ]}|| j||d q"|S )av  
        Creates a list of [`~data.processors.squad.SquadExample`] using a TFDS dataset.

        Args:
            dataset: The tfds dataset loaded from *tensorflow_datasets.load("squad")*
            evaluate: Boolean specifying if in evaluation mode or in training mode

        Returns:
            List of SquadExample

        Examples:

        ```python
        >>> import tensorflow_datasets as tfds

        >>> dataset = tfds.load("squad")

        >>> training_examples = get_examples_from_dataset(dataset, evaluate=False)
        >>> evaluation_examples = get_examples_from_dataset(dataset, evaluate=True)
        ```Z
validationtrain)r   )r   rE   r   )r   r   r   r   r   r   r   r   get_examples_from_dataset>  s    
z(SquadProcessor.get_examples_from_datasetc              	   C   sj   |dkrd}| j dkrtdttj||dkr6| j n|ddd}t|d }W 5 Q R X | |dS )	a  
        Returns the training examples from the data directory.

        Args:
            data_dir: Directory containing the data files used for training and evaluating.
            filename: None by default, specify this if the training file has a different name than the original one
                which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.

        Nr0   NSquadProcessor should be instantiated via SquadV1Processor or SquadV2Processorrr   encodingr   r   )	
train_file
ValueErroropenospathr   jsonload_create_examplesr   data_dirfilenamereader
input_datar   r   r   get_train_examples_  s    

  z!SquadProcessor.get_train_examplesc              	   C   sj   |dkrd}| j dkrtdttj||dkr6| j n|ddd}t|d }W 5 Q R X | |dS )	a  
        Returns the evaluation example from the data directory.

        Args:
            data_dir: Directory containing the data files used for training and evaluating.
            filename: None by default, specify this if the evaluation file has a different name than the original one
                which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively.
        Nr0   r   r   r   r   r   dev)	dev_filer   r   r   r   r   r   r   r   r   r   r   r   get_dev_examplesu  s    	
  zSquadProcessor.get_dev_examplesc                 C   s   |dk}g }t |D ]}|d }|d D ]}|d }|d D ]|}	|	d }
|	d }d }d }g }|	dd	}|s|r|	d
 d }|d }|d }n|	d
 }t|
|||||||d}|| q<q(q|S )Nr   r   Z
paragraphsr   Zqasr   r   r?   Fr   r   r   r   )r@   rJ   r   rA   r   r   r?   r   )r   getr   rE   )r   r   set_typerZ   r   entryr   Z	paragraphr   Zqar@   rJ   r   rA   r   r?   r   rU   r   r   r   r     s>    

zSquadProcessor._create_examples)F)F)N)N)rH   
__module____qualname____doc__r   r   r   r   r   r   r   r   r   r   r   r     s   

!

r   c                   @   s   e Zd ZdZdZdS )SquadV1Processorztrain-v1.1.jsonzdev-v1.1.jsonNrH   r   r   r   r   r   r   r   r   r     s   r   c                   @   s   e Zd ZdZdZdS )SquadV2Processorztrain-v2.0.jsonzdev-v2.0.jsonNr   r   r   r   r   r     s   r   c                   @   s   e Zd ZdZg dfddZdS )r   aT  
    A single training/test example for the Squad dataset, as loaded from disk.

    Args:
        qas_id: The example's unique identifier
        question_text: The question string
        context_text: The context string
        answer_text: The answer string
        start_position_character: The character position of the start of the answer
        title: The title of the example
        answers: None by default, this is used during evaluation. Holds answers as well as their start positions.
        is_impossible: False by default, set to True if the example has no possible answer.
    Fc	                 C   s   || _ || _|| _|| _|| _|| _|| _d\| _| _g }	g }
d}| jD ]H}t	|rZd}n$|rj|	
| n|	d  |7  < d}|

t|	d  qH|	| _|
| _|d k	r|s|
| | _|
t|t| d t|
d  | _d S )N)r   r   Tr   Fr   )r@   rJ   r   rA   r   r?   r   r=   r>   r+   rE   rF   r   char_to_word_offsetr    )r   r@   rJ   r   rA   r   r   r   r?   r   r   Zprev_is_whitespacer*   r   r   r   __init__  s4    

zSquadExample.__init__NrH   r   r   r   r   r   r   r   r   r     s   r   c                   @   s"   e Zd ZdZdeedddZdS )rS   a  
    Single squad example features to be fed to a model. Those features are model-specific and can be crafted from
    [`~data.processors.squad.SquadExample`] using the
    :method:*~transformers.data.processors.squad.squad_convert_examples_to_features* method.

    Args:
        input_ids: Indices of input sequence tokens in the vocabulary.
        attention_mask: Mask to avoid performing attention on padding token indices.
        token_type_ids: Segment token indices to indicate first and second portions of the inputs.
        cls_index: the index of the CLS token.
        p_mask: Mask identifying tokens that can be answers vs. tokens that cannot.
            Mask with 1 for tokens than cannot be in the answer and 0 for token that can be in an answer
        example_index: the index of the example
        unique_id: The unique Feature identifier
        paragraph_len: The length of the context
        token_is_max_context:
            List of booleans identifying which tokens have their maximum context in this feature object. If a token
            does not have their maximum context in this feature object, it means that another feature object has more
            information related to that token and should be prioritized over this feature for that token.
        tokens: list of tokens corresponding to the input ids
        token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer.
        start_position: start of the answer token index
        end_position: end of the answer token index
        encoding: optionally store the BatchEncoding with the fast-tokenizer alignment methods.
    N)r@   r   c                 C   sd   || _ || _|| _|| _|| _|| _|| _|| _|	| _|
| _	|| _
|| _|| _|| _|| _|| _d S re   )r3   r:   r9   ra   rb   r;   r<   r4   r7   r5   r6   r=   r>   r?   r@   r   )r   r3   r:   r9   ra   rb   r;   r<   r4   r7   r5   r6   r=   r>   r?   r@   r   r   r   r   r     s     zSquadFeatures.__init__)NN)rH   r   r   r   strr   r   r   r   r   r   rS     s   *  rS   c                   @   s   e Zd ZdZdddZdS )SquadResultaJ  
    Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.

    Args:
        unique_id: The unique identifier corresponding to that example.
        start_logits: The logits corresponding to the start of the answer
        end_logits: The logits corresponding to the end of the answer
    Nc                 C   s,   || _ || _|| _|r(|| _|| _|| _d S re   )start_logits
end_logitsr<   start_top_indexend_top_index
cls_logits)r   r<   r   r   r   r   r   r   r   r   r   E  s    zSquadResult.__init__)NNNr   r   r   r   r   r   ;  s   	r   )r.   Fr   T)+r   r   	functoolsr   multiprocessingr   r   r   rQ   r   Zmodels.bert.tokenization_bertr   Ztokenization_utils_baser   r	   r
   utilsr   r   r   r   rN   r|   Ztorch.utils.datar   Z
tensorflowrq   Z
get_loggerrH   rC   r   r$   r%   r+   rc   rf   r   r   r   r   r   rS   r   r   r   r   r   <module>   sH   
 P    
 b ?C