函数内部的变量输入
我正在尝试制作一个可以用于所有情况的 dormand-prince 类型的函数,这样我就可以处理不同数量的变量耦合微分方程,比如2个、3个、4个等等。我希望这个函数能够适用于所有的情况。
比如,lotkav 是一个有两个变量的耦合微分方程,而 lorenz 是一个有三个变量的耦合微分方程。在我写的这个函数里,我需要根据输入参数的不同数量来评估 func
,因为它可能是 lotkav 也可能是 lorenz。
bt = np.array([[0. , 0. , 0. ,0. ,0. , 0. , 0. ], # h
[1/5. , 1/5. , 0. ,0. ,0. , 0. , 0. ], # k0
[3/10.,3/40. , 9/40. ,0. ,0. , 0. , 0. ], # k1
[4/5. ,44/45. ,-56/15. ,32/9. ,0. , 0. , 0. ], # k2
[8/9. ,19372/6561.,-25360/2187 ,64448/6561 ,-212/729., 0. , 0. ], # k3
[1. ,9017/3168. ,-355/33. ,46732/5247 ,49/176 ,-5103/18656. , 0. ], # k4
[1 ,35/384 ,0 ,500/1113 ,125/192 ,-2187/6784. ,11/84.]]) # k5
def lotkav(t, x, y):
return [x*(a-b*y), -y*(c-d*x)]
def lorenz(t, x, y, z):
return [sigma*(y-x), x*(rho-z)-y, x*y - beta*z]
def ode45(func, t_span, y0, hi = .001, hmax = .01, hmin = .000001, tol = 0.00000001):
k = np.zeros(shape=(7,len(y0)))
h = hi
vi = y0.copy()
yp = np.zeros(len(y0)) # fifth order solution
zp = np.zeros(len(y0)) # sixth order solution
ts = [ ] # time at which values are calculated
ts.append(t_span[0])
t = t_span[0]
yr = [y0]
ri = np.zeros(len(y0))
sig = signature(func)
while t <= t_span[1]:
for i in range(7):
sum = np.zeros(len(y0))
ti = t + bt[i][0]*h
for j in range(0,i):
sum = np.add(sum, bt[i][j+1]*k[j])
for narg in range(len(y0)):
ri[narg] = vi[narg] + sum[narg]*h
# I want to change the following line so it works for both without if else.
k[i] = func(ti, ri[0], ri[1]) # for lorenz it will be k[i] = func(ti, ri[0], ri[1], ri[2])
for i in range(len(y0)):
yp[i] = vi[i] + ((35/384.)*k[0][i]+(500/1113.)*k[2][i]+(125/192.)*k[3][i]+(-2187/6784.)*k[4][i]+(11/84.)*k[5][i])*h
zp[i] = vi[i] + ((5179/57600.)*k[0][i]+(7571/16695.)*k[2][i]+(393/640.)*k[3][i]+(-92097/339200.)*k[4][i]+(187/2100)*k[5][i])*h+(1/40)*k[6][i]*h
err = np.min(np.abs(np.subtract(yp, zp)))
if 0 < err < tol: # if error within tolerance accept result and try for larger step
yr.append(zp.copy())
t = t + h
ts.append(t)
vi = zp
h = 2.0*h # increase time step by 20%
if h > hmax:
h = hmax
elif err == 0: # if error becomes 0 then do the same
yr.append(zp.copy())
t = t + h
ts.append(t)
vi = zp
h = 2.0*h
if h > hmax:
h = hmax
else: # error is not within limit reduce the step size
h = h*(tol*h/(2*(err)))**(1/5.0)
if h < hmin:
h = hmin
return ts,yr
tp, yp = ode45(lotkav, [0,50], y0)
t,r = ode45(lorenz, [0,100], y0)
我知道 *arg
本来可以在 lorenz 或 lotkav 的函数参数中使用,但 scipy 的 odeint 并不需要这样做。
1 个回答
0
写一个包装函数,这个函数可以接受不同数量的参数——对于lotkav函数,它可以接收3个值;对于lorenz函数,它可以接收4个值。
大概是这样的:
def lotkav(a, b, c):
return "lotkav"
def lorenz(a, b, c, d):
return "lorenz"
def lowrapper(*args):
match len(args):
case 3:
return lotkav(*args)
case 4:
return lorenz(*args)
case _:
raise Exception("Invalid number of arguments")
ti = 1
ri = [2, 3]
print(lowrapper(ti, *ri))
ri = [2, 3, 4]
print(lowrapper(ti, *ri))
输出:
lotkav
lorenz