如何加速nsolve或二分法?

2024-06-10 01:41:12 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在编写一个需要求根(root finding)的程序,但我用过的每一种求根方法都非常慢,我想找办法加快速度。

我用了 SymPy 的 nsolve,结果虽然非常精确,但速度极慢(程序做 12 次迭代需要 12 个多小时)。我自己写了一个二分法,效果好一些,但仍然很慢(12 次迭代约需 1 小时)。我没能找到 symengine 的求解器,否则我就会用它。下面贴出我的两个程序(分别使用二分法和 nsolve)。任何加快速度的建议都将不胜感激。

下面是使用nsolve的代码:

from symengine import *
import sympy
from sympy import Matrix
from sympy import nsolve

trial = Matrix()

# r: radial coordinate; E1: unknown of the 2x2 problem; E: energy in det(h - E*S).
r, E1, E = symbols('r, E1, E')
# Placeholder symbols for the 2x2 Hamiltonian (H..) and overlap (S..) elements.
H11, H22, H12, H21 = symbols("H11, H22, H12, H21")
S11, S22, S12, S21 = symbols("S11, S22, S12, S21")
# Integration range for r: [0, oo).
low, high = 0, oo


def integrate(*args):
    """Integrate with sympy.integrate, then evaluate the result numerically."""
    exact = sympy.integrate(*args)
    return sympy.N(exact)


# det(H - E1*S) = 0 for a 2x2 basis, solved once symbolically for E1.
quadratic_expression = (H11-E1*S11)*(H22-E1*S22)-(H12-E1*S12)*(H21-E1*S21)
# NOTE(review): keeps whichever root sympy lists first; confirm it is the
# lower (ground-state) root for the values substituted later.
candidate_roots = sympy.solve(quadratic_expression, E1)
general_solution = sympify(candidate_roots[0])


def solve_quadratic(**kwargs):
    """Evaluate the precomputed 2x2 root `general_solution` at given values.

    Keyword names are the symbol names H11, H22, H12, H21, S11, S22, S12,
    S21; each is replaced by the supplied value.
    """
    substitutions = dict(kwargs)
    return general_solution.subs(substitutions)


def H(fun):
    """Return -f''/2 - f'/r - f/r for f = `fun`, differentiating in r.

    Presumably the l = 0 radial hydrogen Hamiltonian in atomic units.
    """
    first = fun.diff(r)
    second = fun.diff(r, 2)
    return -second/2 - first/r - fun/r


# Initial trial function and its energy <psi0|H|psi0> / <psi0|psi0>,
# integrated over r in [low, high) with the weight 4*pi*r**2.
psi0 = exp(-3*r/2)
trial = trial.row_insert(0, Matrix([psi0]))
I1 = integrate(4*pi*(r**2)*psi0*H(psi0), (r, low, high))
I2 = integrate(4*pi*(r**2)*psi0**2, (r, low, high))
E0 = I1/I2
print(E0)

# Each pass builds the basis {f1, f2} with f2 = r*(H - E0)psi0, solves the
# 2x2 problem det(H - E*S) = 0 for a new E0, forms the new psi0 and appends it
# to `trial`.
# NOTE: `x` is read again after the loop (matrix size x+2), and the names
# H11..S22 below rebind the module-level symbols to numeric values.
for x in range(10):
    f1 = psi0
    f2 = r * (H(psi0)-E0*psi0)
    # NOTE(review): .simplify() is expensive on large expressions; the answer
    # on this page reports .expand() is much faster.
    Hf1 = H(f1).simplify()
    Hf2 = H(f2).simplify()

    # Hamiltonian matrix elements <fi|H|fj>.
    # NOTE(review): if H is Hermitian for these functions, H21 == H12
    # analytically and this integral could be reused like S21 = S12 below.
    H11 = integrate(4*pi*(r**2)*f1*Hf1, (r, low, high))
    H12 = integrate(4*pi*(r**2)*f1*Hf2, (r, low, high))
    H21 = integrate(4*pi*(r**2)*f2*Hf1, (r, low, high))
    H22 = integrate(4*pi*(r**2)*f2*Hf2, (r, low, high))

    # Overlap matrix elements <fi|fj>; the integrand is symmetric, so S21 = S12.
    S11 = integrate(4*pi*(r**2)*f1**2, (r, low, high))
    S12 = integrate(4*pi*(r**2)*f1*f2, (r, low, high))
    S21 = S12
    S22 = integrate(4*pi*(r**2)*f2**2, (r, low, high))

    # New energy: the precomputed root of the 2x2 determinant.
    E0 = solve_quadratic(
            H11=H11, H22=H22, H12=H12, H21=H21,
            S11=S11, S22=S22, S12=S12, S21=S21,
        )
    print(E0)

    # Mixing coefficient from the first row of (H - E0*S) c = 0 with c = (1, C).
    C = -(H11 - E0*S11)/(H12 - E0*S12)
    psi0 = (f1 + C*f2).simplify()
    trial = trial.row_insert(x+1, Matrix([[psi0]]))

