理解努比的信仰

2024-06-11 12:07:25 发布

您现在位置:Python中文网/ 问答频道 /正文

我很难确切地理解einsum是如何工作的。我看过文档和一些例子,但似乎不太好用。

下面是我们在课堂上复习的一个例子:

C = np.einsum("ij,jk->ki", A, B)

对于两个数组AB

我想这需要A^T * B,但我不确定(这需要其中一个的转位对吧?)。有人能告诉我这里到底发生了什么吗(通常在使用einsum时)?


Tags: 文档np数组例子课堂jkijki
3条回答

让我们制作两个数组,用不同但兼容的维度来突出它们之间的相互作用

In [43]: A=np.arange(6).reshape(2,3)
Out[43]: 
array([[0, 1, 2],
       [3, 4, 5]])


In [44]: B=np.arange(12).reshape(3,4)
Out[44]: 
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

你的计算,取一个(2,3)和一个(3,4)的“点”(乘积之和)来产生一个(4,2)数组。iA的第一个维度,是C的最后一个维度;kB的最后一个维度,是C的第一个维度。j被总和“消耗”。

In [45]: C=np.einsum('ij,jk->ki',A,B)
Out[45]: 
array([[20, 56],
       [23, 68],
       [26, 80],
       [29, 92]])

这与np.dot(A,B).T相同-这是最后一个被转置的输出。

要查看j发生的更多情况,请将C下标更改为ijk

In [46]: np.einsum('ij,jk->ijk',A,B)
Out[46]: 
array([[[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [16, 18, 20, 22]],

       [[ 0,  3,  6,  9],
        [16, 20, 24, 28],
        [40, 45, 50, 55]]])

这也可以通过以下方式产生:

A[:,:,None]*B[None,:,:]

也就是说,在A的末尾添加一个k维度,在B的前面添加一个i维度,得到一个(2,3,4)数组。

0 + 4 + 16 = 209 + 28 + 55 = 92等;对j求和并转置以获得早期结果:

np.sum(A[:,:,None] * B[None,:,:], axis=1).T

# C[k,i] = sum(j) A[i,j (,k) ] * B[(i,)  j,k]

如果你能直观地理解^{}的概念,那么理解它是非常容易的。作为一个例子,让我们从一个简单的描述开始,涉及矩阵乘法。


要使用^{},您只需将所谓的下标字符串作为参数传递,然后传递输入数组

假设你有两个二维数组,分别是AB,你想做矩阵乘法。所以,你需要:

np.einsum("ij, jk -> ik", A, B)

这里,下标字符串ij对应于数组A,而下标字符串jk对应于数组B。另外,这里要注意的最重要的一点是,每个下标字符串中的字符数必须与数组的维数匹配。(也就是说,二维数组有两个字符,三维数组有三个字符,等等)如果在下标字符串之间重复这些字符(在我们的例子中,是指在下标字符串之间重复这些字符),那么就意味着希望沿着这些维度发生einsum。这样,它们的总和就会减少。(即,该维度将消失)

在这个->之后的下标字符串将是我们的结果数组。 如果将其保留为空,则将对所有内容进行求和,并作为结果返回标量值。否则生成的数组将根据下标字符串具有维数。在我们的示例中,它将是ik。这是直观的,因为我们知道,对于矩阵乘法,数组中的列数必须与数组中的行数相匹配,这就是这里所发生的情况(即,我们通过重复下标字符串中的charj


这里有更多的例子,简明扼要地说明了^{}在实现一些常见的张量nd数组操作时的使用/能力。

输入

# a vector
In [197]: vec
Out[197]: array([0, 1, 2, 3])

# an array
In [198]: A
Out[198]: 
array([[11, 12, 13, 14],
       [21, 22, 23, 24],
       [31, 32, 33, 34],
       [41, 42, 43, 44]])

# another array
In [199]: B
Out[199]: 
array([[1, 1, 1, 1],
       [2, 2, 2, 2],
       [3, 3, 3, 3],
       [4, 4, 4, 4]])

1)矩阵乘法(类似于np.matmul(arr1, arr2)

In [200]: np.einsum("ij, jk -> ik", A, B)
Out[200]: 
array([[130, 130, 130, 130],
       [230, 230, 230, 230],
       [330, 330, 330, 330],
       [430, 430, 430, 430]])

2)沿着主对角线提取元素(类似于np.diag(arr)

In [202]: np.einsum("ii -> i", A)
Out[202]: array([11, 22, 33, 44])

3)Hadamard积(即两个数组的元素积)(类似于arr1 * arr2

In [203]: np.einsum("ij, ij -> ij", A, B)
Out[203]: 
array([[ 11,  12,  13,  14],
       [ 42,  44,  46,  48],
       [ 93,  96,  99, 102],
       [164, 168, 172, 176]])

4)按元素的平方(类似于np.square(arr)arr ** 2

In [210]: np.einsum("ij, ij -> ij", B, B)
Out[210]: 
array([[ 1,  1,  1,  1],
       [ 4,  4,  4,  4],
       [ 9,  9,  9,  9],
       [16, 16, 16, 16]])

5)痕量(即主要对角线元素之和)(类似于np.trace(arr)

In [217]: np.einsum("ii -> ", A)
Out[217]: 110

6)矩阵转置(类似于np.transpose(arr)

In [221]: np.einsum("ij -> ji", A)
Out[221]: 
array([[11, 21, 31, 41],
       [12, 22, 32, 42],
       [13, 23, 33, 43],
       [14, 24, 34, 44]])

7)外积(向量的)(类似于np.outer(vec1, vec2)

In [255]: np.einsum("i, j -> ij", vec, vec)
Out[255]: 
array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]])

