tf.contrib.学习快速启动:修复float64警告

2024-04-27 02:20:36 发布

您现在位置:Python中文网/ 问答频道 /正文

通过阅读发布的教程,我开始使用TensorFlow。在

我有运行在Fedora23(二十三)上的LinuxCPU python2.7版本0.10.0。在

我正在尝试tf.contrib.学习根据下面的代码快速入门教程。在

https://www.tensorflow.org/versions/r0.10/tutorials/tflearn/index.html#tf-contrib-learn-quickstart

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

# Data sets
IRIS_TRAINING = "IRIS_data/iris_training.csv"
IRIS_TEST = "IRIS_data/iris_test.csv"

# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TRAINING,
                                                   target_dtype=np.int)
test_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TEST,
                                               target_dtype=np.int)

# Specify that all features have real-value data
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                        hidden_units=[10, 20, 10],
                                        n_classes=3,
                                        model_dir="/tmp/iris_model")

# Fit model.
classifier.fit(x=training_set.data, 
           y=training_set.target, 
           steps=2000)

# Evaluate accuracy.
accuracy_score = classifier.evaluate(x=test_set.data,
                                 y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))

# Classify two new flower samples.
new_samples = np.array(
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = classifier.predict(new_samples)
print('Predictions: {}'.format(str(y)))

代码执行,但给出float64警告。因此:

^{pr2}$

注意:将“load_csv()”替换为“load_csv_with_header()”将生成正确的预测。但float64警告仍然存在。在

我尝试过显式地列出dtype(np.int32; np.浮动32; tf.int32型; tf.float32型)用于训练集、测试集和新样本。在

我还尝试了“选角”专栏:

feature_columns = tf.cast(feature_columns, tf.float32)

float64的问题是已知的开发问题,但我想知道是否有一些解决方法?在


Tags: columnscsvtestimportiristargetdatatf
1条回答
网友
1楼 · 发布于 2024-04-27 02:20:36

我通过github从开发团队那里得到了这个答案。在

Hi @qweelar, the float64 warning is due to a bug with the load_csv_with_header function that was fixed in commit b6813bd. This fix isn't in TensorFlow release 0.10, but should be in the next release.

In the meantime, for the purposes of the tf.contrib.learn quickstart, you can safely ignore the float64 warning.

(Side note: In terms of the other deprecation warning, I will be updating the tutorial code to use load_csv_with_header, and will update this issue when that's in place.)

相关问题 更多 >