如何使用numpy(及scipy)找到函数的所有零点?

9 投票
4 回答
48410 浏览
提问于 2025-04-17 15:53

假设我有一个函数 f(x),它的定义范围在 ab 之间。这个函数可能有很多零点,也可能有很多渐近线。我需要找到这个函数的所有零点。有什么好的方法吗?

其实,我的策略是这样的:

  1. 我在一些特定的点上计算这个函数的值。
  2. 我检查函数值是否有符号变化。
  3. 我在那些符号发生变化的点之间找到零点。
  4. 我确认找到的零点是否真的为零点,或者它是否是一个渐近线。

    U = numpy.linspace(a, b, 100) # evaluate function at 100 different points
    c = f(U)
    s = numpy.sign(c)
    for i in range(100-1):
        if s[i] + s[i+1] == 0: # oposite signs
            u = scipy.optimize.brentq(f, U[i], U[i+1])
            z = f(u)
            if numpy.isnan(z) or abs(z) > 1e-3:
                continue
            print('found zero at {}'.format(u))
    

这个算法看起来是有效的,但我发现有两个潜在的问题:

  1. 它无法检测到那些不穿过x轴的零点(比如在函数 f(x) = x**2 中)。不过,我觉得在我正在评估的函数中,这种情况不会发生。
  2. 如果计算的点之间距离太远,可能会有多个零点在它们之间,这样算法可能就找不到它们了。

你有没有更好的策略(仍然高效)来找到一个函数的所有零点呢?


我觉得这对问题不是很重要,但对于好奇的人来说,我正在处理光纤中波传播的特征方程。这个函数看起来是这样的(其中 Vell 是之前定义的,ell 是一个正整数):

def f(u):
    w = numpy.sqrt(V**2 - u**2)

    jl = scipy.special.jn(ell, u)
    jl1 = scipy.special.jnjn(ell-1, u)
    kl = scipy.special.jnkn(ell, w)
    kl1 = scipy.special.jnkn(ell-1, w)

    return jl / (u*jl1) + kl / (w*kl1)

4 个回答

2

我发现用 scipy.optimize.fsolve 来实现自己的根查找器其实挺简单的。

  • 想法:通过不断改变 x0 的值,反复调用 fsolve,在区间 (start, stop) 内以步长 step 找到所有的零点。为了找到所有的根,步长可以设置得相对小一些。

  • 只能在一维中搜索零点(其他维度必须固定)。如果你有其他需求,我建议使用 sympy 来计算解析解。

  • 注意:它可能并不总是能找到所有的零点,但我看到它的结果相对不错。我把代码放在了一个 gist 上,如果需要的话我会更新。

import numpy as np
import scipy
from scipy.optimize import fsolve
from matplotlib import pyplot as plt

# Defined below
r = RootFinder(1, 20, 0.01)
args = (90, 5)
roots = r.find(f, *args)
print("Roots: ", roots)

# plot results
u = np.linspace(1, 20, num=600)
fig, ax = plt.subplots()
ax.plot(u, f(u, *args))
ax.scatter(roots, f(np.array(roots), *args), color="r", s=10)
ax.grid(color="grey", ls="--", lw=0.5)
plt.show()

示例输出:

Roots:  [ 2.84599497  8.82720551 12.38857782 15.74736542 19.02545276]

在这里输入图片描述

放大: 在这里输入图片描述

根查找器定义

import numpy as np
import scipy
from scipy.optimize import fsolve
from matplotlib import pyplot as plt


