解析求解三次样条插值数据导数的零点

1 投票
1 回答
2458 浏览
提问于 2025-04-18 18:14

我有一组数据,我用三次样条插值(UnivariateSpline,三次的)来处理这些数据。我想做一种峰值检测的方法,不是通过求导数然后找零点,而是直接求导数,然后把结果代入二次方程来找到所有的零点。

这个函数到底返回了什么呢?因为为了生成一组数据,使其符合插值结果,你必须给返回的结果提供一系列的点,像这样:

from numpy import linspace,exp
from numpy.random import randn
import matplotlib.pyplot as plt
from scipy.interpolate import UnivariateSpline
x = linspace(-3, 3, 100) # original data x axis
y = exp(-x**2) + randn(100)/10  #original data y axis
s = UnivariateSpline(x, y, s=1) # interpolation, returned to value s
xs = linspace(-3, 3, 1000) #values for x axis
ys = s(xs) # create new y axis
plt.plot(x, y, '.-')
plt.plot(xs, ys)
plt.show()

那么这个返回的东西s到底是什么?它是列出了三次方程的系数吗?如果是的话,我该如何通过对这些值求导来找到峰值呢?

1 个回答

0

对象 sscipy.interpolate.fitpack2.UnivariateSpline 的一个实例。

和其他 Python 对象一样,它有一些属性和方法可以用来操作它。想要了解你可以对某个 Python 对象做什么,可以使用内置的函数 typedirvars

在这个情况下,dir 最有帮助。只需运行

dir(s)

你就能看到 s 的所有属性,具体包括:

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__doc__',
 '__format__',
 '__getattribute__',
 '__hash__',
 '__init__',
 '__module__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_data',
 '_eval_args',
 '_from_tck',
 '_reset_class',
 '_reset_nest',
 '_set_class',
 'antiderivative',
 'derivative',
 'derivatives',
 'get_coeffs',
 'get_knots',
 'get_residual',
 'integral',
 'roots',
 'set_smoothing_factor']

Python 有个约定,属性和方法如果名字是以一个下划线开头的,那就是私有的,通常不建议使用,除非你知道自己在做什么。不过,正如你所看到的,列表的末尾包含了你想从 s 中获取的信息:比如样条系数、导数、根等等。

接下来我们来看一个例子:

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import UnivariateSpline
x = np.linspace(-3, 3, 100) # original data x axis
y = np.exp(-x**2) + randn(100)/10  #original data y axis
s = UnivariateSpline(x, y, s=1) # interpolation, returned to value s

# watch the changes
xs = np.linspace(-3, 3, 10000) #values for x axis
ys = s(xs) # create new y axis
ds = s.derivative(n=1)  # get the derivative
dy = ds(xs)  # compute it on xs
tol=1e-4  # stabilish a tolerance 
root_index = np.where((dy>-tol)&(dy<tol))  # find indices where dy is  close to zero within tolerance
root = xs[root_index]  # get the correspondent xs values
root = set(np.round(root, decimals=2).tolist())  # remove redundancy duo to tolerance
root = np.array(list(root))
print(root)
plt.plot(x, y, '.-')
plt.plot(xs, ys)
plt.plot(xs, dy, 'r--')  # plot the derivative
plt.vlines(root, -1, 1, lw=2, alpha=.4)  # draw vertical lines through each root
plt.hlines(0, -3, 3, lw=2, alpha=.4)  # draw a horizontal line through zero

撰写回答