无法获取tensorflow DNNClassifi的预测

2024-04-25 19:50:35 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用MNIST教程中的代码:

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=2,
                                            model_dir="/tmp/iris_model")

classifier.fit(x=np.array(train, dtype = 'float32'),
               y=np.array(y_tr, dtype = 'int64'),
               steps=2000)

accuracy_score = classifier.evaluate(x=np.array(test, dtype = 'float32'),
                                     y=y_test)["auc"]
print('AUC: {0:f}'.format(accuracy_score))

from tensorflow.contrib.learn import SKCompat
ds_test_ar = np.array(ds_test, dtype = 'float32')

ds_predict_tf = classifier.predict(input_fn = _my_predict_data)
print('Predictions: {}'.format(str(ds_predict_tf)))

但最后我得到了以下结果而不是预测:

Predictions: <generator object DNNClassifier.predict.<locals>.<genexpr> at 0x000002CE41101CA8>

我做错了什么?


Tags: columnstestmodeltfnpdscontribarray
3条回答

您接收并保存到ds_predict_tf的是一个生成器表达式。 要打印它,您可以:

for i in ds_predict_tf:
    print i

或者

print(list(ds_predict_tf))

您可以阅读有关genexprhere的更多信息。

The DNNClassifier predict function by default have as_iterable=True. Thus, it returns an generator. For getting values of predictions instead of generator, pass as_iterable=False in classifier.predict method.

例如

classifier.predict(input_fn = _my_predict_data,as_iterable=False)



以了解有关分类器方法和参数的更多信息。以下是预测方法文档的一部分。

来自DNNClassifier文档:

预测

Args:

  • x: 特点。
  • 输入:输入功能。如果设置,x必须为无。
  • 批大小:覆盖默认批大小。
  • 输出:str列表,要预测的输出的名称。如果没有,则返回类。
  • as-iterable:如果为True,则返回一个iterable,该iterable将为每个示例生成预测,直到输入用尽为止。注意:如果您希望iterable终止,则输入必须终止(例如,如果您使用类似于read_batch的功能,请确保传递num_epochs=1)。

Returns:

  • 具有形状[批处理大小](如果as iterable为True,则为或可预测类的iterable)的预测类的Numpy数组。每个预测类都由其类索引表示(即从0到n_类-1的整数)。如果设置了输出,则返回预测结果。

解决方案:-

pred = classifier.fit(x=training_set.data, y=training_set.target, steps=2000).predict(test_set.data)

print ("Predictions:")

print(list(pred))

就这样。。。

相关问题 更多 >