快速选择算法

2 投票
1 回答
1859 浏览
提问于 2025-04-17 21:34

我现在正在做计算机科学课的作业,但我搞不清楚我的代码在执行快速选择时出了什么问题。

def partition(a_list, first, last):
    pivot = a_list[last]
    i = first-1
    for j in range(first, last):
        if a_list[j] <= pivot:
            i += 1
            a_list[i], a_list[j] = a_list[j], a_list[i]
    a_list[i+1], a_list[last] = a_list[last], a_list[i+1]
    print(a_list)
    return i+1

def selection(a_list, first, last, k):
    pivot = a_list[last]
    pivotIndex = partition(a_list, first, last)
    if first == last:
        return a_list[k]
    elif k <= pivotIndex:
        return selection(a_list, first, pivotIndex-1, k)
    else:
        return selection(a_list, pivotIndex+1, last, k - pivotIndex)

print(selection([5,4,1,10,8,3,2], 0, 6, 1))
print(selection([5,4,1,10,8,3,2], 0, 6, 3))
print(selection([5,4,1,10,8,3,2], 0, 6, 6))
print(selection([5,4,1,10,8,3,2], 0, 6, 7))
print(selection([46, 50, 16, 88, 79, 77, 17, 2, 43, 13, 86, 12, 68, 33, 81, \
74, 19, 52, 98, 70, 61, 71, 93, 5, 55], 0, 24, 19))

在第三个打印语句之后,我的代码就陷入了一个循环,最后因为达到最大递归次数而崩溃。第一个输出应该是1,我知道为什么会这样。但我就是找不到解决办法来修复它。

这是我的输出,直到最后出现最大递归深度达到的错误。(可以忽略打印的列表,它只是让我看看正在进行的分区情况)

[1, 2, 5, 10, 8, 3, 4]
2
[1, 2, 5, 10, 8, 3, 4]
[1, 2, 3, 4, 8, 5, 10]
3
[1, 2, 5, 10, 8, 3, 4]
[1, 2, 3, 4, 8, 5, 10]
[1, 2, 3, 4, 8, 5, 10]
[1, 2, 3, 4, 5, 8, 10]
[1, 2, 3, 5, 4, 8, 10]
[1, 2, 3, 4, 5, 8, 10]
[1, 2, 3, 5, 4, 8, 10]
[1, 2, 3, 4, 5, 8, 10]

1 个回答

1

partition函数看起来没什么问题。主要的问题出在selection函数上。具体有以下几点:

  1. 混用了0索引和1索引
  2. 检查selection函数的边界
  3. 处理递归中的k

第一点:混用0索引和1索引

这个例子说明了问题:

print(selection([5,4,1,10,8,3,2], 0, 6, 1))

你在问题中提到期望的输出是1。列表[5,4,1,10,8,3,2]排序后是[1,2,3,4,5,8,10]。在调用selection函数时,你提供了06作为firstlast。这两个变量使用的是0索引。而对于k,你提供了1,并期望selection函数的输出是1。这就用了1索引。

这样做没有错,但会很容易让人混淆。我们应该统一一下。我选择对k使用0索引。

第二点:检查selection函数的边界

特别是这条语句:

if first == last:

应该改成:

if first >= last:

因为下面的语句:

elif k <= pivotIndex:
    return selection(a_list, first, pivotIndex-1, k)
else:
    return selection(a_list, pivotIndex+1, last, k - pivotIndex)

在这两个递归调用selection时,有可能出现first >= pivotIndex - 1pivotIndex + 1 >= last的情况。在这种情况下,我们知道子列表中只剩下1个元素,所以我们应该直接返回这个元素。

第三点:处理递归中的k

在这条语句中:

return selection(a_list, pivotIndex+1, last, k - pivotIndex)

其实不需要从k中减去pivotIndex。即使下一个selection调用只考虑从pivotIndex+1last(包括last)的子列表,我们并没有创建一个只包含a_list[pivotIndex+1]a_list[last]的新的数组,因此我们关心的元素仍然会在位置k

做了这些改动后

我们可以保持partition函数不变。这里是更新后的selection函数:

def selection(a_list, first, last, k):
    # Handle possibility that first >= last, so we only have
    # one element remaining in the sublist
    if first >= last:
        return a_list[k]
    pivot = a_list[last]
    pivotIndex = partition(a_list, first, last)
    if k < pivotIndex:
        return selection(a_list, first, pivotIndex-1, k)
    else:
        # k is left as it is
        return selection(a_list, pivotIndex+1, last, k)

你应该把对selection的调用改成对k使用0索引。

希望这些能帮到你!

撰写回答