U
    4Aft                  
   @   sH  d Z ddlZddlZddlZddlmZmZmZmZ ddl	m
Z
 ddlmZmZ ddlmZ ddlmZ dd	lmZmZmZmZ dd
lmZ eeeeef  eeeeeef  eeeeef   eeeeeef   f ZG dd deZeeeeef eeeef dddZeeedddZdd Zdd Z dd Z!dddZ"dS )zProcessor class for KOSMOS-2.    N)ListOptionalTupleUnion   )BatchFeature)
ImageInput
is_batched)ProcessorMixin)
AddedToken)BatchEncodingPaddingStrategy	TextInputTruncationStrategy)
TensorTypec                       sn  e Zd ZdZddgZdgZdZdZd& fdd		Zd'e	e
eee f eee ee eee
eeef e
eeef ee ee ee eeee
eef  edddZdd Zdd Zd(e
eee f e	eee e
eee f dddZdd Zdd Zd)ddZedd Zee
eee  eee   f ed d!d"Z!e
eeef ee e e e f f eeef d#d$d%Z"  Z#S )*Kosmos2Processora,  
    Constructs an KOSMOS-2 processor which wraps a KOSMOS-2 image processor and a KOSMOS-2 tokenizer into a single
    processor.

    [`Kosmos2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and some functionalities of
    [`XLMRobertaTokenizerFast`]. See the docstring of [`~Kosmos2Processor.__call__`] and [`~Kosmos2Processor.decode`]
    for more information.

    Args:
        image_processor (`CLIPImageProcessor`):
            An instance of [`CLIPImageProcessor`]. The image processor is a required input.
        tokenizer (`XLMRobertaTokenizerFast`):
            An instance of ['XLMRobertaTokenizerFast`]. The tokenizer is a required input.
        num_patch_index_tokens (`int`, *optional*, defaults to 1024):
            The number of tokens that represent patch indices.
    image_processor	tokenizernum_patch_index_tokensZCLIPImageProcessor)ZXLMRobertaTokenizerZXLMRobertaTokenizerFast   c                    s   d|_ d| _d| _d| _d| _d| _d| _d| _d	| _d
| _	d| _
d| _| j| j| j| j| j| j| j| j| j	| j
| jg| _|| _dd t| jD }g }| j| D ]}|t|dddd q|| t || d S )NFz</doc>z<image>z</image>z</chunk>z</line>z<phrase>z	</phrase>z<object>z	</object></delimiter_of_multi_objects/>z<grounding>c                 S   s"   g | ]}d t |d dqS )<patch_index_   >)strzfill.0x r   R/tmp/pip-unpacked-wheel-zw5xktn0/transformers/models/kosmos2/processing_kosmos2.py
<listcomp>a   s     z-Kosmos2Processor.__init__.<locals>.<listcomp>T)lstriprstrip
normalized)Zreturn_token_type_idsZ	eod_token	boi_token	eoi_tokenZ	eoc_tokenZ	eol_tokenZ	bop_tokenZ	eop_tokenZ	boo_tokenZ	eoo_tokenZ	dom_tokenZ	grd_tokenZ
tag_tokensr   rangeappendr   Z
add_tokenssuper__init__)selfr   r   r   kwargsZpatch_index_tokensZtokens_to_addtoken	__class__r   r    r*   =   s>    
zKosmos2Processor.__init__N@   TF)imagestextbboxesnum_image_tokensfirst_image_token_idadd_special_tokensadd_eos_tokenpadding
truncation
max_lengthpad_to_multiple_ofreturn_attention_maskreturn_lengthverbosereturn_tensorsreturnc           !         sp  |dkr|dkrt dt }|dk	r>j||d}|| |dk	rj||||d}|r|st|tr|jj | }nt|t	rfdd|D }jf ||o||o|dk|	|
|dkr|n||||dkr|ndd	|}|| |dk	rl|dk	rl|dkrjj
d }|}t|d }t	t||| }d	gdg|  d	g }g }g }|d
 }t|tr||g}|d g|d< |D ]n}|d| | ||| d  }|| t|}|rd	g| }|d	gt|t|  7 }|| qt|t	rtdd t|jD dd d}|d	 \}}|d \} }jf ||  g|oJ|||	|
