我是否正确理解PyTorch的加法和乘法?

2024-05-28 18:40:47 发布

您现在位置:Python中文网/ 问答频道 /正文

作者在this notebook中写了以下nesterov update

def nesterov_update(w, dw, v, lr, weight_decay, momentum):
    dw.add_(weight_decay, w).mul_(-lr)
    v.mul_(momentum).add_(dw)
    w.add_(dw.add_(momentum, v))

据我所知,PyTorch中的a.add(b)实现了a+ba.add(b,c)实现了a+(b*c),因为b位于alpha parameter的插槽中。最后,add_执行add的就地版本

Q:我到目前为止是对的吗?

然后,如果我要以一种说明逻辑的扩展形式绘制上述nesterov更新,我会写:

dw = -lr*(dw + weight_decay*w)
v = v*momentum + dw
w = w + dw + momentum*v

Q:这是正确的吗?

我不打算使用上面扩展的“代码”,我只是以这种方式编写它,试图传达我所理解的它正在做的事情,以进行检查


Tags: alphaadddefupdate作者pytorchthisnotebook
1条回答
网友
1楼 · 发布于 2024-05-28 18:40:47

请务必注意本教程使用的PyTorch版本(1.1.0)。根据1.1.0,torch.add的函数原型是torch.add(input, value=1, other, out=None)。那么,你对下面这句话的解释是:

dw.add_(weight_decay, w)

as:dw = dw + weight_decay * w是正确的。所以,你第一个问题的答案是,是的,你是对的

但是,对于PyTorch的最新版本,如果以相同的方式使用torch.add,则会出现错误

a = torch.FloatTensor([0, 1.0, 2.0, 3.0])
b = torch.FloatTensor([0, 4.0, 5.0, 6.0])
c = 1.0
z = a.add(b, c)

上述代码给出:(在PyTorch 1.5.0中)

TypeError: add() takes 1 positional argument but 2 were given

但是,如果执行以下操作,则它可以正常工作

z = a.add(b, alpha=c)

注意,torch.add的原型现在是:torch.add(input, other, *, alpha=1, out=None)


第二个问题的答案是,是的,你是对的

相关问题 更多 >

    热门问题