从numpy数组中删除多个切片
我有一个numpy数组,还有一个包含多个切片对象的列表(或者说是包含(开始, 结束)
元组的列表)。我想从原始数组中去掉这些切片对象的位置,然后得到一个新的数组,里面是剩下的值。
举个简单的例子:
myarray = np.arange(20)
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19])
mylist=(slice(2,4),slice(15,19))
做一些操作后,结果应该是
array([0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
这个数组可能有几十万的大小,而切片对象的列表可能有几千个元素,我需要经常执行这个操作,所以速度对我来说很重要。
据我所知,numpy的删除功能似乎不支持直接使用切片列表?
目前我是在生成我的切片对象列表的补集,然后对那个进行切片,但生成补集的过程有点麻烦,我需要先对切片列表进行排序,然后逐个遍历,按需创建补集的切片对象。我希望能找到一种更优雅的方法,但我还没想到!
3 个回答
1
我想不出一个干净的方法来连接这些切片;不过,我觉得使用组合的方式是个不错的选择。也许可以试试这样做:
import numpy as np
# Create test data
n_data = 1000000
n_slices = 10000
data = np.arange(n_data)
slices = []
for i in range(n_slices):
r = np.random.randint(n_data-1000)
slices.append(slice(r,r + np.random.randint(1000)))
# Remove slices
keep_mask = np.ones_like(data, dtype=bool)
for slice in slices: keep_mask[slice] = False
data = data[keep_mask] # or np.take, etc.
2
你可以使用 np.r_[]
来把切片合并成一个数组:
myarray = np.arange(20)
mylist=(slice(2, 4),slice(15, 19))
np.delete(myarray, np.r_[tuple(mylist)])
输出结果:
array([ 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19])
不过我觉得这个方法速度不是很快。
1
你可以用 set()
来找出哪些位置会被保留,然后用 np.take()
来获取相应的值,像这样做:
ind = np.indices(myarray.shape)[0]
rm = np.hstack([ind[i] for i in mylist])
ans = np.take(myarray, sorted(set(ind)-set(rm)))
需要注意的是,np.hstack()
是用来把所有要移除的索引合并成一个数组的。这种方法大约只需要 @HYRY 方案一半的时间。