Pytorch的各种损失函数

参考官网torch.nn中的loss fuction

1. BCELoss (Binary Cross Entropy Loss)

  • 用途:用于二分类问题。
  • 案例:逻辑回归,二分类的神经网络。
  • 代码示例
    loss = nn.BCELoss()
    input = torch.randn(3, requires_grad=True)
    target = torch.randn(3).random_(2)
    output = loss(input, target)
    

2. NLLLoss (Negative Log Likelihood Loss)

  • 用途:用于多分类问题,通常与LogSoftmax层一起使用。
  • 案例:多分类的神经网络。
  • 代码示例
    loss = nn.NLLLoss()
    input = torch.randn(3, 5, requires_grad=True)
    target = torch.tensor([1, 0, 4])
    output = loss(F.log_softmax(input, dim=1), target)
    

3. BCEWithLogitsLoss

  • 用途:结合了Sigmoid层和BCELoss,用于二分类问题。
  • 案例:二分类的神经网络。
  • 代码示例
    loss = nn.BCEWithLogitsLoss()
    input = torch.randn(3, requires_grad=True)
    target = torch.randn(3).random_(2)
    output = loss(input, target)
    

4. CrossEntropyLoss

  • 用途:用于多分类问题,结合了LogSoftmaxNLLLoss
  • 案例:图像分类,文本分类。
  • 代码示例
    loss = nn.CrossEntropyLoss()
    input = torch.randn(3, 5, requires_grad=True)
    target = torch.tensor([1, 0, 4])
    output = loss(input, target)
    

5. L1Loss (MAELoss)

  • 用途:用于回归问题,计算输出和目标之间的平均绝对误差。
  • 案例:时间序列预测,回归分析。
  • 代码示例
    loss = nn.L1Loss()
    input = torch.randn(3, requires_grad=True)
    target = torch.randn(3)
    output = loss(input, target)
    

6. MSELoss (Mean Square Error Loss)

  • 用途:用于回归问题,计算输出和目标之间的均方误差。
  • 案例:线性回归,神经网络回归。
  • 代码示例
    loss = nn.MSELoss()
    input = torch.randn(3, requires_grad=True)
    target = torch.randn(3)
    output = loss(input, target)
    

7. HingeEmbeddingLoss

  • 用途:用于度量两个输入是否相似。
  • 案例:度量学习,推荐系统。
  • 代码示例
    loss = nn.HingeEmbeddingLoss()
    input = torch.randn(3, requires_grad=True)
    target = torch.tensor([1, -1, 1])
    output = loss(input, target)
    

8. CosineEmbeddingLoss

  • 用途:用于度量两个输入的余弦相似度。
  • 案例:度量学习,推荐系统。
  • 代码示例
    loss = nn.CosineEmbeddingLoss()
    input1 = torch.randn(2, 3, requires_grad=True)
    input2 = torch.randn(2, 3, requires_grad=True)
    target = torch.tensor([1, -1])
    output = loss(input1, input2, target)
    

当然,PyTorch还提供了其他多种损失函数,用于不同的机器学习任务和场景。以下是一些额外的损失函数及其用途:

9. MarginRankingLoss

  • 用途:用于比较两个样本的相对距离,常用于排序任务。
  • 案例:学习排序,信息检索。
  • 代码示例
    loss = nn.MarginRankingLoss()
    input1 = torch.randn(3, requires_grad=True)
    input2 = torch.randn(3, requires_grad=True)
    target = torch.randn(3).sign()
    output = loss(input1, input2, target)
    

10. MultiLabelMarginLoss

  • 用途:用于多标签分类问题,其中每个样本可以属于多个类别。
  • 案例:图像标注,多标签文本分类。
  • 代码示例
    loss = nn.MultiLabelMarginLoss()
    input = torch.randn(3, 5, requires_grad=True)
    target = torch.tensor([[0, 1, 2], [0, 2, 1], [1, 0, 2]])
    output = loss(input, target)
    

11. SmoothL1Loss

  • 用途:用于回归问题,是L1Loss的一个平滑版本,对异常值不敏感。
  • 案例:目标检测,姿态估计。
  • 代码示例
    loss = nn.SmoothL1Loss()
    input = torch.randn(3, requires_grad=True)
    target = torch.randn(3)
    output = loss(input, target)
    

12. SoftMarginLoss

  • 用途:用于二分类问题,类似于逻辑回归的损失函数。
  • 案例:二分类的神经网络。
  • 代码示例
    loss = nn.SoftMarginLoss()
    input = torch.randn(3, requires_grad=True)
    target = torch.tensor([1, -1, 1])
    output = loss(input, target)
    

13. TripletMarginLoss

  • 用途:用于度量学习,特别是三元组损失,用于学习嵌入空间中的相对距离。
  • 案例:人脸识别,行人重识别。
  • 代码示例
    loss = nn.TripletMarginLoss(margin=1.0)
    anchor = torch.randn(100, 128, requires_grad=True)
    positive = torch.randn(100, 128, requires_grad=True)
    negative = torch.randn(100, 128, requires_grad=True)
    output = loss(anchor, positive, negative)
    

14. PoissonNLLLoss

  • 用途:用于计数数据,如时间序列或频率数据,当数据遵循泊松分布时使用。
  • 案例:自然语言处理中的词频模型,交通流量预测。
  • 代码示例
    loss = nn.PoissonNLLLoss()
    log_input = torch.randn(5, 2, requires_grad=True)
    target = torch.randn(5, 2)
    output = loss(log_input, target)
    

15. KLDivLoss

  • 用途:用于度量两个概率分布之间的差异,特别是KL散度。
  • 案例:变分自编码器,概率图模型。
  • 代码示例
    loss = nn.KLDivLoss(reduction='batchmean')
    log_input = torch.randn(5, 2, requires_grad=True)
    target = torch.randn(5, 2)
    output = loss(log_input, target)
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值