如何解决PyTorch中剪枝模型的深拷贝错误

0 投票
1 回答
50 浏览
提问于 2025-04-14 16:57

我正在尝试构建一个强化学习模型,其中我的演员网络有一些被剪枝的连接。

在使用torchrl中的数据收集器SyncDataCollector时,深拷贝操作失败了(见下面的错误信息)。

这似乎是因为那些被剪枝的连接,它们的设置是gradfn(而不是requires_grad=True),这个建议可以参考这篇帖子

下面是我想运行的代码示例,SyncDataCollector在尝试对模型进行深拷贝,

device = torch.device("cpu")

model = nn.Sequential(
    nn.Linear(1,5),
    nn.Linear(5,1)
)
mask = torch.tensor([1,0,0,1,0]).reshape(-1,1)
prune.custom_from_mask(model[0], name='weight', mask=mask)


policy_module = TensorDictModule(
    model, in_keys=["in"], out_keys=["out"]
)

env = FlyEnv()

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=1,
    total_frames=2,
    split_trajs=False,
    device=device,
)

这里是一个最小的示例,产生了错误

import torch
from torch import nn
from copy import deepcopy

import torch.nn.utils.prune as prune

device = torch.device("cpu")

model = nn.Sequential(
    nn.Linear(1,5),
    nn.Linear(5,1)
)
mask = torch.tensor([1,0,0,1,0]).reshape(-1,1)
prune.custom_from_mask(model[0], name='weight', mask=mask)

new_model = deepcopy(model)

错误信息是

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment.  If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001

我尝试通过prune.remove(model[0], 'weight')来移除剪枝,然后设置model[0].requires_grad_(),这样虽然解决了问题,但所有的权重都会被训练...

我觉得可以通过在每次前向传播之前“手动”屏蔽被剪枝的权重来解决这个问题,但这样做似乎既不高效也不优雅。

1 个回答

1

这个错误是因为参数被移动到了 <param>_orig,而被遮罩的值和它一起存储。当 SyncDataCollector 处理这些参数并把它们放到一个叫“meta”的设备上,以创建一个无状态的策略时,这些额外的值就被忽略了,因为它们不再是参数了(所以在调用 "to" 时不会被捕捉到)。

作为解决方法,你可以在创建收集器之前调用

policy_module.module[0].weight = policy_module.module[0].weight.detach()

这样做是可以的,因为 weight 属性在下一个前向调用时会重新计算。

TorchRL 可能需要更好地处理深拷贝,虽然在这个情况下,错误是因为一个张量在不应该要求梯度的地方要求了梯度。个人认为,剪枝方法应该在前向调用时计算 "weight"(就像它们现在做的那样),然后再进行剪枝。

撰写回答