解析求解三次样条插值数据导数的零点
我有一组数据,我用三次样条插值(UnivariateSpline,三次的)来处理这些数据。我想做一种峰值检测的方法,不是通过求导数然后找零点,而是直接求导数,然后把结果代入二次方程来找到所有的零点。
这个函数到底返回了什么呢?因为为了生成一组数据,使其符合插值结果,你必须给返回的结果提供一系列的点,像这样:
from numpy import linspace,exp
from numpy.random import randn
import matplotlib.pyplot as plt
from scipy.interpolate import UnivariateSpline
x = linspace(-3, 3, 100) # original data x axis
y = exp(-x**2) + randn(100)/10 #original data y axis
s = UnivariateSpline(x, y, s=1) # interpolation, returned to value s
xs = linspace(-3, 3, 1000) #values for x axis
ys = s(xs) # create new y axis
plt.plot(x, y, '.-')
plt.plot(xs, ys)
plt.show()
那么这个返回的东西s到底是什么?它是列出了三次方程的系数吗?如果是的话,我该如何通过对这些值求导来找到峰值呢?
1 个回答
0
对象 s
是 scipy.interpolate.fitpack2.UnivariateSpline
的一个实例。
和其他 Python 对象一样,它有一些属性和方法可以用来操作它。想要了解你可以对某个 Python 对象做什么,可以使用内置的函数 type
、dir
和 vars
。
在这个情况下,dir
最有帮助。只需运行
dir(s)
你就能看到 s
的所有属性,具体包括:
['__call__',
'__class__',
'__delattr__',
'__dict__',
'__doc__',
'__format__',
'__getattribute__',
'__hash__',
'__init__',
'__module__',
'__new__',
'__reduce__',
'__reduce_ex__',
'__repr__',
'__setattr__',
'__sizeof__',
'__str__',
'__subclasshook__',
'__weakref__',
'_data',
'_eval_args',
'_from_tck',
'_reset_class',
'_reset_nest',
'_set_class',
'antiderivative',
'derivative',
'derivatives',
'get_coeffs',
'get_knots',
'get_residual',
'integral',
'roots',
'set_smoothing_factor']
Python 有个约定,属性和方法如果名字是以一个下划线开头的,那就是私有的,通常不建议使用,除非你知道自己在做什么。不过,正如你所看到的,列表的末尾包含了你想从 s
中获取的信息:比如样条系数、导数、根等等。
接下来我们来看一个例子:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import UnivariateSpline
x = np.linspace(-3, 3, 100) # original data x axis
y = np.exp(-x**2) + randn(100)/10 #original data y axis
s = UnivariateSpline(x, y, s=1) # interpolation, returned to value s
# watch the changes
xs = np.linspace(-3, 3, 10000) #values for x axis
ys = s(xs) # create new y axis
ds = s.derivative(n=1) # get the derivative
dy = ds(xs) # compute it on xs
tol=1e-4 # stabilish a tolerance
root_index = np.where((dy>-tol)&(dy<tol)) # find indices where dy is close to zero within tolerance
root = xs[root_index] # get the correspondent xs values
root = set(np.round(root, decimals=2).tolist()) # remove redundancy duo to tolerance
root = np.array(list(root))
print(root)
plt.plot(x, y, '.-')
plt.plot(xs, ys)
plt.plot(xs, dy, 'r--') # plot the derivative
plt.vlines(root, -1, 1, lw=2, alpha=.4) # draw vertical lines through each root
plt.hlines(0, -3, 3, lw=2, alpha=.4) # draw a horizontal line through zero