基于pytorch框架使用多gpu训练时,如何有效降低显存
文章目录
1. 问题阐述
我们在基于pytorch框架采用多gpu进行大数据的训练时,由于自己的batch_size设计的偏大或者其他原因,经常会在程序运行过程中遇到RuntimeError
,使得程序中断。在RuntimeError中,Out of Memory``错误出现的最多,其次是 CUDNN_STATUS_EXECUTION_FAILED
错误。这两种错误发生的原因都是因为显存不够。
2. 解决方案
为了优化显存的使用,在网上查阅的资料和自己亲身实践的基础上,总结出以下几点
2.1 周期性使用torch.cuda.empty_cache()
函数
- 使用pytorch在进行神经网络训练时,会产生很多临时变量,随着训练的进行,临时变量将会对训练效率产生一点影响,所以定期采用
empty_cache()
不仅可以及时清除训练过程中的缓存,释放更多的显存空间,而且提高训练效率。 - 在
model.eval()
之前使用empty_cache()
,可以及时清理训练阶段的临时数据 - 在保存网络模型前使用
empty_cache()
,可以减少保存的模型大小。
总之,empty_cache()函数是个非常有用的函数,对于优化显存使用有着很大的帮助
2.2 将loss的计算写入网络的forward()
函数中
这里要先说明一下pytorch框架所采用的多gpu并行计算的原理。pytoch会先根据你的设定选用第一