1 基本概念
Checkpointing 是一种用于减少训练期间GPU内存使用的技术。这是通过避免在向前传递期间存储中间激活张量来实现的。具体而言,Checkpointing 在正向传播过程中,只会记住分区边界处的张量,所有其他中间张量都不会记住,而是在向后传播过程中跟踪原始输入来重新计算向前传播。因此,隐藏层消耗的内存仅为带有检查点的单个微批次所需要的数量。
Checkpointing 是性能和内存之间的折衷,因为如果完全重计算,则所花费的时间与正向传播所花费的时间相同。但 Checkpointing 减少了存储大型激活张量的需要,从而允许我们增加批量大小,增加模型的净吞吐量。
2 使用
在 GPipe之中,Checkpointing 应用于每个分区,以最小化模型的总体内存消耗。
Checkpointing 会极大减少内存使用,但总体训练速度会降低25%左右。您可以处理如何在模型上应用检查点。Checkpointing 只有三种选择,不能够指定某些特定
-
“always” :在所有微批次上应用检查点。
-
“except_last” : 在最后一个微批次之外应用检查点。
-
“never” :从不应用检查点。
@pytest.mark.parametrize('checkpoint', ['never', 'always', 'except_last'])
通常,在最后一个微批次上的检查点可能没有用处,因为保存的内存将立即重建。这就是为什么我们选择"except_last"作为默认选项。如果您决定根本不使用检查点,那么<torch.nn.DataParallel>
可能比GPipe更有效。
3 实现概述
Checkpointing 已经作为“torch.utils.checkpoint.checkpoint_wrapper"API的一部分实现,通过该API可以包装前向过程中的不同模块。
Checkpointing 通过重写“torch.autograd.Function"来实现。在处理模块前向传递的“forward"函数中,如果使用“no_grad",我们可以在很长一段时间内(即直到反向传播之前)防止正向图的创建和中间激活张量的物化(?什么意思?)。相反,在后向传播期间,会再次执行前向传播,然后执行后向传播。
前向传播过程的输入使用上下文对象保存,然后在后向传播过程中访问该上下文对象以检索原始输入。PyTorch还保存了RNG(Random Number Generator)的状态,用于前向传播和后向传播,如 Dropout layers 所需。
以下是几个注意点:
-
内存节省完全取决于检查点所包装的模型和分段。每个backprop由几个迷你前向传播(mini-forward)和backprop过程组成。收益完全取决于每层激活值的内存占用。
-
使用BatchNormalization时,您可能需要冻结统计数据的计算,因为我们运行了两次正向传递。
-
确保输入张量的’requires_grad’字段设置为True。为了触发后向传播功能,输出需要设置此字段。通过在输入张量设置这个字段,我们可以确保将其传播到输出,并触发’backward’函数。