AI大模型LLM微调技术—Activation Checkpointing
Activation Checkpointing(激活检查点)是一种在大型语言模型(LLMs)的微调中用来更有效管理内存使用的技术。在训练过程中,尤其是在非常大的模型中,内存限制可能是一个重大挑战。Activation Checkpointing通过仅保存反向传播所需的部分中间激活(网络每一层的输出),而不是一次性存储所有激活,从而帮助解决这一问题。这种方法减少了内存占用,使得可以训练更大的模型或使用更大的批量大小。然而,它可能会增加计算开销,因为在反向传递过程中需要重新计算某些激活。
划重点:Activation Checkpointing是一种在训练大型神经网络(包括大型语言模型LLMs)中用来减少训练期间内存使用的技术。它使得可以在内存容量有限的硬件(如GPU或TPU)上高效地微调或训练具有数十亿或数万亿参数的模型。
概念与目的
在深度学习中,前向传递期间,激活(层的中间输出)被计算并存储在内存中,因为稍后在反向传递中需要它们来计算梯度。对于大型模型,这些激活可能会消耗大量内存,经常超出硬件的容量。
Activation Checkpointing通过在前向传递期间策略性地仅存储部分激活(或“Checkpointing”)并在需要时在反向传递期间重新计算其他激活来解决这一问题。这种权衡在减少内存使用的同时增加了反向传播期间的额外计算。
工作原理
-
前向传递:
- 算法不是存储所有中间激活,而是只检查点一些关键激活(例如,在某些层之后)。
- 未存储的激活被丢弃以释放内存。
-
反向传递:
- 当需要丢弃的激活时,通过从最近的Checkpointing激活执行部分前向传递来重新计算它们。
- 然后使用这些重新计算的激活计算梯度。
示例
考虑一个有100层的网络:
- 没有Activation Checkpointing:所有层的所有中间激活都将存储在内存中。
- 使用Checkpointing:只存储来自层0、25、50、75和100的激活。在反向传播期间,层1-24的中间激活从层0重新计算,层26-49从层25重新计算,依此类推。
好处
- 减少内存使用: 允许在内存有限的硬件上训练更大的模型或使用更大的批