显存不足时的Trick
- 降低batch_size
适当降低batch size, 则模型每层的输入输出就会成线性减少 - 数据类型
32位的浮点数,切换到 16位的浮点数 - 梯度累积
当执行 loss.backward() 时, 会为每个参数计算梯度,并将其存储在 paramter.grad 中, 注意到, paramter.grad 是一个张量, 其会累加每次计算得到的梯度。
在 Pytorch 中, 只有调用 optimizer.step()时才会进行梯度下降更新网络参数。
batch size 与占用显存息息相关,但有时候我们的batch size 又不能设置的太小,这咋办呢?答案就是梯度累加
传统训练:
for i,(feature,target) in enumerate(train_loader):
outputs