我有一个架构,在输入RNN之前使用编码器。编码器输入形状为[batch, height, width, channels]
,RNN输入形状为[batch, time, height, width, channels]
。我想把编码器的输出直接传送到RNN,但这会造成内存问题。我必须一次将batch*time ~= 3*100
(通过重塑)图像输入编码器。我知道tf.nn.dynamic_rnn
可以利用swap_memory
,我也想在编码器中利用它。下面是一些简明代码:
#image inputs [batch, time, height, width, channels]
inputs = tf.placeholder(tf.float32, [batch, time, in_sh[0], in_sh[1], in_sh[2]])
#This is where the trouble starts
#merge batch and time
inputs = tf.reshape(inputs, [batch*time, in_sh[0], in_sh[1], in_sh[2]])
#build the encoder (and get shape of output)
enc, enc_sh = build_encoder(inputs)
#change back to time format
enc = tf.reshape(enc, [batch, time, enc_sh[0], enc_sh[1], enc_sh[2]])
#build rnn and get initial state (zero_state)
rnn, initial_state = build_rnn()
#use dynamic unrolling
rnn_outputs, rnn_state = tf.nn.dynamic_rnn(
rnn, enc,
initial_state=initial_state,
swap_memory=True,
time_major=False)
我目前使用的方法是在我的所有图像上运行编码器(并保存到光盘),但我想执行数据集扩充(到图像),一旦提取了特征就不可能了。你知道吗
对于任何其他遇到这个问题的人。我从
RNNCell
派生了一个包装器,它完成了我需要的东西。model_fn
是一个使用输入构建子图并返回输出张量的函数。不幸的是,输出形状必须是已知的(至少我不能让它以其他方式工作)。你知道吗相关问题 更多 >
编程相关推荐