Python/NumPy子数组首次出现

39 投票
10 回答
35105 浏览
提问于 2025-04-16 23:49

在Python或NumPy中,找出一个子数组第一次出现的最佳方法是什么?

比如,我有

a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]

有没有什么快速的方法(从运行时间上看)来找出b在a中的位置?我知道对于字符串来说,这个很简单,但对于列表或numpy的ndarray呢?

非常感谢!

[编辑] 我更喜欢使用numpy的解决方案,因为根据我的经验,numpy的向量化操作比Python的列表推导要快得多。同时,那个大数组非常大,所以我不想把它转换成字符串;那样会太长。

10 个回答

24

(编辑:增加了更深入的讨论,更好的代码和更多的基准测试)


总结

为了追求速度和效率,可以使用 Cython 或 Numba 加速的版本(当输入是 Python 序列或 NumPy 数组时)来实现经典算法。

推荐的方法有:

  • find_kmp_cy() 用于 Python 序列(listtuple 等)
  • find_kmp_nb() 用于 NumPy 数组

其他高效的方法有 find_rk_cy()find_rk_nb(),它们在内存使用上更高效,但不一定能保证线性时间运行。

如果没有 Cython 或 Numba,那么 find_kmp()find_rk() 仍然是大多数情况下不错的解决方案,尽管在平均情况下,对于 Python 序列,简单的方法(如 find_pivot())可能会更快。对于 NumPy 数组,find_conv()(来自 @Jaime 的回答)的表现优于任何未加速的简单方法。

(完整代码在下面,这里那里可以找到。)


理论

这是计算机科学中的一个经典问题,称为字符串搜索或字符串匹配问题。简单的方法基于两个嵌套循环,平均计算复杂度为 O(n + m),但最坏情况下为 O(n m)。多年来,已经开发出多种 替代方法,可以保证更好的最坏情况性能。

在经典算法中,最适合通用序列的算法(因为它们不依赖于字母表)有:

最后一种算法依赖于计算 滚动哈希 来提高效率,因此可能需要对输入有一些额外的了解,以获得最佳性能。最终,它最适合同质数据,例如数字数组。在 Python 中,NumPy 数组就是一个显著的例子。

备注

  • 简单算法由于其简单性,可以在 Python 中有不同的实现,运行速度各异。
  • 其他算法在通过语言技巧进行优化时灵活性较低。
  • 在 Python 中显式循环可能成为速度瓶颈,可以使用多种技巧在解释器外部执行循环。
  • Cython 特别擅长加速通用 Python 代码中的显式循环。
  • Numba 特别擅长加速 NumPy 数组中的显式循环。
  • 这是生成器的一个优秀用例,因此所有代码将使用生成器而不是常规函数。

Python 序列(listtuple 等)

基于简单算法

  • find_loop()find_loop_cy()find_loop_nb() 分别是纯 Python、Cython 和使用 Numba JIT 的显式循环实现。注意 Numba 版本中的 forceobj=True,这是因为我们使用了 Python 对象作为输入。
