在Python中对500,000个地理空间点进行聚类

8 投票
2 回答
9929 浏览
提问于 2025-04-18 08:28

我现在遇到了一个问题,想在Python中对大约50万个经纬度坐标进行聚类。到目前为止,我尝试用numpy计算一个距离矩阵,然后把它传给scikit-learn的DBSCAN算法,但因为输入数据太大,系统很快就报了内存错误。

这些点是以元组的形式存储的,每个元组包含了经度、纬度和该点的数据值。

简单来说,我想知道在Python中,有什么高效的方法可以对大量的经纬度坐标进行空间聚类?为了提高速度,我愿意在一定程度上牺牲一些准确性。

补充说明:算法需要找到的聚类数量是事先不知道的。

2 个回答

6

在早期版本的scikit-learn中,DBSCAN算法会计算一个完整的距离矩阵。

可惜的是,计算这个距离矩阵需要占用O(n^2)的内存,这可能就是你内存不够用的原因。

新版本的scikit-learn(你用的是哪个版本呢?)应该可以在不计算距离矩阵的情况下工作;至少在使用索引的时候是这样。对于50万个对象,你确实想使用索引加速,因为这可以把运行时间从O(n^2)减少到O(n log n)

不过,我不太清楚scikit-learn在其索引中对地理距离的支持情况。ELKI是我知道的唯一一个可以使用R*-树索引来加速地理距离计算的工具;这使得它在这个任务上非常快速(尤其是在批量加载索引的时候)。你可以试试看。

可以看看Scikit-learn的索引文档,试着设置algorithm='ball_tree'

4

我没有你的数据,所以我随便生成了50万个随机数字,分成三列。

import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.vq import kmeans2, whiten

arr = np.random.randn(500000*3).reshape((500000, 3))
x, y = kmeans2(whiten(arr), 7, iter = 20)  #<--- I randomly picked 7 clusters
plt.scatter(arr[:,0], arr[:,1], c=y, alpha=0.33333);

out[1]:

在这里输入图片描述

我测了一下,这个Kmeans2运行了1.96秒,所以我觉得问题不在于你数据的大小。你可以把你的数据放进一个500000 x 3的numpy数组里,然后试试kmeans2。

撰写回答