Python:微调多个拟合函数
我需要一些帮助来调整我代码生成的图表。这个代码写得很粗糙,但基本上是用来拟合一些数据,这些数据在时间和计数的分布中有两个峰值。每个峰的上升部分用高斯函数来拟合,而下降部分则用指数函数来拟合。
我需要对这些拟合进行微调,因为从图表上看,数据虽然被拟合了,但效果并不好。我想避免不同函数之间的断裂(也就是说,这些函数需要“接触”在一起),并且我希望得到的拟合能够真正跟随数据,并且符合它们的定义(比如,第一个高斯函数在峰值处没有呈现出“钟形”,而第二个高斯函数停止得“太早”)。
这个代码是从网上获取数据的,所以可以直接运行。希望代码和图像能比我的描述更清晰。
非常感谢!
#!/usr/bin/env python
import pyfits, os, re, glob, sys
from scipy.optimize import leastsq
from numpy import *
from pylab import *
from scipy import *
# ---------------- Functions ---------------------------#
def right_exp(p, x, y, err1):
yfit1 = p[0]*exp(-p[2]*(x - p[1]))
dev_exp = (y - yfit1)/err1
return dev_exp
def left_gauss(p, x, y, err2):
yfit2 = p[0]*(1/sqrt(2*pi*(p[2]**2)))*exp(-(x - p[1])**2/(2*p[2]**2))
dev_gauss = (y - yfit2)/err2
return dev_gauss
# ------------------------------------------------------ #
tmin = 56200
tmax = 56249
data=pyfits.open('http://heasarc.gsfc.nasa.gov/docs/swift/results/transients/weak/GX304-1.orbit.lc.fits')
time = data[1].data.field(0)/86400. + data[1].header['MJDREFF'] + data[1].header['MJDREFI']
rate = data[1].data.field(1)
error = data[1].data.field(2)
data.close()
cond1 = ((time > 56200) & (time < 56209)) #| ((time > 56225) & (time < 56234))
time1 = time[cond1]
rate1 = rate[cond1]
error1 = error[cond1]
cond2 = ((time > 56209) & (time < 56225)) #| ((time > 56234) & (time < 56249))
time2 = time[cond2]
rate2 = rate[cond2]
error2 = error[cond2]
cond3 = ((time > 56225) & (time < 56234))
time3 = time[cond3]
rate3 = rate[cond3]
error3 = error[cond3]
cond4 = ((time > 56234) & (time < 56249))
time4 = time[cond4]
rate4 = rate[cond4]
error4 = error[cond4]
totaltime = np.append(time1, time2)
totalrate = np.append(rate1, rate2)
v0= [0.23, 56209.0, 1] #inital guesses for Gaussian Fit, just do it around the peaks
v1= [0.40, 56233.0, 1]
# ------------------------ First peak -------------------------------------------------------------------#
out = leastsq(left_gauss, v0[:], args=(time1, rate1, error1), maxfev = 100000, full_output = 1)
p = out[0]
v = out[0]
xxx = arange(min(time1), max(time1), time1[1] - time1[0])
yfit1 = p[0]*(1/sqrt(2*pi*(p[2]**2)))*exp(-(xxx - p[1])**2/(2*p[2]**2))
out2 = leastsq(right_exp, v0[:], args = (time2, rate2, error2), maxfev = 100000, full_output = 1)
p2 = out2[0]
v2 = out2[0]
xxx2 = arange(min(time2), max(time2), time2[1] - time2[0])
yfit2 = p2[0]*exp(-p2[2]*(xxx2 - p2[1]))
# ------------------------ Second peak -------------------------------------------------------------------#
out3 = leastsq(left_gauss, v1[:], args=(time3, rate3, error3), maxfev = 100000, full_output = 1)
p3 = out3[0]
v3 = out3[0]
xxx3 = arange(min(time3), max(time3), time3[1] - time3[0])
yfit3 = p3[0]*(1/sqrt(2*pi*(p3[2]**2)))*exp(-(xxx3 - p3[1])**2/(2*p3[2]**2))
out4 = leastsq(right_exp, v1[:], args = (time4, rate4, error4), maxfev = 100000, full_output = 1)
p4 = out4[0]
v4 = out4[0]
xxx4 = arange(min(time4), max(time4), time4[1] - time4[0])
yfit4 = p4[0]*exp(-p4[2]*(xxx4 - p4[1]))
# ------------------------------------------------------------------------------------------------------- #
fig = figure(figsize = (9, 9)) #make a plot
ax1 = fig.add_subplot(111)
ax1.plot(time, rate, 'g.')
ax1.plot(xxx, yfit1, 'b-')
ax1.plot(xxx2, yfit2, 'b-')
ax1.plot(xxx3, yfit3, 'b-')
ax1.plot(xxx4, yfit4, 'b-')
axis([tmin, tmax, -0.00, 0.45])
savefig("first peak.png")
1 个回答
2
使用三角级数可以很好地解决这个问题,从而创建一个连续的函数。下面的例子如果粘贴到你的代码后面就能正常工作。如果需要的话,你可以改变三角级数中的项数。
import numpy as np
from scipy.optimize import curve_fit
x = np.concatenate((time1, time2, time3, time4))
y_points = np.concatenate((rate1, rate2, rate3, rate4))
den = x.max() - x.min()
def func(x, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15):
return a1 *sin( 1*pi*x/den)+\
a2 *sin( 2*pi*x/den)+\
a3 *sin( 3*pi*x/den)+\
a4 *sin( 4*pi*x/den)+\
a5 *sin( 5*pi*x/den)+\
a6 *sin( 6*pi*x/den)+\
a7 *sin( 7*pi*x/den)+\
a8 *sin( 8*pi*x/den)+\
a9 *sin( 9*pi*x/den)+\
a10*sin(10*pi*x/den)+\
a11*sin(11*pi*x/den)+\
a12*sin(12*pi*x/den)+\
a13*sin(13*pi*x/den)+\
a14*sin(14*pi*x/den)+\
a15*sin(15*pi*x/den)
popt, pcov = curve_fit(func, x, y_points)
y = func(x, *popt)
plot(x,y, color='r', linewidth=2.)
show()
编辑
正如@Alfe所建议的,这个拟合函数可以用更简洁的格式来写,比如:
def func(x, a):
return sum(a_i * sin(i * pi * x / den) for i, a_i in enumerate(a, 1))