如何使用numpy(及scipy)找到函数的所有零点?
假设我有一个函数 f(x)
,它的定义范围在 a
和 b
之间。这个函数可能有很多零点,也可能有很多渐近线。我需要找到这个函数的所有零点。有什么好的方法吗?
其实,我的策略是这样的:
- 我在一些特定的点上计算这个函数的值。
- 我检查函数值是否有符号变化。
- 我在那些符号发生变化的点之间找到零点。
我确认找到的零点是否真的为零点,或者它是否是一个渐近线。
U = numpy.linspace(a, b, 100) # evaluate function at 100 different points c = f(U) s = numpy.sign(c) for i in range(100-1): if s[i] + s[i+1] == 0: # oposite signs u = scipy.optimize.brentq(f, U[i], U[i+1]) z = f(u) if numpy.isnan(z) or abs(z) > 1e-3: continue print('found zero at {}'.format(u))
这个算法看起来是有效的,但我发现有两个潜在的问题:
- 它无法检测到那些不穿过x轴的零点(比如在函数
f(x) = x**2
中)。不过,我觉得在我正在评估的函数中,这种情况不会发生。 - 如果计算的点之间距离太远,可能会有多个零点在它们之间,这样算法可能就找不到它们了。
你有没有更好的策略(仍然高效)来找到一个函数的所有零点呢?
我觉得这对问题不是很重要,但对于好奇的人来说,我正在处理光纤中波传播的特征方程。这个函数看起来是这样的(其中 V
和 ell
是之前定义的,ell
是一个正整数):
def f(u):
w = numpy.sqrt(V**2 - u**2)
jl = scipy.special.jn(ell, u)
jl1 = scipy.special.jnjn(ell-1, u)
kl = scipy.special.jnkn(ell, w)
kl1 = scipy.special.jnkn(ell-1, w)
return jl / (u*jl1) + kl / (w*kl1)
4 个回答
我发现用 scipy.optimize.fsolve 来实现自己的根查找器其实挺简单的。
想法:通过不断改变
x0
的值,反复调用fsolve
,在区间(start, stop)
内以步长step
找到所有的零点。为了找到所有的根,步长可以设置得相对小一些。只能在一维中搜索零点(其他维度必须固定)。如果你有其他需求,我建议使用 sympy 来计算解析解。
注意:它可能并不总是能找到所有的零点,但我看到它的结果相对不错。我把代码放在了一个 gist 上,如果需要的话我会更新。
import numpy as np
import scipy
from scipy.optimize import fsolve
from matplotlib import pyplot as plt
# Defined below
r = RootFinder(1, 20, 0.01)
args = (90, 5)
roots = r.find(f, *args)
print("Roots: ", roots)
# plot results
u = np.linspace(1, 20, num=600)
fig, ax = plt.subplots()
ax.plot(u, f(u, *args))
ax.scatter(roots, f(np.array(roots), *args), color="r", s=10)
ax.grid(color="grey", ls="--", lw=0.5)
plt.show()
示例输出:
Roots: [ 2.84599497 8.82720551 12.38857782 15.74736542 19.02545276]
根查找器定义
import numpy as np
import scipy
from scipy.optimize import fsolve
from matplotlib import pyplot as plt
class RootFinder:
def __init__(self, start, stop, step=0.01, root_dtype="float64", xtol=1e-9):
self.start = start
self.stop = stop
self.step = step
self.xtol = xtol
self.roots = np.array([], dtype=root_dtype)
def add_to_roots(self, x):
if (x < self.start) or (x > self.stop):
return # outside range
if any(abs(self.roots - x) < self.xtol):
return # root already found.
self.roots = np.append(self.roots, x)
def find(self, f, *args):
current = self.start
for x0 in np.arange(self.start, self.stop + self.step, self.step):
if x0 < current:
continue
x = self.find_root(f, x0, *args)
if x is None: # no root found.
continue
current = x
self.add_to_roots(x)
return self.roots
def find_root(self, f, x0, *args):
x, _, ier, _ = fsolve(f, x0=x0, args=args, full_output=True, xtol=self.xtol)
if ier == 1:
return x[0]
return None
测试函数
虽然 scipy.special.jnjn
已经不存在了,但我为这个情况创建了一个类似的测试函数。
def f(u, V=90, ell=5):
w = np.sqrt(V ** 2 - u ** 2)
jl = scipy.special.jn(ell, u)
jl1 = scipy.special.yn(ell - 1, u)
kl = scipy.special.kn(ell, w)
kl1 = scipy.special.kn(ell - 1, w)
return jl / (u * jl1) + kl / (w * kl1)
你为什么只局限于 numpy
呢?其实 scipy
有一个包正好能满足你的需求:
http://docs.scipy.org/doc/scipy/reference/optimize.nonlin.html
我学到的一课是:数值编程很难,所以别自己做了 :)
不过,如果你真的想自己动手写算法,我刚才提到的 scipy
文档页面(加载速度慢得要命)给你列了一些可以开始的算法。之前我用过的一种方法是把函数离散化到适合你问题的程度。也就是说,调整 \delta x 让它比你问题中的特征大小小得多。这样你就可以观察函数的特征(比如符号的变化)。而且,你可以很容易地计算一条线段的导数(可能从幼儿园就学过),所以你离散化后的函数有一个明确的导数。因为你把 dx 调整得比特征大小小,所以你可以确保不会遗漏对你问题重要的函数特征。
如果你想知道什么是“特征大小”,可以找找你函数中带有长度单位或 1/长度单位的参数。比如,对于某个函数 f(x),假设 x 有长度单位,而 f 没有单位。然后找找那些乘以 x 的东西。例如,如果你想离散化 cos(\pi x),那么乘以 x 的参数(如果 x 有长度单位)必须有 1/长度的单位。所以 cos(\pi x) 的特征大小是 1/\pi。如果你把离散化做得比这个小,就不会有问题。不过要注意,这个方法并不总是有效,所以你可能需要进行一些调整。
我看到的主要问题是,能否找到所有的根——正如评论中提到的,这并不是总能做到的。如果你确定你的函数不是特别复杂(比如sin(1/x)
已经被提到过),那么接下来就是你能接受错过一个或多个根的容忍度。换句话说,就是你愿意花多大力气去确保没有漏掉任何根——据我所知,没有通用的方法可以帮你找到所有的根,所以你得自己动手。你展示的内容已经是一个合理的第一步了。这里有几点建议:
- 布伦特方法在这里确实是个不错的选择。
- 首先,处理发散的问题。因为在你的函数中,分母有贝塞尔函数,你可以先找出它们的根——最好查一下,比如在阿布拉莫维奇和斯特根的书中(Mathworld链接)。这会比你现在使用的临时网格要好。
- 一旦你找到了两个根或发散点
x_1
和x_2
,可以在区间[x_1+epsilon, x_2-epsilon]
中再次进行搜索。继续这个过程,直到不再找到新的根(布伦特方法保证会收敛到一个根,只要存在一个)。 - 如果你无法列出所有的发散点,可能需要更小心地验证一个候选点是否真的发散:给定
x
,不仅要检查f(x)
是否很大,还要检查,比如|f(x-epsilon/2)| > |f(x-epsilon)|
,并对多个epsilon
值(如1e-8, 1e-9, 1e-10等)进行检查。 - 如果你想确保没有根只是轻轻碰到零,可以寻找函数的极值点,对于每个极值点
x_e
,检查f(x_e)
的值。