||dd|}t|jd	  | krjjdkrʇ fdd|D } fdd|D } fdd|d D |d< nNjjdkr fdd|D } fdd|D } fdd|d D |d< t|trN|dkrN|d	 }|d d	 |d< |d	 }|t||d |d|d |S )a	  
        This method uses [`CLIPImageProcessor.__call__`] method to prepare image(s) for the model, and
        [`XLMRobertaTokenizerFast.__call__`] to prepare text for the model.

        Please refer to the docstring of the above two methods for more information.

        The rest of this documentation shows the arguments specific to `Kosmos2Processor`.

        Args:
            bboxes (`Union[List[Tuple[int]], List[Tuple[float]], List[List[Tuple[int]]], List[List[Tuple[float]]]]`, *optional*):
                The bounding bboxes associated to `texts`.
            num_image_tokens (`int`, *optional* defaults to 64):
                The number of (consecutive) places that are used to mark the placeholders to store image information.
                This should be the same as `latent_query_num` in the instance of `Kosmos2Config` you are using.
            first_image_token_id (`int`, *optional*):
                The token id that will be used for the first place of the subsequence that is reserved to store image
                information. If unset, will default to `self.tokenizer.unk_token_id + 1`.
            add_eos_token (`bool`, defaults to `False`):
                Whether or not to include `EOS` token id in the encoding when `add_special_tokens=True`.
        Nz*You have to specify either images or text.)r?   )r4   c                    s   g | ]} j j | qS r   )r   	bos_token)r   s)r+   r   r    r!      s     z-Kosmos2Processor.__call__.<locals>.<listcomp>)	r2   r6   r8   r9   r:   r;   r<   r>   r?      r   	input_idsattention_maskc                 S   s   g | ]\}}|t |fqS r   len)r   idxr   r   r   r    r!      s     c                 S   s   | d S Nr   )r   r   r   r    <lambda>       z+Kosmos2Processor.__call__.<locals>.<lambda>)keyrJ   )r2   r6   r8   r9   r:   r;   r>   r?   rightc                    s&   g | ]}|j jg t|   qS r   r   Zpad_token_idrG   r   max_len_paddedr+   r   r    r!      s     c                    s"   g | ]}|d g t |   qS r   rF   r   rQ   r   r    r!      s    c                    s"   g | ]}|d g t |   qS rR   rF   r   rS   r   r    r!      s    leftc                    s&   g | ]}j jg t|  | qS r   rO   r   rP   r   r    r!      s     c                    s"   g | ]}d g t |  | qS rR   rF   r   rS   r   r    r!      s    c                    s"   g | ]}d g t |  | qS rR   rF   r   rS   r   r    r!      s    )rD   rE   image_embeds_position_mask)dataZtensor_type)
ValueErrorr   r   updatepreprocess_examples
isinstancer   r   rA   listZunk_token_idintr'   r(   copyrG   sorted	enumeraterD   Zpadding_sider   )!r+   r1   r2   r3   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   r>   r?   r,   encodingZimage_encodingZtext_encodingZwith_bosstart_indexZimage_token_idsZbase_image_embeds_position_maskrD   rU   Zall_input_idsZtext_idsmaskZsorted_length_Zmin_len_not_paddedrH   r   rP   r    __call__j   s    '






 


 	





