我想在生成器中设置我的LSTM隐藏状态。但是,状态集仅在发电机外部工作:
K.set_value(model.layers[0].states[0], np.random.randn(batch_size,num_outs)) # this works
def gen_data():
x = np.zeros((batch_size, num_steps, num_input))
y = np.zeros((batch_size, num_steps, num_output))
while True:
for i in range(batch_size):
K.set_value(model.layers[0].states[0], np.random.randn(batch_size,num_outs)) # error
x[i, :, :] = X_train[gen_data.current_idx]
y[i, :, :] = Y_train[gen_data.current_idx]
gen_data.current_idx += 1
yield x, y
gen_data.current_idx = 0
在fit_generator
函数中调用生成器:
model.fit_generator(gen_data(), len(X_train)//batch_size, 1, validation_data=None)
这是我打印状态时的结果:
print(model.layers[0].states[0])
<tf.Variable 'lstm/Variable:0' shape=(1, 2) dtype=float32>
这是生成器中发生的错误:
ValueError: Tensor("Placeholder_1:0", shape=(1, 2), dtype=float32) must be from the same graph as Tensor("lstm/Variable:0", shape=(), dtype=resource)
我做错什么了?你知道吗
生成器是多线程的,因此生成器中使用的图形将在不同于创建图形的线程中运行。因此,访问模型表单生成器将访问不同的图形。一个简单(但不好)的解决方案是通过设置
workers=0
,强制生成器在与创建图形的线程相同的线程中运行。你知道吗调试代码:
输出
您可以看到图形对象是不同的。生成
workers=0
将强制生成器运行单线程。你知道吗使用
结果
同一个单线程生成器可以访问同一个图形。你知道吗
但是,要启用多线程生成器,一个优雅的方法是将图形保存到创建图形的主进程中的一个变量中,并将其传递给使用传递的图形作为默认图形的生成器。你知道吗
相关问题 更多 >
编程相关推荐