我试图量化一个使用PReLU
的模型。用ReLU
替换PReLU
是不可能的,因为它会严重影响网络性能,甚至毫无用处
据我所知,Pytorch在量化方面不支持PReLU
。因此,我尝试手动重写这个模块,并使用torch.FloatFunctional()
实现乘法和加法来绕过这个限制
这就是我到目前为止提出的问题:
class PReLU_Quantized(nn.Module):
def __init__(self, prelu_object):
super().__init__()
self.weight = prelu_object.weight
self.quantized_op = nn.quantized.FloatFunctional()
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, inputs):
# inputs = torch.max(0, inputs) + self.weight * torch.min(0, inputs)
self.weight = self.quant(self.weight)
weight_min_res = self.quantized_op.mul(self.weight, torch.min(inputs)[0])
inputs = self.quantized_op.add(torch.max(inputs)[0], weight_min_res).unsqueeze(0)
self.weight = self.dequant(self.weight)
return inputs
至于更换:
class model(nn.Module):
def __init__(self)
super().__init__()
....
self.prelu = PReLU()
self.prelu_q = PReLU_Quantized(self.prelu)
....
基本上,我读取了现有prelu模块的学习参数,并在新模块中自己运行计算。从某种意义上说,该模块似乎在工作,它并没有使整个应用程序失败
然而,为了评估我的实现是否确实正确并产生与原始模块相同的结果,我尝试对其进行测试。
以下是正常模型(即非量化模型)的对应项:
由于某种原因,实际的PReLU
和我的实现之间的错误非常大
以下是不同层中的示例差异:
diff : 1.1562038660049438
diff : 0.02868632599711418
diff : 0.3653906583786011
diff : 1.6100226640701294
diff : 0.8999372720718384
diff : 0.03773299604654312
diff : -0.5090572834014893
diff : 0.1654307246208191
diff : 1.161868691444397
diff : 0.026089997962117195
diff : 0.4205571115016937
diff : 1.5337920188903809
diff : 0.8799554705619812
diff : 0.03827812895178795
diff : -0.40296515822410583
diff : 0.15618863701820374
在前进过程中,差的计算如下:
def forward(self, x):
residual = x
out = self.bn0(x)
out = self.conv1(out)
out = self.bn1(out)
out = self.prelu(out)
out2 = self.prelu2(out)
print(f'diff : {( out - out2).mean().item()}')
out = self.conv2(out)
...
这是我在普通模型(即未量化!)上使用的正常实现,用于评估它是否产生正确的结果,然后转到量化版本:
class PReLU_2(nn.Module):
def __init__(self, prelu_object):
super().__init__()
self.prelu_weight = prelu_object.weight
self.weight = self.prelu_weight
def forward(self, inputs):
x = self.weight
tmin, _ = torch.min(inputs,dim=0)
tmax, _ = torch.max(inputs,dim=0)
weight_min_res = torch.mul(x, tmin)
inputs = torch.add(tmax, weight_min_res)
inputs = inputs.unsqueeze(0)
return inputs
我错过了什么
我知道了!我一开始就犯了一个很大的错误。我需要计算一下
或
而不是实际的
torch.min
!或者torch.max
!这没有任何意义! 以下是正常模型的最终解决方案(即未量化)!:这是量化版本:
旁注:
我在计算差异时也有一个输入错误:
需要
更新:
如果您面临量化问题,您可以尝试以下方法version:
并确保阅读相关注释here
相关问题 更多 >
编程相关推荐