如何缩短二维连接域上集成的集成时间

2024-04-20 04:17:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要计算简单连通域上的许多二维积分(大多数时候是凸的)。我正在使用python函数scipy.integrate.nquad来进行此集成。然而,与矩形区域上的积分相比,此操作所需的时间明显较大。有没有可能更快的实现

这是一个例子;我首先在一个圆形域(使用函数内部的约束)上积分一个常量函数,然后在一个矩形域(默认的nquad函数域)上积分

from scipy import integrate
import time

def circular(x,y,a):
  if x**2 + y**2 < a**2/4:
    return 1 
  else:
    return 0

def rectangular(x,y,a):
  return 1

a = 4
start = time.time()
result = integrate.nquad(circular, [[-a/2, a/2],[-a/2, a/2]], args=(a,))
now = time.time()
print(now-start)

start = time.time()
result = integrate.nquad(rectangular, [[-a/2, a/2],[-a/2, a/2]], args=(a,))
now = time.time()
print(now-start)

矩形域只需要0.00029秒,而圆形域需要2.07061秒才能完成

此外,循环积分给出以下警告:

IntegrationWarning: The maximum number of subdivisions (50) has been achieved.
If increasing the limit yields no improvement it is advised to analyze 
the integrand in order to determine the difficulties.  If the position of a 
local difficulty can be determined (singularity, discontinuity) one will 
probably gain from splitting up the interval and calling the integrator 
on the subranges.  Perhaps a special-purpose integrator should be used.
**opt)

Tags: the函数fromimportreturntimedefscipy
1条回答
网友
1楼 · 发布于 2024-04-20 04:17:00

加快计算速度的一种方法是使用^{},这是一种用于Python的即时编译器

@jit装饰器

Numba提供了一个^{} decorator来编译一些Python代码,并输出可以在多个CPU上并行运行的优化机器代码。jit被积函数只需要很少的努力,并且会节省一些时间,因为代码经过了优化以运行得更快。人们甚至不用担心类型,Numba在幕后完成了所有这些

from scipy import integrate
from numba import jit

@jit
def circular_jit(x, y, a):
    if x**2 + y**2 < a**2 / 4:
        return 1 
    else:
        return 0

a = 4
result = integrate.nquad(circular_jit, [[-a/2, a/2],[-a/2, a/2]], args=(a,))

这确实运行得更快,在我的机器上计时时,我得到:

 Original circular function: 1.599048376083374
 Jitted circular function: 0.8280022144317627

这将减少约50%的计算时间

Scipy的LowLevelCallable

由于Python语言的性质,Python中的函数调用相当耗时。与C等编译语言相比,这种开销有时会使Python代码速度变慢

为了缓解这种情况,Scipy提供了一个^{}类,可用于提供对低级编译回调函数的访问。通过这种机制,可以绕过Python的函数调用开销,进一步节省时间

注意,在nquad的情况下,传递给LowerLevelCallablecfunc的签名必须是以下之一:

double func(int n, double *xx)
double func(int n, double *xx, void *user_data)

其中int是参数的数量,参数的值在第二个参数中user_data用于需要上下文操作的回调

因此,我们可以稍微更改Python中的循环函数签名以使其兼容

from scipy import integrate, LowLevelCallable
from numba import cfunc
from numba.types import intc, CPointer, float64


@cfunc(float64(intc, CPointer(float64)))
def circular_cfunc(n, args):
    x, y, a = (args[0], args[1], args[2]) # Cannot do `(args[i] for i in range(n))` as `yield` is not supported
    if x**2 + y**2 < a**2/4:
        return 1 
    else:
        return 0

circular_LLC = LowLevelCallable(circular_cfunc.ctypes)

a = 4
result = integrate.nquad(circular_LLC, [[-a/2, a/2],[-a/2, a/2]], args=(a,))

用这种方法我得到

LowLevelCallable circular function: 0.07962369918823242

与原始版本相比,该函数减少了95%,与jitted版本相比,该函数减少了90%

定制的装饰师

为了使代码更整洁,并保持被积函数签名的灵活性,可以创建一个定制的装饰函数。它将jit被积函数,并将其包装成LowLevelCallable对象,然后可与nquad一起使用

from scipy import integrate, LowLevelCallable
from numba import cfunc, jit
from numba.types import intc, CPointer, float64

def jit_integrand_function(integrand_function):
    jitted_function = jit(integrand_function, nopython=True)

    @cfunc(float64(intc, CPointer(float64)))
    def wrapped(n, xx):
        return jitted_function(xx[0], xx[1], xx[2])
    return LowLevelCallable(wrapped.ctypes)


@jit_integrand_function
def circular(x, y, a):
    if x**2 + y**2 < a**2 / 4:
        return 1
    else:
        return 0

a = 4
result = integrate.nquad(circular, [[-a/2, a/2],[-a/2, a/2]], args=(a,))

任意数量的参数

如果参数的数量未知,那么我们可以使用Numba提供的方便的^{} functionCPointer(float64)转换为Numpy数组

import numpy as np
from scipy import integrate, LowLevelCallable
from numba import cfunc, carray, jit
from numba.types import intc, CPointer, float64

def jit_integrand_function(integrand_function):
    jitted_function = jit(integrand_function, nopython=True)

    @cfunc(float64(intc, CPointer(float64)))
    def wrapped(n, xx):
        ar = carray(xx, n)
        return jitted_function(ar[0], ar[1], ar[2:])
    return LowLevelCallable(wrapped.ctypes)


@jit_integrand_function
def circular(x, y, a):
    if x**2 + y**2 < a[-1]**2 / 4:
        return 1
    else:
        return 0

ar = np.array([1, 2, 3, 4])
a = ar[-1]
result = integrate.nquad(circular, [[-a/2, a/2],[-a/2, a/2]], args=ar)

相关问题 更多 >