<p>There are two ways to do this. The first is to manually update the weights to their average before saving, as in the <a href="https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/AveragedOptimizerWrapper#example_7" rel="nofollow noreferrer">example</a> from the documentation:</p>
<pre><code>import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([...])

# Stochastic Weight Averaging: start averaging at step 100,
# then update the average every 10 steps
opt = tfa.optimizers.SWA(
    tf.keras.optimizers.SGD(lr=2.0), 100, 10)
model.compile(opt, ...)
model.fit(x, y, ...)

# Update the weights to their mean before saving
opt.assign_average_vars(model.variables)
model.save('model.h5')
</code></pre>
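<p>To make that snippet concrete, here is a minimal end-to-end sketch of the first approach. The toy data, layer sizes, and learning rate are placeholders of my own, not part of the documentation example:</p>
<pre><code>import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# Toy regression data, purely for illustration
x = np.random.rand(256, 8).astype("float32")
y = np.random.rand(256, 1).astype("float32")

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dense(1),
])

# Start averaging at step 100 and update the average every 10 steps;
# 20 epochs x 8 steps/epoch = 160 steps, so averaging actually kicks in
opt = tfa.optimizers.SWA(
    tf.keras.optimizers.SGD(learning_rate=0.01),
    start_averaging=100,
    average_period=10)
model.compile(optimizer=opt, loss="mse")
model.fit(x, y, epochs=20, batch_size=32)

# Swap the running average into the model's variables, then save
opt.assign_average_vars(model.variables)
model.save('swa_model.h5')
</code></pre>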
<p>The second option is to let <code>AverageModelCheckpoint</code> update the weights for you by setting <code>update_weights=True</code>, as shown in the Colab notebook <a href="https://colab.research.google.com/github/tensorflow/addons/blob/master/docs/tutorials/average_optimizers_callback.ipynb" rel="nofollow noreferrer">example</a>:</p>
<pre><code>avg_callback = tfa.callbacks.AverageModelCheckpoint(filepath=checkpoint_dir,
                                                    update_weights=True)
...
# Build the model
model = create_model(moving_avg_sgd)

# Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
</code></pre>
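<p>The notebook snippet above assumes a <code>moving_avg_sgd</code> optimizer and a <code>create_model</code> helper defined earlier in the notebook. Below is a sketch of what those pieces could look like; the architecture and learning rate are illustrative. Note that <code>AverageModelCheckpoint</code> asserts that the optimizer is an <code>AveragedOptimizerWrapper</code>, which both <code>tfa.optimizers.MovingAverage</code> and <code>tfa.optimizers.SWA</code> subclass:</p>
<pre><code>import tensorflow as tf
import tensorflow_addons as tfa

# SGD wrapped so the optimizer also maintains a moving average
# of the model weights
moving_avg_sgd = tfa.optimizers.MovingAverage(
    tf.keras.optimizers.SGD(learning_rate=0.01))

def create_model(optimizer):
    # Illustrative Fashion-MNIST classifier; any compiled Keras model works
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"])
    return model
</code></pre>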
<p>Note that <code>AverageModelCheckpoint</code> also calls <code>assign_average_vars</code> before saving the model, as can be seen in the <a href="https://github.com/tensorflow/addons/blob/v0.12.0/tensorflow_addons/callbacks/average_model_checkpoint.py#L81" rel="nofollow noreferrer">source code</a>:</p>
<pre><code>def _save_model(self, epoch, logs):
    optimizer = self._get_optimizer()
    assert isinstance(optimizer, AveragedOptimizerWrapper)

    if self.update_weights:
        optimizer.assign_average_vars(self.model.variables)
        return super()._save_model(epoch, logs)
    ...
</code></pre>
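<p>If <code>update_weights=False</code>, the elided branch instead assigns the averaged weights temporarily, saves the checkpoint, and then restores the model's original (non-averaged) weights, so training continues from the non-averaged state.</p>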