如何在散点图中显示图例以区分类别

2024-06-16 12:15:24 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在研究sklearn的iris数据集。正如您所知,iris数据集有3个类['setosa','VersionColor','virginica']。我已经为这个数据集绘制了散点图。详情如下:

from sklearn.datasets import load_iris
iris=load_iris()
Y_train=iris.target
X_train=iris.data
class_labels=iris.target_names
plt.scatter(X_train[:,0], X_train[:,1], c=Y_train)
plt.xlabel('attr1')
plt.ylabel('attr2')
plt.show()

Saccter plot:

我有散点图,你可以看到黄色、绿色和紫色的点。我想知道哪个色点属于哪个类别('setosa','versicolor','virginica')。我想显示图例,以便我知道哪个颜色代表哪个类别


Tags: 数据fromiristarget绘制loadtrainplt
1条回答
网友
1楼 · 发布于 2024-06-16 12:15:24

在这种情况下,您可以通过循环标签并使用与散点图相同的colormapnorm来创建custom legend。默认情况下,使用'viridis'颜色映射,并使用将最小颜色值映射为零,将最大颜色值映射为一的范数

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

iris = load_iris()
Y_train = iris.target
X_train = iris.data
class_labels = iris.target_names
cmap = plt.get_cmap('viridis')
norm = plt.Normalize(Y_train.min(), Y_train.max())
plt.scatter(X_train[:, 0], X_train[:, 1], c=Y_train, cmap='viridis', norm=norm)
handles = [plt.Line2D([0, 0], [0, 0], color=cmap(norm(i)), marker='o', linestyle='', label=label)
           for i, label in enumerate(class_labels)]
plt.legend(handles=handles, title='Species')
plt.show()

scatter plot with legend

您也可以使用seaborn,尽管当前设置图例标签并不简单

import seaborn as sns

sns.set()
ax = sns.scatterplot(x=X_train[:, 0], y=X_train[:, 1], hue=Y_train, palette='viridis')
ax.legend(ax.legend_.legendHandles, class_labels, title='Species')

相关问题 更多 >