遍历numpy数组的前d个轴

7 投票
3 回答
3708 浏览
提问于 2025-04-18 16:52

我有一个数组,这个数组的维度(也就是轴)可以是任意数量的。我想要遍历前面的'd'个维度,应该怎么做呢?

一开始我想,我可以先创建一个数组,把我想要循环的所有索引放进去,使用

i = np.indices(a.shape[:d])
indices = np.transpose(np.asarray([x.flatten() for x in i]))
for idx in indices:
    a[idx]

但是看起来我不能这样用数组来索引,也就是说,不能用另一个数组来包含索引。

3 个回答

1

你可以把 a 重新调整形状,把第一个 d 维度压缩成一个维度:

for x in a.reshape(-1,*a.shape[d:]):
    print x

或者

aa=a.reshape(-1,*a.shape[d:])
for i in range(aa.shape[0]):
    print aa[i]

我们其实需要更多关于你想怎么使用 a[i] 的信息。


shx2 使用了 np.ndenumerate。这个函数的文档提到了 ndindex。你可以这样使用它:

for i in np.ndindex(a.shape[:d]):
    print i
    print a[i]

这里的 i 是一个元组。查看这些函数的 Python 代码会很有帮助。例如,ndindex 是使用 nditer 的。

1

写一个简单的递归函数:

import numpy as np

data = np.random.randint(0,10,size=4).reshape((1,1,1,1,2,2))

def recursiveIter(d, level=0):
    print level
    if level <= 2:
        for item in d:
            recursiveIter(item, level+1)
    else:
        print d

recursiveIter(data)

输出结果:

0
1
2
3
[[[2 5]
  [6 0]]]
8

你可以使用 ndindex 这个工具:

d = 2
a = np.random.random((2,3,4))
for i in np.ndindex(a.shape[:d]):
    print i, a[i]

输出结果:

(0, 0) [ 0.72730488  0.2349532   0.36569509  0.31244037]
(0, 1) [ 0.41738425  0.95999499  0.63935274  0.9403284 ]
(0, 2) [ 0.90690468  0.03741634  0.33483221  0.61093582]
(1, 0) [ 0.06716122  0.52632369  0.34441657  0.80678942]
(1, 1) [ 0.8612884   0.22792671  0.15628046  0.63269415]
(1, 2) [ 0.17770685  0.47955698  0.69038541  0.04838387]

撰写回答