在numpy数组中查找阶梯状结构

1 投票
1 回答
55 浏览
提问于 2025-04-13 13:16

我正在尝试在一个numpy数组中找到一个像楼梯一样的结构(从右到左)。我有一段代码是用for循环来解决这个问题,但我觉得应该有更聪明的方法可以使用numpy的ufunctions(可能是这个?)

def pattern_found(sts):
    z = 0
    for i in range(sts.shape[0]): 
        x = np.argmax(sts[i, z:])
        z += x
        if not x: 
            return False
    return True


states = np.array([[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
               [1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],
               [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]])
print(pattern_found(states)) # stair found going from right to left

states = np.array([[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
                   [0, 1, 0, 0, 0, 0, 0, 1, 0, 1],
                   [0, 0, 0, 1, 1, 0, 0, 0, 0, 0]])

print(pattern_found(states)) # no stair found 

1 个回答

2

我觉得这个任务很难高效地用向量化的方法来处理,因为它本身就是一个需要反复执行的过程。

不过,你可以简化逻辑,只在水平方向上移动,最多测试N个单元格,适用于一个(M, N)的数组。

另外,你可以使用numba来加速计算:

from numba import jit

@jit(nopython=True)     # optional
def pattern_found(sts):
    row = 0
    for col in range(sts.shape[1]):
        if sts[row, col]:  # or if sts[row, col] == 1:
            row += 1
            if row >= sts.shape[0]:
                return True
    return False
        
pattern_found(example1)
# True

pattern_found(example2)
# False

这个是怎么工作的呢?

我们会遍历每一列的索引,如果找到一个1,就增加行的索引。如果在某个时刻行的索引等于数组的行数,那就完成了。否则就失败了。

这里是两个例子的路径:

# example1
np.array([[-, -, -, -, +, 0, 0, 0, 0, 0, 0],
          [1, 0, 1, 0, 0, -, +, 0, 0, 0, 1],
          [0, 0, 0, 0, 0, 0, 0, -, +, 0, 0]])
                                   ->True
# example2
np.array([[-, -, -, -, +, 1, 0, 0, 0, 0],
          [0, 1, 0, 0, 0, -, -, +, 0, 1],
          [0, 0, 0, 1, 1, 0, 0, 0, -, -]]) -> False

可重复的输入:

example1 = np.array([[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                     [1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],
                     [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]])

example2 = np.array([[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
                     [0, 1, 0, 0, 0, 0, 0, 1, 0, 1],
                     [0, 0, 0, 1, 1, 0, 0, 0, 0, 0]])

撰写回答