在numpy中数值稳定地相乘对数概率矩阵的方法

41 投票
4 回答
9164 浏览
提问于 2025-04-18 06:24

我需要对两个包含对数概率的NumPy矩阵(或者其他二维数组)进行矩阵乘法。直接用 np.log(np.dot(np.exp(a), np.exp(b))) 这种方法显然不太好。

使用

from scipy.misc import logsumexp
res = np.zeros((a.shape[0], b.shape[1]))
for n in range(b.shape[1]):
    # broadcast b[:,n] over rows of a, sum columns
    res[:, n] = logsumexp(a + b[:, n].T, axis=1) 

的方法可以实现,但速度比 np.log(np.dot(np.exp(a), np.exp(b))) 慢大约100倍。

使用

logsumexp((tile(a, (b.shape[1],1)) + repeat(b.T, a.shape[0], axis=0)).reshape(b.shape[1],a.shape[0],a.shape[1]), 2).T

或者其他的组合,比如tile和reshape,也能工作,但由于需要的内存量太大,速度比上面的方法还要慢,尤其是当输入矩阵比较大时。

我现在在考虑用C语言写一个NumPy扩展来计算这个,但我当然希望能避免这样做。有没有什么成熟的方法,或者有没有人知道更省内存的计算方式?

编辑:感谢larsmans提供的这个解决方案(下面有推导过程):

def logdot(a, b):
    max_a, max_b = np.max(a), np.max(b)
    exp_a, exp_b = a - max_a, b - max_b
    np.exp(exp_a, out=exp_a)
    np.exp(exp_b, out=exp_b)
    c = np.dot(exp_a, exp_b)
    np.log(c, out=c)
    c += max_a + max_b
    return c

用iPython的魔法命令 %timeit 快速比较一下这个方法和上面提到的 logdot_old 方法,结果如下:

In  [1] a = np.log(np.random.rand(1000,2000))

In  [2] b = np.log(np.random.rand(2000,1500))

In  [3] x = logdot(a, b)

In  [4] y = logdot_old(a, b) # this takes a while

In  [5] np.any(np.abs(x-y) > 1e-14)
Out [5] False

In  [6] %timeit logdot_old(a, b)
1 loops, best of 3: 1min 18s per loop

In  [6] %timeit logdot(a, b)
1 loops, best of 3: 264 ms per loop

显然,larsmans的方法比我的要好得多!

4 个回答

1

你正在访问 resb 的列,这样做的效率不太高,因为它们的 局部性差。可以尝试把它们存储成 列优先的方式

3

目前被接受的答案是Fred Foo和Hassan的答案,但它们在数值上不太稳定(Hassan的答案更好)。稍后会提供一个Hassan答案失败的输入示例。我的实现方式如下:

import numpy as np
from scipy.special import logsumexp

def logmatmulexp(log_A: np.ndarray, log_B: np.ndarray) -> np.ndarray:
    """Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates                                                                                                                                                             
    (log_A.exp() @ log_B.exp()).log() in a numerically stable way.                                                                                                                                                                           
    Has O(ϴRI) time complexity and space complexity."""
    ϴ, R = log_A.shape
    I = log_B.shape[1]
    assert log_B.shape == (R, I)
    log_A_expanded = np.broadcast_to(np.expand_dims(log_A, 2), (ϴ, R, I))
    log_B_expanded = np.broadcast_to(np.expand_dims(log_B, 0), (ϴ, R, I))
    log_pairwise_products = log_A_expanded + log_B_expanded  # shape: (ϴ, R, I)                                                                                                                                                              
    return logsumexp(log_pairwise_products, axis=1)

和Hassan的答案以及Fred Foo的答案一样,我的答案的时间复杂度是O(ϴRI)。他们的答案的空间复杂度是O(ϴR+RI)(我其实不太确定这个),而我的答案不幸的是空间复杂度是O(ϴRI)——这是因为numpy可以在不额外分配ϴ×R×I大小的数组的情况下,将一个ϴ×R的矩阵与一个R×I的矩阵相乘。O(ϴRI)的空间复杂度并不是我方法的固有特性——我认为如果你用循环来写,可以避免这个空间复杂度,但不幸的是,我认为使用标准的numpy函数是做不到的。

我检查了我的代码实际运行的时间,它比普通的矩阵乘法慢20倍。

以下是我如何证明我的答案在数值上是稳定的:

  1. 显然,除了返回行以外的所有行都是数值稳定的。
  2. logsumexp函数是公认的数值稳定。
  3. 因此,我的logmatmulexp函数也是数值稳定的。

