Scikit数字学习

2024-04-24 14:07:37 发布

您现在位置:Python中文网/ 问答频道 /正文

我不明白

iam正在尝试将scikit learn与带数字数据集的matplotlib一起使用

这是我的密码

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn import svm
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()

clf = svm.SVC(gamma=0.001, C=100.)

X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=2017)

clf.fit(X_train, y_train)

pred = clf.predict(X_test)

print("Prediction: {}".format(pred))    
plt.imshow(digits.images[-1], cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()

matplotlib绘制4号图像,当我尝试打印预测时,它每次都会打印这个输出给我

^{pr2}$

我试图在matplotlib中打印数字,但它显示了这个输出

我希望打印prediction: 8


Tags: fromtestimportmatplotlibtrainplt数字sklearn
1条回答
网友
1楼 · 发布于 2024-04-24 14:07:37

好吧,下面是评论中的人想告诉你的:

X_test不是单个数据点,而是测试集的整个数据点。它包括多少个样品?在

X_test.shape
# (360, 64)

因此,它包含360个样本;因此,pred变量也必须包含所有这些360个样本的预测。事实上:

^{pr2}$

您想检查X_test中第一个样本的预测?在

pred[0]
# 8

这个预言的基本事实是什么?在

y_test[0]
# 8

看来你对第一个测试样本的预测是正确的。你想绘制这个样本(X_test[0])吗?您应该首先将其重塑为(8,8),因为train_test_splitload_digits()函数将原始8x8图像展平为长度为64的一维数组:

plt.imshow(X_test[0].reshape(8,8), cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()

enter image description here

嗯,看起来像8(有点)。。。在

相关问题 更多 >