8)内积(矢量)(类似于np.inner(vec1, vec2)

In [256]: np.einsum("i, i -> ", vec, vec)
Out[256]: 14

9)沿0轴求和(类似于np.sum(arr, axis=0)

In [260]: np.einsum("ij -> j", B)
Out[260]: array([10, 10, 10, 10])

10)沿轴1求和(类似于np.sum(arr, axis=1)

In [261]: np.einsum("ij -> i", B)
Out[261]: array([ 4,  8, 12, 16])

11)批量矩阵乘法

In [287]: BM = np.stack((A, B), axis=0)

In [288]: BM
Out[288]: 
array([[[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]],

       [[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3],
        [ 4,  4,  4,  4]]])

In [289]: BM.shape
Out[289]: (2, 4, 4)

# batch matrix multiply using einsum
In [292]: BMM = np.einsum("bij, bjk -> bik", BM, BM)

In [293]: BMM
Out[293]: 
array([[[1350, 1400, 1450, 1500],
        [2390, 2480, 2570, 2660],
        [3430, 3560, 3690, 3820],
        [4470, 4640, 4810, 4980]],

       [[  10,   10,   10,   10],
        [  20,   20,   20,   20],
        [  30,   30,   30,   30],
        [  40,   40,   40,   40]]])

In [294]: BMM.shape
Out[294]: (2, 4, 4)

12)沿轴2求和(类似于np.sum(arr, axis=2)

In [330]: np.einsum("ijk -> ij", BM)
Out[330]: 
array([[ 50,  90, 130, 170],
       [  4,   8,  12,  16]])

13)对数组中的所有元素求和(类似于np.sum(arr)

In [335]: np.einsum("ijk -> ", BM)
Out[335]: 480

14)多轴求和(即边缘化)
(类似于np.sum(arr, axis=(axis0, axis1, axis2, axis3, axis4, axis6, axis7))

# 8D array
In [354]: R = np.random.standard_normal((3,5,4,6,8,2,7,9))

# marginalize out axis 5 (i.e. "n" here)
In [363]: esum = np.einsum("ijklmnop -> n", R)

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [364]: nsum = np.sum(R, axis=(0,1,2,3,4,6,7))

In [365]: np.allclose(esum, nsum)
Out[365]: True

15)Double Dot Products(类似于np.sum(hadamard乘积)cf.3

In [772]: A
Out[772]: 
array([[1, 2, 3],
       [4, 2, 2],
       [2, 3, 4]])

In [773]: B
Out[773]: 
array([[1, 4, 7],
       [2, 5, 8],
       [3, 6, 9]])

In [774]: np.einsum("ij, ij -> ", A, B)
Out[774]: 124

16)二维和三维阵列乘法

这种乘法在求解线性方程组时非常有用(Ax=b)要验证结果的位置。

# inputs
In [115]: A = np.random.rand(3,3)
In [116]: b = np.random.rand(3, 4, 5)

# solve for x
In [117]: x = np.linalg.solve(A, b.reshape(b.shape[0], -1)).reshape(b.shape)

# 2D and 3D array multiplication :)
In [118]: Ax = np.einsum('ij, jkl', A, x)

# indeed the same!
In [119]: np.allclose(Ax, b)
Out[119]: True

相反,如果必须使用^{}进行此验证,则必须执行两次reshape操作才能获得相同的结果,如:

# reshape 3D array `x` to 2D, perform matmul
# then reshape the resultant array to 3D
In [123]: Ax_matmul = np.matmul(A, x.reshape(x.shape[0], -1)).reshape(x.shape)

