在tensorflow(python)的GRUCell中,输入和隐藏状态的大小应该是多少?

2024-04-23 18:24:01 发布

您现在位置:Python中文网/ 问答频道 /正文

我是tensorflow的新手(1天的经验)。在

我尝试以下小代码创建一个简单的基于GRU的RNN,单层,隐藏大小为100,如下所示:

import pickle
import numpy as np
import pandas as pd
import tensorflow as tf

# parameters
batch_size = 50
hidden_size = 100

# create network graph
input_data = tf.placeholder(tf.int32, [batch_size])
output_data = tf.placeholder(tf.int32, [batch_size])

cell = tf.nn.rnn_cell.GRUCell(hidden_size)

initial_state = cell.zero_state(batch_size, tf.float32)

hidden_state = initial_state

output_of_cell, hidden_state = cell(input_data, hidden_state)

但最后一行的错误如下(即调用cell()

^{pr2}$

我做错什么了?在


Tags: importinputoutputdatasizetftensorflowas
1条回答
网友
1楼 · 发布于 2024-04-23 18:24:01

GRUCell的call运算符的输入应该是tf.float32类型的二维张量。以下内容应该起作用:

input_data = tf.placeholder(tf.float32, [batch_size, input_size])

cell = tf.nn.rnn_cell.GRUCell(hidden_size)

initial_state = cell.zero_state(batch_size, tf.float32)

hidden_state = initial_state

output_of_cell, hidden_state = cell(input_data, hidden_state)

相关问题 更多 >