torch.nn
DataParallel详解
定义:model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
代码流程
下面主要是详细介绍一下数据并行的流程,自己动手Debug一下更清晰!
- 在调用网络模型model(DataParallel类实例)时,实际是进入模型基类Module的特殊方法
__call__
(module.py)。
注: 特殊方法__call__
用于实现对类实例的调用。 - 再通过模型基类Module的self.forward进入DataParallel类的forward方法(data_parallel.py):
- 先调用
self.scatter
在第一个维度分配输入 - 调用
self.replicate
产生模型副本放置在多个GPU上,形成modules列表 - 调用
parallel_apply
执行并行操作(parallel_apply.py)
- 先调用
- 在
parallel_apply
中通过多线程模块threading,将不同的module(自定义的网络模型类实例),input,GPU_ID以及kwargs分配给不同的线程。通过for循环控制启动线程活动和等待至线程中止。 - 通过多线程的
run()
方法进入线程目标函数_worker
中,调用module(自定义的网络模型类实例),返回此次线程的结果并存储在指定字典results中。
注:调用module时,同样也是先进入模