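Writing this down because the problem bothered me for a long time before I finally solved it.
I wanted to use an off-the-shelf pretrained optical flow network to extract flow maps, so I loaded it as shown below. Two images are then passed to flow_model to get the flow map.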
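# requires: import torch, import torch.nn as nn
# wrap the network in DataParallel so the checkpoint keys match
flow_model = nn.DataParallel(xxxxFlowNet(args), device_ids=[0, 1, 2, 3, 4, 5])
flow_model.load_state_dict(torch.load('../xxxxFlowNet/pretrained_models/xxxxflownet-things.pth'))
# freeze the flow network; it is only used to extract flow
for param in flow_model.parameters():
    param.requires_grad = False
flow_model = flow_model.cuda()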
However, this raised the following error.
Traceback (most recent call last):
  File "/data2/xxxxx/RGBTCrowdCounting/xxxxNet-pytorch/train.py", line 230, in <module>
    main()
  File "/data2/xxxxx/RGBTCrowdCounting/xxxxNet-pytorch/train.py", line 82, in main
    train(train_list, model, criterion, optimizer, epoch, flow_model)
  File "/data2/xxxx/RGBTCrowdCounting/xxxxNet-pytorch/train.py", line 121, in train
    for i, (img, flow, target) in enumerate(train_loader):
  File "/home/xxxx/anaconda3/envs/COG/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 652, in __next__
    data = self._next_data()
  File "/home/xxxx/anaconda3/envs/COG/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 692, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  File "/home/xxxx/anaconda3/envs/COG/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/xxxx/anaconda3/envs/COG/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/data2/xxxx/RGBTCrowdCounting/xxxxNet-pytorch/dataset.py", line 37, in __getitem__
    img, flow, target = load_data(img_path, self.train, self.GMFlowNet)
  File "/data2/xxxxx/RGBTCrowdCounting/xxxxNet-pytorch/image.py", line 79, in load_data
    _, _ = flow_model(image1, image2, test_mode=True)
  File "/home/xxxx/anaconda3/envs/COG/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/xxxx/anaconda3/envs/COG/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/xxxx/anaconda3/envs/COG/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/xxxx/anaconda3/envs/COG/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/xxxx/anaconda3/envs/COG/lib/python3.9/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
TypeError: Caught TypeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/xxxx/anaconda3/envs/COG/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/xxxx/anaconda3/envs/COG/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: forward() missing 2 required positional arguments: 'image1' and 'image2'
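I searched for a long time and found nothing reliable.
After looking more closely at the network being loaded, I realized the checkpoint had been trained on multiple GPUs, so it also has to be loaded through nn.DataParallel.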
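Why the wrapper matters at load time: a state_dict saved from an nn.DataParallel model has every key prefixed with `module.`, so it will not match a bare model under the default strict loading. Below is a minimal sketch of that behaviour. `TinyNet` is a hypothetical stand-in for the real flow network, and the prefix-stripping at the end is a common alternative, not something the original code does.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real flow network."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 2, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


# Simulate a checkpoint that was saved from a DataParallel-wrapped model.
saved = nn.DataParallel(TinyNet()).state_dict()
saved_keys = list(saved.keys())  # every key starts with "module."

# Loading into a bare model fails: the key names do not match.
bare = TinyNet()
try:
    bare.load_state_dict(saved)
    bare_load_ok = True
except RuntimeError:
    bare_load_ok = False

# Loading through a DataParallel wrapper works, as in the post.
wrapped = nn.DataParallel(TinyNet())
wrapped.load_state_dict(saved)

# Alternative (an assumption, not the post's method): strip the prefix
# and load straight into the bare model.
prefix = "module."
stripped = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in saved.items()
}
bare.load_state_dict(stripped)
```

Either route ends with the same weights in the underlying network; the difference is only whether the DataParallel wrapper is involved.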
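But be careful: when you actually call the model, you must go through model.module. The .module is required! My best guess at the reason, from the traceback: the call happens inside load_data on a single image pair, so DataParallel probably has too few samples to split across the six GPUs, and replica 1 ends up with only the test_mode keyword and no images.
Change the last line of the original code to: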
flow_model = flow_model.module.cuda()
Problem solved.
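Putting the whole fix together, here is a minimal end-to-end sketch: load the checkpoint through nn.DataParallel, then unwrap with `.module` and call the plain network. `TinyFlowNet` is a hypothetical stand-in, its two-output forward only imitates the `_, _ = flow_model(...)` call in the traceback, and the in-memory checkpoint replaces the real .pth file. The `eval()` and `torch.no_grad()` calls are my additions for inference, not part of the original code.

```python
import torch
import torch.nn as nn


class TinyFlowNet(nn.Module):
    """Hypothetical stand-in for the real flow network."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(6, 2, kernel_size=3, padding=1)

    def forward(self, image1, image2, test_mode=False):
        # Concatenate the two frames and predict a 2-channel "flow".
        flow = self.conv(torch.cat([image1, image2], dim=1))
        return flow, flow


# Stand-in for torch.load(...): a state_dict saved from a wrapped model.
checkpoint = nn.DataParallel(TinyFlowNet()).state_dict()

# 1. Load through the wrapper so the "module." keys match.
wrapped = nn.DataParallel(TinyFlowNet())
wrapped.load_state_dict(checkpoint)

# 2. Unwrap: from here on, use the underlying network directly.
flow_model = wrapped.module
flow_model.eval()  # addition: inference mode
for param in flow_model.parameters():
    param.requires_grad = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
flow_model = flow_model.to(device)

# 3. A single image pair no longer passes through DataParallel's scatter.
image1 = torch.rand(1, 3, 32, 32, device=device)
image2 = torch.rand(1, 3, 32, 32, device=device)
with torch.no_grad():
    flow_low, flow_up = flow_model(image1, image2, test_mode=True)
```

The trade-off is that the unwrapped model runs on one device only, which seems fine here since it is called on one image pair at a time.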