NumPy性能:uint8对float和乘法对除法?

2024-04-16 21:10:11 发布

您现在位置:Python中文网/ 问答频道 /正文

我刚刚注意到,我的脚本执行时间几乎减半,只需将乘法改为除法。在

为了调查这一点,我写了一个小例子:

import numpy as np                                                                                                                                                                                
import timeit

# uint8 array
arr1 = np.random.randint(0, high=256, size=(100, 100), dtype=np.uint8)

# float32 array
arr2 = np.random.rand(100, 100).astype(np.float32)
arr2 *= 255.0


def arrmult(a):
    """ 
    mult, read-write iterator
    """
    b = a.copy()
    for item in np.nditer(b, op_flags=["readwrite"]):
        item[...] = (item + 5) * 0.5

def arrmult2(a):
    """ 
    mult, index iterator
    """
    b = a.copy()
    for i, j in np.ndindex(b.shape):
        b[i, j] = (b[i, j] + 5) * 0.5

def arrmult3(a):
    """
    mult, vectorized
    """
    b = a.copy()
    b = (b + 5) * 0.5

def arrdiv(a):
    """ 
    div, read-write iterator 
    """
    b = a.copy()
    for item in np.nditer(b, op_flags=["readwrite"]):
        item[...] = (item + 5) / 2

def arrdiv2(a):
    """ 
    div, index iterator
    """
    b = a.copy()
    for i, j in np.ndindex(b.shape):
           b[i, j] = (b[i, j] + 5)  / 2                                                                                 

def arrdiv3(a):                                                                                                     
    """                                                                                                             
    div, vectorized                                                                                                 
    """                                                                                                             
    b = a.copy()                                                                                                    
    b = (b + 5) / 2                                                                                               




def print_time(name, t):                                                                                            
    print("{: <10}: {: >6.4f}s".format(name, t))                                                                    

timeit_iterations = 100                                                                                             

print("uint8 arrays")                                                                                               
print_time("arrmult", timeit.timeit("arrmult(arr1)", "from __main__ import arrmult, arr1", number=timeit_iterations))
print_time("arrmult2", timeit.timeit("arrmult2(arr1)", "from __main__ import arrmult2, arr1", number=timeit_iterations))
print_time("arrmult3", timeit.timeit("arrmult3(arr1)", "from __main__ import arrmult3, arr1", number=timeit_iterations))
print_time("arrdiv", timeit.timeit("arrdiv(arr1)", "from __main__ import arrdiv, arr1", number=timeit_iterations))  
print_time("arrdiv2", timeit.timeit("arrdiv2(arr1)", "from __main__ import arrdiv2, arr1", number=timeit_iterations))
print_time("arrdiv3", timeit.timeit("arrdiv3(arr1)", "from __main__ import arrdiv3, arr1", number=timeit_iterations))

print("\nfloat32 arrays")                                                                                           
print_time("arrmult", timeit.timeit("arrmult(arr2)", "from __main__ import arrmult, arr2", number=timeit_iterations))
print_time("arrmult2", timeit.timeit("arrmult2(arr2)", "from __main__ import arrmult2, arr2", number=timeit_iterations))
print_time("arrmult3", timeit.timeit("arrmult3(arr2)", "from __main__ import arrmult3, arr2", number=timeit_iterations))
print_time("arrdiv", timeit.timeit("arrdiv(arr2)", "from __main__ import arrdiv, arr2", number=timeit_iterations))  
print_time("arrdiv2", timeit.timeit("arrdiv2(arr2)", "from __main__ import arrdiv2, arr2", number=timeit_iterations))
print_time("arrdiv3", timeit.timeit("arrdiv3(arr2)", "from __main__ import arrdiv3, arr2", number=timeit_iterations))

这将打印以下计时:

^{pr2}$

我一直认为乘法在计算上比除法便宜。然而,对于uint8来说,除法的效果几乎是前者的两倍。这是否与* 0.5必须计算浮点乘法,然后将结果转换回整数?在

至少对于float,乘法似乎比除法快。这通常是真的吗?在

为什么uint8中的乘法比float32中的乘法更膨胀?我以为8位无符号整数应该比32位浮点运算快得多?!在

有人能“解开”这个谜团吗?在

编辑:为了获得更多的数据,我还包括了向量化函数(如建议的),并添加了索引迭代器。矢量化的函数要快得多,因此不具有真正的可比性。然而,如果向量化函数的timeit_iterations设置得更高,则uint8和{}的乘法速度更快。我想这更让人困惑?!在

也许乘法实际上总是比除法快,但是for循环的主要性能漏洞不是算术运算,而是循环本身。尽管这并不能解释为什么循环在不同的操作中表现不同。在

