Python - 求和四维数组

5 投票
3 回答
2543 浏览
提问于 2025-04-18 13:57

给定一个 4D 数组 M: (m, n, r, r),我们该如何把所有的 m * n 个内部矩阵(形状为 (r, r))相加,得到一个新的形状为 (r * r) 的矩阵呢?

举个例子,

    M [[[[ 4,  1],
         [ 2,  1]],

        [[ 8,  2],
         [ 4,  2]]],

       [[[ 8,  2],
         [ 4,  2]],

        [[ 12, 3],
         [ 6,  3]]]]

我希望得到的结果是

      [[32, 8],
       [16, 8]]

3 个回答

1
import numpy as np
l = np.array([[[[ 4,  1],
                [ 2,  1]],
               [[ 8,  2],
                [ 4,  2]]],
              [[[ 8,  2],
                [ 4,  2]],
               [[12,  3],
                [ 6,  3]]]])
sum(sum(l))

输出

array([[32,  8],
       [16,  8]])
3

在最近的numpy版本(1.7或更新版本)中,最简单的方法就是这样做:

np.sum(M, axis=(0, 1))

这样做不会生成一个中间数组,而如果你重复调用np.sum的话,就会生成一个。

5

你可以使用 einsum 这个方法:

In [21]: np.einsum('ijkl->kl', M)
Out[21]: 
array([[32,  8],
       [16,  8]])

还有其他选择,比如把前两个维度合并成一个维度,然后再调用 sum 方法:

In [24]: M.reshape(-1, 2, 2).sum(axis=0)
Out[24]: 
array([[32,  8],
       [16,  8]])

或者调用两次 sum 方法:

In [26]: M.sum(axis=0).sum(axis=0)
Out[26]: 
array([[32,  8],
       [16,  8]])

不过使用 np.einsum 会更快:

In [22]: %timeit np.einsum('ijkl->kl', M)
100000 loops, best of 3: 2.42 µs per loop

In [25]: %timeit M.reshape(-1, 2, 2).sum(axis=0)
100000 loops, best of 3: 5.69 µs per loop

In [43]: %timeit np.sum(M, axis=(0,1))
100000 loops, best of 3: 6.08 µs per loop

In [33]: %timeit sum(sum(M))
100000 loops, best of 3: 8.18 µs per loop

In [27]: %timeit M.sum(axis=0).sum(axis=0)
100000 loops, best of 3: 9.83 µs per loop

需要注意的是:使用 timeit 测试性能时,结果可能会因为很多因素而有很大差异(比如操作系统、NumPy 版本、NumPy 库、硬件等)。不同方法的相对性能有时也会依赖于 M 的大小。因此,最好在一个更接近你实际使用情况的 M 上进行自己的性能测试。

举个例子,对于稍微大一点的数组 M,调用 sum 方法两次可能是最快的:

In [34]: M = np.random.random((100,100,2,2))

In [37]: %timeit M.sum(axis=0).sum(axis=0)
10000 loops, best of 3: 59.9 µs per loop

In [39]: %timeit np.einsum('ijkl->kl', M)
10000 loops, best of 3: 99 µs per loop

In [40]: %timeit np.sum(M, axis=(0,1))
10000 loops, best of 3: 182 µs per loop

In [36]: %timeit M.reshape(-1, 2, 2).sum(axis=0)
10000 loops, best of 3: 184 µs per loop

In [38]: %timeit sum(sum(M))
1000 loops, best of 3: 202 µs per loop

撰写回答