zKosmos2Processor.__call__c                 C   s   |dkrdS t |tstd|D ]x}|dkr2q"nt |tsB|g}|D ]R}t |trt|dkrrtdd |D sFt|dkrtdd |D sFtdqFq"dS )	a  
        Check `bboxes` for a single text example. It could be
            - `None`: no bounding box associated to a text.
            - A list with each element being the bounding boxes associated to one `<phrase> ... </phrase>` pair found
              in a text. This could be:
                  - `None`: no bounding box associated to a `<phrase> ... </phrase>` pair.
                  - A tuple of 2 integers: A single bounding box specified by patch indices.
                  - A tuple of 4 float point number: A single bounding box specified by (normalized) coordinates.
                  - A list containing the above 2 tuple types: Multiple bounding boxes for a
                   `<phrase> ... </phrase>` pair.
        Nz@`bboxes` (for a single text example) should be `None` or a list.   c                 s   s   | ]}t |tV  qd S N)rZ   r\   r   r   r   r    	<genexpr>*  s     zAKosmos2Processor._check_bboxes_for_single_text.<locals>.<genexpr>r   c                 s   s   | ]}t |tV  qd S rf   )rZ   floatr   r   r   r    rg   +  s     a'  Each element in `bboxes` (for a single text example) should be either `None`, a tuple containing 2 integers or 4 float point numbers, or a list containing such tuples. Also make sure the arguments `texts` and `bboxes` passed to `preprocess_text` are both in batches or both for a single example.)rZ   r[   rW   tuplerG   all)r+   r3   bboxelementr   r   r    _check_bboxes_for_single_text  s,    




z.Kosmos2Processor._check_bboxes_for_single_textc                 C   s.   |  }|d k	r| d| }| ||}|S )N )strip_insert_patch_index_tokens)r+   r2   imager3   img_info_tokensr   r   r    _preprocess_single_example4  s
    z+Kosmos2Processor._preprocess_single_example)textsr1   r3   r4   r@   c           	         sD  j g| }dj g| jg  d}t|tr>d}|g}|dkrVdgt| }nt|sd|g}t|t|krtdt| dt| d|s| |g}n>|dk	rt|t	std|D ]}| qndgt| }t|t|krtd	t| dt| d fd
dt
|||D }|s@|d }|S )a-  Add image and bounding box information to `texts` as image and patch index tokens.

        Args:
            texts (`Union[TextInput, List[TextInput]]`): The texts to be processed.
            images (`ImageInput`, *optional*): The images associated to `texts`.
            bboxes (`Union[List[Tuple[int]], List[Tuple[float]], List[List[Tuple[int]]], List[List[Tuple[float]]]]`, *optional*):
                The bounding bboxes associated to `texts`.
            num_image_tokens (`int`, *optional*, defaults to 64):
                The number of image tokens (used as latent queries). This should corresponds to the `latent_query_num`
                attribute in `Kosmos2Config`.

        Returns:
            `Union[TextInput, List[TextInput]]`: The processed texts with image and patch index tokens.
        rn   TFNzGThe number of examples in `texts` and `images` should be the same. Got  v.s. 	 instead.zS`bboxes` should be `None` or a list (as a batch) when `texts` is passed as a batch.zGThe number of examples in `texts` and `bboxes` should be the same. Got c                    s"   g | ]\}}} ||| qS r   )rs   )r   r2   rq   rk   rr   r+   r   r    r!   v  s   z8Kosmos2Processor.preprocess_examples.<locals>.<listcomp>r   )r%   joinr&   rZ   r   rG   r	   rW   rm   r[   zip)	r+   rt   r1   r3   r4   Z
img_tokensZbatchedr   resultr   rw   r    rY   >  sB    



