Cython/Numba编译函数是否可以改进numpy.max.最大(numpy.abs公司(ab))?

2024-04-23 15:26:48 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在优化代码的瓶颈部分——迭代函数a'=f(a),其中a和a'是N乘1个向量,直到max(abs(a'-a))足够小。在

我在f(a)上加了一个Numba包装器,比我能生产的最优化的纯NumPy版本有了一个不错的加速(运行时减少了大约50%)。在

我试着写一个与C兼容的版本numpy.max.最大(numpy.abs公司(aprime-a)),但事实证明这是比较慢的!实际上,我失去了我从麻木迭代的第一部分得到的所有收益!在

有没有可能有一种方法可以让Numba或Cython改进numpy.max.最大(numpy.abs公司(四月一日)?我复制下面的代码以供参考,其中a是P0,a'是pptime:

编辑:对我来说,“flant()”输入到“maxabs()”似乎很重要。当我这样做的时候,表现并不比纽比差。然后,当我像JoshAdel建议的那样在时间括号之外对函数进行“试运行”时,带有“maxabs”的循环比使用“maxabs”的循环稍好一些numpy.max.最大(numpy.abs公司()). 在

from numba import jit
import numpy as np

### Preliminaries, to make the working example fully functional

n = 1200
Gammer = np.exp(-np.random.rand(n,n))

alpher = np.ones((n,1))
xxer = 10000*np.random.rand(n,1)

chii = 6.5
varkappa = 6.5
phi3 = 1.5
A = .5
sig = .2 

mmer = np.dot(Gammer,xxer**phi3)


totalprod = A*alpher + (1-A)*mmer
Gammerchii = Gammer**chii
Gammerrats = Gammerchii[:,0].flatten()/Gammerchii[0,:].flatten()
Gammerrats[(Gammerchii[0,:].flatten() == 0) | (Gammerchii[:,0].flatten() == 0)] = 1.
P0 = (Gammerrats*(xxer[0]/totalprod[0])*(totalprod/xxer).flatten())**(1/(1+2*chii))
P0 *= n/np.sum(P0)
### End of preliminaries

### This is the function to produce a' = f(a)
@jit
def Piteration(P0, chii, sig, n, xxer, totalprod, Gammerrats, Gammerchii):
    Mac = np.zeros((n,))
    Pprime = np.zeros((n,))
    themacpow = 1-(1/chii)*(sig/(1-sig))
    specialchiipow = 1/(1+2*chii)
    Psum = 0.

    for i in range(n):
        for j in range(n):
            Mac[j] += ((P0[i]/P0[j])**chii)*Gammerchii[i,j]*totalprod[j]

    for i in range(n):
        Pprime[i] = (Gammerrats[i]*(xxer[0]/totalprod[0])*(totalprod[i]/xxer[i])*((Mac[i]/Mac[0])**themacpow))**specialchiipow
        Psum += Pprime[i]

    Psum = n/Psum

    for i in range(n):
        Pprime[i] *= Psum

    return Pprime

### This is the function to find max(abs(aprime - a))
@jit
def maxabs(vec1,vec2,n):
    themax = 0.
    curdiff = 0.
    for i in range(n):
        curdiff = vec1[i] - vec2[i]
        if curdiff < 0:
            curdiff *= -1
        if curdiff > themax:
            themax = curdiff
    return themax

### This is the main loop
diff = 1000.
while diff > 1e-2:
    Pprime = Piteration(P0.flatten(),  chii,  sig,  n,  xxer.flatten(), totalprod.flatten(), Gammerrats.flatten(),  Gammerchii)

    diff = maxabs(P0.flatten(),Pprime.flatten(),n)
    P0 = 1.*Pprime

Tags: numpynpabsmaxsigflattenp0maxabs
1条回答
网友
1楼 · 发布于 2024-04-23 15:26:48

当我计算你的maxabs函数与np.max(np.abs(vec1 - vec2))的形状(1200,)数组时,使用numba 0.32.0的numba版本要快2.6倍。在

当您为代码计时时,请确保在计时之前先运行一次函数,这样就不会包括jit代码所需的时间,而jit代码只在第一次运行时付费。一般来说,使用timeit并多次运行可以解决这一问题。我不确定您是如何计时的,因为我发现使用maxabs与numpy调用几乎没有区别,大部分运行时似乎都在对Piteration的调用中。在

相关问题 更多 >