TypeError:传递给参数“indexes”的值的数据类型float32不在允许值列表中:int32,int64

2024-04-28 03:49:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用Keras构建模型,模型中有两个输入,其数据类型为“int32”。然后我使用keras Lamba层在嵌入矩阵中查找K.gather(reference,index)。我看到索引应该是int的张量,我想我的代码符合这一点,我不知道为什么会出错。我真的需要帮助!在

    input_A = Input(batch_shape=(128,1),name='A_input',dtype='int32')
    input_B = Input(batch_shape=(128,1),name='B_input',dtype='int32')

    input_A_ = Lambda(lambda x:K.reshape(x,(-1,)))(input_A)
    input_B_ = Lambda(lambda x:K.reshape(x, (-1,)))(input_B)

    input_A__ = Lambda(lambda x:K.cast(x,dtype='int32'))(input_A_)
    input_B__ = Lambda(lambda x:K.cast(x,dtype='int32'))(input_B_)

    embedded_text_A = Lambda(lambda x:K.gather(M1,x))(input_A__)
    embedded_text_B = Lambda(lambda x:K.gather(M1,x))(input_B__)

Tags: lambdatextname模型inputbatchembeddedshape
1条回答
网友
1楼 · 发布于 2024-04-28 03:49:59

出于某种神秘的原因,如果将K.cast()放在lambda内,它将正常工作:

input_A = Input(batch_shape=(128,1), name='A_input', dtype='int32')
input_B = Input(batch_shape=(128,1), name='B_input', dtype='int32')

input_A_ = Lambda(lambda x: K.reshape(x, (-1,)))(input_A)
input_B_ = Lambda(lambda x: K.reshape(x, (-1,)))(input_B)

embedded_text_A = Lambda(lambda x: K.gather(M1, K.cast(x, dtype='int32')))(input_A_)
embedded_text_B = Lambda(lambda x: K.gather(M1, K.cast(x, dtype='int32')))(input_B_)

因此,Lambda层在其中进行了一些奇怪的数据类型转换。在

我想这是某种bug,我的假设是隐式转换发生在Lambda__call__(which is inherited from ^{})内部。我无法跟踪它,但我想“隐式转换”bug在Layer.__call__中的某个地方,但在451行之前,实际上调用了Lambda.call。在

相关问题 更多 >