Consecutive values in arrays with periodic boundaries in Python

Published 2024-04-20 10:46:27


I have some 2D arrays filled with 0s and 1s:

import numpy as np

a = np.random.randint(2, size=(20, 20))
b = np.random.randint(2, size=(20, 20))
c = np.random.randint(2, size=(20, 20))
d = np.random.randint(2, size=(20, 20)) 

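I want to count consecutive occurrences of 1s with periodic boundaries, i.e. find the longest run when the array wraps around. That means (in 1D for clarity):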

[1 1 0 0 1 1 0 1 1 1]

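should give me 5 (the last three elements + the first two elements).
The 2D arrays should be compared/counted along the third axis (the second axis if you count from 0), for example by first stacking the arrays along axis=2 and then applying the same algorithm as in 1D. But I am not sure this is the simplest way.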
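
A minimal sketch of the stacking idea described above, assuming the four arrays are combined with np.stack and the 1D count is applied along the new axis. The helper name longest_periodic_run_1d and the use of np.apply_along_axis are illustrative choices, not from the original post; this is a slow reference version that is handy for checking faster solutions, not an efficient one.

```python
import numpy as np

def longest_periodic_run_1d(v):
    """Longest run of 1s in a 1D 0/1 sequence, treating it as circular."""
    v = np.asarray(v)
    n = len(v)
    if n == 0 or v.all():
        return n
    best = cur = 0
    # Walk the sequence twice so runs that wrap past the end are seen whole.
    for x in np.concatenate((v, v)):
        cur = cur + 1 if x == 1 else 0
        best = max(best, cur)
    return min(best, n)

rng = np.random.default_rng(0)
a, b, c, d = (rng.integers(2, size=(20, 20)) for _ in range(4))

stacked = np.stack((a, b, c, d), axis=2)   # shape (20, 20, 4)
counts = np.apply_along_axis(longest_periodic_run_1d, 2, stacked)  # shape (20, 20)

print(longest_periodic_run_1d([1, 1, 0, 0, 1, 1, 0, 1, 1, 1]))  # 5
```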


3 answers

Here is a vectorized approach, aimed at performance, for 2D and higher-dimensional ndarrays -

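def count_periodic_boundary(a):
    # Collapse leading dims so each row is one sequence along the last axis
    a = a.reshape(-1,a.shape[-1])
    m = a==1
    # Wrap-around run: trailing ones plus leading ones of each row
    c0 = np.flip(m,axis=-1).argmin(axis=-1)+m.argmin(axis=-1)
    # Pad each row with False so every run has a start and an end edge
    z = np.zeros(a.shape[:-1]+(1,),dtype=bool)
    p = np.hstack((z,m,z))
    # Number of runs per row, and each row's offset into the run-length array
    c = (p[:,:-1]<p[:,1:]).sum(1)
    s = np.r_[0,c[:-1].cumsum()]
    # Run lengths over the flattened array; trailing 0 keeps offsets in bounds
    l = np.r_[np.diff(np.flatnonzero(np.diff(p.ravel())))[::2],0]
    d = np.maximum(c0,np.maximum.reduceat(l,s))
    # Rows with no 1 at all give 0; all-ones rows give the full length
    d = np.where(m.any(-1),d,0)
    return np.where(m.all(-1),a.shape[-1],d)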
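
To make the wrap-around term c0 in the function above more concrete, here is a small sketch of the argmin trick on a boolean mask. The variable names leading, trailing and wrap are illustrative and do not appear in the answer.

```python
import numpy as np

# Two rows of a boolean mask; the first is the 1D example from the question.
m = np.array([[1, 1, 0, 0, 1, 1, 0, 1, 1, 1],
              [0, 1, 1, 1, 0, 0, 0, 0, 1, 1]], dtype=bool)

# argmin on booleans returns the index of the first False,
# which equals the number of leading True values.
leading = m.argmin(axis=-1)
# Flipping each row turns its trailing values into leading ones.
trailing = np.flip(m, axis=-1).argmin(axis=-1)

wrap = leading + trailing
print(leading, trailing, wrap)  # [2 0] [3 2] [5 2]
```

For a row that is all True, argmin returns 0, so this sum would be 0 as well; that appears to be why the function treats all-ones rows separately via m.all(-1).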

Sample runs -

In [75]: np.random.seed(0)
    ...: a = np.random.randint(2, size=(5, 20))

In [76]: a
Out[76]: 
array([[0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1],
       [0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0],
       [0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1],
       [1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0],
       [0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0]])

In [77]: count_periodic_boundary(a)
Out[77]: array([7, 4, 5, 2, 6])


In [72]: np.random.seed(0)
    ...: a = np.random.randint(2, size=(2, 5, 20))

In [73]: a
Out[73]: 
array([[[0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1],
        [0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0],
        [0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1],
        [1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0],
        [0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0]],

       [[1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0],
        [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0],
        [1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1],
        [0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0],
        [1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0]]])

In [74]: count_periodic_boundary(a)
Out[74]: array([7, 4, 5, 2, 6, 2, 5, 4, 2, 1])
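
The run-length step inside count_periodic_boundary (the arrays p, c, s and l) can be traced on a tiny input. This is only a sketch with more descriptive, made-up variable names, and it assumes every row contains at least one 1; it computes the longest run without the wrap-around term.

```python
import numpy as np

m = np.array([[0, 1, 1, 0, 1],
              [1, 1, 1, 0, 0]], dtype=bool)

# Pad each row with False on both sides so runs never touch a row boundary.
z = np.zeros((m.shape[0], 1), dtype=bool)
p = np.hstack((z, m, z))

# Positions where the flattened mask changes value; run starts and run ends alternate.
edges = np.flatnonzero(np.diff(p.ravel()))
# Every other difference is the length of a run of True values.
run_lengths = np.diff(edges)[::2]

# Number of runs per row (rising edges) and each row's offset into run_lengths.
runs_per_row = (p[:, :-1] < p[:, 1:]).sum(1)
offsets = np.r_[0, runs_per_row[:-1].cumsum()]

# Maximum run length within each row's slice of run_lengths.
longest = np.maximum.reduceat(run_lengths, offsets)
print(run_lengths, runs_per_row, offsets, longest)  # [2 1 3] [2 1] [0 2] [2 3]
```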

Here is a two-liner, admittedly with one rather long line:

def max_run_periodic(a):  # wrapper name is illustrative, added so the snippet runs
    *m,n = a.shape
    return np.minimum(n,(np.arange(1,2*n+1)-np.maximum.accumulate(np.where(a[...,None,:],0,np.arange(1,2*n+1).reshape(2,n)).reshape(*m,2*n),-1)).max(-1))

How it works:

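Let us first ignore the wrap-around and consider a simple example: a = [1 0 0 1 1 0 1 1 1 0]. We want to transform it into b = [1 0 0 1 2 0 1 2 3 0], so that we can simply take the maximum. One way to generate b is to take the arange r = [1 2 3 4 5 6 7 8 9 10] and subtract aux = [0 2 3 3 3 6 6 6 6 10]. We create aux by multiplying r by (1-a), which gives [0 2 3 0 0 6 0 0 0 10], and taking the cumulative maximum.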
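
The arithmetic in this paragraph can be reproduced directly. The sketch below uses an input a chosen so that the run-length vector comes out as b = [1 0 0 1 2 0 1 2 3 0]; the name masked for the intermediate product is an illustrative addition.

```python
import numpy as np

a = np.array([1, 0, 0, 1, 1, 0, 1, 1, 1, 0])

r = np.arange(1, a.size + 1)          # 1-based positions
masked = r * (1 - a)                  # position where a is 0, else 0
aux = np.maximum.accumulate(masked)   # position of the most recent 0 seen so far
b = r - aux                           # length of the run of 1s ending at each position

print(masked)   # [ 0  2  3  0  0  6  0  0  0 10]
print(aux)      # [ 0  2  3  3  3  6  6  6  6 10]
print(b)        # [1 0 0 1 2 0 1 2 3 0]
print(b.max())  # 3
```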

To handle the wrap-around, we simply place two copies of a one after the other and then apply the above.
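
A minimal 1D sketch of this doubling step, assuming a non-empty sequence of integer 0s and 1s. The function name longest_run_wrapped is hypothetical; the final min with n caps the all-ones case, where the doubled array would otherwise report a run of 2*n.

```python
import numpy as np

def longest_run_wrapped(a):
    """Longest circular run of 1s in a non-empty 1D integer 0/1 sequence."""
    a = np.asarray(a)
    n = a.size
    aa = np.concatenate((a, a))                 # two copies, one after the other
    r = np.arange(1, 2 * n + 1)
    aux = np.maximum.accumulate(r * (1 - aa))   # position of the most recent 0
    return min(n, int((r - aux).max()))

print(longest_run_wrapped([1, 1, 0, 0, 1, 1, 0, 1, 1, 1]))  # 5
print(longest_run_wrapped([1, 1, 1]))                       # 3
```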

Here is the code again, broken down into smaller pieces and commented:

def max_run_periodic(a):  # wrapper name is illustrative, added so the snippet runs
    *m,n = a.shape
    # r has length 2*n because of how we deal with the wrap around
    r = np.arange(1,2*n+1)
    # create r * (1-a) using essentially np.where(a,0,r)
    # it's a bit more involved because we are cloning a in the same step
    # a will be doubled along a new axis we insert before the last one
    # this will happen by means of broadcasting against r which we distribute
    # over two rows along the new axis
    # in the very end we merge the new and the last axis
    r1_a = np.where(a[...,None,:],0,r.reshape(2,n)).reshape(*m,2*n)
    # take cumulative max
    aux = np.maximum.accumulate(r1_a,-1)
    # finally, take the row wise maximum and deal with all-one rows
    return np.minimum(n,(r-aux).max(-1))
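
The broadcasting step is the least obvious part of the code above, so here is a small sketch that prints the intermediate shapes and values for a 2x3 input. The name doubled for the pre-reshape array is an illustrative addition.

```python
import numpy as np

a = np.array([[1, 0, 1],
              [1, 1, 1]])
*m, n = a.shape
r = np.arange(1, 2 * n + 1)                  # [1 2 3 4 5 6]

# a[..., None, :] has shape (2, 1, 3); r.reshape(2, n) has shape (2, 3).
# Broadcasting repeats each row of a against both halves of r.
doubled = np.where(a[..., None, :], 0, r.reshape(2, n))
print(doubled.shape)                         # (2, 2, 3)

# Merging the last two axes lays the two copies side by side.
r1_a = doubled.reshape(*m, 2 * n)
print(r1_a.tolist())                         # [[0, 2, 0, 0, 5, 0], [0, 0, 0, 0, 0, 0]]

aux = np.maximum.accumulate(r1_a, -1)
result = np.minimum(n, (r - aux).max(-1))
print(result.tolist())                       # [2, 3]
```

The first row [1 0 1] wraps around to a run of 2, and the all-ones row is capped at n = 3 by np.minimum.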

You can use groupby from itertools:

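from itertools import groupby

a = [1, 1, 0, 0, 1, 1, 0, 1, 1, 1]

def get_longest_seq(a):
    # All ones: the whole sequence is a single run
    if all(a):
        return len(a)

    # Lengths of the runs of non-zero values, in order of appearance
    a_lens = [len(list(it)) for k, it in groupby(a) if k != 0]

    # No ones at all, so there is no run to report
    if not a_lens:
        return 0

    # If the sequence starts and ends with 1, the first and last runs join across the boundary
    if a[0] == 1 and a[-1] == 1:
        m = max(max(a_lens), a_lens[0] + a_lens[-1])
    else:
        m = max(a_lens)
    return m

print(get_longest_seq(a))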
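
To see why adding the first and last run lengths handles the periodic boundary, it can help to look at what groupby produces for the example list. This is a small illustration only; the names runs and one_runs are not from the answer.

```python
from itertools import groupby

a = [1, 1, 0, 0, 1, 1, 0, 1, 1, 1]

# groupby yields one (key, group) pair per run of equal consecutive values.
runs = [(k, len(list(g))) for k, g in groupby(a)]
print(runs)        # [(1, 2), (0, 2), (1, 2), (0, 1), (1, 3)]

# Keep only the runs of 1s; the first and last touch the two ends of the list.
one_runs = [length for k, length in runs if k == 1]
print(one_runs)    # [2, 2, 3]
print(one_runs[0] + one_runs[-1])  # 5
```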
