如何解决PyTorch中剪枝模型的深拷贝错误
我正在尝试构建一个强化学习模型,其中我的演员网络有一些被剪枝的连接。
在使用torchrl中的数据收集器SyncDataCollector时,深拷贝操作失败了(见下面的错误信息)。
这似乎是因为那些被剪枝的连接,它们的设置是gradfn(而不是requires_grad=True),这个建议可以参考这篇帖子。
下面是我想运行的代码示例,SyncDataCollector在尝试对模型进行深拷贝,
device = torch.device("cpu")
model = nn.Sequential(
nn.Linear(1,5),
nn.Linear(5,1)
)
mask = torch.tensor([1,0,0,1,0]).reshape(-1,1)
prune.custom_from_mask(model[0], name='weight', mask=mask)
policy_module = TensorDictModule(
model, in_keys=["in"], out_keys=["out"]
)
env = FlyEnv()
collector = SyncDataCollector(
env,
policy_module,
frames_per_batch=1,
total_frames=2,
split_trajs=False,
device=device,
)
这里是一个最小的示例,产生了错误
import torch
from torch import nn
from copy import deepcopy
import torch.nn.utils.prune as prune
device = torch.device("cpu")
model = nn.Sequential(
nn.Linear(1,5),
nn.Linear(5,1)
)
mask = torch.tensor([1,0,0,1,0]).reshape(-1,1)
prune.custom_from_mask(model[0], name='weight', mask=mask)
new_model = deepcopy(model)
错误信息是
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment. If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001
我尝试通过prune.remove(model[0], 'weight')
来移除剪枝,然后设置model[0].requires_grad_()
,这样虽然解决了问题,但所有的权重都会被训练...
我觉得可以通过在每次前向传播之前“手动”屏蔽被剪枝的权重来解决这个问题,但这样做似乎既不高效也不优雅。
1 个回答
1
这个错误是因为参数被移动到了 <param>_orig
,而被遮罩的值和它一起存储。当 SyncDataCollector 处理这些参数并把它们放到一个叫“meta”的设备上,以创建一个无状态的策略时,这些额外的值就被忽略了,因为它们不再是参数了(所以在调用 "to"
时不会被捕捉到)。
作为解决方法,你可以在创建收集器之前调用
policy_module.module[0].weight = policy_module.module[0].weight.detach()
这样做是可以的,因为 weight
属性在下一个前向调用时会重新计算。
TorchRL 可能需要更好地处理深拷贝,虽然在这个情况下,错误是因为一个张量在不应该要求梯度的地方要求了梯度。个人认为,剪枝方法应该在前向调用时计算 "weight"
(就像它们现在做的那样),然后再进行剪枝。