# Free ICI Part
# Build the Hamiltonian (h) and overlap (S) matrices over every collected
# trial function, then solve det(h - E*S) = 0 for E with nsolve, starting at E0.

# `trial` holds the initial psi0 plus one function per pass of the loop above.
n_basis = x + 2

h = zeros(n_basis, n_basis)
HS = zeros(n_basis, 1)
S = zeros(n_basis, n_basis)

# Apply H to each trial function once; every h[i, j] reuses these.
for s in range(n_basis):
    HS[s] = H(trial[s]).simplify()

for i in range(n_basis):
    for j in range(n_basis):
        h[i, j] = integrate(4*pi*(r**2)*trial[i]*HS[j], (r, low, high))

# The overlap integrand trial[i]*trial[j] is symmetric in i and j, so only the
# upper triangle is integrated and then mirrored. This halves the number of
# slow symbolic integrals and gives the same S.
for i in range(n_basis):
    for j in range(i, n_basis):
        S[i, j] = integrate(4*pi*(r**2)*trial[i]*trial[j], (r, low, high))
        S[j, i] = S[i, j]

m = h - E*S
eqn = m.det()

roots = nsolve(eqn, float(E0))

print(roots)

下面是使用我的二分法的代码:

from symengine import *
import sympy
from sympy import Matrix
from sympy import nsolve

trial = Matrix()

# r: radial coordinate; E1: unknown of the 2x2 problem; E: energy in det(h - E*S).
r, E1, E = symbols('r, E1, E')
# Placeholder symbols for the 2x2 Hamiltonian (H..) and overlap (S..) elements.
H11, H22, H12, H21 = symbols("H11, H22, H12, H21")
S11, S22, S12, S21 = symbols("S11, S22, S12, S21")
# Integration range for r: [0, oo).
low, high = 0, oo


def integrate(*args):
    """Integrate with sympy.integrate, then evaluate the result numerically."""
    exact = sympy.integrate(*args)
    return sympy.N(exact)


# det(H - E1*S) = 0 for a 2x2 basis, solved once symbolically for E1.
quadratic_expression = (H11-E1*S11)*(H22-E1*S22)-(H12-E1*S12)*(H21-E1*S21)
# NOTE(review): keeps whichever root sympy lists first; confirm it is the
# lower (ground-state) root for the values substituted later.
candidate_roots = sympy.solve(quadratic_expression, E1)
general_solution = sympify(candidate_roots[0])


def solve_quadratic(**kwargs):
    """Evaluate the precomputed 2x2 root `general_solution` at given values.

    Keyword names are the symbol names H11, H22, H12, H21, S11, S22, S12,
    S21; each is replaced by the supplied value.
    """
    substitutions = dict(kwargs)
    return general_solution.subs(substitutions)


def bisection(fun, a, b, tol, max_iter=100000):
    """Find a root of `fun`, an expression in the module-level symbol E, in [a, b].

    Standard bisection: the bracket is halved until its half-width falls
    below `tol` or the midpoint evaluates to exactly zero.

    Parameters
    ----------
    fun : expression in E, compiled once with symengine's Lambdify.
    a, b : bracket endpoints. f(a) and f(b) are assumed to have opposite
        signs; this is not checked.
    tol : stop once (b - a)/2 < tol.
    max_iter : iteration cap (defaults to the original NMax of 100000).

    Returns
    -------
    The root estimate (a itself if f(a) == 0). Returns None, after printing a
    message, if `max_iter` iterations do not reach `tol`.
    """
    # cse=True lets Lambdify factor out repeated subexpressions, which are
    # abundant in a large symbolic determinant; each evaluation is then cheaper.
    f = Lambdify(E, fun, cse=True)
    FA = f(a)
    if FA == 0:
        # The left endpoint is already an exact root.
        return a
    for _ in range(max_iter):
        p = (b+a)/2
        FP = f(p)
        if FP == 0 or abs(b-a)/2 < tol:
            return p
        if FA*FP > 0:
            # Same sign as f(a): the root lies in [p, b].
            a = p
            FA = FP
        else:
            b = p
    print("Failed to converge to desired tolerance")
    return None



