对不同长度数组进行子集和组合的有效方法

2024-04-25 22:55:25 发布

您现在位置:Python中文网/ 问答频道 /正文

给定三维布尔数据:

np.random.seed(13)
bool_data = np.random.randint(2, size=(2,3,6))

>> bool_data 
array([[[0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 1]],

       [[1, 0, 1, 1, 0, 0],
        [0, 1, 1, 1, 1, 0],
        [1, 1, 1, 0, 0, 0]]])

我希望计算每行中由两个0限定的连续1的数目(沿轴=1),并返回一个带有计数的数组。对于bool_data,这将给出array([1, 1, 2, 4])。你知道吗

由于bool_data的3D结构和每行的变量计数,我不得不笨拙地将计数转换为嵌套列表,使用itertools.chain将其展平,然后将列表反向转换为数组:

# count consecutive 1's bounded by two 0's
def count_consect_ones(input):
    return np.diff(np.where(input==0)[0])-1

# run tallies across all rows in bool_data
consect_ones = []
for i in range(len(bool_data)):
    for j in range(len(bool_data[i])):
        res = count_consect_ones(bool_data[i, j])
        consect_ones.append(list(res[res!=0]))

>> consect_ones
[[], [1, 1], [], [2], [4], []]

# combines nested lists
from itertools import chain
consect_ones_output = np.array(list(chain.from_iterable(consect_ones)))

>> consect_ones_output
array([1, 1, 2, 4])

有没有更有效或更聪明的方法?你知道吗


Tags: inchain列表datacountnponesres
2条回答

consect_ones.append(list(res[res!=0]))

如果改为使用.extend,则直接附加序列的内容。这将保存以后合并嵌套列表的步骤:

consect_ones.extend(res[res!=0])

此外,您可以跳过索引,直接遍历维度:

consect_ones = []
for i in bool_data:
    for j in i:
        res = count_consect_ones(j)
        consect_ones.extend(res[res!=0])

我们可以使用一种技巧,用零填充列,然后在平坦的版本上查找渐变和渐变索引,最后过滤出与边界索引对应的索引,从而为我们自己提供一个矢量化的解决方案,就像这样-

# Input 3D array : a
b = np.pad(a, ((0,0),(0,0),(1,1)), 'constant', constant_values=(0,0))

# Get ramp-up and ramp-down indices/ start-end indices of 1s islands
s0 = np.flatnonzero(b[...,1:]>b[...,:-1])
s1 = np.flatnonzero(b[...,1:]<b[...,:-1])

# Filter only valid ones that are not at borders
n = b.shape[2]
valid_mask = (s0%(n-1)!=0) & (s1%(n-1)!=a.shape[2])
out = (s1-s0)[valid_mask]

解释-

将每行两端的零填充为“sentinent”的想法是,当我们获得一次性切片数组版本并进行比较时,我们可以分别用b[...,1:]>b[...,:-1]b[...,1:]<b[...,:-1]来检测斜坡上升和斜坡下降位置。因此,我们得到s0s1作为1s的每个岛的开始和结束索引。现在,我们不需要边界索引,所以我们需要将它们的列索引追溯到原始的未填充输入数组,从而得到该位:s0%(n-1)s1%(n-1)。我们需要删除1s的每个岛的开始位于左侧边界,而1s的每个岛的结束位于右侧边界的所有情况。开始和结束是s0s1。所以,我们用这些来检查s00还是s1a.shape[2]。这些给了我们有效的。孤岛长度是通过s1-s0获得的,因此使用有效的掩码对其进行掩码以获得所需的输出。你知道吗

样本输入,输出-

In [151]: a
Out[151]: 
array([[[0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 1]],

       [[1, 0, 1, 1, 0, 0],
        [0, 1, 1, 1, 1, 0],
        [1, 1, 1, 0, 0, 0]]])

In [152]: out
Out[152]: array([1, 1, 2, 4])

相关问题 更多 >

    热门问题