zero_grad()函数用于每次计算完一个batch样本后的梯度清零(原因在于pytorch中的梯度反馈在节点上是累加的)
pytorch每计算一次backward会把结果累加给计算图,当我们的batch size为10时,即每处理十个样本并累加了他们的梯度值后再释放显存,相比于batchsize为2时的方差和均值显然是更精确的,但同样的,内存需要存储十个计算图,对卡的性能也提出了更高要求。
那么我们可以通过每计算完多个batch样本后再进行一次zero_grad()清零,就是一种变相提高batch_size的方法,对于显卡不行但是batch_size可能需要设高的领域比较适合,比如目标检测模型的训练。
要注意一点是,学习率也要相应增大。