z$Kosmos2Processor.preprocess_examplesc                 O   s   | j j||S )z
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        )r   batch_decoder+   argsr,   r   r   r    r{     s    zKosmos2Processor.batch_decodec                 O   s   | j j||S )z
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        )r   decoder|   r   r   r    r~     s    zKosmos2Processor.decodec                 C   s    | | jd }|rt|S |S rI   )splitr&   +clean_text_and_extract_entities_with_bboxes)r+   r2   Zcleanup_and_extractcaptionr   r   r    post_process_generation  s    z(Kosmos2Processor.post_process_generationc                 C   s"   | j j}| jj}tt|| S rf   )r   model_input_namesr   r[   dictfromkeys)r+   Ztokenizer_input_namesZimage_processor_input_namesr   r   r    r     s    z"Kosmos2Processor.model_input_names)r2   r3   r@   c                 C   sT  |d kst |dkr|S ttjd|d}t |t |krXtdt | dt | dd}g }t||D ]\}}| \}}	||||	  |	}|d krqjt|t	r|g}g }
t
dd |D std	|D ]&}| |\}}|
| d
|  qt |
dkrqjd|
}|d| d qj|t |k rF|||d   d|}|S )Nr   z<phrase>.+?</phrase>)stringzuThe number of elements in `bboxes` should be the same as the number of `<phrase> ... </phrase>` pairs in `text`. Got ru   rv   c                 s   s   | ]}|d k	V  qd S rf   r   )r   boxr   r   r    rg     s     z>Kosmos2Processor._insert_patch_index_tokens.<locals>.<genexpr>zTThe multiple bounding boxes for a single phrase should not contain any `None` value.rn   z  </delimiter_of_multi_objects/> z	<object> z
 </object> )rG   r[   refinditerrW   ry   spanr(   rZ   ri   rj   #_convert_bbox_to_patch_index_tokensrx   )r+   r2   r3   Zmatched_phrasescurr_posbuffermatchedrk   rc   endZpatch_index_stringsr   Zpatch_index_1Zpatch_index_2Zposition_strr   r   r    rp     sB    


z+Kosmos2Processor._insert_patch_index_tokens)rk   r@   c                 C   sh   t |dkr|\}}ntt| j}t||\}}dt|d d}dt|d d}||fS )Nre   r   r   r   )rG   r\   mathsqrtr   coordinate_to_patch_indexr   r   )r+   rk   Zidx_1Zidx_2num_patches_per_sideZtoken_1Ztoken_2r   r   r    r     s    
z4Kosmos2Processor._convert_bbox_to_patch_index_tokens)r   )NNNr0   NTFFNNNNFTN)NNr0   )T)$__name__
__module____qualname____doc__
attributesZvalid_kwargsZimage_processor_classZtokenizer_classr*   r   r   r   r   	BboxInputr   r\   boolr   r   r   r   r   rd   rm   rs   rY   r{   r~   r   propertyr   r   rh   rp   r   __classcell__r   r   r.   r    r   &   sz   /                (#   C

*.
r   )rk   r   r@   c                 C   s   | \}}}}||kr||ks$t dt|| }t|| }t|| d }t|| d }	|| | }
|	| | }|
|fS )a  Convert a bounding box to a pair of patch indices.

    Args:
        bbox (`Tuple[float, float, float, float]`):
            The 4 coordinates of the bounding box, with the format being (x1, y1, x2, y2) specifying the upper-left and
            lower-right corners of the box. It should have x2 > x1 and y2 > y1.
        num_patches_per_side (`int`): the number of patches along each side.

    Returns:
        `Tuple[int, int]`: A pair of patch indices representing the upper-left patch and lower-right patch.
    zTThe coordinates in `bbox` should be `(x1, y1, x2, y2)` with `x2 > x1` and `y2 > y1`.rC   )rW   r   floorceil)rk   r   x1y1x2y2ul_xul_ylr_xlr_yul_idxlr_idxr   r   r    r     s    r   )r   r   r   c                 C   s   d| }| | }| | }|| }|| }| |krZ|| }|| }	|| | }
