U
    4A·fxu  ã                   @   s  d dl Z d dlZd dlZd dlmZ d dlmZ ddlmZ e 	e
¡ZG dd„ dejƒZG dd	„ d	ejƒZG d
d„ dejƒZG dd„ dejƒZG dd„ dejƒZG dd„ dejƒZd$dd„Zd%dd„Zd&dd„ZG dd„ deƒZG dd„ deƒZG dd„ deƒZd'd d!„ZG d"d#„ d#eƒZdS )(é    N)Únn)ÚFunctioné   )Úloggingc                	       s,   e Zd ZdZd‡ fdd„	Zdd	d
„Z‡  ZS )ÚQuantEmbeddingaÞ  
    Quantized version of `torch.nn.Embedding`. Adds quantization-specific arguments on top of `torch.nn.Embedding`.

    Args:
        weight_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the quantized weight.
        momentum (`float`, *optional*, defaults to `0.95`):
            Momentum for updating the activation quantization range.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    Nç       @Fé   çffffffî?c                    s”   t ƒ  ¡  || _|| _|| _|| _|| _|| _|| _t	 
t ||g¡¡| _|  dt d¡¡ |  dt | j¡¡ |	| _|
| _|| _d| _tj| _d S )NÚweight_scaling_factoré   Úweight_integerF)ÚsuperÚ__init__Znum_ÚdimÚpadding_idxÚmax_normÚ	norm_typeÚscale_grad_by_freqÚsparser   Ú	ParameterÚtorchÚzerosÚweightÚregister_bufferÚ
zeros_likeÚ
weight_bitÚmomentumÚ
quant_modeÚpercentile_modeÚSymmetricQuantFunctionÚapplyÚweight_function)ÚselfZnum_embeddingsZembedding_dimr   r   r   r   r   Z_weightr   r   r   ©Ú	__class__© úK/tmp/pip-unpacked-wheel-zw5xktn0/transformers/models/ibert/quant_modules.pyr   ,   s     
zQuantEmbedding.__init__c           	   	   C   sº   | j s.tj || j| j| j| j| j| j	¡d fS | j}|j
 ¡ }| ¡  d¡}| ¡  d¡}t| j||dƒ| _|  | j| j| j| j¡| _tj || j| j| j| j| j| j	¡}|| j | jfS )Nr   F)r   r   Ú
functionalZ	embeddingr   r   r   r   r   r   ÚdataÚdetachÚminÚexpandÚmaxÚ$symmetric_linear_quantization_paramsr   r
   r!   r   r   )	r"   ÚxZ	positionsZincremental_stateÚwÚw_transformÚw_minÚw_maxZemb_intr%   r%   r&   ÚforwardM   sB    ù	ö
   ÿù	zQuantEmbedding.forward)	NNr   FFNr   r	   F)NN)Ú__name__Ú
__module__Ú__qualname__Ú__doc__r   r3   Ú__classcell__r%   r%   r#   r&   r      s            ô!r   c                       s4   e Zd ZdZd‡ fdd„	Zdd„ Zdd	d
„Z‡  ZS )ÚQuantActap  
    Quantizes the given activation.

    Args:
        activation_bit (`int`):
            Bitwidth for the quantized activation.
        act_range_momentum (`float`, *optional*, defaults to `0.95`):
            Momentum for updating the activation quantization range.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether to or not use channel-wise quantization.
        channel_len (`int`, *optional*):
            Specify the channel length when set the *per_channel* True.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    r	   FNc                    s–   t ƒ  ¡  || _|| _|| _|| _d| _tj| _	| jsŠ|  
dt d¡¡ |  
dt d¡¡ |  
dt d¡¡ |  jd8  _|  jd7  _ntdƒ‚d S )NFÚx_minr   Úx_maxÚact_scaling_factorgñhãˆµøä>ú;per-channel mode is not currently supported for activation.)r   r   Úactivation_bitÚact_range_momentumr   Úper_channelÚ
percentiler   r    Úact_functionr   r   r   r:   r;   ÚNotImplementedError)r"   r>   r?   r@   Zchannel_lenr   r#   r%   r&   r   ƒ   s    
zQuantAct.__init__c              
   C   s:   | j j› d| j› d| j› d| j ¡ d›d| j ¡ d›d
