"""TF IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling
from ...modeling_tf_utils import TFPreTrainedModel, shape_list
from ...tf_utils import flatten
from ...utils import ModelOutput, logging
from .configuration_idefics import IdeficsVisionConfig


logger = logging.get_logger(__name__)


@dataclass
class TFIdeficsVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nimage_embedslast_hidden_statehidden_states
attentions)__name__
__module____qualname____doc__r   r   tfTensor__annotations__r   r   r   r    r   r   I/tmp/pip-unpacked-wheel-zw5xktn0/transformers/models/idefics/vision_tf.pyr   "   s
   
r   c                       s\   e Zd Zed fddZejeeejdddZdeje	ejdd	d
Z
dddZ  ZS )TFIdeficsVisionEmbeddingsconfigc              	      s   t  jf | || _|j| _|j| _|j| _tjj	j
| j| j| jddddd| _| j| j d | _| jd | _tjj	j| j| jdd	| _d S )
NFZvalidZchannels_lastpatch_embedding)filtersZkernel_sizestridesZuse_biaspaddingZdata_formatname   r   position_embeddingr&   )super__init__r!   hidden_size	embed_dim
image_size
patch_sizer   keraslayersZConv2Dr"   num_patchesnum_positionsZ	Embeddingr(   selfr!   kwargs	__class__r   r   r+   @   s*    
  z"TFIdeficsVisionEmbeddings.__init__)
embeddingsheightwidthreturnc                 C   s  t |d d }| | j}t |d d }||kr@||kr@|S |d d df }|d d dd f }t |d }	|| jj }
|| jj }|
d |d  }
}tt|}t	|dt
|t
||	f}|
| }|| }tt|d tj}tt|d tj}t|| tj}t|| tj}tjj|||gtjjjd}t
|
t |d kspt
|t |d krtd	t
|
t
|f d
t |d t |d f dt	|dd|	f}tj|tjd d f |fddS )Nr   r   g?r'   )sizemethodzNumber of patches for images (z/) don't match the shape of position embedding ()Zaxis)r   r(   position_idsr!   r/   mathsqrtfloatr   reshapeintcastshapeZfloat32Zint32imageresizeZResizeMethodZBICUBIC
ValueErrorconcatnewaxis)r5   r9   r:   r;   r2   Z	pos_embedr3   Zclass_pos_embedZpatch_pos_embedr-   Znum_h_patchesZnum_w_patchesZsqrt_num_positionsZscale_heightZscale_widthZoriginal_heightZoriginal_widthZ
new_heightZ	new_widthr   r   r   interpolate_pos_encodingX   sB      0z2TFIdeficsVisionEmbeddings.interpolate_pos_encodingF)pixel_valuesrQ   r<   c           
   
   C   s   t |tr|d }tj|dd}t|\}}}}|sn|| jksH|| jkrntd| d| d| j d| j d	| |}t|dd	}t	| j
tjtjd d f |d| jg}tj||gdd
}	|r|	| |	|| }	n|	| | j }	|	S )NrR   )r   r'   r   r   permzInput image size (*z) doesn't match model (z8). You should try to set `interpolate_pos_encoding=True`r   r'   rC   )
isinstancedictr   	transposer   r.   rN   r"   r   Zbroadcast_toclass_embeddingrP   r-   rO   rQ   r(   rD   )
r5   rR   rQ   Z
batch_sizer:   r;   num_channelsZpatch_embedsZclass_embedsr9   r   r   r   call   s(    
 
 
zTFIdeficsVisionEmbeddings.callNc              	   C   s   | j r
d S d| _ tj| jddtjd d f | _| j| jfdd| _t	| dd d k	rt
| jj | jd d d | jjg W 5 Q R X t	| dd d k	rt
| jj | jd  W 5 Q R X d S )NTzself.position_idsr)   rY   )rK   r&   r"   r(   )builtr   ranger3   rP   rD   Z
add_weightr-   rY   getattr
name_scoper"   r&   buildr!   rZ   r(   r5   input_shaper   r   r   r`      s     "zTFIdeficsVisionEmbeddings.build)F)N)r   r   r   r   r+   r   r   rI   rQ   boolr[   r`   __classcell__r   r   r7   r   r   ?   s   '#r   c                       s   e Zd ZdZ fddZejeedddZdeje	ej e	ej e	e
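# Editor's note: a minimal shape sketch for the embeddings layer above (illustrative only;
# the relative imports mean this module runs inside the `transformers` package, and the
# 224/14 sizes are the hypothetical example values from the note above):
#
#   config = IdeficsVisionConfig()
#   embed_layer = TFIdeficsVisionEmbeddings(config, name="embeddings")
#   pixel_values = tf.zeros((2, config.num_channels, config.image_size, config.image_size))
#   tokens = embed_layer(pixel_values)
#   # NCHW input is transposed to NHWC, Conv2D yields a (grid x grid) patch map that is
#   # flattened to num_patches tokens, and the class token is prepended, so:
#   # tokens.shape == (2, embed_layer.num_patches + 1, config.hidden_size)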
 eeje	ej e	eej  f d	d
dZdddZ  ZS )TFIdeficsVisionAttentionz=Multi-headed attention from 'Attention Is All You Need' paperc                    s   t  jf | || _|j| _|j| _| j| j | _| j| j | jkr^td| j d| j d| jd | _	|j
| _tjjj| jdd| _tjjj| jdd| _tjjj| jdd| _tjjj| jd	d| _d S )
Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      k_projr)   v_projq_projout_proj)r*   r+   r!   r,   r-   Znum_attention_heads	num_headshead_dimrN   scaleZattention_dropoutdropoutr   r0   r1   Denserf   rg   rh   ri   r4   r7   r   r   r+      s    z!TFIdeficsVisionAttention.__init__)tensorseq_lenbszc                 C   s*   t jt |||| j| jfddddgdS )Nr   r'   r   r   rS   )r   rX   rH   rj   rk   )r5   ro   rp   rq   r   r   r   _shape   s    zTFIdeficsVisionAttention._shapeNFr   attention_maskcausal_attention_maskoutput_attentionsr<   c              	   C   s  t |\}}}| || j }| | |d|}	| | |d|}
|| j d| jf}t	| ||||}t	|	|}	t	|
|}
t |	d }tj
j||	dd}tjjt||| j ||gd|| j ||g dt| d |dk	rXt ||d||gkr&td	|d||f dt | t	||| j||f| }t	||| j ||f}|dk	rt ||d||gkrtd	|d||f dt | t	||| j||f| }t	||| j ||f}tjj|dd
}|rt	||| j||f}t	||| j ||f}nd}tjj|| jd}tj
||
}tjjt||| j || jgd|| j || jg dt| d t	||| j|| jf}tj|ddddgd}t	||||f}| |}||fS )z#Input shape: Batch x Time x Channelr=   r   T)Ztranspose_bz$Attention weights should be of size z	, but is )messageNz!Attention mask should be of size rC   )Zrater   r'   r   rS   )r   rh   rl   rr   rf   rg   rj   rk   r   rH   Zlinalgmatmul	debuggingZassert_equalrK   rN   nnZsoftmaxrm   rX   ri   )r5   r   rt   ru   rv   rq   Ztgt_lenr-   Zquery_statesZ
key_statesZvalue_statesZ
proj_shapeZsrc_lenattn_weightsZattn_weights_reshapedZ
attn_probsZattn_outputr   r   r   r[      s\    	 

"
zTFIdeficsVisionAttention.callc              	   C   s  | j r
d S d| _ t| dd d k	rNt| jj | j| j| jf W 5 Q R X t| dd d k	rt| jj | j| j| jf W 5 Q R X t| dd d k	rt| j	j | j	| j| jf W 5 Q R X t| dd d k	r
t| j
j | j
| j| jf W 5 Q R X d S )NTrf   rg   rh   ri   )r\   r^   r   r_   rf   r&   r`   r-   rg   rh   ri   ra   r   r   r   r`     s    zTFIdeficsVisionAttention.build)NNF)N)r   r   r   r   r+   r   r   rI   rr   r   rc   r   r[   r`   rd   r   r   r7   r   re      s      Nre   c                       s:   e Zd Z fddZejejdddZd	ddZ  ZS )
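# Editor's note, the head bookkeeping in `_shape`/`call` as a worked shape example
# (hypothetical sizes, e.g. embed_dim=768 with num_heads=12, so head_dim=64):
#   (bsz, seq, 768) --_shape--> (bsz, 12, seq, 64) --reshape--> (bsz*12, seq, 64)
# so a single batched matmul scores all heads at once, and the inverse reshape/transpose
# at the end of `call` restores (bsz, seq, 768) before `out_proj`.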
class TFIdeficsVisionMLP(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.activation_fn = get_tf_activation(config.hidden_act)
        self.fc1 = tf.keras.layers.Dense(config.intermediate_size, name="fc1")
        self.fc2 = tf.keras.layers.Dense(config.hidden_size, name="fc2")

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "fc1", None) is not None:
            with tf.name_scope(self.fc1.name):
                self.fc1.build(self.config.hidden_size)
        if getattr(self, "fc2", None) is not None:
            with tf.name_scope(self.fc2.name):
                self.fc2.build(self.config.intermediate_size)
dd	d
Z  ZS )TFIdeficsVisionEncoderLayerr    c                    sb   t  jf | |j| _t|dd| _tjjj	|j
dd| _t|dd| _tjjj	|j
dd| _d S )N	self_attnr)   layer_norm1epsilonr&   mlplayer_norm2)r*   r+   r,   r-   re   r   r   r0   r1   LayerNormalizationlayer_norm_epsr   r|   r   r   r4   r7   r   r   r+   D  s    z$TFIdeficsVisionEncoderLayer.__init__Frs   c                 C   sd   |}|  |}| j||||d\}}|| }|}| |}| |}|| }|f}|r`||f7 }|S )a9  
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`tf.Tensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "layer_norm1", None) is not None:
            with tf.name_scope(self.layer_norm1.name):
                self.layer_norm1.build([None, None, self.embed_dim])
        if getattr(self, "layer_norm2", None) is not None:
            with tf.name_scope(self.layer_norm2.name):
                self.layer_norm2.build([None, None, self.embed_dim])


class TFIdeficsVisionEncoder(tf.keras.layers.Layer):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`TFIdeficsVisionEncoderLayer`].

    Args:
        config: IdeficsVisionConfig
    """

    def __init__(self, config: IdeficsVisionConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.layers = [
            TFIdeficsVisionEncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)
        ]
        self.gradient_checkpointing = False

    def call(
        self,
        inputs_embeds,
        attention_mask: Optional[tf.Tensor] = None,
        causal_attention_mask: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = None,
    ) -> Union[Tuple, TFBaseModelOutput]:
        """
        Args:
            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and training:
                # wrap the layer so only its inputs are stored and activations are
                # recomputed during the backward pass
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = tf.recompute_grad(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
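# Editor's note on the checkpointing branch above: `tf.recompute_grad` stores only the
# wrapped call's inputs during the forward pass and re-runs the computation during
# backpropagation, trading compute for activation memory. A minimal sketch of the same
# mechanism outside this class (hypothetical toy layer, not part of this module):
#
#   dense = tf.keras.layers.Dense(16)
#   checkpointed = tf.recompute_grad(lambda x: dense(x))
#   y = checkpointed(tf.ones((4, 16)))  # forward result is identical; the intermediate
#                                       # activations are recomputed on the backward pass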
class TFIdeficsVisionTransformer(TFPreTrainedModel):
    def __init__(self, config: IdeficsVisionConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.embed_dim = config.hidden_size

        self.embeddings = TFIdeficsVisionEmbeddings(config, name="embeddings")
        # NOTE: "pre_layrnorm" is spelled this way to match the checkpoint weight names
        self.pre_layrnorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="pre_layrnorm")
        self.encoder = TFIdeficsVisionEncoder(config, name="encoder")
        self.post_layernorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="post_layernorm"
        )

    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[Tuple, TFBaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]  # pool over the class token
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "pre_layrnorm", None) is not None:
            with tf.name_scope(self.pre_layrnorm.name):
                self.pre_layrnorm.build([None, None, self.embed_dim])
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "post_layernorm", None) is not None:
            with tf.name_scope(self.post_layernorm.name):
                self.post_layernorm.build([None, self.embed_dim])