在Keras网络中间使用辅助输入

0 投票
1 回答
23 浏览
提问于 2025-04-14 18:11

我有两个模型,一个是用来计算轨迹的,另一个是用来判断这些轨迹的。评分模型包含了所有的层,而目标模型只有前半部分。两个模型都需要目标参数来计算轨迹,然后分别进行评分。为了让评分模型能够使用目标数据,它需要把这些数据作为辅助输入(因为目标输出只有两个值)。我尝试用一个连接层来实现这个,但出现了错误:

# targeting_model
inp=layers.Input(4)

tdense1=layers.Dense(1024, activation='relu')(inp)
tdense2=layers.Dense(1024, activation='relu')(tdense1)
tdense3=layers.Dense(1024, activation='relu')(tdense2)

tout=layers.Dense(2, activation='linear')(tdense3)

# scoring_model
auxinp=layers.Input(4)

sinp=layers.Concatenate()([tout,auxinp])
sdense1=layers.Dense(1024, activation='relu')(sinp)
sdense2=layers.Dense(1024, activation='relu')(sdense1)
sout=layers.Dense(1, activation='linear')(sdense2)






targeting_model =  tf.keras.Model(inputs=inp, outputs=tout, name="targeting_model")
print(targeting_model.summary())
scoring_model =  tf.keras.Model(inputs=inp, outputs=sout, name="scoring_model")

这是错误信息:

ValueError`: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, 4), dtype=tf.float32, name='input_2'), name='input_2', description="created by layer 'input_2'") at layer "concatenate". The following previous layers were accessed without issue: ['dense', 'dense_1', 'dense_2', 'dense_3']`

但是tout层是完全连接的,而auxinput层只是一个输入层。那为什么会说图是断开的呢?

补充:我搞明白了。需要在Model()中把auxinp添加到输入中。

1 个回答

0

需要把auxinp添加到模型的输入中

# targeting_model
inp=layers.Input(4)

tdense1=layers.Dense(1024, activation='relu')(inp)
tdense2=layers.Dense(1024, activation='relu')(tdense1)
tdense3=layers.Dense(1024, activation='relu')(tdense2)

tout=layers.Dense(2, activation='linear')(tdense3)

# scoring_model
auxinp=layers.Input(4)

sinp=layers.Concatenate()([tout,auxinp])
sdense1=layers.Dense(1024, activation='relu')(sinp)
sdense2=layers.Dense(1024, activation='relu')(sdense1)
sout=layers.Dense(1, activation='linear')(sdense2)






targeting_model =  tf.keras.Model(inputs=inp, outputs=tout, name="targeting_model")
print(targeting_model.summary())
scoring_model =  tf.keras.Model(inputs=[inp,auxinp], outputs=sout, name="scoring_model")

撰写回答