Principal component analysis with Matplotlib
I have a dataset of about 1500 rows and 6 variables. I run principal component analysis (PCA) on it and then display the result with the following code:
# note: matplotlib.mlab.PCA was deprecated in Matplotlib 2.2 and removed in 3.1
from matplotlib.mlab import PCA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
import numpy as np
data = np.array(data)
results = PCA(data)
# fraction of the total variance explained by each component
print(results.fracs)
# weight matrix used to project the data into PCA space
print(results.Wt)
# 2D array of the data projected into PCA space
print(results.Y)
# the first three principal components become the plot coordinates
x = results.Y[:, 0]
y = results.Y[:, 1]
z = results.Y[:, 2]
fig1 = plt.figure() # make a plotting figure
ax = fig1.add_subplot(111, projection='3d') # create a 3D axes on the plotting figure
pltData = [x,y,z]
ax.scatter(pltData[0], pltData[1], pltData[2], c='b', marker='o') # make a scatter plot of blue dots from the data
# make simple, bare axis lines through space:
xAxisLine = ((min(pltData[0]), max(pltData[0])), (0, 0), (0,0)) # 2 points make the x-axis line at the data extrema along x-axis
ax.plot(xAxisLine[0], xAxisLine[1], xAxisLine[2], 'r') # make a red line for the x-axis.
yAxisLine = ((0, 0), (min(pltData[1]), max(pltData[1])), (0,0)) # 2 points make the y-axis line at the data extrema along y-axis
ax.plot(yAxisLine[0], yAxisLine[1], yAxisLine[2], 'r') # make a red line for the y-axis.
zAxisLine = ((0, 0), (0,0), (min(pltData[2]), max(pltData[2]))) # 2 points make the z-axis line at the data extrema along z-axis
ax.plot(zAxisLine[0], zAxisLine[1], zAxisLine[2], 'r') # make a red line for the z-axis.
# label the axes
ax.set_xlabel("x-axis label")
ax.set_ylabel("y-axis label")
ax.set_zlabel("z-axis label")
ax.set_title("The title of the plot")
plt.show() # show the plot
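Since matplotlib.mlab.PCA is no longer available in recent Matplotlib releases, the same three quantities can be approximated with plain NumPy. The following is only a sketch: pca_svd is a hypothetical helper name, the random array stands in for the real data, and the standardization step (centering and scaling to unit variance) is an assumption about what the original class did by default.

```python
import numpy as np

def pca_svd(data, standardize=True):
    """Hypothetical helper: PCA via SVD, returning (fracs, Wt, Y)."""
    a = np.asarray(data, dtype=float)
    a = a - a.mean(axis=0)          # center each variable
    if standardize:
        a = a / a.std(axis=0)       # scale each variable to unit variance
    # rows of Vh are the principal axes, ordered by decreasing singular value
    U, s, Vh = np.linalg.svd(a, full_matrices=False)
    fracs = s**2 / np.sum(s**2)     # fraction of variance per component
    Y = a @ Vh.T                    # data projected into PCA space
    return fracs, Vh, Y

rng = np.random.default_rng(0)
demo = rng.normal(size=(1500, 6))   # synthetic stand-in for the real data
fracs, Wt, Y = pca_svd(demo)
x, y, z = Y[:, 0], Y[:, 1], Y[:, 2] # first three components for plotting
```

The signs of individual components may differ from those of other PCA implementations, which does not change the shape of the scatter plot.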
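What I would like to do now is color the displayed points according to another variable. For example, if I add a variable called color to each row, taking values from ['blue', 'red', 'green'], can I use that variable to color the points?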
1 Answer
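I think the simplest approach is, as you suggest, to prepare a list of colors with one entry per data point, e.g. colors = ['blue', 'red', 'green', ...], and then plot each point separately: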
for i in range(len(x)):
    ax.scatter(x[i], y[i], z[i], color=colors[i])
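With around 1500 points, one scatter call per point can get slow. As a possible alternative, scatter also accepts a sequence of colors through its c argument, one per point, so the whole cloud can be drawn in a single call. The sketch below uses synthetic coordinates and a made-up labels array in place of the real PCA scores and color variable.

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib

rng = np.random.default_rng(0)
x, y, z = rng.normal(size=(3, 30))  # synthetic stand-in for the PCA scores
# hypothetical per-row color variable with the three values from the question
labels = rng.choice(['blue', 'red', 'green'], size=30)
colors = labels.tolist()            # plain Python strings, one per point

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# a single call: c takes one color per point
points = ax.scatter(x, y, z, c=colors)
```

Call plt.show() afterwards to display the figure.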
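Another option is described in this question. There, each point is assigned a number, and that number is mapped to a color through a colormap. This lets you show the gradient of some parameter across the points.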
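A minimal sketch of the colormap idea, under the assumption that there is one numeric value per row (the values array here is invented for illustration, as are the synthetic coordinates and the choice of the 'viridis' colormap):

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib

rng = np.random.default_rng(0)
x, y, z = rng.normal(size=(3, 30))        # synthetic stand-in for the PCA scores
values = rng.uniform(0.0, 10.0, size=30)  # hypothetical numeric variable, one per row

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# c holds the numbers, cmap decides how they are turned into colors
sc = ax.scatter(x, y, z, c=values, cmap='viridis')
cbar = fig.colorbar(sc, ax=ax)            # legend mapping colors back to numbers
cbar.set_label('value')
```

The colorbar is optional, but it makes the mapping from color back to the underlying number readable.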
Edit: in response to your comment, a color can also be given as a tuple of three values (RGB), each in the range 0 to 1.
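from matplotlib.colors import Normalize
# x, y and z must be NumPy arrays for .min(), .max() and .size to work
x, y, z = np.asarray(x), np.asarray(y), np.asarray(z)
xnorm = Normalize(x.min(), x.max())
ynorm = Normalize(y.min(), y.max())
# red follows x, green follows y, blue is fixed at 0
colors = [(xnorm(x[i]), ynorm(y[i]), 0) for i in range(x.size)]
ax.scatter(x, y, z, c=colors)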