获取输出相对于输入的梯度

2024-04-18 14:44:09 发布

您现在位置:Python中文网/ 问答频道 /正文

我目前正在尝试用Pytorch实现一个ODE解算器,我的解决方案需要计算每个输出wtr到其输入的梯度

y = model(x)

for i in range(len(y)): #compute output grad wrt input
       y[i].backward(retain_graph=True)
    
ydx=x.grad 

我想知道是否有一种更优雅的方法来计算批处理中每个输出的梯度,因为高阶ODE和PDE的代码会变得混乱。 我尝试使用:

torch.autograd.backward(x,y,retain_graph=True)

没有多少成功


Tags: intrueformodelrangepytorch解决方案graph
2条回答

如果您的PyTorch版本实现了API,请尝试torch.autograd.functional.jacobian。我也在为汉堡方程式做同样的事情,并在同一主题上发布了这条帖子:PyTorch how to compute second order jacobian?

使用DL解决PDE是当前的热门话题

可以使用torch.autograd.grad函数直接获取渐变。一个问题是它要求输出(y)是标量。因为您的输出是一个数组,所以仍然需要循环遍历它的值

这个电话看起来像这样

[torch.autograd.grad(outputs=out, inputs=x, retain_graph=True)[0][i] 
    for i, out in enumerate(y)]

这是我的意思的一个例子。让我们考虑变量^ {< CD3}},值{{CD4}}和一个只对其输入进行平方的模型。

x = torch.Tensor([1, 2, 3])
x.requires_grad = True

def model(x): return x ** 2

y = model(x)

现在,如果您像我描述的那样调用torch.autograd.grad,您将得到:

[torch.autograd.grad(outputs=out, inputs=x, retain_graph=True)[0][i] 
    for i, out in enumerate(y)]

# [tensor(2.), tensor(4.), tensor(6.)]

这是wrt衍生产品列表。到值x-[ydx0, ydx1, ydx2]

相关问题 更多 >