如何使用tf.data.Dataset和tf.keras实现多输入和多输出?

2024-04-25 17:49:32 发布

您现在位置:Python中文网/ 问答频道 /正文

我遇到一个关于多输出使用的问题特斯拉斯建立模型并使用tf.data.dataset作为输入管道。下面是我的代码:

  a = tf.keras.layers.Input(shape=(368, 368, 3))
  conv1 = tf.keras.layers.Conv2D(64, 3, 1)(a)
  conv2 = tf.keras.layers.Conv2D(64, 3, 1)(conv1)
  maxpool = tf.keras.layers.MaxPooling2D(pool_size=8, strides=8, 
   padding='same')(conv2)
  conv3 = tf.keras.layers.Conv2D(5, 1, 1)(maxpool)
  conv4 = tf.keras.layers.Conv2D(6, 1, 1)(maxpool)

  inputs = a
  outputs = [conv3, conv4]

  model = tf.keras.models.Model(inputs=inputs, outputs=outputs)


  model.compile(optimizer=tf.keras.optimizers.SGD(),
          loss=tf.keras.losses.mean_squared_error)


  import numpy as np
  data = np.random.rand(10, 368, 368, 3)
  cpm  = np.random.rand(10, 46, 46, 5)
  paf  = np.random.rand(10, 46, 46, 6)

  dataset1 = tf.data.Dataset.from_tensor_slices((data))
  dataset2 = tf.data.Dataset.from_tensor_slices((cpm, paf))
  dataset1 = dataset1.batch(10).repeat()
  dataset2 = dataset2.batch(10).repeat()

  dataset  = tf.data.Dataset.zip((dataset1, dataset2))

  model.fit(dataset, epochs=200, steps_per_epoch=30)

我使用tensorflow==1.10.1,得到的错误如下:

^{pr2}$

更新: 在升级tf==1.11.0之后,我已经让这段代码正常工作了。所以也许我认为是版本错误。在


Tags: datamodellayerstfnprandomoutputsdataset
1条回答
网友
1楼 · 发布于 2024-04-25 17:49:32

您可以尝试将输出串联起来,然后对目标numpy数组执行相同的操作。我不确定它是否对你的应用程序逻辑有意义。在

def conc_op(tensors):
    return K.concatenate(tensors) # K refers to Keras backend

def conc_op_shape(input_shapes):
    shape1 = list(input_shapes[0])
    shape2 = list(input_shapes[1])
    return tuple(shape1[:-1], shape1[-1]+shape2[-1])

a = tf.keras.layers.Input(shape=(368, 368, 3))
conv1 = tf.keras.layers.Conv2D(64, 3, 1)(a)
conv2 = tf.keras.layers.Conv2D(64, 3, 1)(conv1)
maxpool = tf.keras.layers.MaxPooling2D(pool_size=8, strides=8, padding='same')(conv2)
conv3 = tf.keras.layers.Conv2D(5, 1, 1)(maxpool)
conv4 = tf.keras.layers.Conv2D(6, 1, 1)(maxpool)

inputs = a
outputs = [conv3, conv4]
conc_outputs = Lambda(conc_op, output_shape=conc_op_shape)(outputs) # This is a keras layer
model = tf.keras.models.Model(inputs=inputs, outputs=conc_outputs)

model.compile(optimizer=tf.keras.optimizers.SGD(), loss=keras.losses.mean_squared_error)
model.summary()
data = np.random.rand(10, 368, 368, 3)
cpm  = np.random.rand(10, 46, 46, 5)
paf  = np.random.rand(10, 46, 46, 6)
label = np.concatenate((cpm, paf), axis=-1)

dataset = tf.data.Dataset.from_tensor_slices((data, label))
dataset = dataset.batch(2).repeat()
model.fit(dataset.make_one_shot_iterator(), epochs=2, steps_per_epoch=5) # Just to check if it runs

返回结果:

^{pr2}$

相关问题 更多 >