使用Tensorflow的numpy_input fn抛出“list object has no attribute shape”

2024-06-07 09:11:22 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图用RNN创建一个文本分类器。 这个分级机.列车line抛出错误:

    model_fn = rnn_model
    classifier = tf.estimator.Estimator(model_fn=model_fn)

    # Train.
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={WORDS_FEATURE: x_train},
        y=y_train,
        batch_size=len(x_train),
        num_epochs=None,
        shuffle=True)
    classifier.train(input_fn=train_input_fn, steps=100)

这就是Xu火车的样子:

^{pr2}$

这是错误: enter image description here

我使用的是python3.4和Tensorflow 1.4

我知道我需要把名单改成np.数组但我不知道在哪里。在


Tags: 文本inputmodel分类器tf错误linetrain
2条回答

tf.estimator.inputs.numpy_input_fn()函数要求x字典中的所有都是NumPy数组。您可以执行以下必要的转换:

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={WORDS_FEATURE: np.array(x_train)},  # Convert `x_train` to a NumPy array.
    y=y_train,
    batch_size=len(x_train),
    num_epochs=None,
    shuffle=True)

注意,只有当x_train是一个列表列表,其中每个嵌套列表的长度相同时,这才有效。如果不是,则需要将每个嵌套列表填充到相同的长度。在

我不知道为什么这个问题被否决了,这是一个合理的问题。在

答案是您的y_train很可能是一个列表,将其转换为numpy数组应该可以解决这个问题。在

相关问题 更多 >

    热门问题