棘手的中位数问题
给定n个点,从这些点中选择一个,使得到这个点的距离总和最小,相比于其他点。
距离的计算方式是这样的:
对于一个点(x,y),它周围8个相邻的点距离都是1。
(x+1,y)(x+1,y+1),(x+1,y-1),(x,y+1),(x,y-1),(x-1,y)(x-1,y+1),(x-1,y-1)
编辑
更清楚的解释。
定义一个函数foo
foo(point_a,point_b) = max(abs(point_a.x - point_b.x),abs(point_a.y - point_b.y))
找到一个点x,使得sum([foo(x,y) for y in list_of_points])的值最小。
示例
输入:
12 -14
-3 3
-14 7
-14 -3
2 -12
-1 -6
输出
-1 -6
例如:
(4,5)和(6,7)之间的距离是2。
这个问题可以在O(n^2)的时间内解决,通过检查每一对点的距离总和。
有没有更好的算法来解决这个问题呢?
2 个回答
我能想到一种比O(n^2)更好的方案,至少在常见情况下是这样的。
首先,利用你的输入点构建一个四叉树。对于树中的每一个节点,计算该节点内点的数量和平均位置。然后,对于每一个点,你可以利用四叉树来计算它与其他所有点的距离,这样的时间复杂度会小于O(n)。如果你要计算一个点p到一个远处的四叉树节点v的距离,而v与p的45度对角线没有重叠,那么从p到v中所有点的总距离就很容易计算了(如果v在水平方向上与p的距离更远,只需计算v.num_points * |p.x - v.average.x|
,如果v在垂直方向上更远,则使用y坐标进行类似的计算)。如果v与某条45度对角线重叠,就对它的组成部分进行递归计算。
这样做应该能比O(n^2)更快,至少当你能找到一个平衡的四叉树来表示你的点时。
更新: 有时候它找不到最优解,我会把这个留着,直到我找到问题所在。
这个是 O(n)
: 第n个是O(n)(这是预期的,不是最坏情况),遍历列表是O(n)。如果你需要严格的O(),那么可以选择中间的元素并进行排序,但那样就会变成O(n*log(n))。
注意:很容易修改它以返回所有的最优点。
import sys
def nth(sample, n):
pivot = sample[0]
below = [s for s in sample if s < pivot]
above = [s for s in sample if s > pivot]
i, j = len(below), len(sample)-len(above)
if n < i: return nth(below, n)
elif n >= j: return nth(above, n-j)
else: return pivot
def getbest(li):
''' li is a list of tuples (x,y) '''
l = len(li)
lix = [x[0] for x in li]
liy = [x[1] for x in li]
mid_x1 = nth(lix, l/2) if l%2==1 else nth(lix, l/2-1)
mid_x2 = nth(lix, l/2)
mid_y1 = nth(liy, l/2) if l%2==1 else nth(liy, l/2-1)
mid_y2 = nth(liy, l/2)
mindist = sys.maxint
minp = None
for p in li:
dist = 0 if mid_x1 <= p[0] <= mid_x2 else min(abs(p[0]-mid_x1), abs(p[0]-mid_x2))
dist += 0 if mid_y1 <= p[1] <= mid_y2 else min(abs(p[1]-mid_y1), abs(p[1]-mid_y2))
if dist < mindist:
minp, mindist = p, dist
return minp
这个方法是基于一维问题的解决方案——在一串数字中找到一个数字,使得所有数字到这个数字的距离之和最小。
这个问题的解决方案是(排序后的)列表中的中间元素,或者如果列表中的元素数量是偶数,可以选择两个中间元素之间的任何数字(包括这两个元素)。
更新:我的 nth
算法似乎很慢,可能有更好的方法来重写它,sort
在处理少于100000个元素时表现更好,所以如果你要比较速度,只需加上 sort(lix); sort(liy);
然后
def nth(sample, n):
return sample[n]
对于想要测试自己解决方案的朋友们,这里是我使用的方法。只需运行一个循环,生成输入,然后将你的解决方案与暴力破解的输出进行比较。
import random
def example(length):
l = []
for x in range(length):
l.append((random.randint(-100, 100), random.randint(-100,100)))
return l
def bruteforce(li):
bestsum = sys.maxint
bestp = None
for p in li:
sum = 0
for p1 in li:
sum += max(abs(p[0]-p1[0]), abs(p[1]-p1[1]))
if sum < bestsum:
bestp, bestsum = p, sum
return bestp