def find_loop(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        found = True
        for j in range(m):
            if seq[i + j] != subseq[j]:
                found = False
                break
        if found:
            yield i
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_loop_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    for i in range(n - m + 1):
        found = True
        for j in range(m):
            if seq[i + j] != subseq[j]:
                found = False
                break
        if found:
            yield i
find_loop_nb = nb.jit(find_loop, forceobj=True)
find_loop_nb.__name__ = 'find_loop_nb'
  • find_all()all() 替代了内循环,使用了一个生成器表达式。
def find_all(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if all(seq[i + j] == subseq[j] for j in range(m)):
            yield i
  • find_slice() 在切片 [] 后用直接比较 == 替代了内循环。
def find_slice(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i:i + m] == subseq:
            yield i
  • find_mix()find_mix2() 在切片 [] 后用直接比较 == 替代了内循环,但在第一个(和最后一个)字符上增加了一两个短路判断,这可能更快,因为用 int 切片比用 slice() 切片要快得多。
def find_mix(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i:i + m] == subseq:
            yield i
def find_mix2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i + m - 1] == subseq[m - 1] \
                and seq[i:i + m] == subseq:
            yield i
  • find_pivot()find_pivot2() 用多个 .index() 调用替代了外循环,使用子序列的第一个项目,同时在内循环中使用切片,最终在最后一个项目上增加了额外的短路判断(第一个匹配是通过构造保证的)。多个 .index() 调用被封装在 index_all() 生成器中(这本身可能有用)。
def index_all(seq, item, start=0, stop=-1):
    try:
        n = len(seq)
        if n > 0:
            start %= n
            stop %= n
            i = start
            while True:
                i = seq.index(item, i)
                if i <= stop:
                    yield i
                    i += 1
                else:
                    return
        else:
            return
    except ValueError:
        pass


def find_pivot(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    for i in index_all(seq, subseq[0], 0, n - m):
        if seq[i:i + m] == subseq:
            yield i
def find_pivot2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    for i in index_all(seq, subseq[0], 0, n - m):
        if seq[i + m - 1] == subseq[m - 1] and seq[i:i + m] == subseq:
            yield i

基于 Knuth–Morris–Pratt (KMP) 算法

  • find_kmp() 是该算法的普通 Python 实现。由于没有简单的循环或可以用 slice() 切片的地方,因此优化的空间不大,除了使用 Cython(Numba 需要再次使用 forceobj=True,这会导致代码变慢)。
def find_kmp(seq, subseq):
    n = len(seq)
    m = len(subseq)
    # : compute offsets
    offsets = [0] * m
    j = 1
    k = 0
    while j < m: 
        if subseq[j] == subseq[k]: 
            k += 1
            offsets[j] = k
            j += 1
        else: 
            if k != 0: 
                k = offsets[k - 1] 
            else: 
                offsets[j] = 0
                j += 1
    # : find matches
    i = j = 0
    while i < n: 
        if seq[i] == subseq[j]: 
            i += 1
            j += 1
        if j == m:
            yield i - j
            j = offsets[j - 1] 
        elif i < n and seq[i] != subseq[j]: 
            if j != 0: 
                j = offsets[j - 1] 
            else: 
                i += 1
  • find_kmp_cy() 是该算法的 Cython 实现,其中索引使用 C int 数据类型,从而使代码运行得更快。
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_kmp_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    # : compute offsets
    offsets = [0] * m
    cdef Py_ssize_t j = 1
    cdef Py_ssize_t k = 0
    while j < m: 
        if subseq[j] == subseq[k]: 
            k += 1
            offsets[j] = k
            j += 1
        else: 
            if k != 0: 
                k = offsets[k - 1] 
            else: 
                offsets[j] = 0
                j += 1
    # : find matches
    cdef Py_ssize_t i = 0
    j = 0
    while i < n: 
        if seq[i] == subseq[j]: 
            i += 1
            j += 1
        if j == m:
            yield i - j
            j = offsets[j - 1] 
        elif i < n and seq[i] != subseq[j]: 
            if j != 0: 
                j = offsets[j - 1] 
            else: 
                i += 1

基于 Rabin-Karp (RK) 算法

  • find_rk() 是纯 Python 实现,依赖于 Python 的 hash() 进行哈希的计算(和比较)。这种哈希通过简单的 sum() 实现滚动。然后通过从之前的哈希中减去刚访问的项目 seq[i - 1]hash() 结果,并加上新考虑的项目 seq[i + m - 1]hash() 结果来计算滚动。
def find_rk(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if seq[:m] == subseq:
        yield 0
    hash_subseq = sum(hash(x) for x in subseq)  # compute hash
    curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
    for i in range(1, n - m + 1):
        curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])   # update hash
        if hash_subseq == curr_hash and seq[i:i + m] == subseq:
            yield i
  • find_rk_cy() 是该算法的 Cython 实现,其中索引使用适当的 C 数据类型,从而使代码运行得更快。注意 hash() 会根据主机的位宽截断返回值。
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_rk_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    if seq[:m] == subseq:
        yield 0
    cdef Py_ssize_t hash_subseq = sum(hash(x) for x in subseq)  # compute hash
    cdef Py_ssize_t curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
    cdef Py_ssize_t old_item, new_item
    for i in range(1, n - m + 1):
        old_item = hash(seq[i - 1])
        new_item = hash(seq[i + m - 1])
        curr_hash += new_item - old_item  # update hash
        if hash_subseq == curr_hash and seq[i:i + m] == subseq:
            yield i

基准测试

上述函数在两个输入上进行评估:

  • 随机输入
def gen_input(n, k=2):
    return tuple(random.randint(0, k - 1) for _ in range(n))
  • (几乎)最坏情况下的输入,针对简单算法
def gen_input_worst(n, k=-2):
    result = [0] * n
    result[k] = 1
    return tuple(result)

subseq 的大小是固定的(32)。由于有很多替代方案,因此进行了两个单独的分组,并省略了一些变化非常小且几乎相同的时间的解决方案(即 find_mix2()find_pivot2())。对于每个组,两个输入都进行了测试。对于每个基准测试,提供了完整的图表和对最快方法的放大图。

随机情况下的简单算法

bm_full_naive_random bm_zoom_naive_random

最坏情况下的简单算法

bm_full_naive_worst bm_zoom_naive_worst

随机情况下的其他算法

bm_full_other_random bm_zoom_other_random

最坏情况下的其他算法

bm_full_other_worst bm_zoom_other_worst

(完整代码在 这里 可以找到。)


NumPy 数组

基于简单算法

  • find_loop()find_loop_cy()find_loop_nb() 分别是纯 Python、Cython 和使用 Numba JIT 的显式循环实现。前两个的代码与上面相同,因此省略。find_loop_nb() 现在享受快速的 JIT 编译。内循环被写在一个单独的函数中,因为它可以被重用到 find_rk_nb() 中(在 Numba 函数内部调用 Numba 函数不会产生 Python 的函数调用惩罚)。
@nb.jit
def _is_equal_nb(seq, subseq, m, i):
    for j in range(m):
        if seq[i + j] != subseq[j]:
            return False
    return True


@nb.jit
def find_loop_nb(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if _is_equal_nb(seq, subseq, m, i):
            yield i
  • find_all() 与上面相同,而 find_slice()find_mix()find_mix2() 与上面几乎相同,唯一的区别是 seq[i:i + m] == subseq 现在是 np.all() 的参数:np.all(seq[i:i + m] == subseq)

  • find_pivot()find_pivot2() 与上面相同,只是现在使用 np.where() 代替 index_all(),并且需要将数组相等性封装在 np.all() 调用中。

def find_pivot(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    max_i = n - m
    for i in np.where(seq == subseq[0])[0]:
        if i > max_i:
            return
        elif np.all(seq[i:i + m] == subseq):
            yield i


def find_pivot2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    max_i = n - m
    for i in np.where(seq == subseq[0])[0]:
        if i > max_i:
            return
        elif seq[i + m - 1] == subseq[m - 1] \
                and np.all(seq[i:i + m] == subseq):
            yield i
  • find_rolling() 通过滚动窗口表达循环,匹配通过 np.all() 检查。这将所有循环向量化,但代价是创建大型临时对象,同时仍然基本应用简单算法。(该方法来自 @senderle 的回答)。
def rolling_window(arr, size):
    shape = arr.shape[:-1] + (arr.shape[-1] - size + 1, size)
    strides = arr.strides + (arr.strides[-1],)
    return np.lib.stride_tricks.as_strided(arr, shape=shape, strides=strides)


def find_rolling(seq, subseq):
    bool_indices = np.all(rolling_window(seq, len(subseq)) == subseq, axis=1)
    yield from np.mgrid[0:len(bool_indices)][bool_indices]
  • find_rolling2() 是上述方法的稍微更节省内存的变体,其中向量化仅部分进行,并保留一个显式循环(沿着预期的最短维度——subseq 的长度)。(该方法同样来自 @senderle 的回答)。
def find_rolling2(seq, subseq):
    windows = rolling_window(seq, len(subseq))
    hits = np.ones((len(seq) - len(subseq) + 1,), dtype=bool)
    for i, x in enumerate(subseq):
        hits &= np.in1d(windows[:, i], [x])
    yield from hits.nonzero()[0]

基于 Knuth–Morris–Pratt (KMP) 算法

  • find_kmp() 与上面相同,而 find_kmp_nb() 是该算法的直接 JIT 编译版本。
find_kmp_nb = nb.jit(find_kmp)
find_kmp_nb.__name__ = 'find_kmp_nb'

基于 Rabin-Karp (RK) 算法

  • find_rk() 与上面相同,只是 seq[i:i + m] == subseq 再次被封装在 np.all() 调用中。

  • find_rk_nb() 是上述的 Numba 加速版本。使用 _is_equal_nb() 来确定匹配,同时对于哈希,使用 Numba 加速的 sum_hash_nb() 函数,其定义非常简单。

@nb.jit
def sum_hash_nb(arr):
    result = 0
    for x in arr:
        result += hash(x)
    return result


@nb.jit
def find_rk_nb(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if _is_equal_nb(seq, subseq, m, 0):
        yield 0
    hash_subseq = sum_hash_nb(subseq)  # compute hash
    curr_hash = sum_hash_nb(seq[:m])  # compute hash
    for i in range(1, n - m + 1):
        curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])  # update hash
        if hash_subseq == curr_hash and _is_equal_nb(seq, subseq, m, i):
            yield i
  • find_conv() 使用伪 Rabin-Karp 方法,其中初始候选项通过 np.dot() 乘积进行哈希,并通过 np.where()seqsubseq 之间的卷积中定位。该方法是伪的,因为虽然它仍然使用哈希来识别可能的候选项,但可能不被视为滚动哈希(这取决于 np.correlate() 的实际实现)。此外,它需要创建一个与输入大小相同的临时数组。(该方法来自 @Jaime 的回答)。
def find_conv(seq, subseq):
    target = np.dot(subseq, subseq)
    candidates = np.where(np.correlate(seq, subseq, mode='valid') == target)[0]
    check = candidates[:, np.newaxis] + np.arange(len(subseq))
    mask = np.all((np.take(seq, check) == subseq), axis=-1)
    yield from candidates[mask]

基准测试

  • 随机输入
def gen_input(n, k=2):
    return np.random.randint(0, k, n)
  • (几乎)最坏情况下的输入,针对简单算法
def gen_input_worst(n, k=-2):
    result = np.zeros(n, dtype=int)
    result[k] = 1
    return result

subseq 的大小是固定的(32)。这些图表遵循与之前相同的方案,方便起见总结如下。

由于有很多替代方案,因此进行了两个单独的分组,并省略了一些变化非常小且几乎相同的时间的解决方案(即 find_mix2()find_pivot2())。对于每个组,两个输入都进行了测试。对于每个基准测试,提供了完整的图表和对最快方法的放大图。

随机情况下的简单算法

bm_full_naive_random bm_zoom_naive_random

最坏情况下的简单算法

bm_full_naive_worst bm_zoom_naive_worst

随机情况下的其他算法

bm_full_other_random bm_zoom_other_random

最坏情况下的其他算法

bm_full_other_worst bm_zoom_other_worst

(完整代码在 这里 可以找到。)

25

下面的代码应该可以正常运行:

[x for x in xrange(len(a)) if a[x:x+len(b)] == b]

它会返回模式开始的位置索引。

26

我猜你是在找一个专门针对numpy的解决方案,而不是简单的列表推导式或者for循环。一个直接的方法是使用滑动窗口技术来查找合适大小的窗口。

这个方法简单,效果很好,而且比纯Python的解决方案快得多。对于很多使用场景来说,这个方法已经足够了。不过,它并不是最有效的解决方案,原因有很多。如果你想要一个更复杂但在预期情况下更优的方案,可以看看基于numba滚动哈希实现,具体可以参考norok2的回答

下面是滑动窗口的函数:

>>> def rolling_window(a, size):
...     shape = a.shape[:-1] + (a.shape[-1] - size + 1, size)
...     strides = a.strides + (a. strides[-1],)
...     return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
... 

然后你可以做类似这样的操作:

>>> a = numpy.arange(10)
>>> numpy.random.shuffle(a)
>>> a
array([7, 3, 6, 8, 4, 0, 9, 2, 1, 5])
>>> rolling_window(a, 3) == [8, 4, 0]
array([[False, False, False],
       [False, False, False],
       [False, False, False],
       [ True,  True,  True],
       [False, False, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)

为了让这个更有用,你需要沿着轴1使用all来进行归约:

>>> numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
array([False, False, False,  True, False, False, False, False], dtype=bool)

之后你可以像使用布尔数组那样使用它。获取索引的一个简单方法是:

>>> bool_indices = numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
>>> numpy.mgrid[0:len(bool_indices)][bool_indices]
array([3])

对于列表,你可以改编这些滑动窗口迭代器,采用类似的方法。

对于非常大的数组和子数组,你可以这样节省内存:

>>> windows = rolling_window(a, 3)
>>> sub = [8, 4, 0]
>>> hits = numpy.ones((len(a) - len(sub) + 1,), dtype=bool)
>>> for i, x in enumerate(sub):
...     hits &= numpy.in1d(windows[:,i], [x])
... 
>>> hits
array([False, False, False,  True, False, False, False, False], dtype=bool)
>>> hits.nonzero()
(array([3]),)

不过,这样做可能会稍微慢一些。

撰写回答