numpy: 布尔索引和内存使用

7 投票
2 回答
1324 浏览
提问于 2025-04-16 17:26

考虑以下的 numpy 代码:

A[start:end] = B[mask]

这里:

  • AB 是两个维度相同的数组,它们的列数是一样的;
  • startend 是单个的数值;
  • mask 是一个一维的布尔数组,也就是只包含真或假的数组;
  • (end - start) == sum(mask) 这个条件是成立的。

原则上,上面的操作可以用 O(1) 的临时存储来完成,也就是说可以直接把 B 的元素复制到 A 中。

那么,实际上是这样做的吗?还是说 numpy 会为 B[mask] 创建一个临时数组?如果是后者,有没有办法通过重写语句来避免这个临时数组的创建呢?

2 个回答

3

这一行代码

A[start:end] = B[mask]

根据Python语言的定义,它会先计算右边的内容,生成一个新的数组,这个数组里包含了从B中选出的行,并且会占用额外的内存。为了避免这种情况,我知道的最有效的纯Python方法是使用一个显式的循环:

from itertools import izip, compress
for i, b in izip(range(start, end), compress(B, mask)):
    A[i] = b

当然,这种方法的执行时间会比你原来的代码慢很多,但它只会使用O(1)的额外内存。此外,itertools.compress()这个功能在Python 2.7或3.1及以上版本中是可以使用的。

2

使用布尔数组作为索引是一种高级的索引方式,因此numpy需要创建一个副本。如果你遇到内存问题,可以考虑写一个cython扩展来解决。

撰写回答