我找到了这个很好的 PyTorch MobileNet 代码,但无法在 CPU 上运行: https://github.com/rdroste/unisal
我是 PyTorch 的新手,所以不确定该怎么做。
在模块train.py的第174行中,设备设置为:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
就我所知,这是正确的
我是否还需要修改 torch 相关的代码?我试过了,但没有成功。
class BaseModel(nn.Module):
    """Abstract model class with functionality to save and load weights.

    Every loader accepts an optional ``map_location`` that is forwarded to
    ``torch.load``. When it is omitted, tensors are mapped to the CPU if CUDA
    is unavailable, so weights saved from GPU tensors can still be loaded on
    a CPU-only machine. If CUDA is available, the devices stored in the file
    are kept, as before.
    """

    def forward(self, *input):
        raise NotImplementedError

    @staticmethod
    def _resolve_map_location(map_location=None):
        """Return the ``map_location`` argument to pass to ``torch.load``.

        An explicit (non-None) value is returned unchanged. Otherwise return
        ``None`` (keep the devices recorded in the file) when CUDA is
        available, and ``'cpu'`` when it is not. Without this, deserializing
        CUDA tensors on a CPU-only host raises ``RuntimeError``.
        """
        if map_location is not None:
            return map_location
        return None if torch.cuda.is_available() else 'cpu'

    def _torch_load(self, path, map_location=None):
        """Call ``torch.load`` on *path* with a device-safe ``map_location``."""
        return torch.load(
            path, map_location=self._resolve_map_location(map_location))

    def save_weights(self, directory, name):
        """Save ``state_dict()`` to ``directory / 'weights_{name}.pth'``.

        ``directory`` must support the ``/`` operator (e.g. a ``pathlib.Path``).
        """
        torch.save(self.state_dict(), directory / f'weights_{name}.pth')

    def load_weights(self, directory, name, map_location=None):
        """Load a ``state_dict`` written by ``save_weights(directory, name)``."""
        self.load_state_dict(
            self._torch_load(directory / f'weights_{name}.pth', map_location))

    def load_best_weights(self, directory, map_location=None):
        """Load the ``state_dict`` stored in ``directory / 'weights_best.pth'``."""
        self.load_state_dict(
            self._torch_load(directory / 'weights_best.pth', map_location))

    def load_epoch_checkpoint(self, directory, epoch, map_location=None):
        """Load state_dict from a Trainer checkpoint at a specific epoch.

        Reads ``directory / 'chkpnt_epoch{epoch:04d}.pth'``, which must be a
        dict with a ``'model_state_dict'`` entry.
        """
        chkpnt = self._torch_load(
            directory / f"chkpnt_epoch{epoch:04d}.pth", map_location)
        self.load_state_dict(chkpnt['model_state_dict'])

    def load_checkpoint(self, file, map_location=None):
        """Load state_dict from a specific Trainer checkpoint.

        ``file`` must contain a dict with a ``'model_state_dict'`` entry.
        """
        chkpnt = self._torch_load(file, map_location)
        self.load_state_dict(chkpnt['model_state_dict'])

    def load_last_chkpnt(self, directory, map_location=None):
        """Load state_dict from the last Trainer checkpoint in *directory*.

        Checkpoint files matching ``chkpnt_epoch*.pth`` are sorted by name.
        With the zero-padded 4-digit epoch used by ``load_epoch_checkpoint``,
        the last name is the highest epoch.

        Raises:
            IndexError: if no matching checkpoint file exists.
        """
        last_chkpnt = sorted(directory.glob('chkpnt_epoch*.pth'))[-1]
        self.load_checkpoint(last_chkpnt, map_location=map_location)
我不明白。我必须在哪里告诉pytorch没有gpu?
完整错误信息:
Traceback (most recent call last):
File "run.py", line 99, in <module>
fire.Fire()
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/fire/core.py", line 138, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/fire/core.py", line 471, in _Fire
target=component.__name__)
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/fire/core.py", line 675, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "run.py", line 95, in predict_examples
example_folder, is_video, train_id=train_id, source=source)
File "run.py", line 72, in predictions_from_folder
folder_path, is_video, source=source, model_domain=model_domain)
File "/home/b256/Data/saliency_models/unisal-master/unisal/train.py", line 871, in generate_predictions_from_path
self.model.load_best_weights(self.train_dir)
File "/home/b256/Data/saliency_models/unisal-master/unisal/train.py", line 1057, in model
self._model = model_cls(**self.model_cfg)
File "/home/b256/Data/saliency_models/unisal-master/unisal/model.py", line 190, in __init__
self.cnn = MobileNetV2(**self.cnn_cfg)
File "/home/b256/Data/saliency_models/unisal-master/unisal/models/MobileNetV2.py", line 156, in __init__
Path(__file__).resolve().parent / 'weights/mobilenet_v2.pth.tar')
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 367, in load
return _load(f, map_location, pickle_module)
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 538, in _load
result = unpickler.load()
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 504, in persistent_load
data_type(size), location)
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 113, in default_restore_location
result = fn(storage, location)
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 94, in _cuda_deserialize
device = validate_cuda_device(location)
File "/home/b256/anaconda3/envs/unisal36/lib/python3.6/site-packages/torch/serialization.py", line 78, in validate_cuda_device
raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.
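在 https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-on-gpu-load-on-cpu 中,您会看到 torch.load 有一个
map_location
关键字参数,用于把权重加载到合适的设备上,例如 torch.load(path, map_location='cpu')。从上面的回溯可以看出,出错的调用位于 MobileNetV2.py 第 156 行,因此需要给那里(以及 BaseModel 中所有 torch.load 调用)加上 map_location。参见 torch.load 的文档: https://pytorch.org/docs/stable/generated/torch.load.html#torch.load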
相关问题 更多 >
编程相关推荐