How can I incorporate PReLU into a quantized model?

Posted 2024-03-29 15:05:31


I am trying to quantize a model that uses PReLU. Replacing PReLU with ReLU is not an option, because it severely degrades the network's performance, to the point of being useless.

As far as I know, PyTorch does not support PReLU when it comes to quantization. So I tried to rewrite this module manually and implement the multiplication and addition with torch.nn.quantized.FloatFunctional() to get around this limitation.
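
(For reference, a minimal sketch of how nn.quantized.FloatFunctional is typically used; the tensors here are made-up float examples, not part of my model:)

import torch
import torch.nn as nn

# FloatFunctional wraps elementwise ops such as add/mul so that observers can
# be attached to them during quantization preparation and the ops can later be
# converted to their quantized counterparts.
f_add = nn.quantized.FloatFunctional()
f_mul = nn.quantized.FloatFunctional()

a = torch.randn(4)
b = torch.randn(4)
out = f_add.add(a, f_mul.mul(a, b))  # behaves like a + a * b in float mode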

This is what I have come up with so far:

class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.weight = prelu_object.weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = torch.max(0, inputs) + self.weight * torch.min(0, inputs)    
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(self.weight, torch.min(inputs)[0])
        inputs = self.quantized_op.add(torch.max(inputs)[0], weight_min_res).unsqueeze(0)
        self.weight = self.dequant(self.weight)
        return inputs

And as for the replacement:

class model(nn.Module):
    def __init__(self):
        super().__init__()
        ....
        self.prelu = PReLU()
        self.prelu_q = PReLU_Quantized(self.prelu)
        ....

Basically, I read the learned parameter of the existing PReLU module and run the computation myself in the new module. The module seems to work, in the sense that it does not make the whole application fail.
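
(A rough sketch of how this replacement could be automated for all PReLU layers in a model; replace_prelu_modules is a made-up helper name and it assumes the PReLU_Quantized class above:)

import torch.nn as nn

def replace_prelu_modules(module):
    # Hypothetical helper: recursively replace every nn.PReLU child with the
    # wrapper defined above, reusing its learned weight.
    for name, child in module.named_children():
        if isinstance(child, nn.PReLU):
            setattr(module, name, PReLU_Quantized(child))
        else:
            replace_prelu_modules(child)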

However, to check whether my implementation is actually correct and produces the same results as the original module, I tried to test it (the counterpart for the normal, i.e. non-quantized, model is shown further below).
For some reason, the error between the actual PReLU and my implementation is very large.

Here are example differences in different layers:

diff : 1.1562038660049438
diff : 0.02868632599711418
diff : 0.3653906583786011
diff : 1.6100226640701294
diff : 0.8999372720718384
diff : 0.03773299604654312
diff : -0.5090572834014893
diff : 0.1654307246208191
diff : 1.161868691444397
diff : 0.026089997962117195
diff : 0.4205571115016937
diff : 1.5337920188903809
diff : 0.8799554705619812
diff : 0.03827812895178795
diff : -0.40296515822410583
diff : 0.15618863701820374

In the forward pass, the difference is computed as follows:

def forward(self, x):
    residual = x
    out = self.bn0(x)
    out = self.conv1(out)
    out = self.bn1(out)

    out = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {( out - out2).mean().item()}')

    out = self.conv2(out)
...

This is the normal implementation that I use on the normal model (i.e., not quantized!) to check whether it produces correct results, before moving on to the quantized version:

class PReLU_2(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight

    def forward(self, inputs):
        x = self.weight
        tmin, _ = torch.min(inputs,dim=0)
        tmax, _ = torch.max(inputs,dim=0)
        weight_min_res = torch.mul(x, tmin)
        inputs = torch.add(tmax, weight_min_res)
        inputs = inputs.unsqueeze(0)
        return inputs

What am I missing?


1 Answer

Answered 2024-03-29 15:05:31

I figured it out! I had made a big mistake right at the start. I needed to compute

PReLU(x) = max(0, x) + a * min(0, x)

where max and min are taken element-wise against zero, not the actual torch.min or torch.max over the whole tensor, which makes no sense here! Since max(0, x) = relu(x) and min(0, x) = -relu(-x), this is the same as relu(x) - a * relu(-x), which only needs ReLU, multiplication, and addition.
Here is the final solution for the normal model (i.e., not quantized!):

class PReLU_2(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight

    def forward(self, inputs):
        pos = torch.relu(inputs)
        neg = -self.weight * torch.relu(-inputs)
        inputs = pos + neg
        return inputs
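
(A quick sanity check of the class above against the built-in nn.PReLU; the shapes are arbitrary:)

import torch
import torch.nn as nn

prelu = nn.PReLU()
prelu_2 = PReLU_2(prelu)

for _ in range(5):
    x = torch.randn(2, 8, 4, 4)
    # relu(x) - a * relu(-x) equals max(0, x) + a * min(0, x) element-wise
    assert torch.allclose(prelu(x), prelu_2(x))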

And this is the quantized version:

class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = max(0, inputs) + alpha * min(0, inputs) 
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(-self.weight, torch.relu(-inputs))
        inputs = self.quantized_op.add(torch.relu(inputs), weight_min_res)
        inputs = self.dequant(inputs)
        self.weight = self.dequant(self.weight)
        return inputs

Side note:
I also had a mistake in the way I computed the difference:

    out = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {( out - out2).mean().item()}')

    out = self.conv2(out)

which needs to be:

    out1 = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {( out1 - out2).mean().item()}')
    out = self.conv2(out1)

Update:

If you are facing issues with quantization, you can try this version instead:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.quantized as nnq
from torch.quantization import fuse_modules


class QPReLU(nn.Module):
    def __init__(self, num_parameters=1, init: float = 0.25):
        super(QPReLU, self).__init__()
        self.num_parameters = num_parameters
        self.weight = nn.Parameter(torch.Tensor(num_parameters).fill_(init))
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.f_mul_neg_one1 = nnq.FloatFunctional()
        self.f_mul_neg_one2 = nnq.FloatFunctional()
        self.f_mul_alpha = nnq.FloatFunctional()
        self.f_add = nnq.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.quant2 = torch.quantization.QuantStub()
        self.quant3 = torch.quantization.QuantStub()
        # self.dequant2 = torch.quantization.QuantStub()
        self.neg_one = torch.Tensor([-1.0])
        
    
    def forward(self, x):
        x = self.quant(x)
        
        # PReLU, with modules only
        x1 = self.relu1(x)
        
        neg_one_q = self.quant2(self.neg_one)
        weight_q = self.quant3(self.weight)
        x2 = self.f_mul_alpha.mul(
            weight_q, self.f_mul_neg_one2.mul(
                self.relu2(
                    self.f_mul_neg_one1.mul(x, neg_one_q),
                ),
            neg_one_q)
        )
        
        x = self.f_add.add(x1, x2)
        x = self.dequant(x)
        return x
    
m1 = nn.PReLU()
m2 = QPReLU()

# check correctness in fp
for i in range(10):
    data = torch.randn(2, 2) * 1000
    assert torch.allclose(m1(data), m2(data))

# toy model
class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.prelu = QPReLU()
        
    def forward(self, x):
        x = self.prelu(x)
        return x
    
# quantize it
m = M()
m.qconfig = torch.quantization.default_qconfig
torch.quantization.prepare(m, inplace=True)
# calibrate
m(torch.randn(4, 4))
# convert
torch.quantization.convert(m, inplace=True)
# run some data through
res = m(torch.randn(4, 4))
print(res)

And make sure to read the related notes here.
