用Python创建3D平面并进行回归分析
我对编程了解得不多,但我正在尝试用两个独立变量进行回归分析。我知道怎么在Python中创建一个3D散点图,也知道怎么进行回归分析,但我不知道怎么把这些东西结合起来。让我给你看看我想做的事情:回归分析示例。这是我用Excel做的,里面有散点图、趋势线、趋势线的方程和R²值。我现在想看看能不能用Python创建一个3D版本(多一个独立变量)。
我在ChatGPT的帮助下开始了这个项目,但结果并不是我期待的样子。这里是我的代码和输出结果:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from sklearn.linear_model import LinearRegression
x= *insert my numbers*
y= *insert my numbers*
z= *insert my numbers*
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z)
model = LinearRegression()
model.fit(np.column_stack((x, y)), z)
coef_x, coef_y = model.coef_
intercept = model.intercept_
x_plane, y_plane = np.meshgrid(np.linspace([min(x), max(x)], 10), np.linspace([min(y), max(y)], 10))
z_plane = coef_x * x_plane + coef_y * y_plane + intercept
ax.plot_surface(x_plane, y_plane, z_plane, alpha=0.5)
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
plt.show()
虽然散点图看起来很完美,但这个平面并不是我想要的。我希望这个平面能覆盖整个图,而不仅仅是从原点开始的两条线。我的代码中是不是缺少了什么?制作一个3D版本的Excel图表是否真的可行?希望有人能告诉我。
1 个回答
0
你不需要使用任何库。其实你可以自己写一个标准的最小二乘分析。
依次对a、b、c进行偏微分,然后让它们等于0:
把这些整理成一组线性方程:
然后求解出明确的解:
例如:
还有:
等等。
示例代码:
import numpy as np
import matplotlib.pyplot as plt
def planeFit( x, y, z ):
'''
Calculate least-squares best fit for z = ax + by + c
'''
N = len( x )
xbar = sum( x ) / N
ybar = sum( y ) / N
zbar = sum( z ) / N
Sxx = sum( x * x ) / N - xbar * xbar
Syy = sum( y * y ) / N - ybar * ybar
Sxy = sum( x * y ) / N - xbar * ybar
Sxz = sum( x * z ) / N - xbar * zbar
Syz = sum( y * z ) / N - ybar * zbar
denominator = Sxx * Syy - Sxy ** 2
a = ( Syy * Sxz - Sxy * Syz ) / denominator
b = ( Sxx * Syz - Sxy * Sxz ) / denominator
c = zbar - a * xbar - b * ybar
return a, b, c
# Enter data
x = np.array( [ 9, 12, 15, 18 ] )
y = np.array( [ 2, 6, 4, 1 ] )
z = 5 * x - 3 * y + 7
z[3] = z[3] - 25 # a little bit of noise
# Fit the plane as z = ax+by+c
a, b, c = planeFit( x, y, z )
fmt = "{:8.3f} "
print( "Best fit a, b, c = ", ( 3 * fmt ).format( a, b, c ) )
# Check numerical values
zfit = a * x + b * y + c
print( "Values x, y, z, zfit:" )
for i in range( len( x ) ):
print( ( 4 * fmt ).format( x[i], y[i], z[i], zfit[i] ) )
# Check plot
ax = plt.figure().add_subplot( projection='3d' )
ax.scatter( x, y, z, color='r' )
x1d = np.linspace( min( x ), max( x ), 5 )
y1d = np.linspace( min( y ), max( y ), 5 )
x2d, y2d = np.meshgrid( x1d, y1d )
zplane = a * x2d + b * y2d + c
ax.plot_surface( x2d, y2d, zplane, rstride=1, cstride=1, alpha=0.3 )
ax.set_xlabel( 'X' )
ax.set_ylabel( 'Y' )
ax.set_zlabel( 'Z' )
plt.show()