我的实现还有另一个不错的特性。如果你不使用numpy,而是用pytorch或其他具有自动微分的库编写相同的代码,你将自动获得数值稳定的反向传播。以下是我们如何知道反向传播将是数值稳定的:

  1. 我代码中的所有函数在任何地方都是可微分的(与np.max不同)。
  2. 显然,除了返回行以外的所有行的反向传播都是数值稳定的,因为那里没有发生任何奇怪的事情。
  3. 通常,pytorch的开发者知道他们在做什么。所以只要相信他们实现了logsumexp的反向传播是数值稳定的就可以了。
  4. 实际上,logsumexp的梯度是softmax函数(参考可以谷歌“softmax is gradient of logsumexp”或查看https://arxiv.org/abs/1704.00805的命题1)。已知softmax可以以数值稳定的方式计算。所以pytorch的开发者可能在这里直接使用softmax(我实际上没有检查过)。

下面是相同的代码在pytorch中的实现(如果你需要反向传播的话)。由于pytorch反向传播的工作方式,在前向传播期间,它会保存log_pairwise_products张量以供反向传播使用。这个张量很大,你可能不想让它被保存——你可以在反向传播期间重新计算一次。在这种情况下,我建议你使用检查点技术——这真的很简单——请看下面的第二个函数。

import torch
from torch.utils.checkpoint import checkpoint

def logmatmulexp(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
    """Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates                                                                                                                                                             
    (log_A.exp() @ log_B.exp()).log() and its backward in a numerically stable way."""
    ϴ, R = log_A.shape
    I = log_B.shape[1]
    assert log_B.shape == (R, I)
    log_A_expanded = log_A.unsqueeze(2).expand((ϴ, R, I))
    log_B_expanded = log_B.unsqueeze(0).expand((ϴ, R, I))
    log_pairwise_products = log_A_expanded + log_B_expanded  # shape: (ϴ, R, I)                                                                                                                                                              
    return torch.logsumexp(log_pairwise_products, dim=1)


def logmatmulexp_lowmem(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
    """Same as logmatmulexp, but doesn't save a (ϴ, R, I)-shaped tensor for backward pass.                                                                                                                                                   

    Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates                                                                                                                                                                
    (log_A.exp() @ log_B.exp()).log() and its backward in a numerically stable way."""
    return checkpoint(logmatmulexp, log_A, log_B)

以下是一个Hassan的实现失败但我的实现给出正确输出的输入示例:

def logmatmulexp_hassan(A, B):
    max_A = np.max(A,1,keepdims=True)
    max_B = np.max(B,0,keepdims=True)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

log_A = np.array([[-500., 900.]], dtype=np.float64)
log_B = np.array([[900.], [-500.]], dtype=np.float64)
print(logmatmulexp_hassan(log_A, log_B)) # prints -inf, while the correct answer is approximately 400.69.
5

假设有两个矩阵,A.shape==(n,r)B.shape==(r,m)。在计算这两个矩阵的乘积 C=A*B 时,实际上会进行 n*m 次求和。为了在对数空间中获得稳定的结果,你需要在每次求和时使用 logsumexp 技巧。幸运的是,利用 numpy 的广播功能,可以很方便地分别控制矩阵 A 和 B 的行和列的稳定性。

下面是代码:

def logdotexp(A, B):
    max_A = np.max(A,1,keepdims=True)
    max_B = np.max(B,0,keepdims=True)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

注意:

这个方法的原理和 FredFoo 的回答类似,但他只为每个矩阵使用了一个最大值。因为他没有考虑到每次 n*m 的求和,所以最终矩阵中的某些元素可能仍然不稳定,这在评论中也提到过。

与当前被接受的答案进行比较,使用 @identity-m 的反例:

def logdotexp_less_stable(A, B):
    max_A = np.max(A)
    max_B = np.max(B)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

print('old method:')
print(logdotexp_less_stable([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
print('new method:')
print(logdotexp([[0,0],[0,0]], [[-1000,0], [-1000,0]]))

这段代码会输出:

old method:
[[      -inf 0.69314718]
 [      -inf 0.69314718]]
new method:
[[-9.99306853e+02  6.93147181e-01]
 [-9.99306853e+02  6.93147181e-01]]
28

logsumexp 的工作原理是先计算等式右边的部分

log(∑ exp[a]) = max(a) + log(∑ exp[a - max(a)])

也就是说,它在开始求和之前先找出最大的值,这样可以避免在计算 exp 时出现溢出。这个方法同样可以在进行向量点积之前使用:

log(exp[a] ⋅ exp[b])
 = log(∑ exp[a] × exp[b])
 = log(∑ exp[a + b])
 = max(a + b) + log(∑ exp[a + b - max(a + b)])     { this is logsumexp(a + b) }

但通过不同的推导方式,我们可以得到

log(∑ exp[a] × exp[b])
 = max(a) + max(b) + log(∑ exp[a - max(a)] × exp[b - max(b)])
 = max(a) + max(b) + log(exp[a - max(a)] ⋅ exp[b - max(b)])

最终的形式里面包含了一个向量的点积。这个方法也很容易扩展到矩阵乘法,所以我们得到了这个算法

def logdotexp(A, B):
    max_A = np.max(A)
    max_B = np.max(B)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

这个过程会创建两个大小为 A 的临时变量和两个大小为 B 的临时变量,但其中一个可以通过以下方式消除

exp_A = A - max_A
np.exp(exp_A, out=exp_A)

对于 B 也是一样的情况。(如果输入的矩阵可能会被函数修改,那么所有的临时变量都可以被消除。)

撰写回答