使用numpy数组避免for循环组合

2024-06-16 11:10:33 发布

您现在位置:Python中文网/ 问答频道 /正文

一定有一种更像Python的方式:

r = np.arange(100)
results = []

for i in r:
    for j in r:
        for k in r:
            for l in r:

                #Here f is some predefined function
                if f(i,j,k,l) < 5.0:
                     results.append(f(i,j,k,l))

我确信使用数组可以简化这一点,但我不确定如何简化。谢谢!在


Tags: inforifhereisnp方式function
2条回答

使用NumPy的^{}和{a2}可以避免for循环和if语句。提出的方法被包装在comb_np(n)中,而@Ohad-Eytan提出的基于itertools的解决方案被包装在comb_it(n)中。为了方便起见,每个For循环(在您的示例中为100)上的迭代次数作为参数传递给两个函数(n)。为了比较分析这两种方法,我使用了一个简单的政治函数f(x, y, z, t)。在

from numpy import fromfunction
from itertools import product
from numpy import arange

def f(x, y, z, t):
    return x + 2*y + 3*z + 4*t

def comb_np(n):
    arr = fromfunction(f, (n,)*4)
    return arr[arr < 5.0]

def comb_it(n):
    return [f(i,j,k,l) for (i,j,k,l) in product(arange(n),repeat=4) if f(i,j,k,l) < 5.0]

样本运行:

^{pr2}$

两种方法产生相同的结果。到目前为止,还不错。但现在让我们来评估一下在效率方面是否存在任何差异:

In [1304]: import timeit

In [1305]: timeit.timeit("comb_np(10)", setup="from numpy import fromfunction;from __main__ import comb_np, f", number=1)
Out[1305]: 0.0008685288485139608

In [1306]: timeit.timeit("comb_it(10)", setup="from itertools import product;from numpy import arange;from __main__ import comb_it, f", number=1)
Out[1306]: 0.05153228418203071

In [1307]: timeit.timeit("comb_np(100)", setup="from numpy import fromfunction;from __main__ import comb_np, f", number=1)
Out[1307]: 3.4775129712652415

In [1308]: timeit.timeit("comb_it(100)", setup="from itertools import product;from numpy import arange;from __main__ import comb_it, f", number=1)
Out[1308]: 354.3811327822914

从上面的结果可以清楚地看出,在这个特定的问题中,NumPy的向量化代码比迭代器的性能高出大约两个数量级。在


有趣的是,我发现只要将NumPy的arange替换为内置函数range,那么{}的性能就会显著提高:

def comb_it2(n):
    return [f(i,j,k,l) for (i,j,k,l) in product(range(n),repeat=4) if f(i,j,k,l) < 5.0]

结果:

In [1381]: comb_it2(10)
Out[1381]: [0, 4, 3, 2, 4, 1, 4, 3, 2, 4, 3, 4]

In [1382]: timeit.timeit("comb_it2(10)", setup="from itertools import product;from __main__ import comb_it2, f", number=1)
Out[1382]: 0.009133451094385237

In [1383]: timeit.timeit("comb_it2(100)", setup="from itertools import product;from __main__ import comb_it2, f", number=1)
Out[1383]: 32.556062019226374

使用itertools笛卡尔积:

import itertools
r = np.arange(100)
results = []
for (i,j,k,l) in itertools.product(r,repeat=4):
    if f(i,j,k,l) < 5.0:
         results.append(f(i,j,k,l))

或者更简洁的方式,使用列表理解:

^{pr2}$

相关问题 更多 >