colorbar和matsh

2024-04-25 14:10:16 发布

您现在位置:Python中文网/ 问答频道 /正文

我基本上想在下面的代码(link to code)中的每个子窗口添加一个颜色条。我尝试在最后一个子图的循环末尾添加所有颜色条。在

print(__doc__)

import matplotlib.pyplot as plt
from sklearn.datasets import fetch_mldata
from sklearn.neural_network import MLPClassifier

mnist = fetch_mldata("MNIST original")
# rescale the data, use the traditional train/test split
X, y = mnist.data / 255., mnist.target
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]

# mlp = MLPClassifier(hidden_layer_sizes=(100, 100), max_iter=400, alpha=1e-4,
#                     solver='sgd', verbose=10, tol=1e-4, random_state=1)
mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4,
                solver='sgd', verbose=10, tol=1e-4, random_state=1,
                learning_rate_init=.1)

mlp.fit(X_train, y_train)
print("Training set score: %f" % mlp.score(X_train, y_train))
print("Test set score: %f" % mlp.score(X_test, y_test))

fig, axes = plt.subplots(4, 4)
# use global min / max to ensure all weights are shown on the same scale
vmin, vmax = mlp.coefs_[0].min(), mlp.coefs_[0].max()
for coef, ax in zip(mlp.coefs_[0].T, axes.ravel()):
ax.matshow(coef.reshape(28, 28), cmap=plt.cm.gray, vmin=.5 * vmin,
           vmax=.5 * vmax)
ax.set_xticks(())
ax.set_yticks(())

plt.show()

更新: 基于下面注释中的链接,下面是在图表右侧添加颜色条的代码

^{pr2}$

Tags: thetestimport颜色trainpltaxmax