我试图用tensorflow的LSTM模块来预测序列,但是没有用。我想不出这个问题,我希望有人能帮我一把。我的代码是:
首先,我主要创建合成数据,并准备数据加载器
x = np.linspace(0,30.,500)
y = x*np.sin(x) + 2*np.sin(5*x)
nb_steps = 20
def load_batch(batch_size = 32):
x_b = np.zeros((nb_steps,batch_size,1))
y_b = np.zeros((nb_steps*batch_size,1))
inds = np.random.randint(0, 479, (batch_size))
for i,ind in enumerate(inds):
x_b[:,i,0] = x[ind:ind+nb_steps]
y_b[i*nb_steps:(i+1)*nb_steps,0] = y[ind+1:ind+nb_steps+1]
return x_b, y_b
一些捷径
^{pr2}$接下来是我创建模型的部分
with tf.variable_scope('data'):
x_p = tf.placeholder(tf.float32, shape = [nb_steps, None, 1], name = 'x') # batch, steps, features
y_p = tf.placeholder(tf.float32, shape = [None, 1], name = 'labels')
with tf.variable_scope('network'):
cell = lstm(num_units = 100)
outputs, states = tf.nn.dynamic_rnn(cell, x_p, dtype = tf.float32, time_major = True)
reshaped_outputs = tf.reshape(outputs, [-1,100])
projection = dense(reshaped_outputs, 1, activation = None, name = 'projection')
上面是我最不确定的部分。我为每个时间步重塑lstm的输出,并将它们堆叠在第一个轴上(或者是这样?)。然后在整个线性矩阵层中发送。在
with tf.variable_scope('training'):
loss = tf.reduce_mean(tf.square(projection - y_p))
train_lstm = adam(1e-3).minimize(loss)
epochs = 1000
batch_size = 64
f, ax = plt.subplots(2,1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
mean_loss = 0.
for epoch in range(1,epochs+1):
x_b,y_b = load_batch(batch_size)
batch_loss,_ = sess.run([loss, train_lstm], feed_dict = {x_p:x_b, y_p:y_b})
mean_loss += batch_loss
if epoch%100 == 0:
print('Epoch: {} | Loss: {:.6f}'.format(epoch, mean_loss/100.))
mean_loss = 0.
while True :
x_b, y_b = load_batch(1)
pred = sess.run(projection, feed_dict = {x_p:x_b}).reshape(-1)
ax[0].plot(x,y, label= 'Real')
ax[0].plot(x_b.reshape(-1),y_b.reshape(-1), label= 'Real batch')
ax[0].plot(x_b.reshape(-1), pred, label = 'Pred')
ax[1].scatter(x_b.reshape(-1),y_b.reshape(-1), label= 'Real')
ax[1].scatter(x_b.reshape(-1), pred, label = 'Pred')
for a in ax: a.legend()
plt.pause(0.1)
input()
for a in ax:
a.clear()
非常感谢!在
每个LSTM单元产生100个输出,因此在执行tf.nn.动态你需要使输出变平。我宁愿用
在这行之后:
^{pr2}$代替这一行:
希望有帮助:)
编辑:我没注意到你用了time_major=True。我对你的代码做了一点修改,time_major=False,因为它更易于使用。在
我假设你想预测nb_步长的输出。在
代码:
^{4}$相关问题 更多 >
编程相关推荐