S )Nz(activation_bit=z, quant_mode: z, Act_min: z.2fz, Act_max: ú))r$   r4   r>   r   r:   Úitemr;   )r"   r%   r%   r&   Ú__repr__–   s    8ÿzQuantAct.__repr__c                 C   s°  |d kr|n|| }| j r| jr*tdƒ‚| jr8tdƒ‚|j ¡ }|j ¡ }	|	 ¡  ¡ dkrl| ¡  ¡ dksttdƒ‚| j	 ¡ dkrª| j
 ¡ dk rª| j	| | _	| j
|	 | _
nd| jdkrÖt | j	|¡| _	t | j
|	¡| _
n8| j	| j |d| j   | _	| j
| j |	d| j   | _
| js|d fS |d kr.| j	n|}|d krB| j
n|}	t| j||	| jd	| _|d kr~|  || j| j| j¡}
nt ||| j| j||¡}
| j d¡}|
| | jfS )
Nz:percentile mode is not currently supported for activation.r=   r   z5NaN detected when computing min/max of the activationg¢&ú|”ç¾g¢&ú|”ç>éÿÿÿÿr   )r@   )ÚtrainingrA   ÚAssertionErrorr@   r(   r*   r,   ÚisnanÚsumr:   r;   r?   r   r   r-   r>   r<   rB   ÚFixedPointMulr    Úview)r"   r.   Úpre_act_scaling_factorÚidentityÚidentity_scaling_factorZspecified_minZspecified_maxZx_actr:   r;   Zquant_act_intZcorrect_output_scaler%   r%   r&   r3      sT    	

ÿÿþ
   ÿ
ú	zQuantAct.forward)r	   FNF)NNNNN©r4   r5   r6   r7   r   rF   r3   r8   r%   r%   r#   r&   r9   r   s   
     ùr9   c                       s8   e Zd ZdZd‡ fdd„	Z‡ fdd	„Zddd„Z‡  ZS )ÚQuantLineara8  
    Quantized version of `torch.nn.Linear`. Adds quantization-specific arguments on top of `torch.nn.Linear`.

    Args:
        weight_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the quantized weight.
        bias_bit (`int`, *optional*, defaults to `32`):
            Bitwidth for the quantized bias.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether or not to use channel-wise quantization.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    Tr   é    Fc                    s®   t ƒ  ¡  || _|| _t t ||g¡¡| _|  	dt 
| j¡¡ |  	dt | j¡¡ |r~t t |¡¡| _|  	dt 
| j¡¡ || _|| _|| _|| _|| _d| _tj| _d S )Nr   Úfc_scaling_factorÚbias_integerF)r   r   Úin_featuresÚout_featuresr   r   r   r   r   r   r   Úbiasr   r   r@   Úbias_bitr   r   r    r!   )r"   rV   rW   rX   r   rY   r@   r   r#   r%   r&   r   ë   s     
zQuantLinear.__init__c                    s*   t ƒ  ¡ }d|› d| j› d| j› d}|S )Nú(z weight_bit=z, quant_mode=rD   )r   rF   r   r   )r"   Úsr#   r%   r&   rF     s    
zQuantLinear.__repr__Nc           
      C   s   | j s tjj|| j| jdd fS |d k	r2|jdks:tdƒ‚| j}|j 	¡ }| j
rztj|dd d\}}tj|dd d\}}n| ¡  d¡}| ¡  d¡}t| j||| j
ƒ| _|  | j| j| j| j¡| _| j| }| jd k	rî|  | j| jd|¡| _| dd¡}|| }	tjj|	| j| jd| |fS )N)r   rX   )r   z«Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. Please add a QuantAct layer with `per_channel = True` before this QuantAct layerr   )r   ÚoutFrG   )r   r   r'   Zlinearr   rX   ÚshaperI   r(   r)   r@   r   r*   r,   r+   r-   r   rT   r!   r   r   rY   rU   rM   )
r"   r.   Zprev_act_scaling_factorr/   r0   r1   Ú_r2   Zbias_scaling_factorÚx_intr%   r%   r&   r3     s6    ÿ
   ÿ

