A merge sort written in Python with Numba:
import numba as nb
import numpy as np


@nb.jit( nopython=True )
def merge( x ):
    n = x.shape[0]
    width=1
    r = x.copy()
    tgt = np.empty_like( r )
    while width<n:
        i=0
        while i<n:
            istart = i
            imid = i+width
            iend = imid+width
            # i has become i+2*width
            i = iend
            if imid>n:
                imid = n
            if iend>n:
                iend=n
            _merge( r, tgt, istart, imid, iend)
        # Swap them round, so that the partially sorted tgt becomes the result,
        # and the result becomes a new target buffer
        r, tgt = tgt, r
        width*=2
    return r


@nb.jit( nopython=True )
def _merge( src_arr, tgt_arr, istart, imid, iend ):
    """ The merge part of the merge sort """
    i0 = istart
    i1 = imid
    for ipos in range( istart, iend ):
        if ( i0<imid ) and ( ( i1==iend ) or ( src_arr[ i0 ] < src_arr[ i1 ] ) ):
            tgt_arr[ ipos ] = src_arr[ i0 ]
            i0+=1
        else:
            tgt_arr[ ipos ] = src_arr[ i1 ]
            i1+=1
I wrote a test for this and timed it with a small timer class (neither listing is reproduced here). The results were:
nb/np performance 9307.846153856719
nb/np performance 1.1428571428616743
nb/np performance 0.7142857142925115
nb/np performance 0.8333333333302494
nb/np performance 0.9999999999814962
nb/np performance 0.9999999999777955
nb/np performance 0.8333333333456692
nb/np performance 0.8333333333302494
nb/np performance 1.0
nb/np performance 0.8333333333456692
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 0.8333333333456692
nb/np performance 0.9999999999814962
nb/np performance 1.0
nb/np performance 0.9999999999814962
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 1.0000000000185036
nb/np performance 1.2000000000044408
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 1.0000000000185036
nb/np performance 1.2000000000088817
nb/np performance 1.0
nb/np performance 1.1666666666512469
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 0.9999999999814962
nb/np performance 1.1666666666728345
nb/np performance 1.1666666666512469
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 1.1666666666512469
nb/np performance 1.1666666666512469
nb/np performance 1.1666666666728345
nb/np performance 1.1666666666728345
nb/np performance 1.1666666666728345
nb/np performance 1.1666666666728345
nb/np performance 1.1666666666512469
nb/np performance 1.1666666666512469
nb/np performance 1.0
nb/np performance 1.1666666666728345
nb/np performance 1.3333333333456692
nb/np performance 1.3333333333024937
nb/np performance 1.3333333333456692
nb/np performance 1.1428571428435483
nb/np performance 1.3333333333209976
nb/np performance 1.1666666666728345
nb/np performance 1.3333333333456692
nb/np performance 1.3333333333209976
nb/np performance 1.000000000012336
nb/np performance 1.1428571428616743
nb/np performance 1.3333333333456692
nb/np performance 1.3333333333209976
nb/np performance 1.1428571428616743
nb/np performance 1.1428571428616743
nb/np performance 1.3333333333456692
nb/np performance 1.499999999990748
nb/np performance 1.2857142857074884
nb/np performance 1.2857142857233488
nb/np performance 1.2857142857029569
nb/np performance 1.1428571428616743
nb/np performance 1.1428571428435483
nb/np performance 1.2857142857233488
nb/np performance 1.2857142857233488
nb/np performance 1.2857142857233488
nb/np performance 1.2857142857233488
nb/np performance 1.2857142857233488
nb/np performance 1.2857142857029569
nb/np performance 1.1249999999895917
nb/np performance 1.2857142857029569
nb/np performance 1.2857142857233488
nb/np performance 1.4285714285623656
nb/np performance 1.249999999993061
nb/np performance 1.1250000000034694
nb/np performance 1.2857142857029569
Graphical results (from a different run):
Graphical results for a longer run:
Note that for n <= 20, numpy uses insertion sort when mergesort is requested: https://github.com/numpy/numpy/blob/master/numpy/core/src/npysort/mergesort.c.src
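The small-n fallback mentioned above can be sketched as follows. This is an illustrative pure-Python version, not NumPy's C code: the name `insertion_sort_range` is hypothetical, and the constant `CUTOFF = 20` is simply taken from the note above rather than read from NumPy.

```python
import numpy as np

CUTOFF = 20  # assumption: the n <= 20 threshold quoted in the note above


def insertion_sort_range(a, lo, hi):
    """Sort a[lo:hi] in place by insertion (hypothetical helper)."""
    for i in range(lo + 1, hi):
        key = a[i]
        j = i
        # Shift larger elements one slot to the right until key fits.
        while j > lo and key < a[j - 1]:
            a[j] = a[j - 1]
            j -= 1
        a[j] = key


# Usage: a run no longer than CUTOFF is sorted directly, with no merging.
run = np.array([7, 3, 9, 1, 4])
if run.size <= CUTOFF:
    insertion_sort_range(run, 0, run.size)
```

A merge sort would call a routine like this once a run is short enough, instead of splitting it further.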
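So you can see that for small values of n, the numba version of mergesort beats the numpy version.
However, as n grows, numpy consistently outperforms numba by a factor of 2.
Why is that? How can I optimize the numba version so that it beats the numpy version?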
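If your goal in life is to beat numpy's implementation, you might as well replicate more closely what it does. Algorithmically, there are two main differences from what you have implemented:
NumPy takes a top-down approach, implemented with actual recursion. You use a bottom-up approach, which saves you the recursion stack but will often produce unbalanced merges, and those reduce efficiency.
Your ping-pong buffer approach is neat, but you move more data than is strictly necessary. Working in place as NumPy does, with only a half-size workspace for the left run, means touching roughly n + n/2 elements instead of the 2n of two full buffers. That is about 75% of what your implementation touches, which probably helps cache performance as well.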
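The point about unbalanced merges can be made concrete by listing the sizes of the two runs combined in the very last merge. The helper names below are hypothetical; the bottom-up rule mirrors the `width` doubling in the question's `merge`, and the top-down rule assumes a split at the midpoint.

```python
def bottom_up_final_merge(n):
    """Sizes (left, right) of the last merge when widths double from 1 (n >= 2)."""
    width = 1
    # The last pass uses the largest power of two strictly below n.
    while width * 2 < n:
        width *= 2
    return width, n - width


def top_down_final_merge(n):
    """Sizes (left, right) of the last merge when splitting at the midpoint."""
    mid = n // 2
    return mid, n - mid


for n in (8, 9, 1025):
    print(n, bottom_up_final_merge(n), top_down_final_merge(n))
```

For n = 1025 the bottom-up schedule ends by merging runs of 1024 and 1 elements, and the question's `_merge` still walks all 1025 positions to place that single element. How much this costs in practice would need measuring.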
Numba magic aside, this is very close to how NumPy's mergesort works internally:
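A minimal sketch of that scheme follows, written from the two points above rather than copied from NumPy's source: recurse top-down, copy only the left run into a half-size workspace, and merge back into the original array. The names `_mergesort` and `mergesort` and the exact loop layout are assumptions, and the sketch is plain Python and NumPy, so it says nothing about how it would behave under Numba.

```python
import numpy as np


def _mergesort(x, lo, hi, work):
    """Sort x[lo:hi] in place; work must hold at least (hi - lo) // 2 items."""
    if hi - lo <= 1:
        return
    mid = lo + (hi - lo) // 2
    _mergesort(x, lo, mid, work)
    _mergesort(x, mid, hi, work)

    # Only the left run is copied out; the right run is read where it sits.
    nleft = mid - lo
    work[:nleft] = x[lo:mid]

    left = 0     # read index into work (left run)
    right = mid  # read index into x (right run)
    out = lo     # write index into x; it never overtakes `right`
    while left < nleft and right < hi:
        if x[right] < work[left]:
            x[out] = x[right]
            right += 1
        else:
            # Ties take the left element, which keeps the sort stable.
            x[out] = work[left]
            left += 1
        out += 1

    # Leftover left-run items fill the tail; leftover right-run items
    # are already in their final positions.
    x[out:out + (nleft - left)] = work[left:nleft]


def mergesort(a):
    """Return a sorted, flattened copy of the input array."""
    x = np.array(a, copy=True).ravel()
    work = np.empty(x.size // 2, dtype=x.dtype)
    _mergesort(x, 0, x.size, work)
    return x
```

Compared with the two full-size buffers in the question, this version allocates the output array plus a workspace of `n // 2` elements, which is where the 75% figure comes from.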