批量输入共享参数

2024-03-29 11:32:50 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在建立一个网络来排列一组N个输入。理想情况下,它们应该同时输入并共享参数。它们的目标向量应该是一个N-hot向量来匹配输入。你知道吗

这意味着我的输入应该是(批大小,N,序列长度,特征长度)

但是对于任何大于3维的输入,keras都会抛出一个错误,如下所示:

ValueError: Input 0 is incompatible with layer lstm_2: expected ndim=3, found ndim=4

我目前的keras设置是:

x = Input(shape=(72,300))
aux_input = Input(shape=(72, 4))
probs = Input(shape=(1,))
#dim_red_1 = Dense(100)(x)
dim_red_2 = Dense(20, activation='tanh')(x)
cat = concatenate([dim_red_2, aux_input])
encoded = LSTM(64)(cat)
cat2 = concatenate([encoded, probs])
output = Dense(1, activation='sigmoid')(cat2)

lstm_model = Model(inputs=[x, aux_input, probs], outputs=output)
lstm_model.compile(optimizer='ADAM', loss='binary_crossentropy', metrics=['accuracy'])

有没有办法用Keras实现这一点?你知道吗


Tags: inputredactivation向量catkerasdenseshape
1条回答
网友
1楼 · 发布于 2024-03-29 11:32:50

尽管您的代码看起来不错,但请确保导入正确的包:

import numpy as np
from tensorflow.python.keras import Input
from tensorflow.python.keras.engine.training import Model
from tensorflow.python.keras.layers import Dense, LSTM, Concatenate

a = np.zeros(shape=[1000, 72, 300])
b = np.zeros(shape=[1000, 72, 4])
c = np.zeros(shape=[1000, 1])
d = np.zeros(shape=[1000, 1])

x = Input(shape=(72, 300))
aux_input = Input(shape=(72, 4))
probs = Input(shape=(1,))
dim_red_2 = Dense(20, activation='tanh')(x)
cat = Concatenate()([dim_red_2, aux_input])
encoded = LSTM(64)(cat)
cat2 = Concatenate()([encoded, probs])
output = Dense(1, activation='sigmoid')(cat2)

lstm_model = Model(inputs=[x, aux_input, probs], outputs=output)
lstm_model.compile(optimizer='ADAM', loss='binary_crossentropy', metrics=['accuracy'])
lstm_model.summary()
lstm_model.fit([a, b, c], d, batch_size=256)

输出:

256/1000 [======>.......................] - ETA: 2s - loss: 0.6931 - acc: 1.0000
 512/1000 [==============>...............] - ETA: 1s - loss: 0.6910 - acc: 1.0000
 768/1000 [======================>.......] - ETA: 0s - loss: 0.6885 - acc: 1.0000
1000/1000 [==============================] - 1s 1ms/step - loss: 0.6859 - acc: 1.00

相关问题 更多 >