def H(fun):
    """Return -f''/2 - f'/r - f/r for f = `fun`, differentiating in r.

    Presumably the l = 0 radial hydrogen Hamiltonian in atomic units.
    """
    first = fun.diff(r)
    second = fun.diff(r, 2)
    return -second/2 - first/r - fun/r


# Initial trial function and its energy <psi0|H|psi0> / <psi0|psi0>,
# integrated over r in [low, high) with the weight 4*pi*r**2.
psi0 = exp(-3*r/2)
trial = trial.row_insert(0, Matrix([psi0]))
I1 = integrate(4*pi*(r**2)*psi0*H(psi0), (r, low, high))
I2 = integrate(4*pi*(r**2)*psi0**2, (r, low, high))
E0 = I1/I2
print(E0)

# Each pass builds the basis {f1, f2} with f2 = r*(H - E0)psi0, solves the
# 2x2 problem det(H - E*S) = 0 for a new E0, forms the new psi0 and appends it
# to `trial`.
# NOTE: `x` is read again after the loop (matrix size x+2), and the names
# H11..S22 below rebind the module-level symbols to numeric values.
for x in range(11):
    f1 = psi0
    f2 = r * (H(psi0)-E0*psi0)
    # NOTE(review): .simplify() is expensive on large expressions; the answer
    # on this page reports .expand() is much faster.
    Hf1 = H(f1).simplify()
    Hf2 = H(f2).simplify()

    # Hamiltonian matrix elements <fi|H|fj>.
    # NOTE(review): if H is Hermitian for these functions, H21 == H12
    # analytically and this integral could be reused like S21 = S12 below.
    H11 = integrate(4*pi*(r**2)*f1*Hf1, (r, low, high))
    H12 = integrate(4*pi*(r**2)*f1*Hf2, (r, low, high))
    H21 = integrate(4*pi*(r**2)*f2*Hf1, (r, low, high))
    H22 = integrate(4*pi*(r**2)*f2*Hf2, (r, low, high))

    # Overlap matrix elements <fi|fj>; the integrand is symmetric, so S21 = S12.
    S11 = integrate(4*pi*(r**2)*f1**2, (r, low, high))
    S12 = integrate(4*pi*(r**2)*f1*f2, (r, low, high))
    S21 = S12
    S22 = integrate(4*pi*(r**2)*f2**2, (r, low, high))

    # New energy: the precomputed root of the 2x2 determinant.
    E0 = solve_quadratic(
            H11=H11, H22=H22, H12=H12, H21=H21,
            S11=S11, S22=S22, S12=S12, S21=S21,
        )
    print(E0)

    # Mixing coefficient from the first row of (H - E0*S) c = 0 with c = (1, C).
    C = -(H11 - E0*S11)/(H12 - E0*S12)
    psi0 = (f1 + C*f2).simplify()
    trial = trial.row_insert(x+1, Matrix([[psi0]]))

# Free ICI Part
# Build the Hamiltonian (h) and overlap (S) matrices over every collected
# trial function, then bisect det(h - E*S) for E on the bracket [E0 - 1, E0].

# `trial` holds the initial psi0 plus one function per pass of the loop above.
n_basis = x + 2

h = zeros(n_basis, n_basis)
HS = zeros(n_basis, 1)
S = zeros(n_basis, n_basis)

# Apply H to each trial function once; every h[i, j] reuses these.
for s in range(n_basis):
    HS[s] = H(trial[s]).simplify()

for i in range(n_basis):
    for j in range(n_basis):
        h[i, j] = integrate(4*pi*(r**2)*trial[i]*HS[j], (r, low, high))