|| | }nz||ksj||kr|| }|| }	|| | }
|| | }n@|| |d  }|| |d  }	|| |d  }
|| |d  }||	|
|fS )a  
    Given a grid of length `num_patches_per_side` and the indices of the upper-left and lower-right corners of a
    bounding box, returns the normalized coordinates of the bounding box, in the form (x1, y1, x2, y2).

    Args:
        ul_idx (`int`): the index of the grid cell that corresponds to the upper-left corner of the bounding box.
        lr_idx (`int`): the index of the grid cell that corresponds to the lower-right corner of the bounding box.
        num_patches_per_side (`int`): the number of patches along each side.

    Returns:
        `Tuple[float]`: the normalized coordinates of the bounding box, in the form (x1, y1, x2, y2).
    g      ?re   r   )r   r   r   Z	cell_sizer   r   r   r   r   r   r   r   r   r   r    patch_index_to_coordinate  s(    r   c              	   C   s6  d}t || }g }|D ]}|d}| \}}}|sZd}|dd |dd f}|d}	g }
|	D ]v}t d|}t d|dd }|rl|rl|r|
t|dt|df ql|
t|dt|df ql|r||||
f q|
D ]0}d|d  d	|d  d
}||||gf qq|S )a  Extract entities contained in `text`. The bounding bboxes is given in the form of patch indices.

    This functioin is only intended to be used within `clean_text_and_extract_entities_with_bboxes` where further
    processing happens, including converting to normalized coordinates and whitespace character cleaning up.

    Examples:

    ```python
    >>> text = "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>."
    >>> entities = extract_entities_with_patch_indices(text)
    >>> entities
    [(' a snowman', (31, 41), [(44, 863)]), (' a fire', (130, 137), [(5, 911)])]
    ```z(?:(<phrase>([^<]+)</phrase>))?<object>((?:<patch_index_\d+><patch_index_\d+></delimiter_of_multi_objects/>)*<patch_index_\d+><patch_index_\d+>)</object>re   Nr   r   z<patch_index_(\d+)>rC   r   z><patch_index_r   )	r   r   r   groupsr   searchr(   r\   group)r2   patternmatchesentities_with_patch_indicesmatchr   Z
phrase_tagphraseZmatch_contentZpatch_index_pairsZentity_bboxespairr   yrk   entityr   r   r    #extract_entities_with_patch_indices(  s0    


$$r   c                 C   sP   | \}\}}t tdd|d| }t tdd|d| }|||ff}|S )zfAdjust the positions of the entities in `text` to be relative to the text with special fields removed.<.*?>r   N)rG   r   sub)r   r2   entity_namestartr   Zadjusted_startZadjusted_endadjusted_entityr   r   r    adjust_entity_positionsb  s
    r   c                 C   s   |   }t| t|   }g }|D ]j\}\}}}t|t|  }	t|t|  }
|| |	 }|| |
 }|  }||||f|f q$||fS )z9Remove the spaces around the text and the entities in it.)ro   rG   r"   r#   r(   )r2   entitiesZnew_textZleading_spacesZnew_entitiesr   r   r   r3   Zentity_name_leading_spacesZentity_name_trailing_spacesr   r   r    _cleanup_spacesl  s    r       c           
         sp   t dd| }t| }g }|D ]F}|dd |d  }}t|| } fdd|D }	|||	f  qt||S )a  Remove the tag tokens from `text`, extract entities in it with some cleaning up of white characters.

    Examples:

    ```python
    >>> text = "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>."
    >>> clean_text, entities = clean_text_and_extract_entities_with_bboxes(text)
    >>> clean_text
    'An image of a snowman warming himself by a fire.'

    >>> entities
    [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]
    ```r   r   r   re   c                    s    g | ]}t |d  |d  qS )r   rC   )r   )r   rk   r   r   r    r!     s     z?clean_text_and_extract_entities_with_bboxes.<locals>.<listcomp>)r   r   r   r   r(   r   )
r2   r   Zprocessed_textr   r   itemr   r3   r   Zbboxes_in_coordsr   r   r    r     s    
r   )r   )#r   r]   r   r   typingr   r   r   r   Zimage_processing_utilsr   Zimage_utilsr   r	   Zprocessing_utilsr
   Ztokenization_utilsr   Ztokenization_utils_baser   r   r   r   utilsr   r\   rh   r   r   r   r   r   r   r   r   r   r   r   r    <module>   s6      9&-:
