无法从TFA加载实现RSquare度量的Tensorflow LSTM模型

2024-05-15 04:04:31 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在运行Python3.8.10和Tensorflow 2.3.0,安装了conda,Tensorflow插件v.0.13.0,通过pip安装(因为conda上可用的最新版本是0.9.1)。根据Tensorflow插件GitHub自述文件上的compatibility matrix,这些版本应该是兼容的

我创建了一组LSTM模型,并将其保存为TF模型,如下所示:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import tensorflow_addons as tfa

n_lstm_cells = [5, 10, 20, 50, 100]
models = []

for cells in n_lstm_cells:
  model = Sequential()
  model.add(LSTM(units = cells))
  model.add(Dense(units = 182))
  model.compile(optimizer='adam', loss = 'mean_squared_error', metrics = [tfa.metrics.RSquare(y_shape = (182,))])
    
  os.makedirs(os.path.join(r"C:\Users\stefa\Documents\models\Case A\checkpoints\LSTM",f"{cells} cells"), exist_ok = True) 
  checkdir = os.path.join(r"C:\Users\stefa\Documents\models\Case A\checkpoints\LSTM",f"{cells} cells",'noDST-log1p-{epoch:02d}-{val_r_square:.3f}.hdf5')
  callbacks = [ModelCheckpoint(checkdir, save_freq='epoch', save_best_only = True, monitor = 'val_r_square', mode = 'max'),
             EarlyStopping(patience = 10)]
  print(f'Fitting model with {cells} cells')
  history = model.fit(train_data_gen, epochs = 500, validation_data = val_data_gen, callbacks = callbacks)
  models.append(history)

for cells, model in zip(n_lstm_cells, models):
    os.makedirs(r'C:\Users\stefa\Documents\models\Case A\LSTM\noDST-singlestep-24h-LSTM-{}'.format(cells), exist_ok = True)
    model.model.save(r'C:\Users\stefa\Documents\models\Case A\LSTM\noDST-singlestep-24h-LSTM-{}'.format(cells))

然而,当我尝试加载它时,我得到一个“ValueError:当前无法恢复类型为{u tf_keras_metric”的自定义对象。请确保该层在保存时实现get_configfrom_config。此外,在调用load_model()时请使用custom_objects参数,即使我将r{}作为自定义度量传递

from tensorflow.keras.models import load_model
import tensorflow_addons as tfa
lstm_model = load_model(r'C:\Users\stefa\Documents\models\Case A\LSTM\noDST-singlestep-24h-LSTM-100', custom_objects = {'r_square': tfa.metrics.RSquare})

我怎样才能解决这个问题


Tags: fromimportmodelmodelstensorflowusersdocumentskeras
1条回答
网友
1楼 · 发布于 2024-05-15 04:04:31

你几乎是对的,用RSquare代替r_square

lstm_model = load_model(r'C:\Users\stefa\Documents\models\Case A\LSTM\noDST-singlestep-24h-LSTM-100', custom_objects = {'RSquare': tfa.metrics.RSquare})

相关问题 更多 >

    热门问题