PyTorch 1.0 Chinese Documentation: torch.utils.checkpoint

This article describes PyTorch's checkpointing feature in detail: how it trades extra computation for reduced memory during deep-learning model training, which is especially useful for models with many parameters. It also discusses how checkpointing affects the random-number-generator (RNG) state and how that behavior can be adjusted for better performance.


Translator: belonHan

Note

Checkpointing is implemented by re-running a forward-pass segment for each checkpointed segment during the backward pass. This can cause persistent state, such as the RNG state, to be advanced further than it would be without checkpointing. By default, checkpointing includes logic to juggle the RNG state so that checkpointed passes that use the RNG (for example, through dropout) have deterministic output compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit. If deterministic output is not required, set the global flag torch.utils.checkpoint.preserve_rng_state=False to skip stashing and restoring the RNG state during each checkpoint.
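A minimal sketch of skipping the RNG stash for a dropout-containing segment (the layer and shapes here are illustrative assumptions; note that recent PyTorch releases expose this as a per-call keyword argument rather than the global flag described above):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    segment = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
    x = torch.randn(2, 8, requires_grad=True)

    # Skip stashing/restoring the RNG state for this checkpointed segment:
    # dropout masks may differ between the forward pass and the recompute,
    # so the output is no longer deterministic relative to a
    # non-checkpointed pass.
    y = checkpoint(segment, x, preserve_rng_state=False)
    y.sum().backward()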

torch.utils.checkpoint.checkpoint(function, *args)

Checkpoint a model or part of a model.

Checkpointing works by trading computation for memory. Rather than storing all the intermediate activations of the whole computation graph for the backward pass, the checkpointed part does not save its intermediate activations and recomputes them during the backward pass instead. It can be applied to any part of a model.

Specifically, in the forward pass, function is run with torch.no_grad(), i.e. without storing the intermediate activations; instead, the input tuple and the function parameter are saved. In the backward pass, the saved inputs and function are retrieved, and the forward pass is computed on function again, this time tracking the intermediate activations, which are then used to compute the gradients.
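A minimal sketch of this behavior (the module, layer sizes, and names are illustrative assumptions, not from the original text):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.block1 = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
            self.block2 = nn.Sequential(nn.Linear(128, 128), nn.ReLU())

        def forward(self, x):
            # block1 runs under torch.no_grad(): its intermediate
            # activations are not stored, only its inputs are saved.
            x = checkpoint(self.block1, x)
            return self.block2(x)

    x = torch.randn(4, 128, requires_grad=True)
    loss = Net()(x).sum()
    loss.backward()  # block1's forward is re-run here before computing grads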

Warning

Checkpointing doesn't work with torch.autograd.grad(), but only with torch.autograd.backward().

Warning

If the invocation of function during the backward pass does anything different from the one during the forward pass, e.g. due to some global variable, the checkpointed version won't be equivalent, and unfortunately this cannot be detected.

Parameters:

  • function – describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed in as a tuple. For example, in an LSTM, if the user passes (activation, hidden), function should correctly use the first input as activation and the second input as hidden (see the sketch after this list).
  • args – tuple containing inputs to function
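As referenced in the function parameter above, a minimal sketch of a function that takes an LSTM-style pair of inputs (the name step and the tanh body are illustrative assumptions):

    import torch
    from torch.utils.checkpoint import checkpoint

    def step(activation, hidden):
        # checkpoint passes the saved inputs back positionally, so the
        # first argument must be treated as activation and the second
        # as hidden, as the docs describe.
        return torch.tanh(activation + hidden)

    activation = torch.randn(2, 8, requires_grad=True)
    hidden = torch.randn(2, 8, requires_grad=True)
    out = checkpoint(step, activation, hidden)
    out.sum().backward()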


Reposted from: https://www.cnblogs.com/wizardforcel/p/10492605.html

A reader question appended to the page asks: "What does this log output mean???" The training log in question interleaves ordinary metrics (loss, grad_norm, learning_rate, num_tokens, mean_token_accuracy, epoch) with the following warning, repeated periodically:

/usr/local/python3.10.17/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:632: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
return fn(*args, **kwargs)
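The warning is informational: somewhere (here, inside the training framework) checkpoint is being called without an explicit use_reentrant argument, and PyTorch announces that from version 2.5 this will raise an error. Passing the argument explicitly, as the message suggests, silences it. A minimal sketch (the layer and shapes are illustrative assumptions):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    layer = nn.Linear(16, 16)
    x = torch.randn(2, 16, requires_grad=True)

    # Recommended: the newer, non-reentrant implementation.
    y = checkpoint(layer, x, use_reentrant=False)

    # Or keep the historical default behavior explicitly:
    # y = checkpoint(layer, x, use_reentrant=True)
    y.sum().backward()

If the call happens inside a library (as the eval_frame.py path above suggests) rather than in your own code, look for that library's option to forward use_reentrant to checkpoint instead of editing the library. The interleaved loss/grad_norm lines are ordinary training metrics and are unaffected by the warning.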