Converting a ProGAN agent from .pth to .onnx with PyTorch


I used this PyTorch reimplementation to train a ProGAN agent and saved the agent as a .pth file. Now I need to convert the agent to the .onnx format, and I am using this script:

from torch.autograd import Variable

import torch.onnx
import torchvision
import torch

device = torch.device("cuda")

dummy_input = torch.randn(1, 3, 64, 64)
state_dict = torch.load("GAN_agent.pth", map_location = device)

torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")

As soon as I run it, I get the error AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict' (full traceback below). As far as I understand, the problem is that converting the agent to .onnx requires more information. Am I missing something?

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-2-c64481d4eddd> in <module>
     10 state_dict = torch.load("GAN_agent.pth", map_location = device)
     11 
---> 12 torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
    146                         operator_export_type, opset_version, _retain_param_name,
    147                         do_constant_folding, example_outputs,
--> 148                         strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
    149 
    150 

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
     64             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
     65             example_outputs=example_outputs, strip_doc_string=strip_doc_string,
---> 66             dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
     67 
     68 

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size)
    414                                                         example_outputs, propagate,
    415                                                         _retain_param_name, do_constant_folding,
--> 416                                                         fixed_batch_size=fixed_batch_size)
    417 
    418         # TODO: Don't allocate a in-memory string for the protobuf

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _model_to_graph(model, args, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size)
    277             model.graph, tuple(in_vars), False, propagate)
    278     else:
--> 279         graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
    280         state_dict = _unique_state_dict(model)
    281         params = list(state_dict.values())

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _trace_and_get_graph_from_model(model, args, training)
    226     # A basic sanity check: make sure the state_dict keys are the same
    227     # before and after running the model.  Fail fast!
--> 228     orig_state_dict_keys = _unique_state_dict(model).keys()
    229 
    230     # By default, training=False, which is good because running a model in

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\jit\__init__.py in _unique_state_dict(module, keep_vars)
    283     # id(v) doesn't work with it. So we always get the Parameter or Buffer
    284     # as values, and deduplicate the params using Parameters and Buffers
--> 285     state_dict = module.state_dict(keep_vars=True)
    286     filtered_dict = type(state_dict)()
    287     seen_ids = set()

AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'

1 Answer

The file you have is a state_dict, which is simply a mapping from layer names to tensor weights, biases and the like (for a more thorough introduction see here).

This means that you need a model so that the saved weights and biases can be mapped onto it. But first things first:
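
For illustration, here is a minimal sketch of that general pattern (the generator class name below is a hypothetical placeholder; the concrete class comes from the repository in the next step):

import torch

# The .pth file is just an OrderedDict mapping parameter names to tensors
state_dict = torch.load("GAN_agent.pth", map_location="cpu")
print(type(state_dict))              # collections.OrderedDict
print(list(state_dict.keys())[:5])   # parameter names stored in the checkpoint

# General pattern: build the model first, then map the saved weights onto it
# model = SomeGenerator(...)         # hypothetical placeholder for the real class
# model.load_state_dict(state_dict)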

1. Model preparation

Clone the repository where the model definitions are located and open the file /pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.py. We need a few modifications so that it works with ONNX (the ONNX exporter requires inputs to be passed only as torch.tensor, or a list/dict of those, while the Generator class needs int and float arguments).

A simple solution is to modify the forward function (line 80 in the file, which you can verify on GitHub) slightly, into the following:

def forward(self, x, depth, alpha):
    """
    forward pass of the Generator
    :param x: input noise
    :param depth: current depth from where output is required
    :param alpha: value of alpha for fade-in effect
    :return: y => output
    """

    # THOSE TWO LINES WERE ADDED
    # We will pass tensors but unpack them here to `int` and `float`
    depth = depth.item()
    alpha = alpha.item()
    # THOSE TWO LINES WERE ADDED
    assert depth < self.depth, "Requested output depth cannot be produced"

    y = self.initial_block(x)

    if depth > 0:
        for block in self.layers[: depth - 1]:
            y = block(y)

        residual = self.rgb_converters[depth - 1](self.temporaryUpsampler(y))
        straight = self.rgb_converters[depth](self.layers[depth - 1](y))

        out = (alpha * straight) + ((1 - alpha) * residual)

    else:
        out = self.rgb_converters[0](y)

    return out

The only thing added here is the unpacking via item(). Every input that is not of Tensor type should be packed as one in the function definition and unpacked as soon as possible at the top of the function. It will not break the checkpoints you created, so don't worry, as it is just a layer-to-weight mapping.
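
To make the pattern concrete, here is a minimal, self-contained toy sketch (not the ProGAN Generator) of packing scalar arguments into tensors on the call side and unpacking them inside forward:

import torch

class Toy(torch.nn.Module):
    # Toy module illustrating only the pack/unpack pattern
    def forward(self, x, depth, alpha):
        depth = depth.item()   # unpack back to a Python int
        alpha = alpha.item()   # unpack back to a Python float
        return x * alpha + depth

# Call side: every non-tensor argument is wrapped in a tensor,
# which is what tracing-based ONNX export expects
out = Toy()(torch.randn(1, 4), torch.tensor([2]), torch.tensor([0.5]))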

2. Model export

Place this script in /pro_gan_pytorch (where README.md is located):

import torch

from pro_gan_pytorch import PRO_GAN as pg

gen = torch.nn.DataParallel(pg.Generator(depth=9))
gen.load_state_dict(torch.load("GAN_GEN_SHADOW_8.pth"))

module = gen.module.to("cpu")

# Arguments like depth and alpha may need to be changed
dummy_inputs = (torch.randn(1, 512), torch.tensor([5]), torch.tensor([0.1]))
torch.onnx.export(module, dummy_inputs, "GAN_GEN8.onnx", verbose=True)

A few things to note:

  • We have to create the model before loading the weights, because the file is only a state_dict
  • torch.nn.DataParallel is needed because that is how the model was trained (not sure about your case, please adjust accordingly). After loading, we can get the module itself via the module attribute
  • Everything is cast to CPU; I don't think a GPU is needed here, but you can cast everything to GPU if you insist
  • The dummy input to the generator cannot be an image (I used the files provided by the repo authors on their Google Drive); it has to be noise with 512 elements

Run it and your .onnx file should be there.
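
As an optional sanity check (my addition, not part of the original answer), you can validate the exported file with the onnx package:

import onnx

# Load the exported graph, validate it structurally, and print it
model = onnx.load("GAN_GEN8.onnx")
onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))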

Oh, and since you are after a different checkpoint, you may want to follow a similar procedure, though there is no guarantee everything will work (it does look like it should, though).
