曲线拟合问题 - matplotlib

0 投票
2 回答
8373 浏览
提问于 2025-04-18 15:18

我正在尝试将一个正弦函数绘制到一个数据集上。我在网上找到了一篇使用scipy.optimize的教程,但即使我完全复制了代码,似乎也不起作用。

在开头:

def func(x, a, b, c, d):
    return a * np.sin(b * (x + c)) + d

在结尾:

scipy.optimize.curve_fit(func, clean_time, clean_rate)
pylab.show()

输出窗口里没有任何线条。

如果有人想要截图或者完整的代码,欢迎在下面留言。

谢谢!

2 个回答

1

我有一段示例代码,里面的数据是一个正弦曲线的形状,还有一个用户自定义的函数用来拟合这些数据。

下面是代码:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import math

xdata = np.array([2.65, 2.80, 2.96, 3.80, 3.90, 4.60, 4.80, 4.90, 5.65, 5.92])
ydata = np.sin(xdata)

def func(x,p1,p2,p3): # HERE WE DEFINE A SIN FUNCTION THAT WE THINK WILL FOLLOW THE DATA DISTRIBUTION
    return p1*np.sin(x*p2+p3)

# Here you give the initial parameters for p0 which Python then iterates over
# to find the best fit
popt, pcov = curve_fit(func,xdata,ydata,p0=(1.0,1.0,1.0)) #THESE PARAMETERS ARE USER DEFINED

print(popt) # This contains your two best fit parameters

# Performing sum of squares
p1 = popt[0]
p2 = popt[1]
p3 = popt[2]
residuals = ydata - func(xdata,p1,p2,p3)
fres = sum(residuals**2)

print(fres) #THIS IS YOUR CHI-SQUARE VALUE!

xaxis = np.linspace(1,7,100) # we can plot with xdata, but fit will not look good 
curve_y = func(xaxis,p1,p2,p3)
plt.plot(xdata,ydata,'*')
plt.plot(xaxis,curve_y,'-')
plt.show()

在这里输入图片描述

7

当然,它不会画出任何东西,curve_fit 本身并不负责绘图。

你可以查看文档,curve_fit 的返回值是一个数组,里面包含了估算的参数,还有一个二维数组,里面是估算的协方差矩阵。你需要自己用这些估算的参数来绘制拟合函数。

我还建议你使用 a*sin(bx +c) +d 这样的形式来拟合,因为这样 b 和 c 之间就不会互相影响了。

这个方法是有效的:

import matplotlib.pyplot as plt
import numpy as np
from numpy.random import normal
from scipy.optimize import curve_fit

x_data = np.linspace(0, 2*np.pi, 30)
y_data = np.sin(x_data) + normal(0, 0.2, 30)


def func(x, a, b, c, d):
    return a * np.sin(b*x + c) + d

parameter, covariance_matrix = curve_fit(func, x_data, y_data)

x = np.linspace(min(x_data), max(x_data), 1000)
plt.plot(x_data, y_data, 'rx', label='data')
plt.plot(x, func(x, *parameter), 'b-', label='fit')   # the star is to unpack the parameter array
plt.show()

这是结果:

plot result

撰写回答