返回数组中子数组的索引
我在用Python编程,并且使用了一个叫做 numpy
的库。
我有一个名为 may_a
的数组:
may_a = numpy.array([False, True, False, True, True, False, True, False, True, True, False])
我还有一个名为 may_b
的数组:
may_b = numpy.array([False,True,True,False])
我需要在数组 may_a
中找到数组 may_b
。
输出时,我想要得到这些出现位置的索引。
out_index=[2,7]
有人能告诉我,我该如何得到 out_index
呢?
5 个回答
3
这看起来很像一个字符串搜索问题。如果你不想自己实现这些字符串搜索算法,可以利用Python自带的字符串搜索功能,它非常快,你可以这样做:
# I've added [True, True, True] at the end.
may_a = numpy.array([False, True, False, True, True, False, True, False, True, True, False, True, True, True])
may_b = numpy.array([False,True,True,False])
may_a_str = may_a.tostring()
may_b_str = may_b.tostring()
idx = may_a_str.find(may_b_str)
out_index = []
while idx >= 0:
out_index.append(idx)
idx = may_a_str.find(may_b_str, idx+1)
这个方法对于布尔数组应该没问题。如果你想用这种方法处理其他类型的数组,你需要确保这两个数组的步长是匹配的,并且要把out_index除以那个步长。
你也可以使用正则表达式模块来代替循环进行字符串搜索。
5
有一种更酷的方法,虽然性能可能不太好,但它适用于任何数据类型,那就是使用 as_strided
:
In [2]: from numpy.lib.stride_tricks import as_strided
In [3]: may_a = numpy.array([False, True, False, True, True, False,
...: True, False, True, True, False])
In [4]: may_b = numpy.array([False,True,True,False])
In [5]: a = len(may_a)
In [6]: b = len(may_b)
In [7]: a_view = as_strided(may_a, shape=(a - b + 1, b),
...: strides=(may_a.dtype.itemsize,) * 2)
In [8]: a_view
Out[8]:
array([[False, True, False, True],
[ True, False, True, True],
[False, True, True, False],
[ True, True, False, True],
[ True, False, True, False],
[False, True, False, True],
[ True, False, True, True],
[False, True, True, False]], dtype=bool)
In [9]: numpy.where(numpy.all(a_view == may_b, axis=1))[0]
Out[9]: array([2, 7])
不过你得小心,因为尽管 a_view
是 may_a
数据的一个视图,但在和 may_b
比较时,会创建一个临时数组,大小是 (a - b + 1) * b
,这在处理大数据时可能会出现问题。
5
编辑 以下代码可以用来进行卷积的相等性检查。它把 True
映射为 1
,把 False
映射为 -1
。同时,它还会反转 b
,这是为了让它正常工作:
def search(a, b) :
return np.where(np.round(fftconvolve(a * 2 - 1, (b * 2 - 1)[::-1],
mode='valid') - len(b)) == 0)[0]
我已经检查过,它在多种随机输入下的输出和 as_strided
方法是一样的,确实如此。我还对这两种方法的速度进行了测试,发现卷积方法在处理大约256个项目的搜索时才开始显现优势。
虽然看起来有点复杂,但对于布尔数据,你可以使用(甚至滥用)卷积:
In [8]: np.where(np.convolve(may_a, may_b.astype(int),
...: mode='valid') == may_b.sum())[0]
Out[8]: array([2, 7])
对于更大的数据集,使用 scipy.signal.fftconvolve
可能会更快:
In [13]: np.where(scipy.signal.fftconvolve(may_a, may_b,
....: mode='valid') == may_b.sum())[0]
Out[13]: array([2, 7])
不过要小心,因为现在的输出是浮点数,四舍五入可能会影响相等性检查:
In [14]: scipy.signal.fftconvolve(may_a, may_b, mode='valid')
Out[14]: array([ 1., 1., 2., 1., 1., 1., 1., 2.])
所以你可能更适合使用类似于以下的方式:
In [15]: np.where(np.round(scipy.signal.fftconvolve(may_a, may_b, mode='valid') -
....: may_b.sum()) == 0)[0]
Out[15]: array([2, 7])