使用numpy在重复信号的一部分中绘制抛物线
我有一个重复的信号,每次循环的过程大约每秒钟重复一次,虽然每个循环的持续时间和内容在某些参数范围内会有一点变化。每秒钟的信号数据有一千个x,y坐标。在每个循环中,有一小段但很重要的数据是损坏的,我想用一个向上的抛物线来替换每个损坏的部分。
对于每个需要用抛物线替换的数据段,我有三个点的x,y坐标。其中一个点是抛物线的顶点(最低点)。另外两个点是抛物线的左右顶端。换句话说,左顶是这个函数的最低x值的x,y坐标,而右顶是最高x值的x,y坐标。左顶和右顶的y坐标是相等的,都是这个数据段中最高的y值。
我该如何编写代码来绘制这个向上的抛物线中的其余数据点呢? 请记住,这个函数需要在每分钟的数据中调用60到70次,并且每次调用这个函数时,抛物线的形状和公式都需要变化,以适应这三个x,y坐标对之间的不同关系。
def ReplaceCorruptedDataWithParabola(Xarray, Yarray, LeftTopX, LeftTopY
, LeftTopIndex, MinX, MinY, MinIndex
, RightTopX, RightTopY, RightTopIndex):
# Step One: Derive the formula for the upward-facing parabola using
# the following data from the three points:
LeftTopX,LeftTopY,LeftTopIndex
MinX,MinY,MinIndex
RightTopX,RightTopY,RightTopIndex
# Step Two: Use the formula derived in step one to plot the parabola in
# the places where the corrupted data used to reside:
for n in Xarray[LeftTopX:RightTopX]:
Yarray[n]=[_**The formula goes here**_]
return Yarray
注意:Xarray和Yarray都是单列向量,每个索引中的数据将这两个数组链接为一组x,y坐标。它们都是numpy数组。Xarray包含时间信息且不变,而Yarray包含信号数据,包括需要用抛物线数据替换的损坏段,这些数据需要通过这个函数计算出来。
1 个回答
9
根据我的理解,你有三个点,想要把它们拟合成一个抛物线。
通常来说,最简单的方法是使用 numpy.polyfit,不过如果你特别在意速度,而且只需要拟合这三个点,那就没必要用最小二乘法了。
相反,我们这里是一个确定的系统(就是把抛物线拟合到这三个 x,y 点上),可以通过简单的线性代数得到精确的解。
所以,总的来说,你可以这样做(大部分内容是绘制数据):
import numpy as np
import matplotlib.pyplot as plt
def main():
# Generate some random data
x = np.linspace(0, 10, 100)
y = np.cumsum(np.random.random(100) - 0.5)
# Just selecting these arbitrarly
left_idx, right_idx = 20, 50
# Using the mininum y-value within the arbitrary range
min_idx = np.argmin(y[left_idx:right_idx]) + left_idx
# Replace the data within the range with a fitted parabola
new_y = replace_data(x, y, left_idx, right_idx, min_idx)
# Plot the data
fig = plt.figure()
indicies = [left_idx, min_idx, right_idx]
ax1 = fig.add_subplot(2, 1, 1)
ax1.axvspan(x[left_idx], x[right_idx], facecolor='red', alpha=0.5)
ax1.plot(x, y)
ax1.plot(x[indicies], y[indicies], 'ro')
ax2 = fig.add_subplot(2, 1, 2)
ax2.axvspan(x[left_idx], x[right_idx], facecolor='red', alpha=0.5)
ax2.plot(x,new_y)
ax2.plot(x[indicies], y[indicies], 'ro')
plt.show()
def fit_parabola(x, y):
"""Fits the equation "y = ax^2 + bx + c" given exactly 3 points as two
lists or arrays of x & y coordinates"""
A = np.zeros((3,3), dtype=np.float)
A[:,0] = x**2
A[:,1] = x
A[:,2] = 1
a, b, c = np.linalg.solve(A, y)
return a, b, c
def replace_data(x, y, left_idx, right_idx, min_idx):
"""Replace the section of "y" between the indicies "left_idx" and
"right_idx" with a parabola fitted to the three x,y points represented
by "left_idx", "min_idx", and "right_idx"."""
x_fit = x[[left_idx, min_idx, right_idx]]
y_fit = y[[left_idx, min_idx, right_idx]]
a, b, c = fit_parabola(x_fit, y_fit)
new_x = x[left_idx:right_idx]
new_y = a * new_x**2 + b * new_x + c
y = y.copy() # Remove this if you want to modify y in-place
y[left_idx:right_idx] = new_y
return y
if __name__ == '__main__':
main()
希望这能帮到你一点……