为什么我的基数排序实现比快速排序慢?
我把维基百科上的基数排序算法用Python重写了一遍,使用了SciPy的数组来提高性能和减少代码长度,结果我成功了。接着,我又拿了经典的快速排序算法进行比较,快速排序是基于内存和枢轴的,来源于《文艺编程》。我想看看基数排序在某个阈值之后会不会比快速排序更快,但结果并没有达到我的预期。
我还发现了Erik Gorset的博客上提到的问题:“基数排序在整数数组中真的比快速排序快吗?”。那里的答案是:
.. 基准测试显示,对于大数组,MSB就地基数排序的速度始终是快速排序的三倍以上。
可惜的是,我没能复现这个结果;不同之处在于(a)Erik用的是Java而不是Python,(b)他使用的是MSB就地基数排序,而我只是用Python字典填充了桶。
根据理论,基数排序应该比快速排序快(是线性的);但显然这很大程度上取决于具体的实现。那么我的错误在哪里呢?
下面是比较这两种算法的代码:
from sys import argv
from time import clock
from pylab import array, vectorize
from pylab import absolute, log10, randint
from pylab import semilogy, grid, legend, title, show
###############################################################################
# radix sort
###############################################################################
def splitmerge0 (ls, digit): ## python (pure!)
seq = map (lambda n: ((n // 10 ** digit) % 10, n), ls)
buf = {0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]}
return reduce (lambda acc, key: acc.extend(buf[key]) or acc,
reduce (lambda _, (d,n): buf[d].append (n) or buf, seq, buf), [])
def splitmergeX (ls, digit): ## python & numpy
seq = array (vectorize (lambda n: ((n // 10 ** digit) % 10, n)) (ls)).T
buf = {0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]}
return array (reduce (lambda acc, key: acc.extend(buf[key]) or acc,
reduce (lambda _, (d,n): buf[d].append (n) or buf, seq, buf), []))
def radixsort (ls, fn = splitmergeX):
return reduce (fn, xrange (int (log10 (absolute (ls).max ()) + 1)), ls)
###############################################################################
# quick sort
###############################################################################
def partition (ls, start, end, pivot_index):
lower = start
upper = end - 1
pivot = ls[pivot_index]
ls[pivot_index] = ls[end]
while True:
while lower <= upper and ls[lower] < pivot: lower += 1
while lower <= upper and ls[upper] >= pivot: upper -= 1
if lower > upper: break
ls[lower], ls[upper] = ls[upper], ls[lower]
ls[end] = ls[lower]
ls[lower] = pivot
return lower
def qsort_range (ls, start, end):
if end - start + 1 < 32:
insertion_sort(ls, start, end)
else:
pivot_index = partition (ls, start, end, randint (start, end))
qsort_range (ls, start, pivot_index - 1)
qsort_range (ls, pivot_index + 1, end)
return ls
def insertion_sort (ls, start, end):
for idx in xrange (start, end + 1):
el = ls[idx]
for jdx in reversed (xrange(0, idx)):
if ls[jdx] <= el:
ls[jdx + 1] = el
break
ls[jdx + 1] = ls[jdx]
else:
ls[0] = el
return ls
def quicksort (ls):
return qsort_range (ls, 0, len (ls) - 1)
###############################################################################
if __name__ == "__main__":
###############################################################################
lower = int (argv [1]) ## requires: >= 2
upper = int (argv [2]) ## requires: >= 2
color = dict (enumerate (3*['r','g','b','c','m','k']))
rslbl = "radix sort"
qslbl = "quick sort"
for value in xrange (lower, upper):
#######################################################################
ls = randint (1, value, size=value)
t0 = clock ()
rs = radixsort (ls)
dt = clock () - t0
print "%06d -- t0:%0.6e, dt:%0.6e" % (value, t0, dt)
semilogy (value, dt, '%s.' % color[int (log10 (value))], label=rslbl)
#######################################################################
ls = randint (1, value, size=value)
t0 = clock ()
rs = quicksort (ls)
dt = clock () - t0
print "%06d -- t0:%0.6e, dt:%0.6e" % (value, t0, dt)
semilogy (value, dt, '%sx' % color[int (log10 (value))], label=qslbl)
grid ()
legend ((rslbl,qslbl), numpoints=3, shadow=True, prop={'size':'small'})
title ('radix & quick sort: #(integer) vs duration [s]')
show ()
###############################################################################
###############################################################################
这里是比较整数数组排序耗时的结果(纵轴是对数刻度),数组大小范围从2到1250(横轴);快速排序的曲线在下面:
快速排序在幂次变化时表现得很平滑(例如在10、100或1000时),而基数排序则稍微跳动一下,但整体上跟快速排序的路径是差不多的,只是速度慢了很多!
2 个回答
你的数据表示方式成本很高。你为什么要用 哈希表 来做桶呢?为什么要用十进制表示法,这样你还得计算对数(= 计算起来很麻烦)?
尽量避免使用 lambda 表达式之类的,我觉得 Python 目前还不能很好地优化这些。
也许可以先用 10 字节的字符串来进行基准测试。还有:不要使用哈希表和类似的高成本数据结构。
你这里有几个问题。
首先,正如评论中提到的,你的数据集太小了,理论上的复杂度无法克服代码中的开销。
接下来,你的实现方式有很多不必要的函数调用和列表复制,这样效率很低。用简单的步骤式写法来写代码,通常会比函数式的解决方案快(对于Python来说,其他语言可能会有所不同)。你已经有了一个快速排序的步骤式实现,如果你用同样的风格来写基数排序,可能即使是小列表也会更快。
最后,当你尝试处理大列表时,内存管理的开销可能会变得很重要。这意味着在小列表时,代码的效率是主要因素,而在大列表时,内存管理的开销才是主要因素,你的效率窗口是有限的。
这里有一些代码,使用了你的快速排序,但基数排序是用简单的步骤式写法,尽量避免过多的数据复制。你会发现即使是短列表,它的表现也超过了快速排序,更有趣的是,随着数据量的增加,快速排序和基数排序的效率比率也在上升,但当内存管理开始占主导时,这个比率又会下降(像释放一个包含1,000,000个项目的列表这样的简单操作会花费相当多的时间):
from random import randint
from math import log10
from time import clock
from itertools import chain
def splitmerge0 (ls, digit): ## python (pure!)
seq = map (lambda n: ((n // 10 ** digit) % 10, n), ls)
buf = {0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]}
return reduce (lambda acc, key: acc.extend(buf[key]) or acc,
reduce (lambda _, (d,n): buf[d].append (n) or buf, seq, buf), [])
def splitmerge1 (ls, digit): ## python (readable!)
buf = [[] for i in range(10)]
divisor = 10 ** digit
for n in ls:
buf[(n//divisor)%10].append(n)
return chain(*buf)
def radixsort (ls, fn = splitmerge1):
return list(reduce (fn, xrange (int (log10 (max(abs(val) for val in ls)) + 1)), ls))
###############################################################################
# quick sort
###############################################################################
def partition (ls, start, end, pivot_index):
lower = start
upper = end - 1
pivot = ls[pivot_index]
ls[pivot_index] = ls[end]
while True:
while lower <= upper and ls[lower] < pivot: lower += 1
while lower <= upper and ls[upper] >= pivot: upper -= 1
if lower > upper: break
ls[lower], ls[upper] = ls[upper], ls[lower]
ls[end] = ls[lower]
ls[lower] = pivot
return lower
def qsort_range (ls, start, end):
if end - start + 1 < 32:
insertion_sort(ls, start, end)
else:
pivot_index = partition (ls, start, end, randint (start, end))
qsort_range (ls, start, pivot_index - 1)
qsort_range (ls, pivot_index + 1, end)
return ls
def insertion_sort (ls, start, end):
for idx in xrange (start, end + 1):
el = ls[idx]
for jdx in reversed (xrange(0, idx)):
if ls[jdx] <= el:
ls[jdx + 1] = el
break
ls[jdx + 1] = ls[jdx]
else:
ls[0] = el
return ls
def quicksort (ls):
return qsort_range (ls, 0, len (ls) - 1)
if __name__=='__main__':
for value in 1000, 10000, 100000, 1000000, 10000000:
ls = [randint (1, value) for _ in range(value)]
ls2 = list(ls)
last = -1
start = clock()
ls = radixsort(ls)
end = clock()
for i in ls:
assert last <= i
last = i
print("rs %d: %0.2fs" % (value, end-start))
tdiff = end-start
start = clock()
ls2 = quicksort(ls2)
end = clock()
last = -1
for i in ls2:
assert last <= i
last = i
print("qs %d: %0.2fs %0.2f%%" % (value, end-start, ((end-start)/tdiff*100)))
我运行这个代码时的输出是:
C:\temp>c:\python27\python radixsort.py
rs 1000: 0.00s
qs 1000: 0.00s 212.98%
rs 10000: 0.02s
qs 10000: 0.05s 291.28%
rs 100000: 0.19s
qs 100000: 0.58s 311.98%
rs 1000000: 2.47s
qs 1000000: 7.07s 286.33%
rs 10000000: 31.74s
qs 10000000: 86.04s 271.08%
编辑:
为了澄清一下,这里的快速排序实现非常节省内存,它是在原地排序,所以无论列表多大,只是在移动数据而不是复制它。原来的基数排序实际上是对每个数字复制列表两次:一次是放入小列表,然后在连接这些列表时又复制一次。使用 itertools.chain
可以避免第二次复制,但仍然会有很多内存分配和释放的操作。(另外,“两次”是个大概,因为列表追加确实涉及额外的复制,尽管它是摊销的O(1),所以我可能应该说“与两次成比例”。)