# The overlap integrand trial[i]*trial[j] is symmetric in i and j, so only the
# upper triangle is integrated and then mirrored. This halves the number of
# slow symbolic integrals and gives the same S.
for i in range(n_basis):
    for j in range(i, n_basis):
        S[i, j] = integrate(4*pi*(r**2)*trial[i]*trial[j], (r, low, high))
        S[j, i] = S[i, j]

m = h - E*S
eqn = m.det()

# NOTE(review): assumes det(h - E*S) changes sign on [E0 - 1, E0]; bisection
# does not verify this.
roots = bisection(eqn, E0 - 1, E0, 10**(-15))

print(roots)

正如我所说的,两个程序都能按预期工作,只是运行得很慢。


Tags: for, pi, low, f2, f1, integrate, high, trial
1条回答
网友
1楼 · 发布于 2024-06-10 01:41:12

下面是对代码的几处优化:

  1. 使用 `Lambdify(E, fun, cse=True)`,启用公共子表达式消除(CSE)。
  2. 添加 `pi = sympify(sympy.N(pi))`,改用 pi 的数值。把 pi 保留为符号有害无益,因为这会让表达式变得非常庞大。
  3. 将 `.simplify()` 调用改为 `.expand()` 调用。
  4. 你的被积函数都具有一种特殊形式 `integrate(r**n * exp(-p*r), (r, 0, oo))`,它有简单的解析积分结果:$\int_0^\infty r^n e^{-pr}\,dr = \Gamma(n+1)/p^{n+1}$(要求 $p>0$,$n>-1$)。
In [21]: var("n, r, p", positive=True)                                                                                                                                
Out[21]: (n, r, p)

In [22]: integrate(q*r**n*exp(-p*r), (r, 0, oo))                                                                                                                      
Out[22]: p**(-n)*q*gamma(n + 1)/p

你可以用类似下面的方法来利用这一点。(理想情况下,SymPy 本应能更快地完成这项工作,但它在这方面做得并不好。去年夏天,我用符号方法求解狄拉克方程和薛定谔方程、以调试我的数值代码时,也遇到了同样的问题。我猜你也在做类似的事情。)

def integrate(*args):
    """Numerically evaluate ``integrate(expr, (r, low, high))`` with a fast path.

    The integrand is expanded and split into additive terms. On the half-line
    ``(0, oo)``, a term matching ``q*r**n*exp(p*r)`` with numeric ``q``,
    ``n > -1`` and ``p < 0`` is integrated in closed form:

        int_0^oo q*r**n*exp(p*r) dr = q*gamma(n + 1) / (-p)**(n + 1)

    This skips ``sympy.integrate`` for that term. Every other term, and every
    term when the limits are not ``(0, oo)``, falls back to
    ``sympy.N(sympy.integrate(term, limits))``.
    """
    expr = args[0].expand()
    limits = args[1]
    r = sympy.S(limits[0])
    # The closed form is only valid on [0, oo); other limits take the slow path.
    half_line = (
        len(limits) == 3
        and sympy.S(limits[1]) == 0
        and sympy.S(limits[2]) == sympy.oo
    )
    # exclude=[r] keeps all r-dependence inside r**n*exp(p*r), so a successful
    # match yields r-independent q, n and p.
    p = sympy.Wild("p", exclude=[r])
    n = sympy.Wild("n", exclude=[r])
    q = sympy.Wild("q", exclude=[r])
    pattern = q * r**n * sympy.exp(p*r)
    terms = expr.args if expr.is_Add else [expr]
    result = 0
    for term in terms:
        d = sympy.S(term).match(pattern) if half_line else None
        if (
            d is not None
            and d[q].is_number
            and d[n].is_number
            and d[p].is_number
            # Convergence: the exponential must decay (p < 0) and r**n must be
            # integrable at 0 (n > -1); otherwise let sympy handle the term.
            and d[p].is_negative
            and (d[n] + 1).is_positive
        ):
            closed_form = d[q] * sympy.gamma(d[n] + 1) / (-d[p])**(d[n] + 1)
            # Evaluate numerically, like the fallback branch, so a symbolic pi
            # in q does not survive into the result.
            result += sympify(sympy.N(closed_form))
        else:
            result += sympy.N(sympy.integrate(term, limits))
    return result

这 4 处改动把我这边的运行时间缩短到了 16 秒。

相关问题 更多 >