pytorch 交叉熵类定义

这段代码定义了一个`Loss_Calculator`类,用于计算并记录CrossEntropyLoss。`calc_loss`方法接收输出和目标,计算损失并将其添加到历史损失序列中。`get_loss_log`方法返回最近平均损失值。此部分代码适用于训练神经网络模型时跟踪损失变化。
import torch

class Loss_Calculator(object):
    def __init__(self):
        self.criterion = torch.nn.CrossEntropyLoss()        
        self.loss_seq = []
    
    def calc_loss(self, output, target):
        loss = self.criterion(output, target)        
        self.loss_seq.append(loss.item())
        return loss

    def get_loss_log(self, length=100):
        # get recent average loss values
        if len(self.loss_seq) < length:
            length = len(self.loss_seq)
        return sum(self.loss_seq[-length:])/length
loss_calculator = Loss_Calculator()
...
loss_calculator.calc_loss(outputs, targets)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值