# indeed correct!
In [124]: np.allclose(Ax, Ax_matmul)
Out[124]: True

奖励:在这里阅读更多数学知识:Einstein-Summation当然也在这里阅读:Tensor-Notation

(注意:这个答案是基于我不久前写的关于einsum的一个简短的blog post。)

einsum做什么?

假设我们有两个多维数组,AB。现在假设我们想。。。

  • 以特定方式将AB相乘,以创建新的产品阵列;然后可能
  • 这个新数组沿着特定的轴;然后可能
  • 将新数组的轴按特定顺序进行转置。

很有可能einsum将帮助我们更快、更高效地完成这项工作,而像multiplysumtranspose这样的NumPy函数的组合将允许这样做。

einsum如何工作?

下面是一个简单(但不是完全琐碎)的例子。使用以下两个数组:

A = np.array([0, 1, 2])

B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

我们将AB元素相乘,然后沿着新数组的行求和。在“正常”的数字里我们会写:

>>> (A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])

因此,在这里,A上的索引操作将两个数组的第一个轴排成一行,以便可以广播乘法运算。然后对产品数组的行求和以返回答案。

如果我们想用einsum代替,我们可以写:

>>> np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])

签名字符串'i,ij->i'是这里的关键,需要一点解释。你可以把它分成两半。在左边(在->的左边),我们已经标记了两个输入数组。在->的右边,我们已经标记了要结束的数组。

下面是接下来发生的事情:

  • A有一个轴;我们将其标记为i。并且B有两个轴;我们将轴0标记为i,轴1标记为j

  • 通过在两个输入数组中重复标签i,我们告诉einsum这两个轴应该一起相乘。换句话说,我们将数组A与数组B的每一列相乘,就像A[:, np.newaxis] * B那样。

  • 注意,j并没有作为标签出现在我们想要的输出中;我们刚刚使用了i(我们希望以1D数组结束)。通过省略标签,我们告诉einsum沿着这个轴的和。换句话说,我们对产品行求和,就像.sum(axis=1)那样。

这就是使用einsum所需的基本知识。这有助于发挥一点;如果我们在输出中保留两个标签,'i,ij->ij',我们将得到一个二维产品数组(与A[:, np.newaxis] * B相同)。如果我们说没有输出标签,'i,ij->,我们将返回一个数字(与执行(A[:, np.newaxis] * B).sum()相同)。

然而,关于einsum的好处是,它不会首先构建产品的临时数组;它只是对产品进行汇总。这可以大大节省内存使用。

稍大一点的例子

为了解释点积,这里有两个新数组:

A = array([[1, 1, 1],
           [2, 2, 2],
           [5, 5, 5]])

B = array([[0, 1, 0],
           [1, 1, 0],
           [1, 1, 1]])

我们将使用np.einsum('ij,jk->ik', A, B)计算点积。下面的图片显示了AB的标签以及我们从函数中获得的输出数组:

enter image description here

你可以看到标签j是重复的-这意味着我们将A的行与B的列相乘。此外,标签j不包含在输出中-我们正在对这些产品求和。标签ik保留用于输出,因此我们得到一个2D数组。

将这个结果与标签j而不是总和的数组进行比较可能更清楚。在下面的左侧,您可以看到由于写入np.einsum('ij,jk->ijk', A, B)而产生的3D数组(即,我们保留了标签j):

enter image description here

求和轴j给出期望的点积,如右图所示。

一些练习

为了更好地了解einsum,实现熟悉的NumPy是很有用的使用下标符号的数组操作。任何涉及乘法和求和轴组合的内容都可以使用einsum来编写。

设A和B是两个长度相同的一维数组。例如,A = np.arange(10)B = np.arange(5, 15)

  • 可以写入A的和:

    np.einsum('i->', A)
    
  • 元素乘法A * B,可以写入:

    np.einsum('i,i->i', A, B)
    
  • 内积或点积np.inner(A, B)np.dot(A, B)可以写为:

    np.einsum('i,i->', A, B) # or just use 'i,i'
    
  • 外部产品np.outer(A, B)可以写入:

    np.einsum('i,j->ij', A, B)
    

对于二维数组CD,如果轴是兼容的长度(相同长度或其中一个轴的长度为1),下面是几个示例:

  • 可写入C(主对角线之和)的轨迹np.trace(C)

    np.einsum('ii', C)
    
  • 可以写入C的元素乘法和DC * D.T的转置:

    np.einsum('ij,ji->ij', C, D)
    
  • C的每个元素乘以数组D(构成4D数组),可以写入C[:, :, None, None] * D

    np.einsum('ij,kl->ijkl', C, D)  
    

相关问题 更多 >