Checkpointing的作用,设置

本文介绍了Checkpointing技术,一种用于减少深度学习训练中GPU内存使用的策略,通过牺牲部分训练速度以存储分区边界张量而非所有中间激活。GPipe中广泛应用此技术,讨论了不同策略的选择及其对性能的影响。实现涉及torch库的API和重写autograd.Function。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1 基本概念

Checkpointing 是一种用于减少训练期间GPU内存使用的技术。这是通过避免在向前传递期间存储中间激活张量来实现的。具体而言,Checkpointing 在正向传播过程中,只会记住分区边界处的张量,所有其他中间张量都不会记住,而是在向后传播过程中跟踪原始输入来重新计算向前传播。因此,隐藏层消耗的内存仅为带有检查点的单个微批次所需要的数量

Checkpointing 是性能和内存之间的折衷,因为如果完全重计算,则所花费的时间与正向传播所花费的时间相同。但 Checkpointing 减少了存储大型激活张量的需要,从而允许我们增加批量大小,增加模型的净吞吐量。

2 使用

在 GPipe之中,Checkpointing 应用于每个分区,以最小化模型的总体内存消耗。

Checkpointing 会极大减少内存使用,但总体训练速度会降低25%左右。您可以处理如何在模型上应用检查点。Checkpointing 只有三种选择,不能够指定某些特定

  • “always” :在所有微批次上应用检查点。

  • “except_last” : 在最后一个微批次之外应用检查点。

  • “never” :从不应用检查点。

@pytest.mark.parametrize('checkpoint', ['never', 'always', 'except_last'])

通常,在最后一个微批次上的检查点可能没有用处,因为保存的内存将立即重建。这就是为什么我们选择"except_last"作为默认选项。如果您决定根本不使用检查点,那么<torch.nn.DataParallel>可能比GPipe更有效。 

3 实现概述

Checkpointing 已经作为“torch.utils.checkpoint.checkpoint_wrapper"API的一部分实现,通过该API可以包装前向过程中的不同模块。

Checkpointing 通过重写“torch.autograd.Function"来实现。在处理模块前向传递的“forward"函数中,如果使用“no_grad",我们可以在很长一段时间内(即直到反向传播之前)防止正向图的创建和中间激活张量的物化(?什么意思?)。相反,在后向传播期间,会再次执行前向传播,然后执行后向传播。

前向传播过程的输入使用上下文对象保存,然后在后向传播过程中访问该上下文对象以检索原始输入。PyTorch还保存了RNG(Random Number Generator)的状态,用于前向传播和后向传播,如 Dropout layers 所需。

以下是几个注意点:

  1. 内存节省完全取决于检查点所包装的模型和分段。每个backprop由几个迷你前向传播(mini-forward)和backprop过程组成。收益完全取决于每层激活值的内存占用。

  2. 使用BatchNormalization时,您可能需要冻结统计数据的计算,因为我们运行了两次正向传递

  3. 确保输入张量的’requires_grad’字段设置为True。为了触发后向传播功能,输出需要设置此字段。通过在输入张量设置这个字段,我们可以确保将其传播到输出,并触发’backward’函数。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值