I have a real love-hate relationship with the mm series: it makes building baselines very convenient, but it also throws errors constantly, and sometimes those errors are hard to resolve.
Error message:
Traceback (most recent call last):
  File "tools/train.py", line 237, in <module>
    main()
  File "tools/train.py", line 226, in main
    train_detector(
  File "/root/miniconda3/lib/python3.8/site-packages/mmdet/apis/train.py", line 246, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 136, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 53, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/runner/epoch_based_runner.py", line 31, in run_iter
    outputs = self.model.train_step(data_batch, self.optimizer,
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/parallel/data_parallel.py", line 76, in train_step
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/parallel/data_parallel.py", line 55, in scatter
    return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/parallel/scatter_gather.py", line 60, in scatter_kwargs
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/parallel/scatter_gather.py", line 50, in scatter
    return scatter_map(inputs)
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/parallel/scatter_gather.py", line 35, in scatter_map
    return list(zip(*map(scatter_map, obj)))
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/parallel/scatter_gather.py", line 40, in scatter_map
    out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/parallel/scatter_gather.py", line 35, in scatter_map
    return list(zip(*map(scatter_map, obj)))
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/parallel/scatter_gather.py", line 33, in scatter_map
    return Scatter.forward(target_gpus, obj.data)
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/parallel/_functions.py", line 75, in forward
    streams = [_get_stream(device) for device in target_gpus]
  File "/root/miniconda3/lib/python3.8/site-packages/mmcv/parallel/_functions.py", line 75, in <listcomp>
    streams = [_get_stream(device) for device in target_gpus]
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/parallel/_functions.py", line 117, in _get_stream
    if device.type == "cpu":
AttributeError: 'int' object has no attribute 'type'
Cause of the error: mmcv's Scatter.forward hands the raw integer GPU ids in target_gpus (on my single card, just [0]) straight to _get_stream, but the PyTorch version installed here expects _get_stream to receive a torch.device and immediately reads device.type, as the last frame of the traceback shows. A plain int such as 0 has no such attribute, hence the AttributeError. So the fix has to happen at the source: turn the bare int GPU id into a torch.device before it reaches _get_stream.
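The mismatch is easy to reproduce in isolation. A minimal sketch, run on a CUDA machine, assuming a PyTorch build whose private helper torch.nn.parallel._functions._get_stream reads device.type (as in the last frame of the traceback above):

import torch
from torch.nn.parallel._functions import _get_stream

# What this torch version expects: a torch.device it can inspect.
_get_stream(torch.device('cuda:0'))  # fine, torch.device has a .type attribute

# What mmcv actually passes on a single card: the bare GPU id.
_get_stream(0)  # AttributeError: 'int' object has no attribute 'type'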
Solution: to avoid touching files inside the PyTorch package, I patched mmcv instead. The changes below go into mmcv/parallel/_functions.py, the file named in the traceback.
# Define a wrapper: convert a bare int GPU id into a torch.device before
# delegating to torch's _get_stream (torch and _get_stream are already
# imported at the top of mmcv/parallel/_functions.py).
def safe_get_stream(device):
    if isinstance(device, int):
        device = torch.device(f'cuda:{device}')
    return _get_stream(device)
# Rewrite the original Scatter.forward (in the same file) so the stream
# lookup goes through the wrapper instead of passing a bare int.
class Scatter:

    @staticmethod
    def forward(target_gpus: List[int], input: Union[List, Tensor]) -> tuple:
        input_device = get_input_device(input)
        streams = None
        if input_device == -1 and target_gpus != [-1]:
            # Perform CPU to GPU copies in a background stream
            streams = [safe_get_stream(device) for device in target_gpus]
        outputs = scatter(input, target_gpus, streams)
        # Synchronize with the copy stream
        if streams is not None:
            synchronize_stream(outputs, target_gpus, streams)
        return tuple(outputs) if isinstance(outputs, list) else (outputs, )
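After saving the patched file, single-card training should get past the scatter step. For a quick sanity check before relaunching the full job, the following hypothetical snippet pushes a CPU DataContainer through the same mmcv scatter path that train_step uses (it assumes one visible CUDA device):

import torch
from mmcv.parallel import DataContainer, scatter_kwargs

# A CPU-side DataContainer, like the batches the data loader hands to train_step.
dc = DataContainer([torch.zeros(2, 3)], stack=False, cpu_only=False)
inputs, _ = scatter_kwargs((dc,), {}, [0])
print(inputs[0][0].device)  # should print cuda:0 instead of raising the AttributeError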