如何加速Numpy数组过滤/选择?

2024-04-19 19:33:46 发布

您现在位置:Python中文网/ 问答频道 /正文

我有大约40k行,我想测试行上的各种选择组合。我所说的选择是布尔掩码。面罩/过滤器的数量约为250毫米。在

现行简化代码:

np_arr = np.random.randint(1, 40000, 40000)
results = np.empty(250000000)
filters = np.random.randint(1, size=(250000000, 40000))
for i in range(250000000):
    row_selection = np_arr[filters[i].astype(np.bool_)] # Select rows based on next filter
    # Performing simple calculations such as sum, prod, count on selected rows and saving to result
    results[i] = row_selection.sum() # Save simple calculation result to results array

我尝试了Numba和多处理,但由于大多数处理都是在过滤器选择中,而不是在计算中,所以这没有多大帮助。在

解决这个问题最有效的方法是什么?有没有什么方法可以把它并行化?就我所见,我需要遍历每个过滤器,然后分别计算sum、prod、count等,因为我不能并行地应用过滤器(即使应用过滤器后的计算非常简单)。在

感谢任何关于性能改进/加速的建议。在


Tags: 过滤器oncountnprandomprodsimpleresults
2条回答

要在Numba内获得良好的性能,只需避免掩蔽,因此需要非常昂贵的阵列拷贝。你必须自己实现过滤器,但是你提到的过滤器不应该有任何问题。在

并行化也很容易实现。在

示例

import numpy as np
import numba as nb

max_num = 250000 #250000000
max_num2 = 4000#40000
np_arr = np.random.randint(1, max_num2, max_num2)
filters = np.random.randint(low=0,high=2, size=(max_num, max_num2)).astype(np.bool_)

#Implement your functions like this, avoid masking
#Sum Filter
@nb.njit(fastmath=True)
def sum_filter(filter,arr):
  sum=0.
  for i in range(filter.shape[0]):
    if filter[i]==True:
      sum+=arr[i]
  return sum

#Implement your functions like this, avoid masking
#Prod Filter
@nb.njit(fastmath=True)
def prod_filter(filter,arr):
  prod=1.
  for i in range(filter.shape[0]):
    if filter[i]==True:
      prod*=arr[i]
  return sum

@nb.njit(parallel=True)
def main_func(np_arr,filters):
  results = np.empty(filters.shape[0])
  for i in nb.prange(max_num):
    results[i]=sum_filter(filters[i],np_arr)
    #results[i]=prod_filter(filters[i],np_arr)
  return results

一种改进的方法是将as_类型移到循环之外。在我的测试中,它将执行时间缩短了一半以上。 为了进行比较,请检查以下两个代码:

import numpy as np
import time

max_num = 250000 #250000000
max_num2 = 4000#40000
np_arr = np.random.randint(1, max_num2, max_num2)
results = np.empty(max_num)
filters = np.random.randint(1, size=(max_num, max_num2))
start = time.time()
for i in range(max_num):
    row_selection = np_arr[filters[i].astype(np.bool_)] # Select rows based on next filter
    # Performing simple calculations such as sum, prod, count on selected rows and saving to result
    results[i] = row_selection.sum() # Save simple calculation result to results array

end = time.time()
print(end - start)

接受2.12

同时

^{pr2}$

接受0.940

相关问题 更多 >