numpy: 布尔索引和内存使用
考虑以下的 numpy
代码:
A[start:end] = B[mask]
这里:
A
和B
是两个维度相同的数组,它们的列数是一样的;start
和end
是单个的数值;mask
是一个一维的布尔数组,也就是只包含真或假的数组;(end - start) == sum(mask)
这个条件是成立的。
原则上,上面的操作可以用 O(1)
的临时存储来完成,也就是说可以直接把 B
的元素复制到 A
中。
那么,实际上是这样做的吗?还是说 numpy
会为 B[mask]
创建一个临时数组?如果是后者,有没有办法通过重写语句来避免这个临时数组的创建呢?
2 个回答
3
这一行代码
A[start:end] = B[mask]
根据Python语言的定义,它会先计算右边的内容,生成一个新的数组,这个数组里包含了从B
中选出的行,并且会占用额外的内存。为了避免这种情况,我知道的最有效的纯Python方法是使用一个显式的循环:
from itertools import izip, compress
for i, b in izip(range(start, end), compress(B, mask)):
A[i] = b
当然,这种方法的执行时间会比你原来的代码慢很多,但它只会使用O(1)的额外内存。此外,itertools.compress()
这个功能在Python 2.7或3.1及以上版本中是可以使用的。
2
使用布尔数组作为索引是一种高级的索引方式,因此numpy需要创建一个副本。如果你遇到内存问题,可以考虑写一个cython扩展来解决。