如何快速统计numpy.array中相等的元素数量?

2 投票
2 回答
1951 浏览
提问于 2025-04-18 16:09

我有一个Python的矩阵

leafs = np.array([[1,2,3],[1,2,4],[2,3,4],[4,2,1]])

我想计算每一对行中有多少次相同的元素。

在这种情况下,我会得到一个4x4的接近度矩阵

proximity = array([[3, 2, 0, 1],
                   [2, 3, 1, 1],
                   [0, 1, 3, 0],
                   [1, 1, 0, 3]])

这是我现在正在使用的代码。

proximity = []

for i in range(n):
 print(i)
 proximity.append(np.apply_along_axis(lambda x: sum(x==leafs[i, :]), axis=1,
                                      arr=leafs))

我需要一个更快的解决方案

编辑: 被接受的解决方案在这个例子中不起作用

    >>> type(f.leafs)
<class 'numpy.ndarray'>
>>> f.leafs.shape
(7210, 1000)
>>> f.leafs.dtype
dtype('int64')

>>> f.leafs.reshape(7210, 1, 1000) == f.leafs.reshape(1, 7210, 1000)
False
>>> f.leafs
array([[ 19,  32,  16, ..., 143, 194, 157],
       [ 19,  32,  16, ..., 143, 194, 157],
       [ 19,  32,  16, ..., 143, 194, 157],
       ..., 
       [139,  32,  16, ...,   5, 194, 157],
       [170,  32,  16, ...,   5, 194, 157],
       [170,  32,  16, ...,   5, 194, 157]])
>>> 

2 个回答

5

这里有一种方法,使用了广播的技巧。不过要注意:这个临时数组 eq 的形状是 (nrows, nrows, ncols),所以如果 nrows 是 4000,ncols 是 1000,那么 eq 就需要 16GB 的内存。

In [38]: leafs
Out[38]: 
array([[1, 2, 3],
       [1, 2, 4],
       [2, 3, 4],
       [4, 2, 1]])

In [39]: nrows, ncols = leafs.shape

In [40]: eq = leafs.reshape(nrows,1,ncols) == leafs.reshape(1,nrows,ncols)

In [41]: proximity = eq.sum(axis=-1)

In [42]: proximity
Out[42]: 
array([[3, 2, 0, 1],
       [2, 3, 1, 1],
       [0, 1, 3, 0],
       [1, 1, 0, 3]])

另外要提到的是,这个解决方案效率不高:proximity 是对称的,而且对角线上的值总是等于 ncols,但这个方法计算了整个数组,所以它做的工作比实际需要的多了两倍以上。

1

Warren Weckesser 提出了一个很棒的解决方案,利用了广播的特性。不过,即使是用简单的循环方法,性能也差不多。你最开始的方案中使用的 np.apply_along_axis 速度比较慢,因为它没有利用向量化的优势。下面的代码就解决了这个问题:

def proximity_1(leafs):
    n = len(leafs)
    proximity = np.zeros((n,n))
    for i in range(n):
        proximity[i] = (leafs == leafs[i]).sum(1)  
    return proximity

你也可以用列表推导式来让上面的代码更简洁。不同之处在于,np.apply_along_axis 会以一种不优化的方式遍历所有行,而 leafs == leafs[i] 则能利用 numpy 的速度。

Warren Weckesser 的解决方案确实展现了 numpy 的美妙之处。不过,它需要创建一个大小为 nrows*nrows*ncols 的中间三维数组,这会带来额外的开销。所以如果你的数据量很大,简单的循环可能会更高效。

下面是一个例子。以下是 Warren Weckesser 提供的代码,封装在一个函数里。(我不太清楚这里的代码版权规则,所以我假设这个引用就足够了 :)

def proximity_2(leafs):
    nrows, ncols = leafs.shape    
    eq = leafs.reshape(nrows,1,ncols) == leafs.reshape(1,nrows,ncols)
    proximity = eq.sum(axis=-1)  
    return proximity

现在我们来评估一下在一个大小为 10000 x 100 的随机整数数组上的性能。

leafs = np.random.randint(1,100,(10000,100))
time proximity_1(leafs)
>> 28.6 s
time proximity_2(leafs) 
>> 35.4 s 

我在同一台机器上的 IPython 环境中运行了这两个例子。

撰写回答