模型训练的时候,爆显存了,可以调整batch,对数据进行crop等等操作
今天发现一个模型,训练ok,每次测试的时候爆显存。开始以为是因为用了全图(1920x1080略大)进行inference。
这是一方面,但后来发现忘了用with torch.no_grad():
这导致模型运算的时候不能释放显存(记录了梯度信息),所以显存巨大。
加了之后,用了不过3G显存就够了。。确实inference不需要那么多显存的,以后记着这种不正常现象如何处理。
一般训练不爆显存,测试也不会爆;训练时的显存占用远多于inference

本文探讨了解决深度学习模型测试阶段显存溢出的问题,指出使用全图进行推理和未使用with torch.no_grad()是导致显存消耗过大的原因,并分享了通过调整批大小和使用with torch.no_grad()来有效降低显存占用的经验。
404