þzQuantLinear.forward)Tr   rS   FF)NrQ   r%   r%   r#   r&   rR   Ü   s            ÿrR   c                       s4   e Zd ZdZd‡ fdd„	Zdd„ Zdd	d
„Z‡  ZS )ÚIntGELUa}  
    Quantized version of `torch.nn.GELU`. Adds quantization-specific arguments on top of `torch.nn.GELU`.

    Args:
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "gelu" or "nonlinear" is given.
    TÚnonec                    sl   t ƒ  ¡  || _|dkr(t d¡ d| _| js8t ¡ | _d| _d| _	dddg| _
| j
d	  | j
d
   < d S )N)Ú	nonlinearZgeluzForce dequantize geluFgà- ö?é   g]mÅþ²{Ò¿gçû©ñÒMü¿r   é   r   )r   r   r   ÚloggerÚinfor   ZGELUÚactivation_fnÚkÚconstÚcoeff)r"   r   Úforce_dequantr#   r%   r&   r   7  s    


zIntGELU.__init__c                 C   sš   t  | jd | ¡}t  | jd |d  ¡}t  |¡}t  t  |¡| ¡}||| d |  }|d | jd  }t |d| j  ¡}|d| j  }||fS ©Nr   rd   r   )	r   Úfloorrj   Úsignr*   ÚabsÚ	floor_ster    ri   )r"   r_   Úscaling_factorÚb_intÚc_intrn   Zabs_intÚy_intr%   r%   r&   Úint_erfG  s    
zIntGELU.int_erfNc                 C   s^   | j s|  |¡d fS || }|  ||| j ¡\}}d| }|||  }|| d }|| |fS )Nç      ð?rd   )r   rg   ru   rh   )r"   r.   rq   r_   Zsigmoid_intZsigmoid_scaling_factorZ	shift_intr%   r%   r&   r3   V  s    zIntGELU.forward)Tra   )N)r4   r5   r6   r7   r   ru   r3   r8   r%   r%   r#   r&   r`   ,  s   
r`   c                       s:   e Zd ZdZd‡ fdd„	Zdd„ Zdd	„ Zd
d„ Z‡  ZS )Ú
IntSoftmaxaØ  
    Quantized version of `torch.nn.Softmax`. Adds quantization-specific arguments on top of `torch.nn.Softmax`.

    Args:
        output_bit (`int`):
            Bitwidth for the layer output activation.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "softmax" or "nonlinear" is given.
    Fra   c                    s   t ƒ  ¡  || _d| _|| _|dkr4t d¡ d| _td| jd| _d| _	d| _
d	d
dg| _| jd  | jd   < | jd  | jd   < d S )NrS   )rb   ÚsoftmaxzForce dequantize softmaxFé   ©r   gvqà-æ¿é   gN„ª$ôëÖ?g¾Ã'|:ï?rv   r   r   rd   )r   r   Ú
output_bitÚmax_bitr   re   rf   r9   ÚactÚx0ri   Úcoef)r"   r|   r   rk   r#   r%   r&   r   r  s    

zIntSoftmax.__init__c              	   C   sj   t  ¡ 2 t  | jd | ¡}t  | jd |d  ¡}W 5 Q R X || | | }| jd |d  }||fS rl   )r   Úno_gradrm   r€   )r"   r_   rq   rr   rs   Úzr%   r%   r&   Úint_polynomialƒ  s    
"zIntSoftmax.int_polynomialc              	   C   s˜   t  ¡  t  | j| ¡}W 5 Q R X t  || j| ¡}t || ¡}|||  }|  ||¡\}}t j	t |d| j|   ¡dd}|d| j  }||fS )Nrd   r   ©r*   )
r   r   rm   r   r,   ri   rp   r    rƒ   Úclamp)r"   r_   rq   Zx0_intÚqÚrÚexp_intÚexp_scaling_factorr%   r%   r&   Úint_exp‹  s    
"zIntSoftmax.int_expc                 C   s¾   | j stjj|ddd fS || }|jddd\}}|| }|  ||¡\}}|  ||¡\}}|| }|jddd}	t 	d| j
 |	 ¡}
t 	||
 d| j
| j   ¡}dd| j  }|| |fS )NrG   ©r   T)r   Úkeepdimrd   r   )r   r   r'   rx   r,   rŠ   r~   rK   rp   r    r}   r|   )r"   r.   rq   r_   Z	x_int_maxr^   rˆ   r‰   ÚexpZexp_int_sumÚfactorr%   r%   r&   r3   —  s    zIntSoftmax.forward)Fra   )	r4   r5   r6   r7   r   rƒ   rŠ   r3   r8   r%   r%   r#   r&   rw   e  s
   rw   c                       s<   e Zd ZdZd‡ fdd„	Zdd„ Zd	d
„ Zddd„Z‡  ZS )ÚIntLayerNormaû  
    Quantized version of `torch.nn.LayerNorm`. Adds quantization-specific arguments on top of `torch.nn.LayerNorm`.

    Args:
        output_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the layer output activation.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "layernorm" or "nonlinear" is given.
    r   Fra   c                    s’   t ƒ  ¡  || _|| _t t |¡¡| _t t |¡¡| _	|| _
|dkrXt d¡ d| _
|  dt d¡¡ || _d| _d | _t| j| j
d| _d S )N)rb   Z	layernormzForce dequantize layernormFÚshiftr   rS   rz   )r   r   Únormalized_shapeÚepsr   r   r   r   r   rX   r   re   rf   r   r|   r}   Údim_sqrtr9   Z
activation)r"   r‘   r’   r|   r   rk   r#   r%   r&   r   ¹  s    

zIntLayerNorm.__init__c              	   C   sˆ   t  ¡ v |d }t j|ddd}t  t  |d| j  ¡¡ ¡  ¡ }| j}t  | j|¡| _t	 
dt|ƒ› dt| jƒ› ¡ W 5 Q R X d S )Nrd   T©ZaxisrŒ   zDynamic shift adjustment: z -> )r   r   rK   Úlog2Úsqrtr}   Úceilr,   r   re   rf   Úint)r"   rt   Úy_sq_intÚvar_intr   Z	shift_oldr%   r%   r&   Ú	set_shiftÌ  s    
"zIntLayerNorm.set_shiftc                 C   s:   |   |¡ t |d| j  ¡}|d }tj|ddd}|S )z±
        This fallback function is called when overflow is detected during training time, and adjusts the `self.shift`
        to avoid overflow in the subsequent runs.
        rd   Tr”   )r›   rp   r    r   r   rK   )r"   rt   Úy_int_shiftedr™   rš   r%   r%   r&   Úoverflow_fallbackÕ  s
    
zIntLayerNorm.overflow_fallbackNc                 C   s²  | j s\|jddd}|| }tj|d ddd}|t | j| ¡ }|| j | j }|d fS | jd krtj|j	d tj
d}t |¡ |j¡| _|| }t |jddd¡}|| }	t |	d| j  ¡}
|
d }tj|ddd}| jr&| ¡ d| j kr&|  |	¡}| ¡ d| j d k s&tdƒ‚t t |¡¡d| j  }t d| ¡}t |	| d ¡}	| jd }| jj ¡ | jj ¡  }t || ¡}|	| }	|| j }|	| }||fS )	Nrd   Tr”   )Zdtypegš™™™™™¹?zfError detected in overflow handling: `var_int` exceeds `self.max_bit` (the maximum possible bit width)l        i   @)r   Úmeanr   r–   r’   r   rX   r“   Útensorr]   ÚfloatÚtoÚdeviceÚ	round_ster    rp   r   rK   rH   r,   r}   r   rI   r(   r)   )r"   r.   rq   rž   ÚyÚvarÚnr_   Zmean_intrt   rœ   r™   rš   Zstd_intrŽ   rX   Zbias_intr%   r%   r&   r3   à  s@    

ÿ

zIntLayerNorm.forward)r   Fra   )N)	r4   r5   r6   r7   r   r›   r   r3   r8   r%   r%   r#   r&   r   ¬  s
   	r   Fc           	      C   s€   | j d }t|d|d   ƒ}t|| d ƒ}tj| |dj}|dkrP|d }ntj|  |dj }|sx| ¡ }| ¡ }||fS )aÆ  
    Calculate the percentile max and min values in a given tensor

    Args:
        input (`torch.Tensor`):
            The target tensor to calculate percentile max and min.
        lower_percentile (`float`):
            If 0.1, means we return the value of the smallest 0.1% value in the tensor as percentile min.
        upper_percentile (`float`):
            If 99.9, means we return the value of the largest 0.1% value in the tensor as percentile max.
        output_tensor (`bool`, *optional*, defaults to `False`):
            If True, this function returns tensors, otherwise it returns values.

    Returns:
        `Tuple(torch.Tensor, torch.Tensor)`: Percentile min and max value of *input*
    r   r   g{®Gáz„?)rh   )r]   Úroundr   ZkthvalueÚvaluesrE   )	ÚinputZlower_percentileZupper_percentileZoutput_tensorZinput_lengthZlower_indexZupper_indexZupper_boundZlower_boundr%   r%   r&   Úget_percentile_min_max  s    

rª   c                 C   s¢   t | jƒdkr0| dddd¡}| dddd¡}n<t | jƒdkrX| dd¡}| dd¡}n| d¡}| d¡}|rŒ|  d| ¡ |¡ ¡  | S t d| |  | ¡S )a?  
    Quantize single-precision input tensor to integers with the given scaling factor and zeropoint.

    Args:
        input (`torch.Tensor`):
            Single-precision input tensor to be quantized.
        scale (`torch.Tensor`):
            Scaling factor for quantization.
        zero_pint (`torch.Tensor`):
            Shift for quantization.
        inplace (`bool`, *optional*, defaults to `False`):
            Whether to compute inplace or not.

    Returns:
        `torch.Tensor`: Linearly quantized value of *input* according to *scale* and *zero_point*.
    é   rG   r   rd   rv   )Úlenr]   rM   Zmul_Zadd_Zround_r   r§   )r©   ÚscaleÚ
zero_pointÚinplacer%   r%   r&   Úlinear_quantize5  s    

r°   c              	   C   sŒ   t  ¡ z d| d  d }|rZt jt j| ¡ | ¡ gdddd\}}t j|dd| }n$t| ¡ | ¡ ƒ}t j|dd| }W 5 Q R X |S )a/  
    Compute the scaling factor with the given quantization range for symmetric quantization.

    Args:
        saturation_min (`torch.Tensor`):
            Lower bound for quantization range.
        saturation_max (`torch.Tensor`):
            Upper bound for quantization range.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether to or not use channel-wise quantization.

    Returns:
        `torch.Tensor`: Scaling factor that linearly quantizes the given range between *saturation_min* and
        *saturation_max*.
    rd   r   r‹   g:Œ0âŽyE>r„   )r   r   r,   Ústackro   r…   )Znum_bitsZsaturation_minZsaturation_maxr@   r¦   r­   r^   r%   r%   r&   r-   X  s    
(r-   c                   @   s(   e Zd ZdZedd„ ƒZedd„ ƒZdS )r   zw
    Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth.
    c                 C   sP   t  d¡ |j¡}d|d  d }t|||dd}t  || |d ¡}|| _|S )a6  
        Args:
            x (`torch.Tensor`):
                Floating point tensor to be quantized.
            k (`int`):
                Quantization bitwidth.
            percentile_mode (`bool`):
                Whether or not to use percentile calibration.
            scale (`torch.Tensor`):
                Pre-calculated scaling factor for *x*. Note that the current implementation of SymmetricQuantFunction
                requires pre-calculated scaling factor.

        Returns:
            `torch.Tensor`: Symmetric-quantized value of *input*.
        g        rd   r   F)r¯   )r   rŸ   r¡   r¢   r°   r…   r­   )Úctxr.   rh   r   r­   r®   r¦   Znew_quant_xr%   r%   r&   r3   }  s    zSymmetricQuantFunction.forwardc                 C   sb   | j }t|jƒdkr&| dddd¡}n&t|jƒdkrB| dd¡}n
| d¡}| ¡ | d d d d fS )Nr«   rG   r   rd   )r­   r¬   r]   rM   Úclone)r²   Úgrad_outputr­   r%   r%   r&   Úbackward—  s    
zSymmetricQuantFunction.backwardN©r4   r5   r6   r7   Ústaticmethodr3   rµ   r%   r%   r%   r&   r   x  s
   
r   c                   @   s(   e Zd ZdZedd„ ƒZedd„ ƒZdS )rp   z;
    Straight-through Estimator(STE) for torch.floor()
    c                 C   s
   t  |¡S ©N)r   rm   ©r²   r.   r%   r%   r&   r3   ª  s    zfloor_ste.forwardc                 C   s   |  ¡ S r¸   ©r³   ©r²   r´   r%   r%   r&   rµ   ®  s    zfloor_ste.backwardNr¶   r%   r%   r%   r&   rp   ¥  s
   
rp   c                   @   s(   e Zd ZdZedd„ ƒZedd„ ƒZdS )r£   z;
    Straight-through Estimator(STE) for torch.round()
    c                 C   s
   t  |¡S r¸   )r   r§   r¹   r%   r%   r&   r3   ¸  s    zround_ste.forwardc                 C   s   |  ¡ S r¸   rº   r»   r%   r%   r&   rµ   ¼  s    zround_ste.backwardNr¶   r%   r%   r%   r&   r£   ³  s
   
r£   é   c                 C   s®   |   ¡ }|  d¡} t |  ¡  ¡ ¡\}}g }|D ]6}tt |d|  ¡j	t d¡tj
dƒ}| |¡ q0t |¡}t|ƒ| }t |¡ | j¡ |¡t |¡ | j¡ |¡fS )zü
    Decompose the scaling factor into mantissa and twos exponent.

    Args:
        scaling_factor (`torch.Tensor`):
            Target scaling factor to decompose.

    Returns:
        ``Tuple(torch.Tensor, torch.Tensor)`: mantisa and exponent
    rG   rd   Ú1)Úrounding)ÚsizerM   ÚnpÚfrexpÚcpuÚnumpyr˜   ÚdecimalÚDecimalÚquantizeÚROUND_HALF_UPÚappendÚarrayr    r   Z
