Python Numpy 高维矩阵乘法

3 投票
2 回答
3313 浏览
提问于 2025-04-18 06:04

我正在寻找一个在 numpy 中可以加速以下计算的矩阵操作。

我有两个三维矩阵 AB。它们的第一个维度表示样本,每个矩阵都有 n_examples 个样本。我想要做的是对 AB 中的每个样本进行点乘,然后把结果加起来:

import numpy as np

n_examples = 10
A = np.random.randn(n_examples, 20,30)
B = np.random.randn(n_examples, 30,5)
sum = np.zeros([20,5])
for i in range(len(A)):
  sum += np.dot(A[i],B[i])

2 个回答

2

哈哈,这个可以用一行代码搞定:np.einsum('nmk,nkj->mj',A,B)

看看爱因斯坦求和法:http://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html

虽然不是同一个问题,但思路差不多,可以看看我们刚讨论过的这个话题里的讨论和其他方法:numpy 矩阵相乘保留第三个轴

不要把你的变量命名为 sum,这样会覆盖掉内置的 sum 函数。

正如 @Jaime 指出的,对于这些维度,循环实际上更快。事实上,基于 mapsum 的解决方案虽然更简单,但速度反而更慢:

In [19]:

%%timeit
SUM = np.zeros([20,5])
for i in range(len(A)):
  SUM += np.dot(A[i],B[i])
10000 loops, best of 3: 115 µs per loop
In [20]:

%timeit np.array(map(np.dot, A,B)).sum(0)
1000 loops, best of 3: 445 µs per loop
In [21]:

%timeit np.einsum('nmk,nkj->mj',A,B)
1000 loops, best of 3: 259 µs per loop

对于更大的维度,情况就不同了:

n_examples = 1000
A = np.random.randn(n_examples, 20,1000)
B = np.random.randn(n_examples, 1000,5)

还有:

In [46]:

%%timeit
SUM = np.zeros([20,5])
for i in range(len(A)):
  SUM += np.dot(A[i],B[i])
1 loops, best of 3: 191 ms per loop
In [47]:

%timeit np.array(map(np.dot, A,B)).sum(0)
1 loops, best of 3: 164 ms per loop
In [48]:

%timeit np.einsum('nmk,nkj->mj',A,B)
1 loops, best of 3: 451 ms per loop
4

这是一个使用 np.tensordot() 的典型例子:

sum = np.tensordot(A, B, [[0,2],[0,1]])

计时

使用以下代码:

import numpy as np

n_examples = 100
A = np.random.randn(n_examples, 20,30)
B = np.random.randn(n_examples, 30,5)

def sol1():
    sum = np.zeros([20,5])
    for i in range(len(A)):
      sum += np.dot(A[i],B[i])
    return sum

def sol2():
    return np.array(map(np.dot, A,B)).sum(0)

def sol3():
    return np.einsum('nmk,nkj->mj',A,B)

def sol4():
    return np.tensordot(A, B, [[2,0],[1,0]])

def sol5():
    return np.tensordot(A, B, [[0,2],[0,1]])

结果:

timeit sol1()
1000 loops, best of 3: 1.46 ms per loop

timeit sol2()
100 loops, best of 3: 4.22 ms per loop

timeit sol3()
1000 loops, best of 3: 1.87 ms per loop

timeit sol4()
10000 loops, best of 3: 205 µs per loop

timeit sol5()
10000 loops, best of 3: 172 µs per loop

在我的电脑上,tensordot() 是最快的解决方案,改变轴的计算顺序既没有改变结果,也没有影响性能。

撰写回答