如何让数值求解代码运行得更快(numpy/scipy)

2024-03-28 12:20:03 发布

您现在位置:Python中文网/ 问答频道 /正文

如果有人能帮我把这段代码改得更快,我将不胜感激。目前它实在是太慢了。它的基本思路是:生成两个占空比(duty cycle)不同的方波,然后对得到的方波应用一个特殊的滤波器(Goertzel 算法),从中提取某一个频率分量,并通过调整两个方波的占空比,使该频率分量与给定的目标值相匹配。

import os
import numpy as np
from scipy import optimize, integrate, signal
import math
import cmath

PI = np.pi


def Goetrzel(x, target_frequency, sample_rate):
    """Estimate amplitude and phase of a single frequency component of ``x``.

    The classic Goertzel recursion ``s[n] = x[n] + 2*cos(w)*s[n-1] - s[n-2]``
    followed by ``X = s[N-1]*exp(1j*w) - s[N-2]`` has the closed form
    ``X = sum_k x[k] * exp(1j*w*(N - k))``.  That closed form is evaluated
    here with one vectorised dot product instead of a Python-level loop over
    every sample, which is both much faster and less prone to accumulated
    round-off.

    Parameters
    ----------
    x : sequence of numbers
        The sampled signal (treated as one-dimensional).
    target_frequency : number
        Frequency to probe, in the same unit as ``sample_rate``.
    sample_rate : number
        Number of samples per unit of time.

    Returns
    -------
    tuple
        ``(magnitude, phase)`` where ``magnitude`` is ``|X| / (N / 2)`` and
        ``phase`` is the angle of ``X`` in degrees.
    """
    samples = np.asarray(x)
    n_samples = len(samples)

    # Force true division: with two integer arguments Python 2 would
    # otherwise truncate the ratio (usually to 0).
    normalized_frequency = target_frequency / float(sample_rate)
    omega = 2.0 * np.pi * normalized_frequency

    # Sample k is weighted by exp(1j*omega*(N - k)): exponents N, N-1, ..., 1.
    weights = np.exp(1j * omega * np.arange(n_samples, 0, -1))
    XK = np.dot(samples, weights) / (n_samples / 2.)

    return abs(XK), np.degrees(np.angle(XK))


def equations(p, zcurr, z, k1, k2):
    """Objective minimised by the duty-cycle search: sum of squared residuals.

    Builds the product of two square waves with angular frequencies ``k1``
    and ``k2`` and duty cycles ``p = (D1, D2)``, measures the amplitude of
    that product at ``k1/(2*pi)`` and ``k2/(2*pi)`` with ``Goetrzel``, and
    compares the two amplitudes with the targets ``1 - cos(pi/2 * zcurr/L)``
    and ``1 - sin(pi/2 * zcurr/L)``.

    Parameters
    ----------
    p : sequence of two numbers
        The duty cycles ``(D1, D2)`` being optimised.
    zcurr : number
        Position at which the two target amplitudes are evaluated.
    z : unused
        Kept only so the call signature stays unchanged.
    k1, k2 : number
        Angular frequencies of the two square waves.

    Returns
    -------
    The sum of the two squared amplitude errors.

    Notes
    -----
    ``L`` is read from module scope.
    """
    D1, D2 = p

    h = 0.01  # sample spacing of the synthesised waveform
    # The waveform is the same for both amplitude measurements, so build the
    # time grid and the square-wave product once instead of once per call.
    t = np.arange(0., 10., h)
    wave = signal.square(k1*t, duty=D1) * signal.square(k2*t, duty=D2)
    sample_rate = 1./h

    target1 = -np.cos(np.pi/2.*zcurr/L)+1.
    target2 = -np.sin(np.pi/2.*zcurr/L)+1.

    eq1 = Goetrzel(wave, k1/(2.*np.pi), sample_rate)[0] - target1
    eq2 = Goetrzel(wave, k2/(2.*np.pi), sample_rate)[0] - target2

    return eq1**2 + eq2**2


