[Pytorch]寻求矢量化的实现

2021-05-16 09:01:08 发布

您现在位置:Python中文网/ 问答频道 /正文

这是我的代码,'x'是一个常规的CNN权重,大小为:(out\u c,in\u c,k\u h,k\u w),其中out\u c也是CNN内核的数量。lb(下限)和ub(上限)是两个预定义的1D numpy数组,长度out\c。你知道吗

任务是通过剪裁所有异常值,确保所有CNN内核权重的值都在lbub定义的范围内(即lb[i]<;x[i,:,:,:]<;ub[i],0<;=i<;out\u c)。 除此之外,我需要保留异常值的掩码以供其他使用。你知道吗

但是,我找不到完全矢量化的解决方案,因此不涉及for循环。我使用的是Pytorch 0.3.1,有没有什么方法可以使这个代码段完全矢量化(去掉for循环)?谢谢。你知道吗

N = x.size(0) # number of cnn kernel
lower_mask = []
upper_mask = []
for i in range(N):
    lower_mask.append(x[i] < lb.tolist()[i])
    upper_mask.append(ub.tolist()[i] < x[i])
    x[i].clamp_(lb.tolist()[i], ub.tolist()[i])
lower_mask = torch.stack(lower_mask).cuda()
upper_mask = torch.stack(upper_mask).cuda()