在Python中优化均值
我有一个函数,用来在K均值算法中更新中心点(平均值)。
我运行了一个性能分析工具,发现这个函数消耗了很多计算时间。
它的样子是这样的:
def updateCentroid(self, label):
X=[]; Y=[]
for point in self.clusters[label].points:
X.append(point.x)
Y.append(point.y)
self.clusters[label].centroid.x = numpy.mean(X)
self.clusters[label].centroid.y = numpy.mean(Y)
所以我在想,是否有更高效的方法来计算这些点的平均值?如果没有,是否有更优雅的写法呢?;)
编辑:
感谢大家的精彩回复!我在想,也许我可以用一种累积的方式来计算平均值,像这样:

其中x_bar(t)是新的平均值,x_bar(t-1)是旧的平均值。
这样会得到一个类似于这个的函数:
def updateCentroid(self, label):
cluster = self.clusters[label]
n = len(cluster.points)
cluster.centroid.x *= (n-1) / n
cluster.centroid.x += cluster.points[n-1].x / n
cluster.centroid.y *= (n-1) / n
cluster.centroid.y += cluster.points[n-1].y / n
虽然现在并不太管用,但你觉得经过一些调整,这样的方法会有效吗?
9 个回答
3
为什么不避免构建额外的数组呢?
def updateCentroid(self, label):
sumX=0; sumY=0
N = len( self.clusters[label].points)
for point in self.clusters[label].points:
sumX += point.x
sumY += point.y
self.clusters[label].centroid.x = sumX/N
self.clusters[label].centroid.y = sumY/N
5
K-means算法已经在 scipy.cluster.vq 这个库里实现了。如果你想对这个实现做些什么改变,建议你先去看看那里的代码:
In [62]: import scipy.cluster.vq as scv
In [64]: scv.__file__
Out[64]: '/usr/lib/python2.6/dist-packages/scipy/cluster/vq.pyc'
另外,你提到的算法是把数据放在一个字典里(self.clusters
)和通过属性查找(.points
),这样你就不得不使用比较慢的Python循环来获取数据。其实,如果使用numpy数组的话,速度会快很多。可以参考一下scipy中K-means聚类的实现,看看有没有更好的数据结构。
1
好的,我找到了一个快速的移动平均解决方案,而且没有改变数据结构:
def updateCentroid(self, label):
cluster = self.clusters[label]
n = len(cluster.points)
cluster.centroid.x = ((n-1)*cluster.centroid.x + cluster.points[n-1].x)/n
cluster.centroid.y = ((n-1)*cluster.centroid.y + cluster.points[n-1].y)/n
这样一来,整个k均值算法的计算时间降低到了原来的13%。=)
感谢大家提供的精彩见解!