快速选择算法
我现在正在做计算机科学课的作业,但我搞不清楚我的代码在执行快速选择时出了什么问题。
def partition(a_list, first, last):
pivot = a_list[last]
i = first-1
for j in range(first, last):
if a_list[j] <= pivot:
i += 1
a_list[i], a_list[j] = a_list[j], a_list[i]
a_list[i+1], a_list[last] = a_list[last], a_list[i+1]
print(a_list)
return i+1
def selection(a_list, first, last, k):
pivot = a_list[last]
pivotIndex = partition(a_list, first, last)
if first == last:
return a_list[k]
elif k <= pivotIndex:
return selection(a_list, first, pivotIndex-1, k)
else:
return selection(a_list, pivotIndex+1, last, k - pivotIndex)
print(selection([5,4,1,10,8,3,2], 0, 6, 1))
print(selection([5,4,1,10,8,3,2], 0, 6, 3))
print(selection([5,4,1,10,8,3,2], 0, 6, 6))
print(selection([5,4,1,10,8,3,2], 0, 6, 7))
print(selection([46, 50, 16, 88, 79, 77, 17, 2, 43, 13, 86, 12, 68, 33, 81, \
74, 19, 52, 98, 70, 61, 71, 93, 5, 55], 0, 24, 19))
在第三个打印语句之后,我的代码就陷入了一个循环,最后因为达到最大递归次数而崩溃。第一个输出应该是1,我知道为什么会这样。但我就是找不到解决办法来修复它。
这是我的输出,直到最后出现最大递归深度达到的错误。(可以忽略打印的列表,它只是让我看看正在进行的分区情况)
[1, 2, 5, 10, 8, 3, 4]
2
[1, 2, 5, 10, 8, 3, 4]
[1, 2, 3, 4, 8, 5, 10]
3
[1, 2, 5, 10, 8, 3, 4]
[1, 2, 3, 4, 8, 5, 10]
[1, 2, 3, 4, 8, 5, 10]
[1, 2, 3, 4, 5, 8, 10]
[1, 2, 3, 5, 4, 8, 10]
[1, 2, 3, 4, 5, 8, 10]
[1, 2, 3, 5, 4, 8, 10]
[1, 2, 3, 4, 5, 8, 10]
1 个回答
partition
函数看起来没什么问题。主要的问题出在selection
函数上。具体有以下几点:
- 混用了0索引和1索引
- 检查
selection
函数的边界 - 处理递归中的
k
值
第一点:混用0索引和1索引
这个例子说明了问题:
print(selection([5,4,1,10,8,3,2], 0, 6, 1))
你在问题中提到期望的输出是1
。列表[5,4,1,10,8,3,2]
排序后是[1,2,3,4,5,8,10]
。在调用selection
函数时,你提供了0
和6
作为first
和last
。这两个变量使用的是0索引。而对于k
,你提供了1
,并期望selection
函数的输出是1
。这就用了1索引。
这样做没有错,但会很容易让人混淆。我们应该统一一下。我选择对k
使用0索引。
第二点:检查selection
函数的边界
特别是这条语句:
if first == last:
应该改成:
if first >= last:
因为下面的语句:
elif k <= pivotIndex:
return selection(a_list, first, pivotIndex-1, k)
else:
return selection(a_list, pivotIndex+1, last, k - pivotIndex)
在这两个递归调用selection
时,有可能出现first >= pivotIndex - 1
和pivotIndex + 1 >= last
的情况。在这种情况下,我们知道子列表中只剩下1个元素,所以我们应该直接返回这个元素。
第三点:处理递归中的k
值
在这条语句中:
return selection(a_list, pivotIndex+1, last, k - pivotIndex)
其实不需要从k
中减去pivotIndex
。即使下一个selection
调用只考虑从pivotIndex+1
到last
(包括last
)的子列表,我们并没有创建一个只包含a_list[pivotIndex+1]
到a_list[last]
的新的数组,因此我们关心的元素仍然会在位置k
。
做了这些改动后
我们可以保持partition
函数不变。这里是更新后的selection
函数:
def selection(a_list, first, last, k):
# Handle possibility that first >= last, so we only have
# one element remaining in the sublist
if first >= last:
return a_list[k]
pivot = a_list[last]
pivotIndex = partition(a_list, first, last)
if k < pivotIndex:
return selection(a_list, first, pivotIndex-1, k)
else:
# k is left as it is
return selection(a_list, pivotIndex+1, last, k)
你应该把对selection
的调用改成对k
使用0索引。
希望这些能帮到你!