从numpy数组中删除多个切片

10 投票
3 回答
6331 浏览
提问于 2025-04-18 12:34

我有一个numpy数组,还有一个包含多个切片对象的列表(或者说是包含(开始, 结束)元组的列表)。我想从原始数组中去掉这些切片对象的位置,然后得到一个新的数组,里面是剩下的值。

举个简单的例子:

myarray = np.arange(20)

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19])

mylist=(slice(2,4),slice(15,19))

做一些操作后,结果应该是

array([0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])

这个数组可能有几十万的大小,而切片对象的列表可能有几千个元素,我需要经常执行这个操作,所以速度对我来说很重要。

据我所知,numpy的删除功能似乎不支持直接使用切片列表?

目前我是在生成我的切片对象列表的补集,然后对那个进行切片,但生成补集的过程有点麻烦,我需要先对切片列表进行排序,然后逐个遍历,按需创建补集的切片对象。我希望能找到一种更优雅的方法,但我还没想到!

3 个回答

1

我想不出一个干净的方法来连接这些切片;不过,我觉得使用组合的方式是个不错的选择。也许可以试试这样做:

import numpy as np

# Create test data
n_data = 1000000
n_slices = 10000

data = np.arange(n_data)
slices = []
for i in range(n_slices):
    r = np.random.randint(n_data-1000)
    slices.append(slice(r,r + np.random.randint(1000)))

# Remove slices
keep_mask = np.ones_like(data, dtype=bool)
for slice in slices: keep_mask[slice] = False
data = data[keep_mask] # or np.take, etc.
2

你可以使用 np.r_[] 来把切片合并成一个数组:

myarray = np.arange(20)
mylist=(slice(2, 4),slice(15, 19))
np.delete(myarray, np.r_[tuple(mylist)])

输出结果:

array([ 0,  1,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 19])

不过我觉得这个方法速度不是很快。

1

你可以用 set() 来找出哪些位置会被保留,然后用 np.take() 来获取相应的值,像这样做:

ind = np.indices(myarray.shape)[0]
rm = np.hstack([ind[i] for i in mylist])

ans = np.take(myarray, sorted(set(ind)-set(rm)))

需要注意的是,np.hstack() 是用来把所有要移除的索引合并成一个数组的。这种方法大约只需要 @HYRY 方案一半的时间。

撰写回答