基于二维数组中一列数据分桶,并使用cython估计每个桶的均值

1 投票
1 回答
1697 浏览
提问于 2025-04-18 15:15

为了提高我代码的速度,这对我的MCMC(马尔科夫链蒙特卡洛)非常重要,我想用cython来替换我Python代码中的一些瓶颈。因为我正在处理一个很大的二维数组,我需要根据这个二维数组的一列来对数据进行分组,然后在其他列中根据第一列的分组计算每个组的平均值。我之前使用的Python代码是:

   import numpy as np
   d = np.random.random((10**5, 3))
  #binning data again based on first column 
   bins = np.linspace(ndata[0,0], ndata[-1,0], 10)
   #compute the mean in each bin for different input parameters
   digitized = np.digitize(ndata[:,0], bins)
   r= np.array([ndata[digitized == i,0].mean() for i in range(1, len(bins))])
   p= np.array([ndata[digitized == i,1].mean() for i in range(1, len(bins))])
   q= np.array([ndata[digitized == i,2].mean() for i in range(1, len(bins))])

我该如何用cython的代码来加速,至少提高两个数量级,替代numpy.digitize这个函数呢?

1 个回答

5

我觉得你不需要用cython,实际上你可以用numpy.bincount来解决这个问题。下面是一个例子:

import numpy as np
d = np.random.random(10**5)
numbins = 10

bins = np.linspace(d.min(), d.max(), numbins+1)
# This line is not necessary, but without it the smallest bin only has 1 value.
bins = bins[1:]
digitized = bins.searchsorted(d)

bin_means = (np.bincount(digitized, weights=d, minlength=numbins) /
             np.bincount(digitized, minlength=numbins))

更新

我们先来聊聊为什么上面的代码比你问题中的代码快,以及为什么在这种情况下cython可能帮助不大。在你的代码中,当你执行[digitized == i] for i in range(numbins)]时,其实是对digitized数组进行了numbins次遍历。如果你了解大O符号,你会发现这就是O(n * m)的复杂度。另一方面,bincount的做法有点不同。bincount的工作原理大致相当于:

def bincount(digitized, Weights):
   out = zeros(digitized.max() + 1)
   for i, w = zip(digitized, Weights):
       out[i] += w
   return out

它只需要对digitized进行1次遍历(如果算上最大值的话是2次),所以它的复杂度是O(n)。而且bincount已经用C语言编写并编译过,所以它的运行开销非常小,速度也很快。Cython最有用的地方在于当你的代码有很多解释器和类型检查的开销时,通过声明类型和编译代码可以减少这些开销。希望这些信息对你有帮助。

撰写回答