沿着指定轴打乱NumPy数组

32 投票
3 回答
31549 浏览
提问于 2025-04-16 12:04

给定下面这个NumPy数组,

> a = array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5],[1, 2, 3, 4, 5]])

我们可以很简单地打乱一行数据,

> shuffle(a[0])
> a
array([[4, 2, 1, 3, 5],[1, 2, 3, 4, 5],[1, 2, 3, 4, 5]])

那么有没有办法用索引的方式来独立打乱每一行呢?还是说必须一个一个地遍历这个数组?我想的方式是这样的,

> numpy.shuffle(a[:])
> a
array([[4, 2, 3, 5, 1],[3, 1, 4, 5, 2],[4, 2, 1, 3, 5]]) # Not the real output

不过这显然是行不通的。

3 个回答

11

对于最近关注这个问题的人来说,numpy 提供了一个叫做 permuted 的方法,可以在指定的轴上对数组进行独立的随机打乱。

以下是他们文档中的内容(使用 random.Generator

rng = np.random.default_rng()
x = np.arange(24).reshape(3, 8)
x
array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 8,  9, 10, 11, 12, 13, 14, 15],
       [16, 17, 18, 19, 20, 21, 22, 23]])

y = rng.permuted(x, axis=1)
y
array([[ 4,  3,  6,  7,  1,  2,  5,  0],  
       [15, 10, 14,  9, 12, 11,  8, 13],
       [17, 16, 20, 21, 18, 22, 23, 19]])
32

使用 rand+argsort 技巧的向量化解决方案

我们可以在指定的轴上生成唯一的索引,然后用这些索引去访问输入数组,这个过程叫做 高级索引。为了生成这些唯一的索引,我们会使用一种技巧,结合 随机生成浮点数和排序,这样就能得到一个向量化的解决方案。我们还会把这个方法推广到处理一般的 n维 数组,并且可以在不同的 上使用 np.take_along_axis。最终的实现大概是这样的 -

def shuffle_along_axis(a, axis):
    idx = np.random.rand(*a.shape).argsort(axis=axis)
    return np.take_along_axis(a,idx,axis=axis)

请注意,这个洗牌过程不会在原地进行,而是返回一个洗牌后的副本。

示例运行 -

In [33]: a
Out[33]: 
array([[18, 95, 45, 33],
       [40, 78, 31, 52],
       [75, 49, 42, 94]])

In [34]: shuffle_along_axis(a, axis=0)
Out[34]: 
array([[75, 78, 42, 94],
       [40, 49, 45, 52],
       [18, 95, 31, 33]])

In [35]: shuffle_along_axis(a, axis=1)
Out[35]: 
array([[45, 18, 33, 95],
       [31, 78, 52, 40],
       [42, 75, 94, 49]])
24

你需要多次调用 numpy.random.shuffle(),因为你要独立地打乱多个序列。numpy.random.shuffle() 可以在任何可变的序列上使用,并不是一个 ufunc。如果你想要最简洁、最高效的代码来分别打乱一个二维数组 a 的所有行,下面的代码可能是最好的选择:

list(map(numpy.random.shuffle, a))

有些人更喜欢用列表推导的方式来写这个:

[numpy.random.shuffle(x) for x in a]

撰写回答