TensorFlow input function with batch size and shu

Posted 2024-04-16 19:12:45


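I am trying to train in batches with `tf.train.batch()`. I have separate data frames for training, evaluation, and prediction, so `input_fn` should take `df` and `batch_size` as parameters. The data frame contains both continuous and categorical columns.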
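For reference, the Estimator calls `input_fn` with no arguments and unpacks the result into `features, labels`, where `features` is usually a dict mapping each column name to a batch of values. The usual way to pass `df` and `batch_size` in is therefore a factory that closes over them. The sketch below is a plain-Python stand-in, not TensorFlow code: `make_input_fn`, its `shuffle`/`seed` parameters, and the trimmed column lists are hypothetical, and it yields Python lists where a real `input_fn` would return tensors.

```python
import random

# Hypothetical column lists, trimmed from the question for illustration.
CONTINUOUS_COLUMNS = ["atemp", "humidity"]
CATEGORICAL_COLUMNS = ["season", "hour"]
LABEL_COLUMN = "label"


def make_input_fn(rows, batch_size, shuffle=True, seed=None):
    """Hypothetical factory returning a zero-argument input function.

    `rows` is a list of dicts (one per data-frame row). The returned function
    yields (features, labels) pairs, where `features` maps every feature
    column name to a list of up to `batch_size` values.
    """
    def _input_fn():
        order = list(range(len(rows)))
        if shuffle:
            # Shuffle row indices so features and labels stay aligned.
            random.Random(seed).shuffle(order)
        for start in range(0, len(order), batch_size):
            batch = [rows[i] for i in order[start:start + batch_size]]
            features = {
                name: [row[name] for row in batch]
                for name in CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS
            }
            labels = [row[LABEL_COLUMN] for row in batch]
            yield features, labels
    return _input_fn


rows = [{"atemp": 0.1 * i, "humidity": 50 + i, "season": i % 4 + 1,
         "hour": i % 24, "label": 10 * i} for i in range(10)]
for features, labels in make_input_fn(rows, batch_size=4, seed=0)():
    print(sorted(features), len(labels))
```

Because shuffling permutes row indices rather than individual columns, every batch keeps each row's features paired with its own label.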

Revised code:

import pandas as pd
import tensorflow as tf

COLUMNS = ['atemp', 'holiday', 'humidity', 'season', 'temp', 'weather', 'windspeed', 'workingday', 'hour', 'weekday', 'month', 'label']

CONTINUOUS_COLUMNS = ['atemp', 'humidity', 'temp', 'windspeed']
CATEGORICAL_COLUMNS = ['holiday', 'season', 'weather',
                       'workingday', 'weekday', 'month', 'hour']

LEARNING_RATE = 0.1
LABEL_COLUMN = 'label'
batch_size = 128

data_set = pd.read_excel('bike_str.xlsx')

# Split the data into a training set, an eval set and a pred set.
train_set = data_set[:9500]
eval_set = data_set[9500:10800]
pred_set = data_set[10800:]

## Eval and Prediction labels:

eval_label = eval_set['label']
pred_label = pred_set['label']

The input function:


Calling the batch input function directly gives this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-9c356159093d> in <module>()
----> 1 dnnregressor.fit(input_fn= lambda: batch_input_fn(train_set, batch_size), steps=15000 )

C:\Python\Anaconda\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
    287             'in a future version' if date is None else ('after %s' % date),
    288             instructions)
--> 289       return func(*args, **kwargs)
    290     return tf_decorator.make_decorator(func, new_func, 'deprecated',
    291                                        _add_deprecated_arg_notice_to_docstring(

C:\Python\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\estimators\estimator.py in fit(self, x, y, input_fn, steps, batch_size, monitors, max_steps)
    453       hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
    454 
--> 455     loss = self._train_model(input_fn=input_fn, hooks=hooks)
    456     logging.info('Loss for final step: %s.', loss)
    457     return self

C:\Python\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\estimators\estimator.py in _train_model(self, input_fn, hooks)
    951       random_seed.set_random_seed(self._config.tf_random_seed)
    952       global_step = contrib_framework.create_global_step(g)
--> 953       features, labels = input_fn()
    954       self._check_inputs(features, labels)
    955       model_fn_ops = self._get_train_ops(features, labels)

TypeError: 'function' object is not iterable
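The traceback is consistent with `batch_input_fn` being a factory that, like `batched_input_fn` below, returns an inner `_input_fn`. That is an assumption, since its code is not shown. If so, wrapping it in `lambda:` makes `input_fn()` return the inner function instead of a `(features, labels)` tuple, and unpacking a function raises a `TypeError` (the exact wording depends on the Python version). The pure-Python sketch below reproduces this with a hypothetical factory `make_input_fn` and a hypothetical `train` stand-in for `fit`, and shows the two usual fixes.

```python
def make_input_fn(features, labels):
    """Hypothetical factory: returns a zero-argument input function."""
    def _input_fn():
        return features, labels
    return _input_fn


def train(input_fn):
    """Stand-in for Estimator.fit: calls input_fn() and unpacks the result."""
    features, labels = input_fn()
    return features, labels


data_x = {"temp": [0.2, 0.4]}
data_y = [3, 5]

# Wrong: the lambda returns the *inner function*, so unpacking it fails.
try:
    train(lambda: make_input_fn(data_x, data_y))
except TypeError as err:
    print("TypeError:", err)

# Fix 1: pass the function the factory already built.
train(make_input_fn(data_x, data_y))

# Fix 2: keep the lambda, but call the inner function inside it.
train(lambda: make_input_fn(data_x, data_y)())
```

Under the same assumption, the call in the question would become `dnnregressor.fit(input_fn=batch_input_fn(train_set, batch_size), steps=15000)`.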

The following code appears to work, but here the tensors are not a list of dicts:

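import tensorflow as tf

def batched_input_fn(dataset_x, dataset_y, batch_size):
    def _input_fn():
        # Load the full feature matrix and label vector as constant tensors.
        all_x = tf.constant(dataset_x, shape=dataset_x.shape, dtype=tf.float32)
        all_y = tf.constant(dataset_y, shape=dataset_y.shape, dtype=tf.float32)
        # Queue that emits one (x, y) example at a time (shuffled by default).
        sliced_input = tf.train.slice_input_producer([all_x, all_y])
        # Group single examples into batches of batch_size.
        return tf.train.batch(sliced_input, batch_size=batch_size)
    return _input_fn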
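One way to get a dict out of this approach is to batch the feature matrix as above and then split its columns back out by name. The sketch below shows that mapping in plain Python with a hypothetical helper `columns_to_dict`; in TensorFlow 1.x the split itself could be done with `tf.unstack(x_batch, axis=1)` before zipping with the column names, though that adaptation is an assumption here. It only suits numeric columns, since the matrix above is cast to `tf.float32`; categorical columns would need separate handling.

```python
def columns_to_dict(batch_rows, column_names):
    """Split a row-major batch (list of equal-length rows) into a dict of columns.

    Hypothetical helper: column i of the batch is stored under column_names[i].
    """
    if any(len(row) != len(column_names) for row in batch_rows):
        raise ValueError("every row must have one value per column name")
    # zip(*rows) transposes the batch into one tuple per column.
    columns = zip(*batch_rows) if batch_rows else [[] for _ in column_names]
    return {name: list(col) for name, col in zip(column_names, columns)}


CONTINUOUS_COLUMNS = ["atemp", "humidity", "temp", "windspeed"]
x_batch = [
    [0.30, 0.81, 0.24, 0.0],
    [0.29, 0.80, 0.22, 0.0],
]
features = columns_to_dict(x_batch, CONTINUOUS_COLUMNS)
print(features["humidity"])  # [0.81, 0.8]
```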
