Python中numpy einsum的更快替代方案

0 投票
1 回答
64 浏览
提问于 2025-04-12 16:35

我正在尝试对一些张量(可以理解为多维数组)进行一些操作。目前我在使用einsum这个方法,但我在想有没有其他方法(比如使用dot或tensordot)能让这些操作更快,因为我觉得我做的事情大致上就是一些外积和内积。

场景1:

A = np.arange(6).reshape((2, 3))
B = np.arange(12).reshape((2, 3, 2))
res1 = numpy.einsum('ij, kjh->ikjh', A, B)

>>> res1 = 
[[[[ 0  0]
   [ 2  3]
   [ 8 10]]

  [[ 0  0]
   [ 8  9]
   [20 22]]]


 [[[ 0  3]
   [ 8 12]
   [20 25]]

  [[18 21]
   [32 36]
   [50 55]]]].

场景2:

C = np.arange(12).reshape((2, 3, 2))
D = np.arange(6).reshape((3, 2))
res2 = np.einsum('ijk, jk->ij', C, D)

>>> res2 = 
[[ 1 13 41]
 [ 7 43 95]]

我试过使用tensordot和dot,但不知道为什么,我就是找不到正确的方式来设置轴...

1 个回答

1

让我们来看看你的第一个计算。我会用一个小例子开始,确保数值是匹配的。这个小例子的时间可能不代表你在实际应用中的需求。

In [138]: n,m,k = 2,3,4
In [141]: A = np.arange(n*m).reshape(n,m)
In [142]: B = np.arange(n*m*k).reshape(n,m,k)


In [144]: res1 = np.einsum('ij, kjh->ikjh', A, B)    
In [145]: res1.shape
Out[145]: (2, 2, 3, 4)

因为没有求和的乘积(j在所有项中都有),我们可以用广播相乘来处理:

In [146]: x=A[:,None,:,None]*B
In [147]: x.shape
Out[147]: (2, 2, 3, 4)

结果和形状都是匹配的:

In [148]: np.allclose(res1,x)
Out[148]: True

有时候(在通常的标量条件下):

In [149]: timeit res1 = np.einsum('ij, kjh->ikjh', A, B)
13.6 µs ± 67.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [150]: timeit x=A[:,None,:,None]*B
7.2 µs ± 74.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [151]: timeit res1 = np.einsum('ij, kjh->ikjh', A, B, optimize=True)
90.5 µs ± 2.09 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

广播的结果是最好的——我预计在更大的数组中也会这样。而且optimize并没有帮助。我们可以看看einsum_path,但只有两个参数的话,改进的空间不大。

第二个

In [152]: C = np.arange(n*m*k).reshape(n,m,k)
D = np.arange(m*k).reshape(m,k)

In [153]: res2 = np.einsum('ijk, jk->ij', C, D)
In [154]: res2.shape
Out[154]: (2, 3)

这些形状可以广播而不改变:

In [155]: (C*D).shape
Out[155]: (2, 3, 4)    
In [156]: y=(C*D).sum(2)  # sum-of-products on last dimension    
In [157]: y.shape
Out[157]: (2, 3)

这和einsum是匹配的:

In [158]: np.allclose(res2,y)
Out[158]: True

一种矩阵乘法的方法:

In [159]: (C@D.T).shape
Out[159]: (2, 3, 3)    
In [160]: np.allclose((C@D.T)[:,np.arange(3),np.arange(3)],res2)
Out[160]: True

我不喜欢必须取最后的对角线;我还需要再玩一玩。

对于这些小的时间测试,einsum仍然是最好的:

In [164]: timeit res2 = np.einsum('ijk, jk->ij', C, D)
11.9 µs ± 31 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [165]: timeit y=(C*D).sum(2)
13.9 µs ± 25.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [166]: timeit (C@D.T)[:,np.arange(3),np.arange(3)]
21.4 µs ± 78.8 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

第二个矩阵乘法

正确且更快的矩阵乘法

In [167]: (C[:,:,None,:]@D[:,:,None]).shape
Out[167]: (2, 3, 1, 1)

去掉多余的1:

In [168]: np.allclose((C[:,:,None,:]@D[:,:,None])[:,:,0,0],res2)
Out[168]: True

In [169]: timeit (C[:,:,None,:]@D[:,:,None])
6.63 µs ± 24.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

我可能可以用类似的技巧来用矩阵乘法完成第一个例子,利用一个虚拟的大小为1的维度进行求和乘积。

撰写回答