`output_signature`必须包含`tf.TypeSpec`子类对象,但发现了不属于此的<class 'list'>
出错的代码
hist = model.fit(
data_gen_train.generate(),
steps_per_epoch=2 if params['quick_test'] else data_gen_train.get_total_batches_in_data(),
validation_data=data_gen_test.generate(),
validation_steps=2 if params['quick_test'] else data_gen_test.get_total_batches_in_data(),
epochs=1,
verbose=0
)
完整的错误日志
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 123, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/tensorflow/python/data/ops/from_generator_op.py", line 124, in _from_generator
raise TypeError(f"`output_signature` must contain objects that are "
TypeError: `output_signature` must contain objects that are subclass of `tf.TypeSpec` but found <class 'list'> which is not.
现在我在本地电脑上练习一个声音事件检测的论文。不过我之前用的是keras和Tensorflow,在训练的时候遇到了一些问题。
模型是用keras的Model创建的,然后我用下面的代码进行了编译。
model = Model(inputs=spec_start, outputs=[sed, doa])
model.compile(optimizer=Adam(), loss=['binary_crossentropy', 'mse'], loss_weights=weights)
从错误代码来看,<class 'list'>可能在哪里出现?还有,.fit()方法中的output_signature是什么?
1 个回答
0
我正在使用tensorflow 2.16,遇到了一个问题,就是在返回一个列表的时候出现了问题:
def __getitem__(self, index):
X1, X2, y = get_data()
return [X1, X2], y
为了解决这个问题,我把列表改成了一个元组,问题就解决了:
def __getitem__(self, index):
X1, X2, y = get_data()
return (X1, X2), y