如何在Keras2.0中连接2个嵌入层与“mask_zero=True”

2024-04-25 11:36:08 发布

您现在位置:Python中文网/ 问答频道 /正文

我有两个嵌入层,一个指定mask_zero=True,另一个没有,如下面定义的那样。在

a = Input(shape=[30])
b = Input(shape=[30])
emb_a = Embedding(10, 5, mask_zero=True)(a)
emb_b = Embedding(20, 5, mask_zero=False)(b)
cat = Concatenate(axis=1)([emb_a, emb_b]) # problem here
model = Model(inputs=[a, b], outputs=[cat])

当我试图在axis=1处连接它们时,我期望输出的大小为[None, 60, 5],但它引发了一个错误:

^{pr2}$

为什么emb_a的形状变成[None, 30, 1]?为什么还有另一个空张量[]输入到串联中?在

如果两个嵌入层都被分配了mask_zero=True,或者两者都被分配了mask_zero=False,则不会引发此错误。 如果它们在axis=2处串联,也不会引发此错误。在

我的keras版本是2.0.8。在

谢谢。在


Tags: nonefalsetrueinput定义错误maskembedding
1条回答
网友
1楼 · 发布于 2024-04-25 11:36:08

因为在一种情况下,mask_zero=True,而在另一种情况下是{},这导致了一些内部问题(这不应该发生),可能是一个bug,您可以在Github上报告它。在

目前,我认为有效的两个选项是只对两个嵌入使用其中一个:mask_zero=True或{}

a = Input(shape=[30])
b = Input(shape=[30])
emb_a = Embedding(10, 5)(a)
emb_b = Embedding(20, 5)(b)
cat = Concatenate(axis=1)([emb_a, emb_b])
model = Model(inputs=[a, b], outputs=[cat])

print(model.output_shape) # (None, 60, 5)

解决此问题的另一种方法是在axis=-1上连接

^{pr2}$

相关问题 更多 >