需要帮助来加速这段代码Python和numpy吗

2024-04-25 14:27:32 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个函数,它处理一个维数为(h,w,200)的输入数组(数字200可以变化),并返回一个维数为(h,w,50,3)的数组。对于大小为51251220的输入数组,该函数需要约0.8秒。你知道吗

def myfunc(arr, n = 50):
    #shape of arr is (h,w,200)
    #output shape is (h,w,50,3)

    #a1 is an array of length 50, I get them from a different 
    #function, which doesn't take much time. For simplicity, I fix it 
    #as np.arange(0,50)

    a1 = np.arange(0,50)


    output = np.stack((arr[:,:,a1],)*3, axis = -1)

    return output

此预处理步骤在单个批中对~8个数组执行,因此加载一批数据需要8*0.8=6.4秒。有没有办法加快myfunc的计算速度?我能用numba这样的库来做这个吗?你知道吗


Tags: of函数anoutputisdefa1np
1条回答
网友
1楼 · 发布于 2024-04-25 14:27:32

我差不多同时得到:

In [14]: arr = np.ones((512,512,200))                                                                        
In [15]: timeit output = np.stack((arr[:,:,np.arange(50)],)*3, axis=-1)                                      
681 ms ± 5.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [16]: np.stack((arr[:,:,np.arange(50)],)*3, axis=-1).shape                                                
Out[16]: (512, 512, 50, 3)

更详细地看时间安排。你知道吗

首先是索引/复制步骤,大约需要1/3的时间:

In [17]: timeit arr[:,:,np.arange(50)]                                                                       
249 ms ± 306 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

以及stack

In [18]: %%timeit temp = arr[:,:,np.arange(50)] 
    ...: output = np.stack([temp,temp,temp], axis=-1) 
    ...:  
    ...:                                                                                                     
426 ms ± 367 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

stack展开维度,然后连接;因此让我们直接调用concatenate:

In [19]: %%timeit temp = arr[:,:,np.arange(50),None] 
    ...: output = np.concatenate([temp,temp,temp], axis=-1) 
    ...:  
    ...:                                                                                                     
430 ms ± 8.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

另一种方法是使用repeat

In [20]: %%timeit temp = arr[:,:,np.arange(50),None] 
    ...: output = np.repeat(temp, 3, axis=-1) 
    ...:  
    ...:                                                                                                     
531 ms ± 155 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

看来你的代码已经很好了。你知道吗

索引和连接已经使用编译过的代码,所以我不希望numba有太多帮助(这不是因为我有太多的经验)。你知道吗

在新的前轴上堆叠更快(使(3,512,512,50))

In [21]: %%timeit temp = arr[:,:,np.arange(50)] 
    ...: output = np.stack([temp,temp,temp]) 
    ...:  
    ...:                                                                                                     
254 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

尽管后续操作可能会比较慢(如果它们需要拷贝和/或重新排序的话),但可以(廉价地)进行转置。一个普通的copy的完整output数组在大约350毫秒的时间


受到评论的启发,我尝试了广播作业:

In [101]: %%timeit temp = arr[:,:,np.arange(50)]  
     ...: res = np.empty(temp.shape + (3,), temp.dtype) 
     ...: res[...] = temp[...,None] 
     ...:  
     ...:  
     ...:                                                                                                    
337 ms ± 1.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

同样的球场。你知道吗

另一个技巧是使用strides制作“虚拟”副本:

In [74]: res1 = np.broadcast_to(arr, (3,)+arr.shape)                                                         
In [75]: res1.shape                                                                                          
Out[75]: (3, 512, 512, 200)
In [76]: res1.strides                                                                                        
Out[76]: (0, 819200, 1600, 8)

由于某些原因,这不适用于(512,512,200,3)。它可能与broadcast_to实现有关。也许有人可以用as_strided做实验。你知道吗

尽管我可以很好地转置它:

np.broadcast_to(arr, (3,)+arr.shape).transpose(1,2,3,0) 

在任何情况下,这都要快得多:

In [82]: timeit res1 = np.broadcast_to(arr, (3,)+arr.shape)                                                  
10.4 µs ± 188 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

(但是做一个copy会让时间倒流。)

相关问题 更多 >

    热门问题