作者在this notebook中写了以下nesterov update:
def nesterov_update(w, dw, v, lr, weight_decay, momentum):
dw.add_(weight_decay, w).mul_(-lr)
v.mul_(momentum).add_(dw)
w.add_(dw.add_(momentum, v))
据我所知,PyTorch中的a.add(b)
实现了a+b
和a.add(b,c)
实现了a+(b*c)
,因为b
位于alpha parameter的插槽中。最后,add_
执行add
的就地版本
Q:我到目前为止是对的吗?
然后,如果我要以一种说明逻辑的扩展形式绘制上述nesterov更新,我会写:
dw = -lr*(dw + weight_decay*w)
v = v*momentum + dw
w = w + dw + momentum*v
Q:这是正确的吗?
我不打算使用上面扩展的“代码”,我只是以这种方式编写它,试图传达我所理解的它正在做的事情,以进行检查
请务必注意本教程使用的PyTorch版本(1.1.0)。根据1.1.0,torch.add的函数原型是
torch.add(input, value=1, other, out=None)
。那么,你对下面这句话的解释是:as:
dw = dw + weight_decay * w
是正确的。所以,你第一个问题的答案是,是的,你是对的但是,对于PyTorch的最新版本,如果以相同的方式使用torch.add,则会出现错误
上述代码给出:(在PyTorch 1.5.0中)
但是,如果执行以下操作,则它可以正常工作
注意,torch.add的原型现在是:
torch.add(input, other, *, alpha=1, out=None)
第二个问题的答案是,是的,你是对的
相关问题 更多 >
编程相关推荐