Sklearn SGDClassizer“发现了dim为3的数组。估计值在重塑MNIST后应小于等于2”

2024-04-29 18:45:06 发布

您现在位置:Python中文网/ 问答频道 /正文

我不熟悉python和ml,只是尝试使用mnist 我的程序是这样的

from sklearn.datasets import fetch_openml
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.linear_model import SGDClassifier
import numpy as np

mist=fetch_openml('mnist_784',version=1)
mist.keys()
x,y=mist['data'],mist['target']
x.shape
images = x[1]
images=images.reshape(28,28)
plt.imshow(images)
plt.show()
y=y.astype(np.uint8)
y[0]
xtrain,xtest,ytrain,ytest=x[:60000],x[60000:],y[:60000],y[60000:]
ytrain_5=(ytrain==5)
ytest_5=(ytest==5)
sgd = SGDClassifier(random_state=42)
sgd.fit(xtrain,ytrain_5)
sgd.predict([images])

这就是投掷和错误:

找到了具有dim 3的数组。预计估计值<;=2.


Tags: fromimportmatplotlibaspltfetchsklearnimages
1条回答
网友
1楼 · 发布于 2024-04-29 18:45:06

使用此行:

sgd.predict(images.reshape(1, 784))

该算法是使用一个展平的shape(70000, 784)数组进行训练的,因此在通过它之前需要将images展平为shape(1, 784)。您之前已将其重新设置为(28, 28)

完整代码:

from sklearn.datasets import fetch_openml
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.linear_model import SGDClassifier
import numpy as np

mist=fetch_openml('mnist_784',version=1)
x,y=mist['data'],mist['target']
images = x[1]
images=images.reshape(28,28)
y=y.astype(np.uint8)
xtrain,xtest,ytrain,ytest=x[:60000],x[60000:],y[:60000],y[60000:]
ytrain_5=(ytrain==5)
ytest_5=(ytest==5)
sgd = SGDClassifier(random_state=42)
sgd.fit(xtrain,ytrain_5)
sgd.predict(images.reshape(1, 784))

相关问题 更多 >