直接上代码
- DDP（`DistributedDataParallel`）的 `forward` 方法（节选）：
# Excerpt from the body of DDP's forward(). `self`, `inputs` and `kwargs` come
# from the enclosing method, which is not shown; `output` is consumed after this
# excerpt (the return is not shown either).
# NOTE(review): the pasted original lost its indentation; the nesting below is
# reconstructed (the final `else:` pairs with `if self.device_ids:`) — confirm
# against the upstream PyTorch source.
if self.device_ids:
    if len(self.device_ids) == 1:
        # Single device: to_kwargs presumably places inputs/kwargs on
        # device_ids[0]; its result is indexed with [0] below, so it is treated
        # as a per-device sequence with one entry.
        inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
        output = self.module(*inputs[0], **kwargs[0])
    else:
        # Multiple devices: scatter (delegating to scatter_kwargs) presumably
        # splits the arguments into per-device chunks.
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        # Slice the module copies to the number of chunks actually produced.
        # Presumably scatter can yield fewer chunks than devices (e.g. a small
        # batch) — TODO confirm.
        outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
        # Presumably collects the per-device outputs onto self.output_device.
        output = self.gather(outputs, self.output_device)
else:
    # No device_ids configured: call the wrapped module on the original
    # arguments unchanged.
    output = self.module(*inputs, **kwargs)
def scatter(self, inputs, kwargs, device_ids):
    """Distribute positional and keyword arguments over ``device_ids``.

    Thin wrapper around ``scatter_kwargs`` that supplies this module's
    ``dim`` as the split dimension. The forward path unpacks the result as an
    ``(inputs, kwargs)`` pair.

    Args:
        inputs: Positional arguments to distribute.
        kwargs: Keyword arguments to distribute.
        device_ids: Target devices, forwarded to ``scatter_kwargs``.

    Returns:
        Whatever ``scatter_kwargs`` returns for these arguments.
    """
    split_dim = self.dim
    return scatter_kwargs(inputs, kwargs, device_ids, dim=split_dim)
def scatter_kwargs(inputs, kwargs, target_gpus, dim