ERFNet (PyTorch) to ONNX

Posted 2024-04-29 08:25:40


I am trying to convert an ERFNet (PyTorch) model to ONNX.

(The model comes from https://github.com/wvangansbeke/LaneDetection_End2End.)
My code is as follows:

from torch.autograd import Variable
import torch.onnx
import torchvision
from Networks.LSQ_layer import Net
from Networks.utils import define_args, save_weightmap, first_run, \
                           mkdir_if_missing, Logger, define_init_weights,\
                           define_scheduler, define_optim, AverageMeter

dummy_input = Variable(torch.randn(1,3,256,512)).cuda()
parser = define_args()
args = parser.parse_known_args()[0]  

model = Net(args)
define_init_weights(model, args.weight_init)
checkpoint = torch.load("model_best_epoch_204.pth.tar")
model.load_state_dict(checkpoint['state_dict'])
model = model.cuda()

torch.onnx.export(model, dummy_input, "LaneDetection.onnx", verbose=True)

I get the following error:

Traceback (most recent call last):
  File "/tmp/pycharm_project_633/venv/Scripts/darknet2onnx.py", line 22, in <module>
    torch.onnx.export(model, dummy_input, "LaneDetection.onnx", verbose=True)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 27, in export
    return utils.export(*args, **kwargs)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 104, in export
    operator_export_type=operator_export_type)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 281, in _export
    example_outputs, propagate)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 224, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 192, in _trace_and_get_graph_from_model
    trace, torch_out = torch.jit.get_trace_graph(model, args, _force_outplace=True)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 197, in get_trace_graph
    return LegacyTracedModule(f, _force_outplace)(*args, **kwargs)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 252, in forward
    out = self.inner(*trace_inputs)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 487, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/tmp/pycharm_project_633/venv/Scripts/Networks/LSQ_layer.py", line 295, in forward
    shared_encoder, output = self.net(input, end_to_end*self.pretrained)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 487, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 464, in _slow_forward
    input_vars = tuple(torch.autograd.function._iter_tensors(input)) #lulu 
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/autograd/function.py", line 284, in _iter
    for var in _iter(o):
  File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/autograd/function.py", line 293, in _iter
    if condition_msg else ""))
ValueError: Auto nesting doesn't know how to process an input object of type int. Accepted types: Tensors, or lists/tuples of them
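The frame at LSQ_layer.py line 295 passes `end_to_end*self.pretrained` to `self.net` as a positional argument. Assuming both names are plain Python bools (an assumption, not confirmed by the post), their product is always an `int`, never a tensor. The short sketch below, which does not need PyTorch, only shows what that expression evaluates to.

```python
# Assumption: `end_to_end` and `pretrained` are plain Python bools in Net.
# No torch import is needed; this only shows the type of the product.

def flag_product(end_to_end: bool, pretrained: bool):
    """Mimic the expression passed to self.net at LSQ_layer.py line 295."""
    return end_to_end * pretrained


for e2e in (False, True):
    for pre in (False, True):
        value = flag_product(e2e, pre)
        print(f"{e2e!s:5} * {pre!s:5} -> {value!r} ({type(value).__name__})")
```

Under that assumption the result is `0` unless both flags are `True`, and it is an `int` in every case, which matches the `type int` in the error message.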

After debugging, I found that the problem comes from Python3/lib/python3.6/site-packages/torch/autograd/function.py, line 293:

def _iter_filter(condition, allow_unknown=False, condition_msg=None,
                 conversion=None):#lulu change allow_unknown=False to True
    def _iter(obj):
        if conversion is not None:
            obj = conversion(obj)
        if condition(obj):
            yield obj
        elif obj is None:
            return
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                for var in _iter(o):
                    yield var
        elif allow_unknown:
            yield obj
        else:
            raise ValueError("Auto nesting doesn't know how to process "
                             "an input object of type " + torch.typename(obj) +
                             (". Accepted types: " + condition_msg +
                              ", or lists/tuples of them"
                              if condition_msg else ""))

    return _iter

Inside this function, in the loop

for var in _iter(o):
    yield var

o is the int 0 on one iteration, and that is when the ValueError is raised.
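To reproduce this behavior without PyTorch, here is a simplified standalone copy of `_iter_filter` as quoted above. It is a sketch under stated assumptions: `FakeTensor` is a hypothetical stand-in for `torch.Tensor`, `torch.typename(obj)` is replaced by `type(obj).__name__`, and `iter_tensors` is only a rough equivalent of `_iter_tensors`. The positional arguments are modeled on the call `self.net(input, end_to_end*self.pretrained)` from the traceback.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __repr__(self):
        return "FakeTensor()"


def _iter_filter(condition, allow_unknown=False, condition_msg=None,
                 conversion=None):
    # Simplified copy of the function quoted above; torch.typename is
    # replaced by type(obj).__name__ so no torch import is needed.
    def _iter(obj):
        if conversion is not None:
            obj = conversion(obj)
        if condition(obj):
            yield obj
        elif obj is None:
            return
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                # Recurse into nested lists/tuples element by element.
                for var in _iter(o):
                    yield var
        elif allow_unknown:
            yield obj
        else:
            raise ValueError("Auto nesting doesn't know how to process "
                             "an input object of type " + type(obj).__name__ +
                             (". Accepted types: " + condition_msg +
                              ", or lists/tuples of them"
                              if condition_msg else ""))
    return _iter


# Rough equivalent of _iter_tensors: only "tensors" are accepted.
iter_tensors = _iter_filter(lambda x: isinstance(x, FakeTensor),
                            condition_msg="Tensors")

# Positional args of self.net(input, end_to_end*self.pretrained).
positional_args = (FakeTensor(), 0)

try:
    tuple(iter_tensors(positional_args))
except ValueError as err:
    print(err)

# With allow_unknown=True (the "#lulu" edit), the int is yielded instead.
lenient = _iter_filter(lambda x: isinstance(x, FakeTensor),
                       allow_unknown=True, condition_msg="Tensors")
print(tuple(lenient(positional_args)))
```

The tensor passes the `condition` check, but the `0` is neither a tensor, `None`, nor a list/tuple, so with `allow_unknown=False` the function falls through to the `raise`, producing the same message as in the traceback.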

Why is a 0 produced, and how can I fix it? If you need more details, please leave a comment and I will reply as soon as possible.
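One possible direction, offered only as a hedged sketch: in the traceback, `Module.__call__` calls `self._slow_forward(*input, **kwargs)`, and `_slow_forward` passes only `input` (the positional arguments) to `_iter_tensors`. If that holds for this PyTorch version, a flag passed by keyword would not be inspected. The mock below illustrates that idea without PyTorch; `HypotheticalERFNet`, `check_positional`, and the parameter name `only_encode` are all hypothetical, so check ERFNet's actual `forward` signature in the repository before relying on it.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""


def check_positional(inputs):
    # Simplified stand-in for _iter_tensors: accepts tensors, None, and
    # nested lists/tuples of them; anything else raises ValueError.
    for obj in inputs:
        if isinstance(obj, (list, tuple)):
            check_positional(obj)
        elif obj is not None and not isinstance(obj, FakeTensor):
            raise ValueError("unsupported positional input of type "
                             + type(obj).__name__)


class HypotheticalERFNet:
    """Mimics Module.__call__ -> _slow_forward while tracing."""

    def __call__(self, *inputs, **kwargs):
        check_positional(inputs)  # kwargs are deliberately not inspected
        return self.forward(*inputs, **kwargs)

    def forward(self, x, only_encode=False):
        # `only_encode` is a hypothetical name for the flag argument.
        return ("encoder only" if only_encode else "full network", x)


net = HypotheticalERFNet()
end_to_end, pretrained = True, False  # assumed bool flags

try:
    net(FakeTensor(), end_to_end * pretrained)  # int passed positionally
except ValueError as err:
    print("positional:", err)

# Same flag passed by keyword: the positional-only check never sees it.
print("keyword:", net(FakeTensor(), only_encode=bool(end_to_end * pretrained))[0])
```

If the real code behaves like this mock, the change would go at LSQ_layer.py line 295. Keep in mind that tracing records only the branch taken for the flag's value at export time, so the resulting ONNX graph would correspond to that single setting.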

