如何在Tensorflow中实现预训练?如何部分使用检查点文件中保存的权重?

2024-05-01 21:18:51 发布

您现在位置:Python中文网/ 问答频道 /正文

为了便于讨论,对以下模型进行了简化。

假设在我的训练集中有大约40000张512x512的图像。我正在努力实施预培训,我的计划如下:

1.训练一个接受256x256图像的神经网络(称之为net1),并将训练后的模型保存为tensorflow检查点文件格式。在

net_1: input -> 3 conv2d -> maxpool2d -> 2 conv2d -> rmspool -> flatten -> dense

我们将此结构称为net_1_内核:

^{pr2}$

并将剩余部分称为其他层:

other_layers: rmspool -> flatten -> dense

因此,我们可以用以下形式表示net_1:

net_1: input -> net_1_kernel -> other_layers

2.在net1的结构中插入几层,现在称之为net2。应该是这样的:

net_2: input -> net_1_kernel -> maxpool2d -> 3 conv2d -> other_layers

net_2将以512x512图像作为输入。在

当我训练net_2时,我想使用net_1的检查点文件中保存的权重和偏差来初始化net_2中的net_1_内核部分。我该怎么做?

我知道我可以加载检查点来预测测试数据。但在这种情况下,它将加载所有内容(net1_内核和其他层)。我想要的是只加载net_1_内核,并将其用于netu 2中的权重/偏差初始化。在

我还知道我可以将检查点文件中的内容打印为txt,并复制和粘贴以手动初始化权重和偏差。然而,在这些权重和偏差中有太多的数字,这将是我最后的选择。在


Tags: 模型图像inputnetlayers内核检查点权重
1条回答
网友
1楼 · 发布于 2024-05-01 21:18:51

首先,您可以使用以下代码检查保存的ckpt文件中所有检查点的列表。在

from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file(file_name="file.ckpt", tensor_name="xxx", all_tensors=False, all_tensor_names=True)

请记住,当您还原检查点文件时,它将还原检查点文件中的所有变量。如果必须保存和恢复特定的变量,可以执行以下操作:

  1. 列出要从tf.trainable_variables()保存的所有变量的列表

var = [v for v in tf.trainable_variables() if "net_1_kernel" in v.name]

saverAndRestore = tf.train.Saver(var)

  1. 现在您可以方便地保存或恢复var列表中的所有变量,如下所示:

saverAndRestore.save(sess_1,"net_1.ckpt")

这将只将列表var中的变量保存到net_1.ckpt。在

saverAndRestore.restore(sess_1,"net_1.ckpt")

这将仅从net_1.ckpt恢复列表变量中的变量。在

除此之外,您所要做的就是命名/限定变量的范围,这样您就可以轻松地执行上面的步骤1。在

相关问题 更多 >