from_numpyr¡   r¢   )Úinputsr}   Zshape_of_inputZoutput_mZoutput_eZtmp_mÚmZint_m_shiftedr%   r%   r&   Úbatch_frexpÁ  s    
"ÿ
þrÌ   c                   @   s*   e Zd ZdZeddd„ƒZedd„ ƒZdS )rL   aQ  
    Function to perform fixed-point arithmetic that can match integer arithmetic on hardware.

    Args:
        pre_act (`torch.Tensor`):
            Input tensor.
        pre_act_scaling_factor (`torch.Tensor`):
            Scaling factor of the input tensor *pre_act*.
        bit_num (`int`):
            Quantization bitwidth.
        z_scaling_factor (`torch.Tensor`):
            Scaling factor of the output tensor.
        identity (`torch.Tensor`, *optional*):
            Identity tensor, if exists.
        identity_scaling_factor (`torch.Tensor`, *optional*):
            Scaling factor of the identity tensor *identity*, if exists.

    Returns:
        `torch.Tensor`: Output tensor(*pre_act* if *identity* is not given, otherwise the addition of *pre_act* and
        *identity*), whose scale is rescaled to *z_scaling_factor*.
    Nc              
   C   sŒ  t |jƒdkrdd„ }ndd„ }|| _d|d  d }t ¡ B ||ƒ}|d k	rZ||ƒ}|| _t || ¡}	| tj¡}
