在使用 PyTorch 进行深度学习训练时,可能会遇到 GPU 0 的显存占用明显高于其他 GPU 的情况。这可能导致显存不足,影响训练效率。排查思路:
- 开启两个 Shell,Shell A 用于运行训练脚本,Shell B 用于实时监控显存。
- 在 Shell A 运行训练脚本。
- 在另一个 shell 中运行
nvidia-smi -l 1。这个命令会每秒刷新一次nvidia-smi的结果,方便实时观察各 GPU 显存的变化情况。 - 观察 GPU 0 何时开始占用较高显存
- 在适当时机终止训练脚本,检查 traceback,确定代码中的异常行号。
可能的原因有很多,比如使用 DP 而不是 DDP、没有执行 torch.cuda.set_device 等,网上由很多博客已经介绍的比较详细了。这里分享一种相对少见的情况:torch.load 没有指定 map_location='cpu'。
多卡训练脚本一般会定期保存 ckpt。为了防止多个进程同时写入 ckpt 文件,一般只会保存 0 号进程的 ckpt。如果在保存前没有将权重移动到 cpu 上,使用 torch.load 加载 ckpt 的时候就会加载到 GPU 0。
解决方案:
model = torch.load("model.pth", "cpu")
然后再手动将权重分配到合适的 GPU。
1462

被折叠的 条评论
为什么被折叠?



