在使用 PyTorch 进行深度学习训练时,可能会遇到 GPU 0 的显存占用明显高于其他 GPU 的情况。这可能导致显存不足,影响训练效率。排查思路:
- 开启两个 Shell,Shell A 用于运行训练脚本,Shell B 用于实时监控显存。
- 在 Shell A 运行训练脚本。
- 在另一个 shell 中运行
nvidia-smi -l 1
。这个命令会每秒刷新一次nvidia-smi
的结果,方便实时观察各 GPU 显存的变化情况。 - 观察 GPU 0 何时开始占用较高显存
- 在适当时机终止训练脚本,检查 traceback,确定代码中的异常行号。
可能的原因有很多,比如使用 DP 而不是 DDP、没有执行 torch.cuda.set_device
等,网上由很多博客已经介绍的比较详细了。这里分享一种相对少见的情况:torch.load
没有指定 map_location='cpu'
。
多卡训练脚本一般会定期保存 ckpt。为了防止多个进程同时写入 ckpt 文件,一般只会保存 0 号进程的 ckpt。如果在保存前没有将权重移动到 cpu 上,使用 torch.load
加载 ckpt 的时候就会加载到 GPU 0。
解决方案:
model = torch.load("model.pth", "cpu")
然后再手动将权重分配到合适的 GPU。