"""TF 2.0 LED model."""

from __future__ import annotations

import random
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions
from ...modeling_tf_utils import (
    TFModelInputType,
    TFPreTrainedModel,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_led import LEDConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "allenai/led-base-16384"
_CONFIG_FOR_DOC = "LEDConfig"

LARGE_NEGATIVE = -1e8


def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
    pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
    decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
    start_tokens = tf.fill(
        (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
    )
    shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids = tf.where(
        shifted_input_ids == -100,
        tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
        shifted_input_ids,
    )

    # "Verify that `labels` has only positive values and -100"
    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))

    # Make sure the assertion op is called by wrapping the result in an identity no-op
    with tf.control_dependencies([assert_gte0]):
        shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids


def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz = input_ids_shape[0]
    tgt_len = input_ids_shape[1]
    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
    mask_cond = tf.range(shape_list(mask)[-1])

    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)

    if past_key_values_length > 0:
        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))


def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    src_len = shape_list(mask)[1]
    tgt_len = tgt_len if tgt_len is not None else src_len
    one_cst = tf.constant(1.0)
    mask = tf.cast(mask, dtype=one_cst.dtype)
    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

    return (one_cst - expanded_mask) * LARGE_NEGATIVE
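
# Illustrative sketch of the two mask helpers above (comment only, values schematic):
# `_expand_mask` turns a `(bsz, src_len)` padding mask such as [[1, 1, 0]] into a
# `(bsz, 1, tgt_len, src_len)` additive mask with 0.0 where attention is allowed and
# LARGE_NEGATIVE at padded positions, while `_make_causal_mask` builds the lower-triangular
# additive mask used for auto-regressive decoding. Both are added to attention scores
# before the softmax.
#
#   padding_mask = tf.constant([[1, 1, 0]])
#   additive_mask = _expand_mask(padding_mask, tgt_len=3)          # shape (1, 1, 3, 3)
#   causal_mask = _make_causal_mask([1, 3], past_key_values_length=0)  # shape (1, 1, 3, 3)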


class TFLEDLearnedPositionalEmbedding(keras.layers.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)

    def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
        """Input is expected to be of size [bsz x seqlen]."""
        seq_len = input_shape[1]
        position_ids = tf.range(seq_len, delta=1, name="range")
        position_ids += past_key_values_length

        return super().call(tf.cast(position_ids, dtype=tf.int32))
__module____qualname____doc__rG   rO   __classcell__r-   r-   rJ   r.   rB   t   s   rB   c                      s   e Zd Z fddZdddZd ddZd	d
 Zedd Zdd Z	edd Z
edd Zedd Zedd Zdd Zdd Zdd Zdd Z  ZS )!TFLEDEncoderSelfAttentionc                   s  t  jf | || _|j|j dkr<td|j d|j |j| _t|j|j | _|j| _	t
jj| j	t|jdd| _t
jj| j	t|jdd| _t
jj| j	t|jdd| _t
jj| j	t|jdd| _t
jj| j	t|jd	d| _t
jj| j	t|jd
d| _t
j|j| _t
j|j| _|| _|j| j }|d dks\td| j d| |dks|td| j d| |d | _d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads (query)Zkernel_initializerrN   keyvaluequery_global
key_globalvalue_global   z`attention_window` for layer z  has to be an even value. Given z has to be positive. Given )rF   rG   confighidden_sizeZnum_attention_heads
ValueError	num_headsr   head_dim	embed_dimr   layersDenser   Zinitializer_rangerY   rZ   r[   r\   r]   r^   DropoutZattention_probs_dropout_probdropoutglobal_dropoutlayer_idattention_windowAssertionErrorone_sided_attn_window_size)rH   r`   rk   rI   rl   rJ   r-   r.   rG      sh    
z"TFLEDEncoderSelfAttention.__init__Nc              	   C  s  | j s~td | j| jjf W 5 Q R X td | j| jjf W 5 Q R X td | j| jjf W 5 Q R X | j rd S d| _ t	| dd d k	rt| j
j | j
d d | jjg W 5 Q R X t	| dd d k	rt| jj | jd d | jjg W 5 Q R X t	| dd d k	rRt| jj | jd d | jjg W 5 Q R X t	| dd d k	rt| jj | jd d | jjg W 5 Q R X t	| dd d k	rt| jj | jd d | jjg W 5 Q R X t	| dd d k	rt| jj | jd d | jjg W 5 Q R X d S )Nr\   r]   r^   TrY   rZ   r[   )builtr$   
name_scoper\   buildr`   ra   r]   r^   getattrrY   rN   rZ   r[   rH   rL   r-   r-   r.   rq      s8         zTFLEDEncoderSelfAttention.buildFc                 C  s  |\}}}}}}|  |}	| |}
| |}t|\}}}tjj|| jd| j d| d |	tj	tj
| j|	jd }	t|	||| j| jf}	t|
||| j| jf}
| |	|
| j}|dk}tj
||	jdt }| tt||| j}||7 }tjjt|||| j| jd d gd| d	| d	| j d	| jd d  d
t| 
d | |\}}}}|r| j||	|
||||d}t|dd}|rt|ddddddf dd| j| jd | d f}n4t|ddddddf dd| j| jd d f}t|tjt||jd|}|dk	rftjjt|| jgd| j dt| d t|d| }| j||d}t|||| j| jf}|r| j|||||d}n| ||| j}tjjt|||| j| jgdd t||||f}|r| j|||||||||d	\}}nt|| j||f}|rht|ddddddf dd| j| jd | d f}n4t|ddddddf dd| j| jd d f}t|tjt||jd|}|||f}|S )a  
        LongformerSelfAttention expects *len(hidden_states)* to be a multiple of *attention_window*. Padding to
        *attention_window* happens in LongformerModel.forward to avoid redoing the padding on each layer.

        The *attention_mask* is changed in [`LongformerModel.forward`] from 0, 1, 2 to:

            - -10000: no attention
            - 0: local attention
            - +10000: global attention
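
        A schematic example of such a mask for one sequence of length 5 whose first token uses
        global attention and whose last token is padding (illustrative values only; see how the
        mask is prepared in the encoder before it reaches this layer):

        ```python
        attention_mask = tf.constant([[10000.0, 0.0, 0.0, 0.0, -10000.0]])
        ```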
        z&hidden_states should have embed_dim = z
, but has messager"   r   r_   r   zattn_probs should be of size (z, z), but is of size )attn_scoresquery_vectorskey_vectorsmax_num_global_attn_indicesis_index_global_attn_nonzero"is_local_index_global_attn_nonzero%is_local_index_no_global_attn_nonzeror!   r4   N/Head mask for a single layer should be of size 	, but is )r   r   r!   r   training)value_vectors
attn_probsry   rz   r{   zUnexpected size)	attn_outputhidden_statesry   layer_head_maskr{   rz   r|   is_index_maskedr   )rY   rZ   r[   r   r$   r*   assert_equalre   mathsqrtr%   rd   r#   r9   rc    _sliding_chunks_query_key_matmulrn   r7   r6   _get_global_attn_indices"_concat_with_global_key_attn_probsr   r;   r)   r:   ri   (_compute_attn_output_with_global_indices'_sliding_chunks_matmul_attn_probs_value'_compute_global_attn_output_from_hidden)rH   inputsr   r   attention_maskr   r   is_index_global_attnis_global_attnrw   rx   r   
batch_sizerQ   re   rv   Z#remove_from_windowed_attention_maskZ
float_maskZdiagonal_maskry   rz   r{   r|   r   Zmasked_indexr   global_attn_probsZmasked_global_attn_indexoutputsr-   r-   r.   rO      s    


  0

	    
zTFLEDEncoderSelfAttention.callc              
   C  s  t |\}}}}tjj||d  dd|d  d| d tjjt |t |dt | dt | d || d }tt|d	|| ||f}tt|d	|| ||f}| ||}	| ||}
tj|	|
jd
}	t	d|	|
}t
ddgddgddgddgg}| ||}tj|ddddd|d|d f |dddd|dd|d f gdd}tjtj|| d||f|jd
|dddd|d  d|d df gdd}tjtj|d|gddgdddddd|d|f tj|| d||f|jd
gdd}ttj|d tjd
dddddf || d||fdk }t|||}tj||gdd}tt||||d| d fd	}| ||}|S )a  
        Matrix multiplication of query and key tensors using a sliding window attention pattern. This
        implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an
        overlap of size window_overlap
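
        For intuition (an illustrative shape walk-through, assuming `seq_len` is already a multiple
        of `2 * window_overlap`): with `seq_len = 1024` and `window_overlap = 256`, the input is cut
        into `1024 // 256 - 1 = 3` overlapping chunks of length 512, and the returned scores have
        shape `(batch_size, seq_len, num_heads, 2 * window_overlap + 1)`.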
        r_   r   z&Sequence length should be multiple of z. Given rt   z7Shape of query and key should be equal, but got query: z
 and key: r   r   r_   r   r   r"   zbcxd,bcyd->bcxyNr!   r4   r   )shiftr5   )r   r$   r*   r   r9   	transpose_chunkr%   r#   einsumr'    _pad_and_transpose_last_two_dimsr(   r:   Zrollr;   r8   Zint64r)   _mask_invalid_locations)rH   rY   rZ   window_overlapr   rQ   rc   rd   chunks_countZchunked_queryZchunked_keyZchunked_attention_scorespaddingsZ!diagonal_chunked_attention_scoresZdiagonal_attn_scores_up_triangZdiagonal_attn_scores_low_triangZ diagonal_attn_scores_first_chunkZfirst_chunk_maskZdiagonal_attention_scoresr-   r-   r.   r     s    
	
"
""	(
" z:TFLEDEncoderSelfAttention._sliding_chunks_query_key_matmulc                 C  s   t jt jt j||d fddddgd}t dt| d | gdt| d | d gg}t ||}|t j|ddgd }t |d d d d d d f t| d dddf}t	d t 
|  }t t j|d|| } | S )Nr   shaper!   r   r4   r   inf)r$   reverseZlinalgZ	band_partr6   r'   r   padr;   floatZ	ones_liker)   r   greater)Zinput_tensorr   Zmask_2d_upperpaddingZmask_2dZmask_4dZ
inf_tensorr-   r-   r.   r     s    *0z1TFLEDEncoderSelfAttention._mask_invalid_locationsc              	   C  s  t |\}}}}tjj||d  ddd tjjt |dd t |dd dd tjjt |d d| d d	d || d }tt|d
|| || |d| d f}	tt|d
|| ||f}tddg||gddgg}
tj||
dd}d| | }t |d | | | }tj	t||| df||}t||| |d d| |f}tjjt ||| |d d| |gdd | 
|	}	td|	|}tt|||||fd
}|S )z
        Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the
        same shape as `attn_probs`
        r_   r   z0Seq_len has to be multiple of 2 * window_overlaprt   Nr   z:value and attn_probs must have same dims (except head_dim)r   z4attn_probs last dim has to be 2 * window_overlap + 1r   r!   Zconstant_valuesz!Chunked value has the wrong shapezbcwd,bcdh->bcwh)r   r$   r*   r   r9   r   r'   r   signalframe_pad_and_diagonalizer   )rH   r   r[   r   r   rQ   rc   rd   r   Zchunked_attn_probsr   Zpadded_value
frame_sizeframe_hop_sizeZchunked_valuecontextr-   r-   r.   r   2  sl    
  





zATFLEDEncoderSelfAttention._sliding_chunks_matmul_attn_probs_valuec                 C  s4   t | |} t| \}}}}t | ||||f} | S )z)pads rows and then flips rows and columns)r$   r   r   r9   )Zhidden_states_paddedr   r   
chunk_size
seq_length
hidden_dimr-   r-   r.   r   {  s     z:TFLEDEncoderSelfAttention._pad_and_transpose_last_two_dimsc                 C  s   t | \}}}}tddgddgddgd|d gg}t| |} t| ||df} | ddddd| f } t| ||||| f} | ddddddddf } | S )aY  
        Shift every row one step right, converting columns into diagonals.

        Example:

        ```python
        chunked_hidden_states: [
            0.4983,
            2.6918,
            -0.0071,
            1.0492,
            -1.8348,
            0.7672,
            0.2986,
            0.0285,
            -0.7584,
            0.4206,
            -0.0405,
            0.1599,
            2.0514,
            -1.1600,
            0.5372,
            0.2629,
        ]
        window_overlap = num_rows = 4
        ```

                     (pad & diagonalize) => [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
                       0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 0.0000, 0.0000, -0.7584, 0.4206,
                       -0.0405, 0.1599, 0.0000 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
        r   r   r!   N)r   r$   r'   r   r9   )chunked_hidden_statesZtotal_num_headsZ
num_chunksr   r   r   r-   r-   r.   r     s&    !&   z.TFLEDEncoderSelfAttention._pad_and_diagonalizec           	      C  s   t | \}}}d|d|   d }|| }d| }t| ||| f} tj| ||}tjjt ||||gd|||g dt | dd t|||d| |f}|S )zBconvert into overlapping chunks. Chunk size = 2w, overlap size = wr_   r   z^Make sure chunking is correctly applied. `Chunked hidden states should have output  dimension z
, but got .rt   )r   r$   r9   r   r   r*   r   )	r   r   r   r   r   Znum_output_chunksr   r   r   r-   r-   r.   r     s     	z TFLEDEncoderSelfAttention._chunkc                 C  sz   t jj| dd}t j|t djd}t |}t | }t |t j	|ddk }t |}t t j
|}||||fS )z<compute global attn indices required throughout forward passr   r4   r"   r!   )r$   r   Zcount_nonzeror%   r+   r#   Z
reduce_maxr)   r8   Zexpand_dimsZlogical_not)r   Znum_global_attn_indicesry   rz   Zis_local_index_global_attnr{   r|   r-   r-   r.   r     s    

 
z2TFLEDEncoderSelfAttention._get_global_attn_indicesc                 C  s   t |d }t||}	tj||	||| j| jfd}
td||
}t|d}t |d ftt |dd   }t	|d }tj
||jd}t|||}t|d}tj||fd	d
}|S )Nr   r   zblhd,bshd->blhs)r   r   r   r_        r"   )r   r_   r   r   r!   r4   )r   r$   	gather_nd
scatter_ndrc   rd   r   r   tupler6   r%   r#   tensor_scatter_nd_updater(   )rH   rv   rx   rw   ry   rz   r{   r|   r   global_key_vectorsZkey_vectors_only_globalZattn_probs_from_global_keyZ attn_probs_from_global_key_trans
mask_shaper>   r-   r-   r.   r     s4    
z<TFLEDEncoderSelfAttention._concat_with_global_key_attn_probsc                 C  s   t |d }|d d d d d d d |f }t||}tj||||| j| jfd}	td||	}
|d d d d d d |d f }| ||| j}|
| S )Nr   r   zblhs,bshd->blhd)	r   r$   r   r   rc   rd   r   r   rn   )rH   r   r   ry   rz   r{   r   Zattn_probs_only_globalglobal_value_vectorsZvalue_vectors_only_globalZattn_output_only_globalZattn_probs_without_globalZattn_output_without_globalr-   r-   r.   r   -  s(        zBTFLEDEncoderSelfAttention._compute_attn_output_with_global_indicesc
                 C  s  t |d d \}
}t||}tj|||
|| jfd}| |}| |}| |}|tj	tj
| j|jd }| ||
}| ||
}| ||
}tj||dd}tjjt ||
| j ||gd|
| j ||f dt | dd	 t||
| j||f}t|d
}t |d ftt |dd   }t|d }tj
||jd}t|||}t|d
}t|d d d d d d f dt |d ddf}t|d|}t||
| j ||f}t|dd}|d k	r6tjjt || jgd| j dt | d	 t|dt||
| j||f }t||
| j ||f}| j||	d}t||}tjjt ||
| j || jgd|
| j || jf dt | dd	 t||
| j|| jf}tt|d
|}t|t |d df}t|||}t||
| j||f}||fS )Nr_   r   r"   TZtranspose_bz7global_attn_scores have the wrong size. Size should be r~   r   rt   r   r   r   r   r   r!   r4   r}   r   r!   r   r   r   z=global_attn_output tensor has the wrong size. Size should be )r   r$   r   r   re   r\   r]   r^   r   r   r%   rd   r#   reshape_and_transposematmulr*   r   rc   r9   r   r   r6   r   r;   r)   r   rj   )rH   r   r   ry   r   r{   rz   r|   r   r   r   rQ   Zglobal_attn_hidden_statesZ global_query_vectors_only_globalr   r   Zglobal_attn_scoresZglobal_attn_scores_transr   Zglobal_attn_maskZ	attn_maskZglobal_attn_probs_floatr   Zglobal_attn_outputZnonzero_global_attn_outputr-   r-   r.   r   V  s    



 
0
  "

   zATFLEDEncoderSelfAttention._compute_global_attn_output_from_hiddenc                 C  s6   t t t ||d| j| jfd|| j d| jfS )Nr!   r   )r$   r9   r   rc   rd   )rH   Zvectorr   r-   r-   r.   r     s    z/TFLEDEncoderSelfAttention.reshape_and_transpose)N)F)rS   rT   rU   rG   rq   rO   r   staticmethodr   r   r   r   r   r   r   r   r   r   rW   r-   r-   rJ   r.   rX      s,   ;
" 
 @t
I


3

6) 

class TFLEDEncoderAttention(keras.layers.Layer):
    def __init__(self, config, layer_id, **kwargs):
        super().__init__(**kwargs)
        self.longformer_self_attn = TFLEDEncoderSelfAttention(config, layer_id=layer_id, name="longformer_self_attn")
        self.output_dense = keras.layers.Dense(config.d_model, use_bias=True, name="output")
        self.config = config

    def call(self, inputs, training=False):
        (
            hidden_states,
            attention_mask,
            layer_head_mask,
            is_index_masked,
            is_index_global_attn,
            is_global_attn,
        ) = inputs

        self_outputs = self.longformer_self_attn(
            [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
            training=training,
        )

        attention_output = self.output_dense(self_outputs[0], training=training)
        outputs = (attention_output,) + self_outputs[1:]

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "longformer_self_attn", None) is not None:
            with tf.name_scope(self.longformer_self_attn.name):
                self.longformer_self_attn.build(None)
        if getattr(self, "output_dense", None) is not None:
            with tf.name_scope(self.output_dense.name):
                self.output_dense.build([None, None, self.config.d_model])
r   c                      sf   e Zd ZdZddddddd fd	d
ZddddddZddddddddddZdddZ  ZS )TFLEDDecoderAttentionz6Multi-headed attention from "Attention Is All You Needr3   FTr   r   bool)re   rc   ri   
is_decoderbiasc                   s   t  jf | || _|| _tj|| _|| | _| j| | jksJt	d| jd | _
|| _tjj||dd| _tjj||dd| _tjj||dd| _tjj||dd| _d S )Nz(embed_dim must be divisible by num_headsg      k_projr   q_projv_projout_proj)rF   rG   re   rc   r   rf   rh   ri   rd   rm   scalingr   rg   r   r   r   r   )rH   re   rc   ri   r   r   rI   rJ   r-   r.   rG     s    	
zTFLEDDecoderAttention.__init__r   )tensorrQ   r<   c              	   C  s    t t |||| j| jfdS )Nr   )r$   r   r9   rc   rd   )rH   r   rQ   r<   r-   r-   r.   _shape+  s    zTFLEDDecoderAttention._shapeNtf.Tensor | NoneTuple[Tuple[tf.Tensor]] | Nonez"Tuple[tf.Tensor, tf.Tensor | None])r   key_value_statespast_key_valuer   r   returnc              	   C  sR  |dk	}t |\}}	}
| || j }|rD|dk	rD|d }|d }n|rr| | |d|}| | |d|}n|dk	r| | |d|}| | |d|}tj|d |gdd}tj|d |gdd}n(| | |d|}| | |d|}| jr||f}|| j	 d| j
f}t| ||	||}t||}t||}t |d }tj||dd}tjjt ||| j	 |	|gd	|| j	 |	|f d
t | d |dk	r tjjt ||d|	|gd|d|	|f d
t | d t||| j	|	|ftj||jd }t||| j	 |	|f}t|dd}|dk	rtjjt || j	gd| j	 d
t | d t|dt||| j	|	|f }t||| j	 |	|f}| j||d}t||}tjjt ||| j	 |	| j
gd|| j	|	| j
f d
t | d tt||| j	|	| j
fd}t|||	|
f}| |}t||| j	|	|f}|||fS )z#Input shape: Batch x Time x ChannelNr   r   r!   r_   r4   Tr   z$Attention weights should be of size r~   rt   z!Attention mask should be of size r"   r}   r   r   z `attn_output` should be of size r   )r   r   r   r   r   r   r$   r(   r   rc   rd   r9   r   r*   r   r%   r#   r   ri   r   r   )rH   r   r   r   r   r   r   Zis_cross_attentionr<   r=   re   Zquery_statesZ
key_statesZvalue_statesZ
proj_shaper@   Zattn_weightsr   r   r-   r-   r.   rO   .  s    
	

	 
	 	 
zTFLEDDecoderAttention.callc              	   C  s  | j r
d S d| _ t| dd d k	rNt| jj | jd d | jg W 5 Q R X t| dd d k	rt| jj | jd d | jg W 5 Q R X t| dd d k	rt| j	j | j	d d | jg W 5 Q R X t| dd d k	r
t| j
j | j
d d | jg W 5 Q R X d S )NTr   r   r   r   )ro   rr   r$   rp   r   rN   rq   re   r   r   r   rs   r-   r-   r.   rq     s    zTFLEDDecoderAttention.build)r3   FT)NNNNF)N)	rS   rT   rU   rV   rG   r   rO   rq   rW   r-   r-   rJ   r.   r     s           xr   c                      sH   e Zd Zddd fddZdddddddd	d
dZdddZ  ZS )TFLEDEncoderLayerr   r   )r`   rk   c                   s   t  jf | |j| _t||dd| _tjjddd| _	tj
|j| _t|j| _tj
|j| _tjj|jdd| _tjj| jdd| _tjjddd| _|| _d S )	N	self_attnrN   h㈵>self_attn_layer_normepsilonrN   fc1fc2final_layer_norm)rF   rG   r   re   r   r   r   rf   LayerNormalizationr   rh   ri   r	   activation_functionactivation_fnactivation_dropoutrg   encoder_ffn_dimr   r   r   r`   r   rJ   r-   r.   rG     s    zTFLEDEncoderLayer.__init__Fr   r   r   r   r   r   r   r   c           
      C  s   |}| j ||||||g|d}	|	d }tjjt|t|dt| dt| d | j||d}|| }| |}|}| | |}| j	||d}| 
|}| j||d}|| }| |}|f|	dd  S )a  
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
            attention_mask (`tf.Tensor`): attention mask of size
                *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
            layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
                *(config.encoder_attention_heads,)*.
        r   r   z&Self attn modified the shape of query  to rt   r   N)r   r$   r*   r   r   ri   r   r   r   r   r   r   )
rH   r   r   r   r   r   r   r   residuallayer_outputsr-   r-   r.   rO     s,    


zTFLEDEncoderLayer.callNc              	   C  sH  | j r
d S d| _ t| dd d k	rFt| jj | jd  W 5 Q R X t| dd d k	rt| jj | jd d | jg W 5 Q R X t| dd d k	rt| j	j | j	d d | jg W 5 Q R X t| dd d k	rt| j
j | j
d d | jjg W 5 Q R X t| dd d k	rDt| jj | jd d | jg W 5 Q R X d S )NTr   r   r   r   r   )ro   rr   r$   rp   r   rN   rq   r   re   r   r   r`   r   r   rs   r-   r-   r.   rq     s$     zTFLEDEncoderLayer.build)F)Nr   r-   r-   rJ   r.   r     s    -r   c                	      sH   e Zd Zdd fddZdddddddd	d
ddZdddZ  ZS )TFLEDDecoderLayerr   r`   c                   s   t  jf | |j| _t| j|j|jddd| _tj	
|j| _t|j| _tj	
|j| _tj	jddd| _t| j|j|jddd| _tj	jdd	d| _tj	j|jd
d| _tj	j| jdd| _tj	jddd| _|| _d S )Nr   T)re   rc   ri   rN   r   r   r   r   encoder_attn)ri   rN   r   encoder_attn_layer_normr   r   r   r   )rF   rG   r   re   r   Zdecoder_attention_headsZattention_dropoutr   r   rf   rh   ri   r	   r   r   r   r   r   r   r   rg   decoder_ffn_dimr   r   r   r`   rH   r`   rI   rJ   r-   r.   rG   	  s2    zTFLEDDecoderLayer.__init__NFr   zTuple[tf.Tensor] | Nonez?Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]])r   encoder_hidden_statesencoder_attention_maskr   encoder_layer_head_maskr   r   c	                 C  s  |}	|dk	r|dd nd}
| j ||
||d\}}}| j||d}|	| }| |}d}d}|dk	r|}	|dk	r||dd nd}| j|||||d\}}}| j||d}|	| }| |}|| }|}	| | |}| j||d}| |}| j||d}|	| }| 	|}||||fS )a  
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
            attention_mask (`tf.Tensor`): attention mask of size
                *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
            encoder_hidden_states (`tf.Tensor`):
                cross attention input to the layer of shape *(batch, seq_len, embed_dim)*
            encoder_attention_mask (`tf.Tensor`): encoder attention mask of size
                *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
            layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
                *(config.encoder_attention_heads,)*.
            encoder_layer_head_mask (`tf.Tensor`): mask for encoder attention heads in a given layer of
                size *(config.encoder_attention_heads,)*.
            past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states
        Nr_   )r   r   r   r   r   r   )r   r   r   r   r   )
r   ri   r   r   r   r   r   r   r   r   )rH   r   r   r   r   r   r   r   r   r   Zself_attn_past_key_valueZself_attn_weightspresent_key_valueZcross_attn_present_key_valueZcross_attn_weightsZcross_attn_past_key_valuer-   r-   r.   rO   %  sN    



zTFLEDDecoderLayer.callc              	   C  s  | j r
d S d| _ t| dd d k	rFt| jj | jd  W 5 Q R X t| dd d k	rt| jj | jd d | jg W 5 Q R X t| dd d k	rt| j	j | j	d  W 5 Q R X t| dd d k	rt| j
j | j
d d | jg W 5 Q R X t| dd d k	r8t| jj | jd d | jg W 5 Q R X t| dd d k	rzt| jj | jd d | jjg W 5 Q R X t| dd d k	rt| jj | jd d | jg W 5 Q R X d S )	NTr   r   r   r   r   r   r   )ro   rr   r$   rp   r   rN   rq   r   re   r   r   r   r   r`   r   r   rs   r-   r-   r.   rq   u  s0     zTFLEDDecoderLayer.build)NNNNNNF)Nr   r-   r-   rJ   r.   r     s          Pr   c                      s(   e Zd ZeZdZe fddZ  ZS )TFLEDPreTrainedModelledc                   s"   t  j}tjdtjdd|d< |S )N)NNglobal_attention_maskr   )rF   input_signaturer$   Z
TensorSpecrP   )rH   sigrJ   r-   r.   r     s    z$TFLEDPreTrainedModel.input_signature)	rS   rT   rU   r   config_classZbase_model_prefixpropertyr   rW   r-   r-   rJ   r.   r     s   r   c                   @  sB   e Zd ZU dZdZded< dZded< dZded< dZded< dS )	TFLEDEncoderBaseModelOutputaI  
    Base class for Longformer's outputs, with potential hidden states, local and global attentions.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining `attention_window
            + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
            remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
            token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
            (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
            If the attention window contains a token with global attention, the attention weight at the corresponding
            index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
            attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be
            accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    last_hidden_state: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
    global_attentions: Tuple[tf.Tensor, ...] | None = None


@dataclass
class TFLEDSeq2SeqModelOutput(ModelOutput):
    """
    Base class for model encoder's outputs that also contains: pre-computed hidden states that can speed up sequential
    decoding.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
            sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        encoder_global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    Nr   r   List[tf.Tensor] | Nonepast_key_valuesr  decoder_hidden_statesdecoder_attentionscross_attentionsr   encoder_last_hidden_stater   encoder_attentionsencoder_global_attentions)rS   rT   rU   rV   r   r  r  r  r	  r
  r  r   r  r  r-   r-   r-   r.   r    s   
7r  c                   @  s   e Zd ZU dZdZded< dZded< dZded< dZd	ed
< dZ	d	ed< dZ
d	ed< dZded< dZd	ed< dZd	ed< dZd	ed< dS )TFLEDSeq2SeqLMOutputap  
    Base class for sequence-to-sequence language models outputs.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
            sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        encoder_global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    past_key_values: List[tf.Tensor] | None = None
    decoder_hidden_states: Tuple[tf.Tensor, ...] | None = None
    decoder_attentions: Tuple[tf.Tensor, ...] | None = None
    cross_attentions: Tuple[tf.Tensor, ...] | None = None
    encoder_last_hidden_state: tf.Tensor | None = None
    encoder_hidden_states: Tuple[tf.Tensor, ...] | None = None
    encoder_attentions: Tuple[tf.Tensor, ...] | None = None
    encoder_global_attentions: Tuple[tf.Tensor, ...] | None = None


LED_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`LEDConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

LED_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`tf.Tensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`LedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            LED uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
        decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
        head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tf.Tensor`, *optional*):
            hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
            of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
            contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`). Set to `False` during training and `True` during generation.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
c                
      sh   e Zd ZeZdddd fddZdd Zd	d
 ZedddZ	e
jdd Zdd ZdddZ  ZS )TFLEDEncoderNr    Optional[keras.layers.Embedding]r`   embed_tokensc                   s  t  jf |  | _tj j| _ jdkr8t	d d| _
 j| _t jtr jd dkshtd jdksztd jg j  _n,t j jkstd j dt j  j| _|| _t j jd	d
| _ fddt jD | _tjjddd| _ j| _d S )Nr   0Layerdrop is currently disabled in TFLED models.r3   r_   z1`config.attention_window` has to be an even valuez,`config.attention_window` has to be positivezQ`len(config.attention_window)` should equal `config.num_hidden_layers`. Expected z, given embed_positionsr   c                   s    g | ]}t  |d | dqS zlayers.r   )r   .0ir   r-   r.   
<listcomp>  s     z)TFLEDEncoder.__init__.<locals>.<listcomp>r   layernorm_embeddingr   )rF   rG   r`   r   rf   rh   ri   Zencoder_layerdroploggerwarning	layerdropr   padding_idx
isinstancerl   r   rm   Znum_hidden_layerslenr  rB   Zmax_encoder_position_embeddingsr   r  r8   Zencoder_layersr   r  re   rH   r`   r  rI   rJ   r   r.   rG     s0    

zTFLEDEncoder.__init__c                 C  s   | j S rE   r  rH   r-   r-   r.   get_embed_tokens  s    zTFLEDEncoder.get_embed_tokensc                 C  s
   || _ d S rE   r$  rH   r  r-   r-   r.   set_embed_tokens  s    zTFLEDEncoder.set_embed_tokensFc
              	     s  |dk	r|dk	rt dnL|dk	rDt|}
t|| jj | |}n"|dk	r^t|dd }
nt d|dkrzt|
d}|dk	r|tj|d |jd }| j	|||| j
d\ }}}t|}
tjt|tjd}tjt|tjd}tj|}| |
}|| }| |}| j||	d}|dk	rbt|ddd	d	ddf }|ddddddf }|rld
nd}|rzd
nd }}|dk	rtjjt|d	 t| jdt| j dt|d	  dd t| jD ]\}}|r| | }||f }td	d}|	r|| jk rq||||dk	r0|| nd|||d}|d	 }|r|t|d df }|t|d df }q| | }|r d	krt fdd|D n|}|r||f }|stdd |||fD S t||||dS )aW  
        Args:
            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`tf.Tensor` of shape `(num_layers, num_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
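
        Example (a sketch; `encoder` is assumed to be an already-built `TFLEDEncoder` and
        `input_ids` a batch of token ids; typically global attention is put on the first token):

        ```python
        global_attention_mask = tf.concat(
            [tf.ones_like(input_ids[:, :1]), tf.zeros_like(input_ids[:, 1:])], axis=-1
        )
        outputs = encoder(input_ids=input_ids, global_attention_mask=global_attention_mask)
        ```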
        NzDYou cannot specify both input_ids and inputs_embeds at the same timer!   z5You have to specify either input_ids or inputs_embedsr   r"   )r   r   inputs_embedsr   r   r   r-   &The head_mask should be specified for  layers, but it is for r   rt   r   r   r_   )r   r   r   r_   c                   s.   g | ]&}|d d d d d   d d f qS rE   r-   )r  statepadding_lenr-   r.   r  w  s     z%TFLEDEncoder.call.<locals>.<listcomp>c                 s  s   | ]}|d k	r|V  qd S rE   r-   r  vr-   r-   r.   	<genexpr>  s      z$TFLEDEncoder.call.<locals>.<genexpr>)r   r   r  r  ) rb   r   r   r  	input_dimr$   r&   r%   r#   _pad_to_window_sizer   r   ZlessZint8r   Z
reduce_anyr  r  ri   rA   r*   r   r"  rf   	enumeratecompute_hidden_statesrandomuniformr  r   r   r   )rH   r   r)  r   r   	head_maskoutput_attentionsoutput_hidden_statesreturn_dictr   rL   r   r   r   Z	embed_posr   Zencoder_statesZall_attentionsZall_global_attentionsidxZencoder_layerZhidden_states_to_adddropout_probabilityr   r-   r-  r.   rO     s    2







	
zTFLEDEncoder.callc                 C  s"   |dkr|d d d | f S |S )Nr   r-   )rH   r   r.  r-   r-   r.   r5    s    z"TFLEDEncoder.compute_hidden_statesc                 C  s   t | jtr| jnt| j}|d dks6td| |dk	rFt|nt|}|dd \}}|||  | }	|	dkrtd| d||	  d|  t	ddgd|	gg}
|dk	rtj
||
|d}|dk	r|	dkrt||	f|}| |}tj||gd	d
}tj
||
dd}|	|||fS )zaA helper function to pad tokens and mask to work with implementation of Longformer selfattention.r_   r   z2`attention_window` should be an even value. Given Nz(Input ids are automatically padded from r   z0 to be a multiple of `config.attention_window`: r   r   r4   F)r!  rl   r   maxrm   r   r  Zwarning_oncer$   r'   r   r&   r  r(   )rH   r   r   r)  r   rl   rL   r   rQ   r.  r   Zinput_ids_paddingZinputs_embeds_paddingr-   r-   r.   r3    s0    



z TFLEDEncoder._pad_to_window_sizec              
   C  s   | j r
d S d| _ t| dd d k	rFt| jj | jd  W 5 Q R X t| dd d k	rt| jj | jd d | jg W 5 Q R X t| dd d k	r| j	D ]&}t|j |d  W 5 Q R X qd S NTr  r  rf   )
ro   rr   r$   rp   r  rN   rq   r  re   rf   rH   rL   Zlayerr-   r-   r.   rq     s    
zTFLEDEncoder.build)N)	NNNNNNNNF)N)rS   rT   rU   r   r   rG   r&  r(  r   rO   r$   functionr5  r3  rq   rW   r-   r-   rJ   r.   r    s(   	          
-r  c                      sJ   e Zd ZeZdddd fddZdd Zedd
dZdddZ	  Z
S )TFLEDDecoderNr   r  r  c                   s   t  jf |  | _ j| _|| _ jdkr6td d| _	t
 j jdd| _ fddt jD | _tjjdd	d
| _tj j| _d S )Nr   r  r3   r  r   c                   s   g | ]}t  d | dqS r  )r   r  r   r-   r.   r    s     z)TFLEDDecoder.__init__.<locals>.<listcomp>r   r  r   )rF   rG   r`   r   r   r  Zdecoder_layerdropr  r  r  rB   Zmax_decoder_position_embeddingsr   r  r8   Zdecoder_layersrf   r   r   r  rh   ri   r#  rJ   r   r.   rG     s    

zTFLEDDecoder.__init__c                 C  s
   || _ d S rE   r$  r'  r-   r-   r.   r(    s    zTFLEDDecoder.set_embed_tokensFc              
   C  s  |dk	r|dk	rt dn4|dk	r,t|}n"|dk	rFt|dd }nt d|dk	rjt|d d d nd}| ||}|dkrt|| jj | |}|}|d dkrt||d}n&tt	|d |d | f|d d	}|dk	r
|d dkr
|t||d d	 }|dk	r.|dk	r.t||d d	}| 
|| }| j||d
}d}d}d}d}|dk	rtjjt|d t| jdt| j dt|d  dd t| jD ]\}}|r||f7 }tdd}|r|| jk rq|dk	r|| nd}||||||dk	r|| nd|dk	r*|| nd|d\}}}}|	rL||f7 }|
r||f7 }||f7 }q|r|||f7 }nd}|
r|nd}|
r|nd}|	r|nd}|stdd |||||fD S t|||||dS dS )aM  
        Args:
            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            encoder_head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
                on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
                Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
                decoding. If `past_key_values` are used, the user can optionally input only the last
                `decoder_input_ids` (those that don't have their past key value states given to this model) of shape
                `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0

        # embed positions
        positions = self.embed_positions(input_shape, past_key_values_length)

        if inputs_embeds is None:
            check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
        else:
            combined_attention_mask = _expand_mask(
                tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
            )

        if attention_mask is not None and input_shape[-1] > 1:
            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])

        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])

        hidden_states = self.layernorm_embedding(hidden_states + positions)
        hidden_states = self.dropout(hidden_states, training=training)

        # decoder layers
        all_hidden_states = ()
        all_self_attns = ()
        all_cross_attentions = ()
        present_key_values = ()

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            tf.debugging.assert_equal(
                shape_list(head_mask)[0],
                len(self.layers),
                message=(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {shape_list(head_mask)[0]}."
                ),
            )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            dropout_probability = random.uniform(0, 1)
            if training and (dropout_probability < self.layerdrop):
                continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
                hidden_states,
                attention_mask=combined_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=head_mask[idx] if head_mask is not None else None,
                encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None,
                past_key_value=past_key_value,
            )

            if output_attentions:
                all_self_attns += (layer_self_attn,)
                all_cross_attentions += (layer_cross_attn,)

            if use_cache:
                present_key_values += (present_key_value,)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        else:
            all_hidden_states = None

        all_self_attns = all_self_attns if output_attentions else None
        all_cross_attentions = all_cross_attentions if output_attentions else None
        present_key_values = present_key_values if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        else:
            return TFBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
                past_key_values=present_key_values,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
                cross_attentions=all_cross_attentions,
            )
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embed_positions", None) is not None:
            with tf.name_scope(self.embed_positions.name):
                self.embed_positions.build(None)
        if getattr(self, "layernorm_embedding", None) is not None:
            with tf.name_scope(self.layernorm_embedding.name):
                self.layernorm_embedding.build([None, None, self.config.d_model])
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)

  ZS )TFLEDMainLayerr   r   c                   sl   t  jf | || _tjj|j|jtjj	| jj
ddd| _d| j_t|| jdd| _t|| jdd| _d S )N)stddevz
led.shared)r2  Z
output_dimZembeddings_initializerrN   encoderr   decoder)rF   rG   r`   r   rf   	Embedding
vocab_sizer   ZinitializersZTruncatedNormalZinit_stdsharedload_weight_prefixr  rG  rB  rH  r   rJ   r-   r.   rG     s    zTFLEDMainLayer.__init__c                 C  s   | j S rE   )rK  r%  r-   r-   r.   get_input_embeddings  s    z#TFLEDMainLayer.get_input_embeddingsc                 C  s   || _ | j | j_| j | j_d S rE   )rK  rG  r  rH  )rH   Znew_embeddingsr-   r-   r.   set_input_embeddings  s    
z#TFLEDMainLayer.set_input_embeddingsNFz3Optional[Union[Tuple, TFLEDEncoderBaseModelOutput]])encoder_outputsc                 K  s   |d kr|d krd}|d kr:| j |||||
||||d	}n`|rt|tst|d t|dkrd|d nd t|dkrz|d nd d}n|st|ts| }| j|||d ||||	||||||d}|s|| S t|j|j	|j
|j|j|j|j
|j|jd	S )	NF)	r   r   r   r8  r)  r9  r:  r;  r   r   r   r_   )r   r   r  )r   r   r   r8  rC  r  r)  rD  r9  r:  r;  r   	r   r  r  r	  r
  r  r   r  r  )rG  r!  r   r"  r   Zto_tuplerH  r  r   r  r   r  r
  r  )rH   r   r   decoder_input_idsdecoder_attention_maskr8  decoder_head_maskrO  r   r  r)  decoder_inputs_embedsrD  r9  r:  r;  r   rI   Zdecoder_outputsr-   r-   r.   rO     sd    zTFLEDMainLayer.callc              	   C  s   | j r
d S d| _ t| jjd | jj d  | jd  W 5 Q R X t| dd d k	r|t| jj | jd  W 5 Q R X t| dd d k	rt| j	j | j	d  W 5 Q R X d S )NT/rG  rH  )
ro   r$   rp   rK  rL  rN   rq   rr   rG  rH  rs   r-   r-   r.   rq   	  s     zTFLEDMainLayer.build)NNNNNNNNNNNNNNNF)N)rS   rT   rU   r   r   rG   rM  rN  r   rO   rq   rW   r-   r-   rJ   r.   rE    s.                   MrE  zQThe bare LED Model outputting raw hidden-states without any specific head on top.c                      s   e Zd Z fddZdd Zdd Zeee	de
eeeddddddddddddddddddddddZdd ZdddZ  ZS )
TFLEDModelc                   s&   t  j|f|| t|dd| _d S )Nr   r   )rF   rG   rE  r   rH   r`   r   rI   rJ   r-   r.   rG   +	  s    zTFLEDModel.__init__c                 C  s   | j jS rE   r   rG  r%  r-   r-   r.   get_encoder0	  s    zTFLEDModel.get_encoderc                 C  s   | j jS rE   r   rH  r%  r-   r-   r.   get_decoder3	  s    zTFLEDModel.get_decoderzbatch_size, sequence_length)
checkpointoutput_typer   NFTFModelInputType | Noner   r   bool | Noner   z*Tuple[tf.Tensor] | TFLEDSeq2SeqModelOutput)r   r   rQ  rR  r8  rS  rO  r   r  r)  rT  rD  r9  r:  r;  r   r   c                 K  s.   | j |||||||||	|
||||||d}|S )N)r   r   rQ  rR  rO  r   r8  rS  r  r)  rT  rD  r9  r:  r;  r   )r   )rH   r   r   rQ  rR  r8  rS  rO  r   r  r)  rT  rD  r9  r:  r;  r   rI   r   r-   r-   r.   rO   6	  s&    zTFLEDModel.callc           	      C  s   | j jrt|jd nd }| j jr0t|jnd }| j jrHt|j	nd }| j jr`t|j
nd }| j jrxt|jnd }| j jrt|jnd }| j jrt|jnd }t|j|||||j|||d	S )Nr   rP  )r`   rD  r$   r   r  r:  r'   r  r9  r	  r
  r   r  r  r  r   r  	rH   r   ZpkvZdec_hsZ	dec_attnsZcross_attnsZenc_hsZ	enc_attnsZenc_g_attnsr-   r-   r.   serving_outputf	  s$    zTFLEDModel.serving_outputc              	   C  sJ   | j r
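
# A minimal illustrative sketch (not part of the library code): driving the bare model above.
# Running it downloads the `allenai/led-base-16384` checkpoint, so it is only meant as a usage
# reference and is never called at import time.
def _example_tf_led_model_forward():
    from transformers import AutoTokenizer, TFLEDModel

    tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")
    model = TFLEDModel.from_pretrained("allenai/led-base-16384")
    batch = tokenizer(["My friends are cool but they eat too many carbs."], return_tensors="tf")
    outputs = model(input_ids=batch.input_ids, attention_mask=batch.attention_mask)
    return outputs.last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)
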
d S d| _ t| dd d k	rFt| jj | jd  W 5 Q R X d S )NTr   )ro   rr   r$   rp   r   rN   rq   rs   r-   r-   r.   rq   {	  s    zTFLEDModel.build)NNNNNNNNNNNNNNNF)N)rS   rT   rU   rG   rY  r[  r   r   LED_INPUTS_DOCSTRINGformatr   _CHECKPOINT_FOR_DOCr  _CONFIG_FOR_DOCrO   ra  rq   rW   r-   r-   rJ   r.   rV  &	  s:                   6)rV  c                      s(   e Zd ZdZ fddZdd Z  ZS )	BiasLayerz
    Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
    so all weights have to be registered in a layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        return x + self.bias


@add_start_docstrings(
    "The LED Model with a language modeling head. Can be used for summarization.",
    LED_START_DOCSTRING,
)
class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
 Zdd Zdd Z	dd Z
eeeeeedd+dddddddddddddddddddddZdd  Zd,d!d"Zd#d$d%d&Zd'd( Zd-d)d*Z  ZS ).TFLEDForConditionalGenerationzled.encoder.embed_tokens.weightzled.decoder.embed_tokens.weightc                   sL   t  j|f|| t|dd| _|j| _tdd|jgddd| _d| _d S )Nr   r   final_logits_biasr   r:   Frg  )	rF   rG   rE  r   rD  rf  rJ  
bias_layerZsupports_xla_generationrW  rJ   r-   r.   rG   	  s       z&TFLEDForConditionalGeneration.__init__c                 C  s   | j jS rE   rZ  r%  r-   r-   r.   r[  	  s    z)TFLEDForConditionalGeneration.get_decoderc                 C  s   | j jS rE   rX  r%  r-   r-   r.   rY  	  s    z)TFLEDForConditionalGeneration.get_encoderc                 C  s   d| j jiS )Nrl  )rm  r   r%  r-   r-   r.   get_bias	  s    z&TFLEDForConditionalGeneration.get_biasc                 C  s:   |d j d }tdd|gddd| _| jj|d  d S )Nrl  r!   r   r:   Frg  )r   rf  rm  r   Zassign)rH   r[   rJ  r-   r-   r.   set_bias	  s       z&TFLEDForConditionalGeneration.set_biasc                 C  s   |   S rE   )rM  r%  r-   r-   r.   get_output_embeddings	  s    z3TFLEDForConditionalGeneration.get_output_embeddingsc                 C  s   |  | d S rE   )rN  )rH   r[   r-   r-   r.   set_output_embeddings	  s    z3TFLEDForConditionalGeneration.set_output_embeddings)r]  r   NFr^  znp.ndarray | tf.Tensor | Nonez"TFLEDEncoderBaseModelOutput | Nonez1Tuple[Tuple[Union[np.ndarray, tf.Tensor]]] | Noner_  r   r   z'Tuple[tf.Tensor] | TFLEDSeq2SeqLMOutput)r   r   rQ  rR  r8  rS  rO  r   r  r)  rT  rD  r9  r:  r;  labelsr   r   c                 C  s   |dk	r0d}|dkr0|dkr0t || jj| jj}| j|||||||||	|
||||||d}tj|d | jjjdd}| 	|}|dkrdn
| 
||}|s|f|dd  }|dk	r|f| S |S t|||j|j|j|j|j|j|j|jd
S )	a  
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, TFLEDForConditionalGeneration
        >>> import tensorflow as tf

        >>> mname = "allenai/led-base-16384"
        >>> tokenizer = AutoTokenizer.from_pretrained(mname)
        >>> TXT = "My friends are <mask> but they eat too many carbs."
        >>> model = TFLEDForConditionalGeneration.from_pretrained(mname)
        >>> batch = tokenizer([TXT], return_tensors="tf")
        >>> logits = model(inputs=batch.input_ids).logits
        >>> probs = tf.nn.softmax(logits[0])
        >>> # probs[5] is associated with the mask token
        ```"""

        if labels is not None:
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.led(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            global_attention_mask=global_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        lm_logits = tf.matmul(outputs[0], self.led.shared.weights, transpose_b=True)
        lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return TFLEDSeq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
            encoder_global_attentions=outputs.encoder_global_attentions,
        )
    def serving_output(self, output):
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
        enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None

        return TFLEDSeq2SeqLMOutput(
            logits=output.logits,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
            encoder_global_attentions=enc_g_attns,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "use_cache": use_cache,
        }
    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
    def hf_compute_loss(self, labels, logits):
        """CrossEntropyLoss that ignores pad tokens"""
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        if self.config.tf_legacy_loss:
            melted_labels = tf.reshape(labels, (-1,))
            active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
            labels = tf.boolean_mask(melted_labels, active_loss)
            return loss_fn(labels, reduced_logits)

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
        # make sure only non-padding labels affect the loss
        loss_mask = tf.cast(labels != self.config.pad_token_id, dtype=unmasked_loss.dtype)
        masked_loss = unmasked_loss * loss_mask
        reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
        return tf.reshape(reduced_masked_loss, (1,))
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "led", None) is not None:
            with tf.name_scope(self.led.name):
                self.led.build(None)
        if getattr(self, "bias_layer", None) is not None:
            with tf.name_scope(self.bias_layer.name):
                self.bias_layer.build(None)

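
# A minimal illustrative sketch (not part of the model): the padding-aware loss reduction
# used by `hf_compute_loss` above, on made-up logits and labels. Label value 1 stands in for
# the pad token id here; in the real model it comes from `config.pad_token_id`.
def _example_masked_lm_loss():
    import tensorflow as tf

    pad_token_id = 1
    labels = tf.constant([[5, 7, pad_token_id]])  # (batch_size=1, seq_len=3)
    logits = tf.random.normal((1, 3, 10))         # (batch_size, seq_len, vocab_size=10)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    unmasked_loss = loss_fn(tf.nn.relu(labels), logits)           # per-token loss, shape (1, 3)
    loss_mask = tf.cast(labels != pad_token_id, dtype=unmasked_loss.dtype)
    masked_loss = unmasked_loss * loss_mask                       # zero out the pad position
    return tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)  # mean over real tokens only
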