在一张图中绘制散点图
我想用散点图来展示我的数据,但是我好像无法把三个图放到一张图里。请问我该怎么解决这个问题呢:
nu_cluster = 3
kmeans = KMeans(n_clusters=nu_cluster,random_state=0)
data_df["cluster"] = kmeans.fit_predict(X_std)
print("after Kmeans predict")# visualization
plt.figure(figsize=(8, 6))
for i in range(nu_cluster):
cluster_data = data_df[data_df["cluster"] == i]#return a boolean and then passed to data_df
plt.scatter(cluster_data["charges"], cluster_data["age"],c=[plt.cm.viridis(i / (nu_cluster - 1))] ,label=f"Cluster {i + 1}")
plt.xlabel("Charges")
plt.ylabel("Age")
plt.title("Cluster of age against charges", fontsize=16, fontweight="bold")
plt.legend(loc="lower right")
plt.show()
我试过使用matplotlib提供的figure()函数,但没有成功。plt.figure(figsize=(8,7))
3 个回答
0
在你的循环每次运行时,实际上是在覆盖之前的图表——可以理解为你的图表没有被“保存”到任何地方。
在循环中你只需要:
plt.figure(figsize=(8, 6))
for i in range(nu_cluster):
cluster_data = data_df[data_df["cluster"] == i]
plt.scatter(cluster_data["charges"], cluster_data["age"], c=[plt.cm.viridis(i / (nu_cluster - 1))], label=f"Cluster {i + 1}")
你可以把循环外面的代码去掉(不断给图表命名是不高效的,尽量只把需要重复执行的代码放在循环里)。
关于你的评论,散点图非常适合用来展示两个连续的变量。如果你发现有趋势,可以在后面加一条线,但如果你的变量之间没有多项式或线性关系,那加线就没什么意义了。想想看,把线图加到聚类图上会有多大用处?
0
你想用Pyplot的subplots()
函数在同一个图形中创建多个坐标轴(或者说多个图)。这个subplots函数会返回两个东西,一个是figure(图形),另一个是一个axes的列表。
- https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplots.html 这里可以找到关于subplots的详细信息。
简单来说:
import matplotlib.pyplot as plt
fig, axs = plt.subplots(3)
axs[0].plot([1,2,3])
axs[1].plot([1,2,3], [1,2,3])
axs[2].plot([1,2,3], [2,4,6])
fig.show()
1
把 plt.show()
移到循环外面去。一旦你调用了 plt.show()
,图形就会被“丢弃”,之后的绘图命令会自动创建一个新的图形。
更好的方法是使用 Matplotlib 的面向对象的明确接口:
nu_cluster = 3
kmeans = KMeans(n_clusters=nu_cluster,random_state=0)
data_df["cluster"] = kmeans.fit_predict(X_std)
print("after Kmeans predict")# visualization
fig, ax = plt.subplots(figsize=(8, 6))
for i in range(nu_cluster):
cluster_data = data_df[data_df["cluster"] == i]#return a boolean and then passed to data_df
ax.scatter(cluster_data["charges"], cluster_data["age"],c=[plt.cm.viridis(i / (nu_cluster - 1))] ,label=f"Cluster {i + 1}")
ax.set_xlabel("Charges")
ax.set_ylabel("Age")
ax.set_title("Cluster of age against charges", fontsize=16, fontweight="bold")
ax.legend(loc="lower right")
fig.show()