<p>这取决于<code>y</code>最初是如何构造的。这里我假设<code>y</code>是批处理中每个序列的单值标签。在</p>
<p>当有多个输入/输出时,<code>model.fit()</code>期望给出相应的输入/输出列表。<code>np.split(y, output_branches, axis=-1)</code>在下面的一个完全可复制的例子中,正是这样做的-对于每个批次,将单个输出列表拆分为一个单独的输出列表,其中每个输出(在本例中)是1个元素列表:</p>
<pre class="lang-py prettyprint-override"><code>import tensorflow as tf
import numpy as np
tf.enable_eager_execution()
batch_size = 100
seq_length = 10
feature_cnt = 5
output_branches = 3
# Say we've got:
# - 100-element batch
# - of 10-element sequences
# - where each element of a sequence is a vector describing 5 features.
X = np.random.random_sample([batch_size, seq_length, feature_cnt])
# Every sequence of a batch is labelled with `output_branches` labels.
y = np.random.random_sample([batch_size, output_branches])
# Here y.shape() == (100, 3)
# Here we split the last axis of y (output_branches) into `output_branches` separate lists.
y = np.split(y, output_branches, axis=-1)
# Here y is not a numpy matrix anymore, but a list of matrices.
# E.g. y[0].shape() == (100, 1); y[1].shape() == (100, 1) etc...
outputs = []
main_input = tf.keras.layers.Input(shape=(seq_length, feature_cnt), name='main_input')
lstm = tf.keras.layers.LSTM(32, return_sequences=True)(main_input)
for _ in range(output_branches):
prediction = tf.keras.layers.LSTM(8, return_sequences=False)(lstm)
out = tf.keras.layers.Dense(1)(prediction)
outputs.append(out)
model = tf.keras.models.Model(inputs=main_input, outputs=outputs)
model.compile(optimizer='rmsprop', loss='mse')
model.fit(X, y)
</code></pre>
<p>由于没有指定数据的确切外观,可能需要使用轴。在</p>
<p>编辑:
当作者从官方来源寻找答案时,它提到了<a href="https://www.tensorflow.org/beta/guide/keras/functional#models_with_multiple_inputs_and_outputs" rel="nofollow noreferrer">here</a>(虽然不是明确的,它只提到数据集应该产生什么,因此<code>model.fit()</code>期望什么样的输入结构):</p>
<blockquote>
<p>When calling fit with a Dataset object, it should yield either a tuple of lists like <code>([title_data, body_data, tags_data], [priority_targets, dept_targets])</code> or a tuple of dictionaries like <code>({'title': title_data, 'body': body_data, 'tags': tags_data}, {'priority': priority_targets, 'department': dept_targets})</code>.</p>
</blockquote>