从Numpy数组的每一行中选择随机样本,排除负数

3 投票
1 回答
2490 浏览
提问于 2025-04-16 00:38

我有一个Numpy数组,长得像这样:

>>> a
array([[ 3. ,  2. , -1. ],
       [-1. ,  0.1,  3. ],
       [-1. ,  2. ,  3.5]])

我想从每一行随机选择一个值,但我希望在随机选择时排除掉-1这个值。

我现在的做法是:

x=[]
for i in range(a.shape[0]):
    idx=numpy.where(a[i,:]>0)[0]
    idxr=random.sample(idx,1)[0]
    xi=a[i,idxr]
    x.append(xi)

然后得到:

>>> x
[3.0, 3.0, 2.0]

但是对于很大的数组,这种方法变得有点慢。我想知道有没有办法可以在不逐行处理的情况下,从原始的 a 矩阵中有条件地选择随机值。

1 个回答

3

我觉得在Numpy里找不到完全符合你要求的现成解决方案,所以我决定分享一些我想到的优化方法。

这里有几个可能导致速度慢的原因。首先,numpy.where()这个函数比较慢,因为它需要检查切片数组中的每一个值(每一行都会生成一个切片),然后再生成一个值的数组。如果你打算在同一个矩阵上反复进行这个操作,最好的办法是先对每一行进行排序。这样你就可以用二分查找来找到正值开始的位置,然后随机选择一个正值。当然,你也可以在第一次用二分查找找到正值的位置后,把这些位置的索引存起来,以后直接用。

如果你不打算多次进行这个操作,我建议你使用Cython来加速numpy.where这一行代码。Cython可以让你不需要切片行,从而整体加快处理速度。

最后,我的建议是使用random.choice而不是random.sample,除非你真的打算选择的样本数量大于1。

撰写回答