Fitting a line in 3D
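Is there an algorithm that returns the equation of a straight line from a set of 3D data points? I can find plenty of sources that give the equation of a line from 2D data sets, but none for 3D.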
2 Answers
6
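If your data is fairly well behaved, it should be sufficient to minimize the sum of squared component distances rather than the true perpendicular distances. You can then run two independent linear regressions: z against x, and then z against y.
Following the documentation example: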
import numpy as np
pts = np.add.accumulate(np.random.random((10,3)))
x,y,z = pts.T
# this will find the slope and z-intercept of the plane
# parallel to the y-axis that best fits the data (z = m_xz * x + c_xz)
A_xz = np.vstack((x, np.ones(len(x)))).T
# rcond=None selects the current default cutoff and avoids a FutureWarning
m_xz, c_xz = np.linalg.lstsq(A_xz, z, rcond=None)[0]
# again for a plane parallel to the x-axis (z = m_yz * y + c_yz)
A_yz = np.vstack((y, np.ones(len(y)))).T
m_yz, c_yz = np.linalg.lstsq(A_yz, z, rcond=None)[0]
# the intersection of those two planes and
# the function for the line would be:
# z = m_yz * y + c_yz
# z = m_xz * x + c_xz
# or:
def lin(z):
    x = (z - c_xz)/m_xz
    y = (z - c_yz)/m_yz
    return x,y
#verifying:
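from mpl_toolkits.mplot3d import Axes3D  # only needed to register '3d' on Matplotlib < 3.2
import matplotlib.pyplot as plt
fig = plt.figure()
# On recent Matplotlib releases Axes3D(fig) does not add the axes to the
# figure, so request the 3D projection through add_subplot instead.
ax = fig.add_subplot(projection='3d')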
zz = np.linspace(0,5)
xx,yy = lin(zz)
ax.scatter(x, y, z)
ax.plot(xx,yy,zz)
plt.savefig('test.png')
plt.show()
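If you want to minimize the actual orthogonal distances (measured perpendicular to the line) between the line and the points in 3-space, then I would write a function that computes the residual sum of squares (RSS) and solve it with one of the minimization functions in scipy.optimize.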
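A minimal sketch of that idea, not the answerer's own code. It assumes the line passes through the centroid of the points (which holds for the orthogonal least-squares optimum) and parametrizes the direction by two spherical angles. The helper names `direction`, `orthogonal_rss` and `fit_line_orthogonal` are made up for illustration, and Nelder-Mead is just one reasonable choice of solver.

```python
import numpy as np
from scipy.optimize import minimize

def direction(angles):
    """Unit vector from spherical angles (theta: polar, phi: azimuth)."""
    theta, phi = angles
    return np.array([np.sin(theta) * np.cos(phi),
                     np.sin(theta) * np.sin(phi),
                     np.cos(theta)])

def orthogonal_rss(angles, centered):
    """Sum of squared perpendicular distances from the centered points
    to the line through the origin with the given direction."""
    d = direction(angles)
    along = centered @ d                  # signed length of each projection
    perp = centered - np.outer(along, d)  # part of each point perpendicular to the line
    return np.sum(perp ** 2)

def fit_line_orthogonal(pts):
    """Return (point_on_line, unit_direction, rss) for an (n, 3) array."""
    centroid = pts.mean(axis=0)
    centered = pts - centroid
    # Initial guess: direction from the first point to the last one
    # (assumes those two points are distinct).
    guess = centered[-1] - centered[0]
    guess /= np.linalg.norm(guess)
    theta0 = np.arccos(np.clip(guess[2], -1.0, 1.0))
    phi0 = np.arctan2(guess[1], guess[0])
    res = minimize(orthogonal_rss, x0=[theta0, phi0],
                   args=(centered,), method="Nelder-Mead")
    return centroid, direction(res.x), res.fun

# Usage on the same kind of data as above
pts = np.add.accumulate(np.random.random((10, 3)))
point, d, rss = fit_line_orthogonal(pts)
# Points on the fitted line: point + t * d for any scalar t
```

The spherical parametrization is degenerate when the line is almost parallel to the z-axis, so this sketch is only meant to show the shape of the RSS function and the call to `minimize`.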
64
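If you are trying to predict one value from the other two, then you should use lstsq with the a argument as your independent variables (plus a column of 1's to estimate an intercept) and b as your dependent variable.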
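As a rough illustration of that call (the synthetic data, the coefficients 1.5, -0.5 and 2.0, and the variable names are all assumptions made up for this sketch), predicting z from x and y could look like this. Note that this fits a plane z = cx * x + cy * y + intercept, not a line:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: z depends linearly on x and y, plus a little noise
x = rng.uniform(-2, 5, size=100)
y = rng.uniform(1, 9, size=100)
z = 1.5 * x - 0.5 * y + 2.0 + rng.normal(scale=0.1, size=100)

# a: independent variables plus a column of ones for the intercept
a = np.column_stack((x, y, np.ones_like(x)))
# b: dependent variable
b = z

coef, residuals, rank, sv = np.linalg.lstsq(a, b, rcond=None)
cx, cy, intercept = coef

# Predicted z for the observed (x, y) pairs
z_pred = a @ coef
```

The residuals minimized here are measured along z only, which is what you want for prediction but not for a geometric best-fit line.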
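If, on the other hand, you just want the line that best fits the data, i.e. the line that minimizes the sum of squared distances between each real point and its projection onto the line, then what you want is the first principal component.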
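Put simply, the first principal component is the line that passes through the mean of your data and whose direction vector is the eigenvector of the covariance matrix corresponding to the largest eigenvalue. That said, eig(cov(data)) is a really bad way to calculate it, since it does a lot of needless computation and copying and is potentially less accurate than using svd. See below: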
import numpy as np
# Generate some data that lies along a line
x = np.mgrid[-2:5:120j]
y = np.mgrid[1:9:120j]
z = np.mgrid[-5:3:120j]
data = np.concatenate((x[:, np.newaxis],
                       y[:, np.newaxis],
                       z[:, np.newaxis]),
                      axis=1)
# Perturb with some Gaussian noise
data += np.random.normal(size=data.shape) * 0.4
# Calculate the mean of the points, i.e. the 'center' of the cloud
datamean = data.mean(axis=0)
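# Do an SVD on the mean-centered data. full_matrices=False skips building
# the full 120x120 uu matrix; vv is the same either way.
uu, dd, vv = np.linalg.svd(data - datamean, full_matrices=False)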
# Now vv[0] contains the first principal component, i.e. the direction
# vector of the 'best fit' line in the least squares sense.
# Now generate some points along this best fit line, for plotting.
# I use -7, 7 since the spread of the data is roughly 14
# and we want it to have mean 0 (like the points we did
# the svd on). Also, it's a straight line, so we only need 2 points.
linepts = vv[0] * np.mgrid[-7:7:2j][:, np.newaxis]
# shift by the mean to get the line in the right place
linepts += datamean
# Verify that everything looks right.
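import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d as m3d  # only needed to register '3d' on Matplotlib < 3.2
# On recent Matplotlib releases Axes3D(fig) does not add the axes to the
# figure, so request the 3D projection through add_subplot instead.
ax = plt.figure().add_subplot(projection='3d')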
ax.scatter3D(*data.T)
ax.plot3D(*linepts.T)
plt.show()
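Here's what it looks like: (image: 3D scatter plot of the noisy points with the best-fit line drawn through them)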
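Separately from the plot, here is a small cross-check of the definition given earlier in this answer: the first right singular vector of the mean-centered data should match the eigenvector of the covariance matrix with the largest eigenvalue, up to sign. This is only a sketch on made-up data. It uses np.linalg.eigh (the symmetric-matrix variant) rather than a general eig, and the variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)

# Noisy points scattered around a line through the origin
t = np.linspace(-7, 7, 120)
direction_true = np.array([7.0, 8.0, 8.0])
direction_true /= np.linalg.norm(direction_true)
data = np.outer(t, direction_true) + rng.normal(scale=0.4, size=(120, 3))

centered = data - data.mean(axis=0)

# Route 1: SVD of the centered data; rows of vt are the principal directions
_, s, vt = np.linalg.svd(centered, full_matrices=False)
dir_svd = vt[0]

# Route 2: eigen-decomposition of the 3x3 covariance matrix
cov = np.cov(centered, rowvar=False)
evals, evecs = np.linalg.eigh(cov)  # eigenvalues in ascending order
dir_eig = evecs[:, -1]              # eigenvector of the largest eigenvalue

# The two directions agree up to sign, so compare |cos(angle)|
alignment = abs(dir_svd @ dir_eig)

# Squared singular values, scaled by 1/(n - 1), equal the eigenvalues
evals_from_svd = s ** 2 / (len(data) - 1)
```

Both routes give the same line here; the SVD route simply avoids forming the covariance matrix explicitly.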