我有一个函数,它处理一个维数为(h,w,200)的输入数组(数字200可以变化),并返回一个维数为(h,w,50,3)的数组。对于大小为51251220的输入数组,该函数需要约0.8秒。你知道吗
def myfunc(arr, n = 50):
#shape of arr is (h,w,200)
#output shape is (h,w,50,3)
#a1 is an array of length 50, I get them from a different
#function, which doesn't take much time. For simplicity, I fix it
#as np.arange(0,50)
a1 = np.arange(0,50)
output = np.stack((arr[:,:,a1],)*3, axis = -1)
return output
此预处理步骤在单个批中对~8个数组执行,因此加载一批数据需要8*0.8=6.4秒。有没有办法加快myfunc的计算速度?我能用numba这样的库来做这个吗?你知道吗
我差不多同时得到:
更详细地看时间安排。你知道吗
首先是索引/复制步骤,大约需要1/3的时间:
以及
stack
:stack
展开维度,然后连接;因此让我们直接调用concatenate:另一种方法是使用
repeat
:看来你的代码已经很好了。你知道吗
索引和连接已经使用编译过的代码,所以我不希望
numba
有太多帮助(这不是因为我有太多的经验)。你知道吗在新的前轴上堆叠更快(使(3,512,512,50))
尽管后续操作可能会比较慢(如果它们需要拷贝和/或重新排序的话),但可以(廉价地)进行转置。一个普通的
copy
的完整output
数组在大约350毫秒的时间受到评论的启发,我尝试了广播作业:
同样的球场。你知道吗
另一个技巧是使用
strides
制作“虚拟”副本:由于某些原因,这不适用于
(512,512,200,3)
。它可能与broadcast_to
实现有关。也许有人可以用as_strided
做实验。你知道吗尽管我可以很好地转置它:
在任何情况下,这都要快得多:
(但是做一个
copy
会让时间倒流。)相关问题 更多 >
编程相关推荐