state_dict = torch.load(self._model_path).cuda()
model = nn.DataParallel(model)
model.load_state_dict(state_dict.state_dict())
model_single_gpu = model.module
model_single_gpu.eval()
多GPU训练,加载模型测试
最新推荐文章于 2024-11-12 01:04:23 发布