对点/元组列表进行分组
我有一组元组(点),我想知道怎么把每个元组分组,条件是它们之间的距离在一定范围内。这个问题有点难以解释,但我写的简短代码应该能让你明白我的意思……我就是找不到解决办法,也不知道怎么更好地解释这个问题。
举个例子:
TPL = [(1, 1), (2, 1), (3, 2), (7, 5), (2, 7), (6, 4), (2, 3), (2, 6), (3, 1)]
Print GroupTPL(TPL, distance=1)
> [
> [(2, 7), (2, 6)],
> [(6, 4), (7, 5)],
> [(3, 2), (3, 1), (2, 3), (1, 1), (2, 1)]
> ]
我尝试过的所有方法和想法都不太好,所以我觉得没必要分享这些。希望你们能给我一些建议和技巧。
3 个回答
0
这里提供一个替代方案,虽然它默认情况下不一定比musically-ut
给出的并查集代码更快,但使用Cython
后可以实现3倍的速度提升。而且在某些情况下,它的速度反而会更快。这不是我的作品,而是在这里找到的:https://github.com/MerlijnWajer/Simba/blob/master/Units/MMLCore/tpa.pas
Cython代码:(去掉cdef int...,以及int w, int h,就可以用Python来使用)
def group_pts(pts, int w, int h):
cdef int t1, t2, c, ec, tc, l
l = len(pts)-1
if (l < 0): return False
result = [list() for i in range(l+1)]
c = 0
ec = 0
while ((l - ec) >= 0):
result[c].append(pts[0])
pts[0] = pts[l - ec]
ec += 1
tc = 1
t1 = 0
while (t1 < tc):
t2 = 0
while (t2 <= (l - ec)):
if (abs(result[c][t1][0] - pts[t2][0]) <= w) and \
(abs(result[c][t1][1] - pts[t2][1]) <= h):
result[c].append(pts[t2])
pts[t2] = pts[l - ec]
ec += 1
tc += 1
t2 -= 1
t2 += 1
t1 += 1
c += 1
return result[0:c]
这个代码可能还有一些优化的空间,但我没有花时间去做。这个方法也允许重复的元素,而并查集结构对此并不太友好。
使用SciPy的kd树来处理这个问题可能会很有趣,这样在处理更大的数据集时,速度肯定会提升。
1
我的回答来得有点晚,不过这个方法简单有效!!
from itertools import combinations
def groupTPL(inputlist):
ptdiff = lambda (p1,p2):(p1,p2,abs(p1[0]-p2[0])+ abs(p1[1]-p2[1]),sqrt((p2[1] - p1[1])**2 + (p2[0] - p1[0])**2 ))
diffs=[ x for x in map(ptdiff, combinations(inputlist,2)) if x[2]==1 or x[3]==sqrt(2)]
nk1=[]
for x in diffs:
if len(nk1)>0:
for y in nk1:
if x[0] in y or x[1] in y:
y.add(x[0])
y.add(x[1])
else:
if set(x[0:2]) not in nk1:
nk1.append(set(x[0:2]))
else:
nk1.append(set(x[0:2]))
return [list(x) for x in nk1]
print groupTPL([(1, 1), (2, 1), (3, 2), (7, 5), (2, 7), (6, 4), (2, 3), (2, 6), (3, 1)])
这个会输出如下内容::::
[[(3, 2), (3, 1), (2, 3), (1, 1), (2, 1)], [(6, 4), (7, 5)], [(2, 7), (2, 6)]]
3
我猜你想要用的是切比雪夫距离来把这些点聚在一起。
在这种情况下,最简单的方法就是使用并查集数据结构。
这是我用过的一个实现:
class UnionFind:
"""Union-find data structure. Items must be hashable."""
def __init__(self):
"""Create a new empty union-find structure."""
self.weights = {}
self.parents = {}
def __getitem__(self, obj):
"""X[item] will return the token object of the set which contains `item`"""
# check for previously unknown object
if obj not in self.parents:
self.parents[obj] = obj
self.weights[obj] = 1
return obj
# find path of objects leading to the root
path = [obj]
root = self.parents[obj]
while root != path[-1]:
path.append(root)
root = self.parents[root]
# compress the path and return
for ancestor in path:
self.parents[ancestor] = root
return root
def union(self, obj1, obj2):
"""Merges sets containing obj1 and obj2."""
roots = [self[obj1], self[obj2]]
heavier = max([(self.weights[r],r) for r in roots])[1]
for r in roots:
if r != heavier:
self.weights[heavier] += self.weights[r]
self.parents[r] = heavier
然后写一个叫做groupTPL
的函数就很简单了:
def groupTPL(TPL, distance=1):
U = UnionFind()
for (i, x) in enumerate(TPL):
for j in range(i + 1, len(TPL)):
y = TPL[j]
if max(abs(x[0] - y[0]), abs(x[1] - y[1])) <= distance:
U.union(x, y)
disjSets = {}
for x in TPL:
s = disjSets.get(U[x], set())
s.add(x)
disjSets[U[x]] = s
return [list(x) for x in disjSets.values()]
在你的数据集上运行它会得到:
>>> groupTPL([(1, 1), (2, 1), (3, 2), (7, 5), (2, 7), (6, 4), (2, 3), (2, 6), (3, 1)])
[
[(2, 7), (2, 6)],
[(6, 4), (7, 5)],
[(3, 2), (3, 1), (2, 3), (1, 1), (2, 1)]
]
不过,这个实现虽然简单,但时间复杂度是O(n^2)
。如果点的数量变得非常多,效率更高的实现方法会使用k-d树。