如何有效地比较numpy数组中的条目?

2024-03-28 12:14:38 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个numpy数组embed_vec,长度为tot_vec,其中每个条目都是一个3d向量:

[[ 0.52483319  0.78015841  0.71117216]
 [ 0.53041481  0.79462171  0.67234534]
 [ 0.53645428  0.80896727  0.63119403]
 ..., 
 [ 0.72283509  0.40070804  0.15220522]
 [ 0.71277758  0.38498613  0.16141834]
 [ 0.70221445  0.36918032  0.17370776]]

对于这个数组中的每个元素,我想找出“接近”该条目的其他条目的数量。我的意思是两个向量之间的距离小于指定值R。为此,我必须将这个数组中所有可能的对进行比较,然后找出数组中每个向量的接近向量数。所以我要这样做:

^{pr2}$

但是,这是非常低效的,因为我有两个嵌套的python循环,对于更大的数组大小,这需要花费很长时间。如果这是在C++或FORTRAN中,那就不是一个大问题。我的问题是,使用numpy可以有效地使用向量化方法实现同样的事情吗?顺便说一句,我不介意用熊猫来解决这个问题。在


Tags: 方法目的numpy元素距离数量条目embed
3条回答

方法1:矢量化方法-

def vectorized_app(embed_vec, R):  
    tot_vec = embed_vec.shape[0]          
    r,c = np.triu_indices(tot_vec,1)
    subs = embed_vec[r] - embed_vec[c]
    dists = np.einsum('ij,ij->i',subs,subs)
    return np.bincount(r,dists<R**2,minlength=tot_vec)

方法2:循环复杂度较低(对于非常大的阵列)-

^{pr2}$

标杆管理

原始方法-

def loopy_app(embed_vec, R):
    tot_vec = embed_vec.shape[0]
    p = np.zeros(tot_vec) # This contains the number of close vectors
    for i in range(tot_vec-1):
        for j in range(i+1, tot_vec):
            if np.linalg.norm(embed_vec[i]-embed_vec[j]) < R:
                p[i] += 1
    return p                

时间安排-

In [76]: # Sample random array
    ...: embed_vec = np.random.rand(3000,3)
    ...: R = 0.5
    ...: 

In [77]: %timeit loopy_app(embed_vec, R)
1 loops, best of 3: 50.5 s per loop

In [78]: %timeit loopy_less_app(embed_vec, R)
10 loops, best of 3: 143 ms per loop

350x+加速!在

使用更大的数组和建议的loopy_less_app-

In [81]: # Sample random array
    ...: embed_vec = np.random.rand(20000,3)
    ...: R = 0.5
    ...: 

In [82]: %timeit loopy_less_app(embed_vec, R)
1 loops, best of 3: 4.47 s per loop

首先广播差异:

disp_vecs=tot_vec[:,None,:]-tot_vec[None,:,:]

现在,根据你的数据集有多大,你可能想做一个没有所有数学运算的第一次传递。如果距离小于r,则所有组件都应小于r

first_mask=np.max(disp_vec, axis=-1)<r

然后进行实际计算

disps=np.linlg.norm(disp_vec[first_mask],axis=-1)
second_mask=disps<r

现在重新分配

^{pr2}$

disps现在是好值,first_mask是一个布尔掩码。你可以从那里处理。在

我对这个问题很感兴趣,并试图用scipy的cKDTree有效地解决它。但是,这种方法可能会耗尽内存,因为内部会维护距离为<;=R的所有对的列表。如果R和tot_vec足够小,则可以:

import numpy as np
from scipy.spatial import cKDTree as KDTree

tot_vec = 60000
embed_vec = np.random.randn(tot_vec, 3)
R = 0.1

tree = KDTree(embed_vec, leafsize=100)
p = np.zeros(tot_vec)
for pair in tree.query_pairs(R):
    p[pair[0]] += 1
    p[pair[1]] += 1

如果内存是个问题,那么只要付出一定的努力,就可以在Python中将query_pairs重写为一个生成器函数,但代价是C性能。在

相关问题 更多 >