在Python中优化均值

1 投票
9 回答
875 浏览
提问于 2025-04-16 04:38

我有一个函数,用来在K均值算法中更新中心点(平均值)。

我运行了一个性能分析工具,发现这个函数消耗了很多计算时间。

它的样子是这样的:

def updateCentroid(self, label):
    X=[]; Y=[]
    for point in self.clusters[label].points:
        X.append(point.x)
        Y.append(point.y)
    self.clusters[label].centroid.x = numpy.mean(X)
    self.clusters[label].centroid.y = numpy.mean(Y)

所以我在想,是否有更高效的方法来计算这些点的平均值?如果没有,是否有更优雅的写法呢?;)

编辑:

感谢大家的精彩回复!我在想,也许我可以用一种累积的方式来计算平均值,像这样:

alt text

其中x_bar(t)是新的平均值,x_bar(t-1)是旧的平均值。

这样会得到一个类似于这个的函数:

def updateCentroid(self, label):
    cluster = self.clusters[label]
    n = len(cluster.points)
    cluster.centroid.x *= (n-1) / n
    cluster.centroid.x += cluster.points[n-1].x / n
    cluster.centroid.y *= (n-1) / n
    cluster.centroid.y += cluster.points[n-1].y / n

虽然现在并不太管用,但你觉得经过一些调整,这样的方法会有效吗?

9 个回答

3

为什么不避免构建额外的数组呢?

def updateCentroid(self, label):
  sumX=0; sumY=0
  N = len( self.clusters[label].points)
  for point in self.clusters[label].points:
    sumX += point.x
    sumY += point.y
  self.clusters[label].centroid.x = sumX/N
  self.clusters[label].centroid.y = sumY/N
5

K-means算法已经在 scipy.cluster.vq 这个库里实现了。如果你想对这个实现做些什么改变,建议你先去看看那里的代码:

In [62]: import scipy.cluster.vq as scv
In [64]: scv.__file__
Out[64]: '/usr/lib/python2.6/dist-packages/scipy/cluster/vq.pyc'

另外,你提到的算法是把数据放在一个字典里(self.clusters)和通过属性查找(.points),这样你就不得不使用比较慢的Python循环来获取数据。其实,如果使用numpy数组的话,速度会快很多。可以参考一下scipy中K-means聚类的实现,看看有没有更好的数据结构。

1

好的,我找到了一个快速的移动平均解决方案,而且没有改变数据结构:

def updateCentroid(self, label):
    cluster = self.clusters[label]
    n = len(cluster.points)
    cluster.centroid.x = ((n-1)*cluster.centroid.x + cluster.points[n-1].x)/n
    cluster.centroid.y = ((n-1)*cluster.centroid.y + cluster.points[n-1].y)/n

这样一来,整个k均值算法的计算时间降低到了原来的13%。=)

感谢大家提供的精彩见解!

撰写回答