在Pytorch中检查导数张量
在这个问题中,大家在讨论一个函数的导数,函数的形式是f(x) = Axx'A / (x'AAx)
,其中x
是一个向量,而A
是一个对称的、半正定的方阵。
这个函数在某个点x
的导数是一个张量。当这个张量“作用”于另一个向量h
时,它就变成了一个矩阵。这个帖子下面的回答在这个矩阵的表达式上有不同的看法,所以我想用Pytorch
或Autograd
来进行数值验证。
这是我用Pytorch的尝试:
import torch
def P(x, A):
x = x.unsqueeze(1) # Convert to column vector
vector = torch.matmul(A, x)
denom = (vector.transpose(0, 1) @ vector).squeeze()
P_matrix = (vector @ vector.transpose(0, 1)) / denom
return P_matrix.squeeze()
A = torch.tensor([[1.0, 0.5], [0.5, 1.3]], dtype=torch.float32)
x = torch.tensor([1.0, 2.0], dtype=torch.float32, requires_grad=True)
h = torch.tensor([2.0, -1.0], dtype=torch.float32)
Pxh = torch.matmul(P(x, A), h)
# compute gradient
Pxh.backward()
但是这个方法不行。我哪里做错了呢?
JAX
我也希望能有一个JAX的解决方案。我试过jax.grad
,但也不行。
1 个回答
1
如果你想要进行反向传播并处理非标量值(也就是不是单个数字的值),你需要传递一个全是1的张量。
import torch
def P(x, A):
x = x.unsqueeze(1) # Convert to column vector
vector = torch.matmul(A, x)
denom = (vector.transpose(0, 1) @ vector).squeeze()
P_matrix = (vector @ vector.transpose(0, 1)) / denom
return P_matrix.squeeze()
A = torch.tensor([[1.0, 0.5], [0.5, 1.3]], dtype=torch.float32)
x = torch.tensor([1.0, 2.0], dtype=torch.float32, requires_grad=True)
h = torch.tensor([2.0, -1.0], dtype=torch.float32)
Pxh = torch.matmul(P(x, A), h)
Pxh.backward(torch.ones_like(Pxh))
x.grad
> tensor([ 0.4853, -0.2427])