def DutyCycleSolver(z, k1, k2, display=False):
    """Search, for every position in ``z``, duty cycles that minimise ``equations``.

    For each position a Nelder-Mead search (``optimize.fmin``) is started
    from every pair of initial guesses taken from ``np.arange(0.8, 1., 0.1)``.
    Results whose duty cycles fall outside ``[0, 1]`` (with a 1e-8 margin)
    are discarded.  The first result with an error of at most 1e-3 is
    accepted and ends the search for that position; otherwise the result
    with the smallest error seen is kept.

    Parameters
    ----------
    z : sequence of numbers
        Positions to solve for.
    k1, k2 : number
        Angular frequencies forwarded to ``equations``.
    display : bool, optional
        When true, print a short report for every position.

    Returns
    -------
    tuple of numpy.ndarray
        ``(D1, D2, Derr)``, each of length ``len(z)``: the two duty cycles
        and the objective value reached.  Positions for which no admissible
        result was produced keep ``nan`` duty cycles and an ``inf`` error.
    """
    n_points = len(z)
    D1 = np.full(n_points, np.nan)
    D2 = np.full(n_points, np.nan)
    Derr = np.full(n_points, np.inf)
    # Initialised so the report never shows uninitialised memory when no
    # admissible result was recorded for a position.
    D1_D2_guess = np.full((n_points, 2), np.nan)

    accept_err = 1.e-3  # objective value regarded as "solved"
    bound_tol = 1.e-8   # slack allowed around the [0, 1] duty-cycle range
    guesses = np.arange(0.8, 1., 0.1)

    for i, zcurr in enumerate(z):
        solutionFound = False
        for guessD1 in guesses:
            for guessD2 in guesses:
                temp = optimize.fmin(equations,
                                     x0=(guessD1, guessD2),
                                     args=(zcurr, z, k1, k2),
                                     xtol=1e-6,
                                     ftol=1e-6,
                                     disp=False,
                                     full_output=True)
                best, DerrCur = temp[0], temp[1]

                # Duty cycles only make sense inside [0, 1].
                if best.min() < -bound_tol or best.max() > 1. + bound_tol:
                    continue

                if DerrCur <= accept_err:
                    D1[i], D2[i] = best
                    Derr[i] = DerrCur
                    D1_D2_guess[i] = [guessD1, guessD2]
                    solutionFound = True
                    break
                elif DerrCur < Derr[i]:
                    # Not good enough yet, but the best candidate so far.
                    D1[i], D2[i] = best
                    Derr[i] = DerrCur
                    D1_D2_guess[i] = [guessD1, guessD2]

            if solutionFound:
                break

        if display:
            # Single-argument print(...) calls behave the same on Python 2
            # and Python 3.
            if solutionFound:
                print('Solution found at %s' % (zcurr,))
                print('Using: %s %s' % (D1[i], D2[i]))
                print('Found with guess: %s' % (D1_D2_guess[i],))
            else:
                print('No solution found at %s' % (zcurr,))
                print('Using: %s %s' % (D1[i], D2[i]))
                print('With guess: %s' % (D1_D2_guess[i],))
            print('Error: %s' % (Derr[i],))
            print('')

    return D1, D2, Derr


# Problem set-up.  These stay at module level because equations() reads L
# as a global.
h = 0.3                   # spacing between consecutive positions
L = 2.e3                  # upper end of the position range
z = np.arange(0., L, h)   # positions to solve for

# Run the (slow) solver only when executed as a script, so that importing
# this module does not trigger the whole computation.
if __name__ == "__main__":
    DutyCycleSolver(z, 3., 8., display=True)

Tags: import len np pi k2 k1 temp d2
1条回答
网友
1楼 · 发布于 2024-03-28 12:20:03

这只能让你的代码快一点点(0.000001%?),所以其实没什么用。

之前:

# Excerpt from the innermost guess loop of DutyCycleSolver, as originally
# written.
DerrCur = temp[1]
if DerrCur <= 1.e-3:
    D1[i], D2[i] = temp[0]
    Derr[i] = temp[1]
    D1_D2_guess[i] = [guessD1, guessD2]
    solutionFound = True
    break
# The `DerrCur > 1.e-3` test below is redundant: the `if` above already
# handles every DerrCur <= 1.e-3.
elif DerrCur > 1.e-3 and DerrCur < Derr[i]:
    D1[i], D2[i] = temp[0]
    Derr[i] = temp[1]
    D1_D2_guess[i] = [guessD1, guessD2]

之后:

# Excerpt from the innermost guess loop of DutyCycleSolver with the
# `DerrCur > 1.e-3` test removed from the elif.
DerrCur = temp[1]
if DerrCur <= 1.e-3:
    D1[i], D2[i] = temp[0]
    Derr[i] = temp[1]
    D1_D2_guess[i] = [guessD1, guessD2]
    solutionFound = True
    break
# Reached only when DerrCur is not <= 1.e-3, so no lower-bound test is needed.
elif DerrCur < Derr[i]:
    D1[i], D2[i] = temp[0]
    Derr[i] = temp[1]
    D1_D2_guess[i] = [guessD1, guessD2]

显然,temp = optimize.fmin(equations, ...) 这一行才是瓶颈。向量运算大约执行了 20*20*1000 = 400k 次。这意味着,如果每次向量计算耗时 1 毫秒,整个过程就需要大约 400 秒才能完成。

相关问题 更多 >