| tj	¡ tj¡}|
| }||ƒ}t
|ƒ\}}|	 tj¡| tj¡ }t |d|  ¡}|d k	rZt || ¡}| tj¡}
| tj	¡ tj¡}|
| }||ƒ}t
|ƒ\}}| tj¡| tj¡ }t |d|  ¡}|| }t | tj	¡| d |¡W  5 Q R £ S Q R X d S )Nr   c                 S   s   | S r¸   r%   ©r.   r%   r%   r&   Ú<lambda>  ó    z'FixedPointMul.forward.<locals>.<lambda>c                 S   s   |   ddd¡S )Nr   rG   )rM   rÍ   r%   r%   r&   rÎ     rÏ   rd   r   r   )r¬   r]   rO   r   r   Úz_scaling_factorr§   ÚtypeÚdoubler    rÌ   r…   )r²   Zpre_actrN   Zbit_numrÐ   rO   rP   Zreshaper¦   Zz_intZ_AZ_BZ	new_scalerË   ÚeÚoutputZwx_intÚm1Úe1Zoutput1r%   r%   r&   r3   ú  s:    


zFixedPointMul.forwardc                 C   s8   d }| j d k	r| ¡ | j }| ¡ | j d d d d |d fS r¸   )rO   r³   rÐ   )r²   r´   Zidentity_gradr%   r%   r&   rµ   /  s    
zFixedPointMul.backward)NNr¶   r%   r%   r%   r&   rL   ã  s     ù4rL   )F)F)F)r¼   )rÄ   rÃ   rÀ   r   r   Ztorch.autogradr   Úutilsr   Z
get_loggerr4   re   ÚModuler   r9   rR   r`   rw   r   rª   r°   r-   r   rp   r£   rÌ   rL   r%   r%   r%   r&   Ú<module>   s(   
SjP9Ge
$
#
 -
"