Keras Lambda层在函数API中抛出ndim错误,但在Sequenti中不抛出

2021-04-12 00:49:56 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图从每个时间步的LSTM层获取输出,并且仅在最后一个时间步(步骤输出和上下文向量)分别获得,所以我发现解决方法是创建一个lambda层,用return_sequences=True从LSTM中提取上下文向量。在顺序模型中,它工作得很好,但是当我试图在函数式API中实现它时,它突然不再接受维度,声明所有的东西都是ndim=1的,即使它不是。 代码:

def ContextVector(x):
    return x[-1][-1]
def ContextVectorOut(input_shape):
    print([None, input_shape[-1]])
    print((input_shape[::2]))
    print(input_shape)
    return list((None, input_shape[-1]))

input_layer = Input(shape=(10, 5))
LSTM_layer = LSTM(5, return_sequences=True)(input_layer)
context_layer = Lambda(ContextVector, output_shape=ContextVectorOut)(LSTM_layer)
repeat_context_layer = RepeatVector(10, name='context')(context_layer)
timed_dense = TimeDistributed(Dense(10))(LSTM_layer)
connected_dense = Dense(2)
connect_dense_context = connected_dense(repeat_context_layer)
connect_dense_time = connected_dense(timed_dense)
concat_out = concatenate([connect_dense_context, connect_dense_time])
output_dense = Dense(5)(concat_out)
model = Model(inputs = [input_layer], output = output_dense)

#model.add(LSTM(20, input_shape = (10, 5), return_sequences=True))
#model.add(Lambda(ContextVector, output_shape=ContextVectorOut))
#model.add(Dense(1))

model.summary()

错误:

^{pr2}$