理解Python中的einsum

2 投票
2 回答
505 浏览
提问于 2025-04-18 10:02

我在用 einsum,因为它非常快,而且每次用的时候能让我省下2到3行代码。不过,我对它的理解有点困难。

我来举个例子:我想训练一个神经网络。为了计算与梯度相关的东西,我需要做以下操作:

假设有一个矩阵 W(第 i 行第 j 列表示从神经元 i 到下一层神经元 j 的连接权重),还有一个神经元输出的向量 S(希望下面的视觉示例能帮到你),我需要做的是:

S[i]*W[i,:] 这会生成一个新矩阵中的一行。

我发现下面的代码可以完成这个操作:

 einsum('ji,kj->ij',W,S)

现在我明白这个逻辑了,但我花了很长时间才搞明白。经历了很多尝试和错误(有些方法虽然能运行但结果是错的,有些则直接报错)。

现在我想一次性计算一批数据——也就是说,不再是向量 S,而是一个大小为 (NeuronsNum, BatchSize) 的矩阵,我想计算:

einsum('ji,kj->ij',W,S[:,b])

对于所有 b=0BatchSize-1。为了节省时间(也为了理解 einsum),我想一次性完成所有计算,并得到一个结果矩阵 (Neurons in layer l-1, Neurons in layer l, BatchSize)

我似乎一直搞不定这个。所以感谢你能读到这里,也希望能得到任何帮助来理解这个函数。

视觉示例:

enter image description here

i 个神经元发送一些值,这些值会根据它与目标之间的连接强度进行加权。我们对每个 i 都这样做。

2 个回答

1

感谢 usethedeathstar 的帮助,我终于找到了解决办法:

einsum('ki,jk->kij',W,S)

这段代码会给我一个三维数组,其中的结果 R 满足:

R[:,:,b]=einsum('ji,kj->ij',W,S[:,b])

他的建议是:写一个循环的代码,然后再把这个循环删掉,这样你就能得到索引了!

2

就像你不知道numpy可以做到这一点一样,直接用嵌套的for循环把它写出来,然后保留numpy einsum所需的索引。这样你就能把具体的公式写下来。

在你的例子中:

R = einsum('ki,jk->kij',W,S)

这会给你一个三维数组,结果R满足:

R[:,:,b] = einsum('ki,k->ki',W,S[:,b])

撰写回答