从切点Tensorflow对象检测API恢复权重

2024-04-25 19:58:32 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用tensorflow对象检测API,并在以下网站上学习了他们的教程: https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/eager_few_shot_od_training_tf2_colab.ipynb

问题是我想训练更快的R-CNN,但我不知道如何从切点加载预先训练的重量。 下面的代码(从教程中报告)必须为此目的进行调整,但我不知道如何。。我没有找到任何参考资料或文档

# Set up object-based checkpoint restore --- RetinaNet has two prediction
# `heads` --- one for classification, the other for box regression.  We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    # _prediction_heads=detection_model._box_predictor._prediction_heads,
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()

Tags: theboxformodeltfrestorepredictorhead

热门问题