棘手的中位数问题

11 投票
2 回答
749 浏览
提问于 2025-04-16 23:35

给定n个点,从这些点中选择一个,使得到这个点的距离总和最小,相比于其他点。

距离的计算方式是这样的:

对于一个点(x,y),它周围8个相邻的点距离都是1。

(x+1,y)(x+1,y+1),(x+1,y-1),(x,y+1),(x,y-1),(x-1,y)(x-1,y+1),(x-1,y-1)

编辑

更清楚的解释。

定义一个函数foo

foo(point_a,point_b) = max(abs(point_a.x - point_b.x),abs(point_a.y - point_b.y))

找到一个点x,使得sum([foo(x,y) for y in list_of_points])的值最小。

示例

输入:

12 -14
-3 3
-14 7
-14 -3
2 -12
-1 -6

输出

-1 -6

例如:

(4,5)和(6,7)之间的距离是2。

这个问题可以在O(n^2)的时间内解决,通过检查每一对点的距离总和。

有没有更好的算法来解决这个问题呢?

2 个回答

1

我能想到一种比O(n^2)更好的方案,至少在常见情况下是这样的。

首先,利用你的输入点构建一个四叉树。对于树中的每一个节点,计算该节点内点的数量和平均位置。然后,对于每一个点,你可以利用四叉树来计算它与其他所有点的距离,这样的时间复杂度会小于O(n)。如果你要计算一个点p到一个远处的四叉树节点v的距离,而v与p的45度对角线没有重叠,那么从p到v中所有点的总距离就很容易计算了(如果v在水平方向上与p的距离更远,只需计算v.num_points * |p.x - v.average.x|,如果v在垂直方向上更远,则使用y坐标进行类似的计算)。如果v与某条45度对角线重叠,就对它的组成部分进行递归计算。

这样做应该能比O(n^2)更快,至少当你能找到一个平衡的四叉树来表示你的点时。

5

更新: 有时候它找不到最优解,我会把这个留着,直到我找到问题所在。

这个是 O(n): 第n个是O(n)(这是预期的,不是最坏情况),遍历列表是O(n)。如果你需要严格的O(),那么可以选择中间的元素并进行排序,但那样就会变成O(n*log(n))。

注意:很容易修改它以返回所有的最优点。

import sys

def nth(sample, n):
    pivot = sample[0]
    below = [s for s in sample if s < pivot]
    above = [s for s in sample if s > pivot]
    i, j = len(below), len(sample)-len(above)
    if n < i:      return nth(below, n)
    elif n >= j:   return nth(above, n-j)
    else:          return pivot

def getbest(li):
    ''' li is a list of tuples (x,y) '''
    l = len(li)
    lix = [x[0] for x in li]
    liy = [x[1] for x in li]

    mid_x1 = nth(lix, l/2) if l%2==1 else nth(lix, l/2-1)
    mid_x2 = nth(lix, l/2)
    mid_y1 = nth(liy, l/2) if l%2==1 else nth(liy, l/2-1)
    mid_y2 = nth(liy, l/2)

    mindist = sys.maxint
    minp = None
    for p in li:
        dist = 0 if mid_x1 <= p[0] <= mid_x2 else min(abs(p[0]-mid_x1), abs(p[0]-mid_x2))
        dist += 0 if mid_y1 <= p[1] <= mid_y2 else min(abs(p[1]-mid_y1), abs(p[1]-mid_y2))
        if dist < mindist:
            minp, mindist = p, dist
    return minp

这个方法是基于一维问题的解决方案——在一串数字中找到一个数字,使得所有数字到这个数字的距离之和最小。

这个问题的解决方案是(排序后的)列表中的中间元素,或者如果列表中的元素数量是偶数,可以选择两个中间元素之间的任何数字(包括这两个元素)。

更新:我的 nth 算法似乎很慢,可能有更好的方法来重写它,sort 在处理少于100000个元素时表现更好,所以如果你要比较速度,只需加上 sort(lix); sort(liy); 然后

def nth(sample, n):
    return sample[n]

对于想要测试自己解决方案的朋友们,这里是我使用的方法。只需运行一个循环,生成输入,然后将你的解决方案与暴力破解的输出进行比较。

import random
def example(length):
    l = []
    for x in range(length):
        l.append((random.randint(-100, 100), random.randint(-100,100)))
    return l

def bruteforce(li):
    bestsum = sys.maxint
    bestp = None
    for p in li:
        sum = 0
        for p1 in li:
            sum += max(abs(p[0]-p1[0]), abs(p[1]-p1[1]))
        if sum < bestsum:
            bestp, bestsum = p, sum
    return bestp

撰写回答