给定三维布尔数据:
np.random.seed(13)
bool_data = np.random.randint(2, size=(2,3,6))
>> bool_data
array([[[0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 1]],
[[1, 0, 1, 1, 0, 0],
[0, 1, 1, 1, 1, 0],
[1, 1, 1, 0, 0, 0]]])
我希望计算每行中由两个0限定的连续1的数目(沿轴=1),并返回一个带有计数的数组。对于bool_data
,这将给出array([1, 1, 2, 4])
。你知道吗
由于bool_data
的3D结构和每行的变量计数,我不得不笨拙地将计数转换为嵌套列表,使用itertools.chain
将其展平,然后将列表反向转换为数组:
# count consecutive 1's bounded by two 0's
def count_consect_ones(input):
return np.diff(np.where(input==0)[0])-1
# run tallies across all rows in bool_data
consect_ones = []
for i in range(len(bool_data)):
for j in range(len(bool_data[i])):
res = count_consect_ones(bool_data[i, j])
consect_ones.append(list(res[res!=0]))
>> consect_ones
[[], [1, 1], [], [2], [4], []]
# combines nested lists
from itertools import chain
consect_ones_output = np.array(list(chain.from_iterable(consect_ones)))
>> consect_ones_output
array([1, 1, 2, 4])
有没有更有效或更聪明的方法?你知道吗
如果改为使用.extend,则直接附加序列的内容。这将保存以后合并嵌套列表的步骤:
此外,您可以跳过索引,直接遍历维度:
我们可以使用一种技巧,用零填充列,然后在平坦的版本上查找渐变和渐变索引,最后过滤出与边界索引对应的索引,从而为我们自己提供一个矢量化的解决方案,就像这样-
解释-
将每行两端的零填充为“sentinent”的想法是,当我们获得一次性切片数组版本并进行比较时,我们可以分别用
b[...,1:]>b[...,:-1]
和b[...,1:]<b[...,:-1]
来检测斜坡上升和斜坡下降位置。因此,我们得到s0
和s1
作为1s
的每个岛的开始和结束索引。现在,我们不需要边界索引,所以我们需要将它们的列索引追溯到原始的未填充输入数组,从而得到该位:s0%(n-1)
和s1%(n-1)
。我们需要删除1s
的每个岛的开始位于左侧边界,而1s
的每个岛的结束位于右侧边界的所有情况。开始和结束是s0
和s1
。所以,我们用这些来检查s0
是0
还是s1
是a.shape[2]
。这些给了我们有效的。孤岛长度是通过s1-s0
获得的,因此使用有效的掩码对其进行掩码以获得所需的输出。你知道吗样本输入,输出-
相关问题 更多 >
编程相关推荐