class RootFinder:
    def __init__(self, start, stop, step=0.01, root_dtype="float64", xtol=1e-9):

        self.start = start
        self.stop = stop
        self.step = step
        self.xtol = xtol
        self.roots = np.array([], dtype=root_dtype)

    def add_to_roots(self, x):

        if (x < self.start) or (x > self.stop):
            return  # outside range
        if any(abs(self.roots - x) < self.xtol):
            return  # root already found.

        self.roots = np.append(self.roots, x)

    def find(self, f, *args):
        current = self.start

        for x0 in np.arange(self.start, self.stop + self.step, self.step):
            if x0 < current:
                continue
            x = self.find_root(f, x0, *args)
            if x is None:  # no root found.
                continue
            current = x
            self.add_to_roots(x)

        return self.roots

    def find_root(self, f, x0, *args):

        x, _, ier, _ = fsolve(f, x0=x0, args=args, full_output=True, xtol=self.xtol)
        if ier == 1:
            return x[0]
        return None

测试函数

虽然 scipy.special.jnjn 已经不存在了,但我为这个情况创建了一个类似的测试函数。

def f(u, V=90, ell=5):
    w = np.sqrt(V ** 2 - u ** 2)

    jl = scipy.special.jn(ell, u)
    jl1 = scipy.special.yn(ell - 1, u)
    kl = scipy.special.kn(ell, w)
    kl1 = scipy.special.kn(ell - 1, w)

    return jl / (u * jl1) + kl / (w * kl1)
4

你为什么只局限于 numpy 呢?其实 scipy 有一个包正好能满足你的需求:

http://docs.scipy.org/doc/scipy/reference/optimize.nonlin.html

我学到的一课是:数值编程很难,所以别自己做了 :)


不过,如果你真的想自己动手写算法,我刚才提到的 scipy 文档页面(加载速度慢得要命)给你列了一些可以开始的算法。之前我用过的一种方法是把函数离散化到适合你问题的程度。也就是说,调整 \delta x 让它比你问题中的特征大小小得多。这样你就可以观察函数的特征(比如符号的变化)。而且,你可以很容易地计算一条线段的导数(可能从幼儿园就学过),所以你离散化后的函数有一个明确的导数。因为你把 dx 调整得比特征大小小,所以你可以确保不会遗漏对你问题重要的函数特征。

如果你想知道什么是“特征大小”,可以找找你函数中带有长度单位或 1/长度单位的参数。比如,对于某个函数 f(x),假设 x 有长度单位,而 f 没有单位。然后找找那些乘以 x 的东西。例如,如果你想离散化 cos(\pi x),那么乘以 x 的参数(如果 x 有长度单位)必须有 1/长度的单位。所以 cos(\pi x) 的特征大小是 1/\pi。如果你把离散化做得比这个小,就不会有问题。不过要注意,这个方法并不总是有效,所以你可能需要进行一些调整。

1

我看到的主要问题是,能否找到所有的根——正如评论中提到的,这并不是总能做到的。如果你确定你的函数不是特别复杂(比如sin(1/x)已经被提到过),那么接下来就是你能接受错过一个或多个根的容忍度。换句话说,就是你愿意花多大力气去确保没有漏掉任何根——据我所知,没有通用的方法可以帮你找到所有的根,所以你得自己动手。你展示的内容已经是一个合理的第一步了。这里有几点建议:

  • 布伦特方法在这里确实是个不错的选择。
  • 首先,处理发散的问题。因为在你的函数中,分母有贝塞尔函数,你可以先找出它们的根——最好查一下,比如在阿布拉莫维奇和斯特根的书中(Mathworld链接)。这会比你现在使用的临时网格要好。
  • 一旦你找到了两个根或发散点x_1x_2,可以在区间[x_1+epsilon, x_2-epsilon]中再次进行搜索。继续这个过程,直到不再找到新的根(布伦特方法保证会收敛到一个根,只要存在一个)。
  • 如果你无法列出所有的发散点,可能需要更小心地验证一个候选点是否真的发散:给定x,不仅要检查f(x)是否很大,还要检查,比如|f(x-epsilon/2)| > |f(x-epsilon)|,并对多个epsilon值(如1e-8, 1e-9, 1e-10等)进行检查。
  • 如果你想确保没有根只是轻轻碰到零,可以寻找函数的极值点,对于每个极值点x_e,检查f(x_e)的值。

撰写回答