EDIT2:如@jotasi所述,我们正在寻找division与{}和{}(或{})与{}(或{})的完整解释。此外,解释矢量化方法和迭代器的不同趋势也很有趣,因为在矢量化的情况下,除法似乎比较慢,而在迭代器的情况下则更快。在


Tags: fromimportnumbertimemaindefnpprint
3条回答

这个答案只考虑矢量化操作,因为其他操作速度慢的原因已经由ead回答。在

很多“优化”都是基于旧硬件。在旧硬件上的优化是正确的,而在新的硬件上则不是旧的。在

管道和分区

除法慢。除法运算由几个单元组成,每个单元都必须一个接一个地执行一个计算。这就是分裂缓慢的原因。在

然而,在浮点处理单元(FPU)[在大多数现代cpu上常见]中,有专门的单元被安排在一个“流水线”中,用于除法指令。一旦一个单元完成,剩下的操作就不再需要这个单元了。如果你有几次除法行动,你可以让这些部队在下一次除法行动中无所事事。因此,尽管每个操作都很慢,但FPU实际上可以实现高吞吐量的除法运算。管道化与矢量化不同,但结果基本相同——当有许多相同的操作要做时,吞吐量会更高。在

想想管道交通。比较三条以30英里/小时的速度行驶的车辆与以90英里/小时行驶的一条车道的车辆。较慢的交通量肯定单独较慢,但三车道道路的吞吐量仍然相同。在

问题是你的假设,你测量除法或乘法所需的时间,这是不正确的。您正在测量除法或乘法所需的开销。在

我们真的要看确切的代码来解释每种效果,因为版本不同而不同。这个答案只能给出一个想法,一个必须考虑的问题。在

问题是一个简单的int在python中一点都不简单:它是一个真正的对象,必须在垃圾收集器中注册,它的大小随着它的值而增长-对于所有你必须付出的代价:例如,对于一个8位整数,需要24字节的内存!python浮动也是如此。在

另一方面,一个numpy数组由简单的c风格的整数/浮点数组成,没有开销,您可以节省大量内存,但在访问numpy数组的一个元素时需要为此付出代价。a[i]意味着:必须构造一个python整数,并将其注册到垃圾回收器中,并且只有在它可以使用时才可以使用它,这会产生大量的开销。在

考虑以下代码:

li1=[x%256 for x in xrange(10**4)]
arr1=np.array(li1, np.uint8)

def arrmult(a):    
    for i in xrange(len(a)):
        a[i]*=5;

arrmult(li1)arrmult(arr1)快25,因为列表中的整数已经是python int,不必创建!大部分的计算时间是用来创建对象的——其他的几乎可以忽略不计。在


让我们看看你的代码,首先是乘法:

^{pr2}$

对于uint8,必须发生以下情况(为了简单起见,我忽略了+5):

  1. 必须创建python int
  2. 必须将其强制转换为float(python float creation),以便能够进行浮点乘法
  3. 并转换回python int或/和uint8

对于float32,要做的工作更少(乘法不需要太多): 1创建了一个python float 2铸造后浮动32。在

应该是更快的版本。在


现在让我们来看看这个部门:

def arrdiv2(a):
    ...
    b[i, j] = (b[i, j] + 5)  / 2 

这里的陷阱是:所有操作都是整数运算。因此,与乘法相比,不需要强制转换为python float,因此与乘法相比,我们的开销更小。对于unint8,除法比乘法快。在

然而,对于float32,除法和乘法同样快/慢,因为在这种情况下几乎没有什么变化-我们仍然需要创建一个python float。在


现在的矢量化版本:它们可以使用c风格的“原始”float32s/uint8而无需转换(及其成本!)到引擎盖下相应的python对象。为了得到有意义的结果,你应该增加迭代的次数(现在运行时间太少,不能肯定地说些什么)。在

  1. float32的除法和乘法可以有相同的运行时间,因为我希望numpy通过0.5的乘法来替换除法2(但是为了确保人们必须查看代码)。

  2. uint8的乘法应该慢一些,因为每个uint8整数在与0.5相乘之前必须强制转换为浮点,然后再转换回uint8。

  3. 对于uint8的情况,numpy不能用0.5代替2的除法,因为它是整数除法。对于许多架构,整数除法比浮点乘法慢-这是最慢的矢量化操作。


附言:我不想过多地讨论成本乘法和除法,还有太多的其他事情会对性能产生更大的影响。例如,创建不必要的临时对象,或者如果numpy数组很大并且不能放入缓存中,那么内存访问将成为瓶颈—您将看不到乘法和除法之间的任何区别。在

这是因为你把一个int乘以一个float,然后把结果存储为一个int。 尝试使用不同的整数或浮点值来进行乘法/除法的arr_mult和arr_div测试。尤其是,比较乘“2”和乘“2”

相关问题 更多 >