在numpy中数值稳定地相乘对数概率矩阵的方法
我需要对两个包含对数概率的NumPy矩阵(或者其他二维数组)进行矩阵乘法。直接用 np.log(np.dot(np.exp(a), np.exp(b)))
这种方法显然不太好。
使用
from scipy.misc import logsumexp
res = np.zeros((a.shape[0], b.shape[1]))
for n in range(b.shape[1]):
# broadcast b[:,n] over rows of a, sum columns
res[:, n] = logsumexp(a + b[:, n].T, axis=1)
的方法可以实现,但速度比 np.log(np.dot(np.exp(a), np.exp(b)))
慢大约100倍。
使用
logsumexp((tile(a, (b.shape[1],1)) + repeat(b.T, a.shape[0], axis=0)).reshape(b.shape[1],a.shape[0],a.shape[1]), 2).T
或者其他的组合,比如tile和reshape,也能工作,但由于需要的内存量太大,速度比上面的方法还要慢,尤其是当输入矩阵比较大时。
我现在在考虑用C语言写一个NumPy扩展来计算这个,但我当然希望能避免这样做。有没有什么成熟的方法,或者有没有人知道更省内存的计算方式?
编辑:感谢larsmans提供的这个解决方案(下面有推导过程):
def logdot(a, b):
max_a, max_b = np.max(a), np.max(b)
exp_a, exp_b = a - max_a, b - max_b
np.exp(exp_a, out=exp_a)
np.exp(exp_b, out=exp_b)
c = np.dot(exp_a, exp_b)
np.log(c, out=c)
c += max_a + max_b
return c
用iPython的魔法命令 %timeit
快速比较一下这个方法和上面提到的 logdot_old
方法,结果如下:
In [1] a = np.log(np.random.rand(1000,2000))
In [2] b = np.log(np.random.rand(2000,1500))
In [3] x = logdot(a, b)
In [4] y = logdot_old(a, b) # this takes a while
In [5] np.any(np.abs(x-y) > 1e-14)
Out [5] False
In [6] %timeit logdot_old(a, b)
1 loops, best of 3: 1min 18s per loop
In [6] %timeit logdot(a, b)
1 loops, best of 3: 264 ms per loop
显然,larsmans的方法比我的要好得多!
4 个回答
目前被接受的答案是Fred Foo和Hassan的答案,但它们在数值上不太稳定(Hassan的答案更好)。稍后会提供一个Hassan答案失败的输入示例。我的实现方式如下:
import numpy as np
from scipy.special import logsumexp
def logmatmulexp(log_A: np.ndarray, log_B: np.ndarray) -> np.ndarray:
"""Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates
(log_A.exp() @ log_B.exp()).log() in a numerically stable way.
Has O(ϴRI) time complexity and space complexity."""
ϴ, R = log_A.shape
I = log_B.shape[1]
assert log_B.shape == (R, I)
log_A_expanded = np.broadcast_to(np.expand_dims(log_A, 2), (ϴ, R, I))
log_B_expanded = np.broadcast_to(np.expand_dims(log_B, 0), (ϴ, R, I))
log_pairwise_products = log_A_expanded + log_B_expanded # shape: (ϴ, R, I)
return logsumexp(log_pairwise_products, axis=1)
和Hassan的答案以及Fred Foo的答案一样,我的答案的时间复杂度是O(ϴRI)。他们的答案的空间复杂度是O(ϴR+RI)(我其实不太确定这个),而我的答案不幸的是空间复杂度是O(ϴRI)——这是因为numpy可以在不额外分配ϴ×R×I大小的数组的情况下,将一个ϴ×R的矩阵与一个R×I的矩阵相乘。O(ϴRI)的空间复杂度并不是我方法的固有特性——我认为如果你用循环来写,可以避免这个空间复杂度,但不幸的是,我认为使用标准的numpy函数是做不到的。
我检查了我的代码实际运行的时间,它比普通的矩阵乘法慢20倍。
以下是我如何证明我的答案在数值上是稳定的:
- 显然,除了返回行以外的所有行都是数值稳定的。
logsumexp
函数是公认的数值稳定。- 因此,我的
logmatmulexp
函数也是数值稳定的。
我的实现还有另一个不错的特性。如果你不使用numpy,而是用pytorch或其他具有自动微分的库编写相同的代码,你将自动获得数值稳定的反向传播。以下是我们如何知道反向传播将是数值稳定的:
- 我代码中的所有函数在任何地方都是可微分的(与
np.max
不同)。 - 显然,除了返回行以外的所有行的反向传播都是数值稳定的,因为那里没有发生任何奇怪的事情。
- 通常,pytorch的开发者知道他们在做什么。所以只要相信他们实现了logsumexp的反向传播是数值稳定的就可以了。
- 实际上,logsumexp的梯度是softmax函数(参考可以谷歌“softmax is gradient of logsumexp”或查看https://arxiv.org/abs/1704.00805的命题1)。已知softmax可以以数值稳定的方式计算。所以pytorch的开发者可能在这里直接使用softmax(我实际上没有检查过)。
下面是相同的代码在pytorch中的实现(如果你需要反向传播的话)。由于pytorch反向传播的工作方式,在前向传播期间,它会保存log_pairwise_products
张量以供反向传播使用。这个张量很大,你可能不想让它被保存——你可以在反向传播期间重新计算一次。在这种情况下,我建议你使用检查点技术——这真的很简单——请看下面的第二个函数。
import torch
from torch.utils.checkpoint import checkpoint
def logmatmulexp(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
"""Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates
(log_A.exp() @ log_B.exp()).log() and its backward in a numerically stable way."""
ϴ, R = log_A.shape
I = log_B.shape[1]
assert log_B.shape == (R, I)
log_A_expanded = log_A.unsqueeze(2).expand((ϴ, R, I))
log_B_expanded = log_B.unsqueeze(0).expand((ϴ, R, I))
log_pairwise_products = log_A_expanded + log_B_expanded # shape: (ϴ, R, I)
return torch.logsumexp(log_pairwise_products, dim=1)
def logmatmulexp_lowmem(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
"""Same as logmatmulexp, but doesn't save a (ϴ, R, I)-shaped tensor for backward pass.
Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates
(log_A.exp() @ log_B.exp()).log() and its backward in a numerically stable way."""
return checkpoint(logmatmulexp, log_A, log_B)
以下是一个Hassan的实现失败但我的实现给出正确输出的输入示例:
def logmatmulexp_hassan(A, B):
max_A = np.max(A,1,keepdims=True)
max_B = np.max(B,0,keepdims=True)
C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
np.log(C, out=C)
C += max_A + max_B
return C
log_A = np.array([[-500., 900.]], dtype=np.float64)
log_B = np.array([[900.], [-500.]], dtype=np.float64)
print(logmatmulexp_hassan(log_A, log_B)) # prints -inf, while the correct answer is approximately 400.69.
假设有两个矩阵,A.shape==(n,r)
和 B.shape==(r,m)
。在计算这两个矩阵的乘积 C=A*B
时,实际上会进行 n*m
次求和。为了在对数空间中获得稳定的结果,你需要在每次求和时使用 logsumexp 技巧。幸运的是,利用 numpy 的广播功能,可以很方便地分别控制矩阵 A 和 B 的行和列的稳定性。
下面是代码:
def logdotexp(A, B):
max_A = np.max(A,1,keepdims=True)
max_B = np.max(B,0,keepdims=True)
C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
np.log(C, out=C)
C += max_A + max_B
return C
注意:
这个方法的原理和 FredFoo 的回答类似,但他只为每个矩阵使用了一个最大值。因为他没有考虑到每次 n*m
的求和,所以最终矩阵中的某些元素可能仍然不稳定,这在评论中也提到过。
与当前被接受的答案进行比较,使用 @identity-m 的反例:
def logdotexp_less_stable(A, B):
max_A = np.max(A)
max_B = np.max(B)
C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
np.log(C, out=C)
C += max_A + max_B
return C
print('old method:')
print(logdotexp_less_stable([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
print('new method:')
print(logdotexp([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
这段代码会输出:
old method:
[[ -inf 0.69314718]
[ -inf 0.69314718]]
new method:
[[-9.99306853e+02 6.93147181e-01]
[-9.99306853e+02 6.93147181e-01]]
logsumexp
的工作原理是先计算等式右边的部分
log(∑ exp[a]) = max(a) + log(∑ exp[a - max(a)])
也就是说,它在开始求和之前先找出最大的值,这样可以避免在计算 exp
时出现溢出。这个方法同样可以在进行向量点积之前使用:
log(exp[a] ⋅ exp[b])
= log(∑ exp[a] × exp[b])
= log(∑ exp[a + b])
= max(a + b) + log(∑ exp[a + b - max(a + b)]) { this is logsumexp(a + b) }
但通过不同的推导方式,我们可以得到
log(∑ exp[a] × exp[b])
= max(a) + max(b) + log(∑ exp[a - max(a)] × exp[b - max(b)])
= max(a) + max(b) + log(exp[a - max(a)] ⋅ exp[b - max(b)])
最终的形式里面包含了一个向量的点积。这个方法也很容易扩展到矩阵乘法,所以我们得到了这个算法
def logdotexp(A, B):
max_A = np.max(A)
max_B = np.max(B)
C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
np.log(C, out=C)
C += max_A + max_B
return C
这个过程会创建两个大小为 A
的临时变量和两个大小为 B
的临时变量,但其中一个可以通过以下方式消除
exp_A = A - max_A
np.exp(exp_A, out=exp_A)
对于 B
也是一样的情况。(如果输入的矩阵可能会被函数修改,那么所有的临时变量都可以被消除。)