当我将 ERFNet(PyTorch)模型转换为 ONNX 时,遇到了下面的错误。
(此模型来自https://github.com/wvangansbeke/LaneDetection_End2End)
我的代码如下:
from torch.autograd import Variable
import torch.onnx
import torchvision
from Networks.LSQ_layer import Net
from Networks.utils import define_args, save_weightmap, first_run, \
mkdir_if_missing, Logger, define_init_weights,\
define_scheduler, define_optim, AverageMeter
# Export the pretrained lane-detection network to an ONNX file.

# Example input used only to trace the network. The deprecated
# torch.autograd.Variable wrapper is dropped: since PyTorch 0.4 a plain
# tensor behaves identically.
dummy_input = torch.randn(1, 3, 256, 512).cuda()

# presumably define_args() returns an argparse parser -- verify against
# Networks/utils.py. parse_known_args() tolerates unrecognised CLI flags and
# element [0] is the parsed namespace.
parser = define_args()
args = parser.parse_known_args()[0]

model = Net(args)
define_init_weights(model, args.weight_init)

# Deserialize onto the CPU first so loading does not depend on the GPU index
# the checkpoint was saved from; the weights move to the GPU with the model.
checkpoint = torch.load("model_best_epoch_204.pth.tar", map_location="cpu")
model.load_state_dict(checkpoint['state_dict'])
model = model.cuda()
# Make inference mode explicit so the exported graph does not depend on
# training-time layer behaviour.
model.eval()

# NOTE(review): the traceback recorded for this script shows Net.forward
# passing `end_to_end*self.pretrained` (a Python int) to `self.net`. The
# legacy tracer only accepts tensors (or lists/tuples of them) as module
# inputs, so the export raises ValueError until that value is turned into a
# tensor or kept out of the traced call -- confirm in Networks/LSQ_layer.py.
torch.onnx.export(model, dummy_input, "LaneDetection.onnx", verbose=True)
出现此错误。
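ValueError: Auto nesting doesn't know how to process an input object of type int. Accepted types: Tensors, or lists/tuples of them(自动嵌套不知道如何处理 int 类型的输入对象;可接受的类型:Tensor,或由 Tensor 组成的列表/元组)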
Traceback (most recent call last):
File "/tmp/pycharm_project_633/venv/Scripts/darknet2onnx.py", line 22, in <module>
torch.onnx.export(model, dummy_input, "LaneDetection.onnx", verbose=True)
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 27, in export
return utils.export(*args, **kwargs)
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 104, in export
operator_export_type=operator_export_type)
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 281, in _export
example_outputs, propagate)
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 224, in _model_to_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 192, in _trace_and_get_graph_from_model
trace, torch_out = torch.jit.get_trace_graph(model, args, _force_outplace=True)
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 197, in get_trace_graph
return LegacyTracedModule(f, _force_outplace)(*args, **kwargs)
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 252, in forward
out = self.inner(*trace_inputs)
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 487, in __call__
result = self._slow_forward(*input, **kwargs)
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in _slow_forward
result = self.forward(*input, **kwargs)
File "/tmp/pycharm_project_633/venv/Scripts/Networks/LSQ_layer.py", line 295, in forward
shared_encoder, output = self.net(input, end_to_end*self.pretrained)
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 487, in __call__
result = self._slow_forward(*input, **kwargs)
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 464, in _slow_forward
input_vars = tuple(torch.autograd.function._iter_tensors(input)) #lulu
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/autograd/function.py", line 284, in _iter
for var in _iter(o):
File "/work/dependence/anaconda3/lib/python3.6/site-packages/torch/autograd/function.py", line 293, in _iter
if condition_msg else ""))
ValueError: Auto nesting doesn't know how to process an input object of type int. Accepted types: Tensors, or lists/tuples of them
经过调试,我发现问题来自 Python3/lib/python3.6/site-packages/torch/autograd/function.py 的第 293 行:
def _iter_filter(condition, allow_unknown=False, condition_msg=None,
conversion=None):#lulu change allow_unknown=False to True
def _iter(obj):
if conversion is not None:
obj = conversion(obj)
if condition(obj):
yield obj
elif obj is None:
return
elif isinstance(obj, (list, tuple)):
for o in obj:
for var in _iter(o):
yield var
elif allow_unknown:
yield obj
else:
raise ValueError("Auto nesting doesn't know how to process "
"an input object of type " + torch.typename(obj) +
(". Accepted types: " + condition_msg +
", or lists/tuples of them"
if condition_msg else ""))
return _iter
在这个函数中
for var in _iter(o):
yield var
在这个循环中,o 取到了整数 0(int 类型),于是抛出了 ValueError。
为什么会出现这个 0?我该如何修复?如果需要了解更多细节,请留言,我会尽快回复。
目前没有回答
相关问题 更多 >
编程相关推荐