pytorch backward 求梯度 累计 样式

PyTorch的backward函数在计算梯度时采用累计方式,因此需要optimizer.zero_grad()重置梯度。当批量大小过大导致显存不足时,使用gradientaccumulation将大批次拆分为小批次,平均损失并更新权重。在每个小批次的累积迭代次数达到预设值或最后一个批次后,执行optimizer.step()更新模型参数。
部署运行你感兴趣的模型镜像

pytorch backwad 函数计算梯度是 累计式的

关于 pytorch 的 backward()函数反向传播计算梯度是 累计式的,见下图(主要是图中用黑框框出来的部分内容)。
因为这样,所以才需要 optimizer.zero_grad()
在这里插入图片描述

利用 gradient accumulation 的框架

一般的优化框架是:

# loop through batches
for (inputs, labels) in data_loader:

    # extract inputs and labels
    inputs = inputs.to(device)
    labels = labels.to(device)

    # passes and weights update
    with torch.set_grad_enabled(True):
        
        # forward pass 
        preds = model(inputs)
        loss  = criterion(preds, labels)

        # backward pass
        loss.backward() 

        # weights update
        optimizer.step()
        optimizer.zero_grad()

利用 gradient accumulation 的框架是这样的。
为什么需要 gradient accumulation 呢?
因为可能会出现 训练集 batch size 比较大,电脑 显存吃不下的情况,这样就需要将一个 batch 分为几个小 batch训练,但是同时又都利用它们的gradient 信息。

# batch accumulation parameter
accum_iter = 4  

# loop through enumaretad batches
for batch_idx, (inputs, labels) in enumerate(data_loader):

    # extract inputs and labels
    inputs = inputs.to(device)
    labels = labels.to(device)

    # passes and weights update
    with torch.set_grad_enabled(True):
        
        # forward pass 
        preds = model(inputs)
        loss  = criterion(preds, labels)

        # normalize loss to account for batch accumulation
        loss = loss / accum_iter 

        # backward pass
        loss.backward()

        # weights update
        if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):
            optimizer.step()
            optimizer.zero_grad()

您可能感兴趣的与本文相关的镜像

ACE-Step

ACE-Step

音乐合成
ACE-Step

ACE-Step是由中国团队阶跃星辰(StepFun)与ACE Studio联手打造的开源音乐生成模型。 它拥有3.5B参数量,支持快速高质量生成、强可控性和易于拓展的特点。 最厉害的是,它可以生成多种语言的歌曲,包括但不限于中文、英文、日文等19种语言

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

培之

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值