一、常见损失函数的用法

本文详细介绍了多分类交叉熵、均方差、二分类交叉熵等损失函数在深度学习中的应用,展示了如何在PyTorch中实现这些损失函数,并通过实例演示了如何在MNIST数据集上训练网络,强调了损失函数在模型训练中的关键作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

定义损失函数的常用方法,其中包括多分类交叉熵、均方差、二分类交叉熵的用法。其作用包括:1.衡量模型输出值和标签值的差异;2.评估模型的预测值与真实值不一致程度;3.神经网络中优化的目标函数,损失函数越小,预测值越接近真实值,模型健壮性也越好。


一、L1-loss(MAE)、L2- loss(MSE)、smooth L1- loss、交叉熵损失函数是什么?

在这里插入图片描述

在这里插入图片描述

二、使用步骤

1.损失函数方法

代码如下(示例):

#定义损失函数,更新梯度----
loss_fn = torch.nn.CrossEntropyLoss()#多分类交叉熵不需要用激活幂函数输出
# loss_fn=torch.nn.MSELoss()#均方差
# loss_fn=torch.nn.BCELoss()#二分类交叉熵
# loss_fn=torch.nn.BCEWithLogitsLoss()#自动引入激活函数

2.代码操作

代码如下(示例):

import torch
from torchvision import datasets,transforms
from torch.utils.data import DataLoader#类时加载数据的核心,返回可迭代的数据
import os
import matplotlib.pyplot as plt


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()#继承
        self.fc1 = torch.nn.Sequential(
            torch.nn.Linear(784,256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU())
        #nn.Sequential 将网络层和激活函数结合起来,输出激活后的网络节点。
        #nn.Linear(in_features,out_features,bias = True )对传入数据应用线性变换
        #784,每个输入样本的大小----即为28*28,图像的像素值w*h
        # 256 每个输出样本的大小-----784通过Linear函数
        #BatchNorm1d(256),#自适应标准化-正态分布--输入值落在非线性函数敏感的区域,避免梯度消失问题产生
        #nn.ReLU() 激活函数 relu
        self.fc2 = torch.nn.Sequential(
            torch.nn.Linear(256,128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU())
        self.fc3 = torch.nn.Linear(128,10)
    def forward(self,x):#forward函数里面实现在前向传播运算
        # print(x.shape)
        #N,C,H,W(batchsize,channels,x,y)-->N,V
        #x.size(0)==batchsize,转换后有几行
        #最后通过x.view(x.size(0), -1)将tensor的结构转换为了(batchsize, channels*x*y),
        # 即将(channels,x,y)拉直,然后就可以和fc层连接了
        #-1指在不告诉函数有多少列的情况下,根据原tensor数据和batchsize自动分配列数。
        x = torch.reshape(x,[x.size(0),-1])#变换形状,换成2维,reshape=view
        # print(x.shape)
        y=self.fc1(x)#N,256
        y=self.fc2(y)#N,128     #y=w*sqrt(x2+bias)
        # y=self.fc3(y)#N,10
        self.y=self.fc3(y)
        y=torch.softmax(self.y,1)
        return y


if __name__ == '__main__':
    save_params = r"./save_params/parmas.pth"#保存参数
    save_net = r"./save_params/net.pth"#保存网络
    transf = transforms.Compose([transforms.ToTensor(),
             transforms.Normalize(mean=[0.5,],std=[0.5,])])
    #transforms.Compose 将transforms列表里面的transform操作进行遍历。
    #transforms.ToTensor() 灰度范围从0-255变换到0-1之间
    #transforms.Normalize把0-1变换到(-1,1),(image-mean)/std
    train_data = datasets.MNIST("./data",train=True,transform=transf,download=True)#读取训练数据
    test_data = datasets.MNIST("./data",train=False,transform=transf,download=False)#读取测试数据

    # 100涨图片,True 是否打乱,随机,给出不同的特征才能学习
    trin_loader = DataLoader(train_data,100,True)#加载数据
    test_loader = DataLoader(test_data,100,True)
    # DataLoader()
    # 利用多进程来加速batchdata的处理
    # 直观的网络输入数据结构,便于使用和扩展
    print(train_data.data.shape)
    print(train_data.targets.shape)
    print(test_data.data.shape)
    print(test_data.targets.shape)
    print(test_data.classes)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    net=Net().to(device)#网络加载器开始读取数据时的tensor变量copy一份到device所指定的cuda上去
    if os.path.exists(save_params):
        net.load_state_dict(torch.load(save_params))#只加载参数
        print("参数加载成功")
    else:
        print("No params!")
    # net = torch.load(save_net).to(device)#加载参数和网络
    # loss_fn = torch.nn.CrossEntropyLoss()#多分类交叉熵,定义损失函数,更新梯度----
    # loss_fn=torch.nn.MSELoss()#均方差
    # loss_fn=torch.nn.BCELoss()#二分类交叉熵
    loss_fn=torch.nn.BCEWithLogitsLoss()#二分类交叉熵,对输入值自动做sigmoid
    # optim = torch.optim.SGD(net.parameters(),lr=1e-3)#比较稳定
    optim = torch.optim.Adam(net.parameters(),lr=1e-3)#创建优化器
    #torch.optim.Adam  优化器
    # (net.parameters(), 待优化参数的iterable或者是定义了参数组的dict
    # lr=1e-3) 学习率或步长因子

    #测试,实时画图分析
    a = []
    b = []
    # plt.ion()
    net.train()
    for epoch in range(1):
        for i ,(x,y) in enumerate(trin_loader):
            x = x.to(device)
            y = y.to(device)
            y_ = torch.zeros(len(y), max(y) + 1).to(device)
            y_[torch.arange(len(y)), y] = 1
            out = net(x)#前向输出
            # loss = loss_fn(out,y)#求损失
            loss = loss_fn(net.y,y_)
            optim.zero_grad()#清空当前梯度
            loss.backward()#计算当前梯度
            optim.step()#沿着当前梯度更新一步
            # a.append(i)
            # b.append(loss.item())
            # plt.clf()
            # plt.plot(a,b)
            # plt.pause(0.1)
            if i%50==0:
                print("loss",loss.item())
    plt.ioff()
    plt.show()


    # 测试
    eval_loss=0
    eval_acc=0
    net.eval()
    for i,(x,y) in enumerate(test_loader):
        x = x.to(device)#x传给网络
        y = y.to(device)
        y_ = torch.zeros(len(y), max(y) + 1).to(device)
        y_[torch.arange(len(y)), y] = 1
        out = net(x)
        # loss = loss_fn(out, y)
        loss = loss_fn(out,y_)
        eval_loss += loss.item() * y.size(0)
        eval_acc += (y == torch.argmax(out, 1)).cpu().sum().item()
    avg_loss = eval_loss / len(test_data)
    avg_acc = eval_acc / len(test_data)
    print(avg_loss)
    print(avg_acc)

    if not os.path.exists("./save_params"):
        os.mkdir("./save_params")
    torch.save(net.state_dict(),"./save_params/parmas.pth")#只保存参数
    torch.save(net,"./save_params/net.pth")

总结

提示:这里对文章进行总结:

loss_fn = torch.nn.CrossEntropyLoss()#多分类交叉熵,输出不需要加激活函数。
loss_fn=torch.nn.MSELoss()#均方差、输出需要加激活函数。
loss_fn=torch.nn.BCELoss()#二分类交叉熵、输出需要加激活函数。
loss_fn=torch.nn.BCEWithLogitsLoss()#二分类交叉熵,对输入值自动做sigmoid
资源下载链接为: https://pan.quark.cn/s/5c50e6120579 在Android移动应用开发中,定位功能扮演着极为关键的角色,尤其是在提供导航、本地搜索等服务时,它能够帮助应用获取用户的位置信息。以“baiduGPS.rar”为例,这是一个基于百度地图API实现定位功能的示例项目,旨在展示如何在Android应用中集成百度地图的GPS定位服务。以下是对该技术的详细阐述。 百度地图API简介 百度地图API是由百度提供的一系列开放接口,开发者可以利用这些接口将百度地图的功能集成到自己的应用中,涵盖地图展示、定位、路径规划等多个方面。借助它,开发者能够开发出满足不同业务需求的定制化地图应用。 Android定位方式 Android系统支持多种定位方式,包括GPS(全球定位系统)和网络定位(通过Wi-Fi及移动网络)。开发者可以根据应用的具体需求选择合适的定位方法。在本示例中,主要采用GPS实现高精度定位。 权限声明 在Android应用中使用定位功能前,必须在Manifest.xml文件中声明相关权限。例如,添加<uses-permission android:name="android.permission.ACCESS_FINE_LOCATION" />,以获取用户的精确位置信息。 百度地图SDK初始化 集成百度地图API时,需要在应用启动时初始化地图SDK。通常在Application类或Activity的onCreate()方法中调用BMapManager.init(),并设置回调监听器以处理初始化结果。 MapView的创建 在布局文件中添加MapView组件,它是地图显示的基础。通过设置其属性(如mapType、zoomLevel等),可以控制地图的显示效果。 定位服务的管理 使用百度地图API的LocationClient类来管理定位服务
资源下载链接为: https://pan.quark.cn/s/dab15056c6a5 Oracle Instant Client是一款轻量级的Oracle数据连接工具,能够在不安装完整Oracle客户端软件的情况下,为用户提供访问Oracle数据的能力。以“instantclient-basic-nt-12.1.0.1.0.zip”为例,它是针对Windows(NT)平台的Instant Client基本版本,版本号为12.1.0.1.0,包含连接Oracle数据所需的基本组件。 Oracle Instant Client主要面向开发人员和系统管理员,适用于数据查询、应用程序调试、数据迁移等工作。它支持运行SQL*Plus、PL/SQL Developer等管理工具,还能作为ODBC和JDBC驱动的基础,让非Oracle应用连接到Oracle数据。 安装并解压“instantclient_12_1”后,为了使PL/SQL Developer等应用程序能够使用该客户端,需要进行环境变量配置。设置ORACLE_HOME指向Instant Client的安装目录,如“C:\instantclient_12_1”。添加TNS_ADMIN环境变量,用于存放网络配置文件(如tnsnames.ora)。将Instant Client的bin目录添加到PATH环境变量中,以便系统能够找到oci.dll等关键动态链接。 oci.dll是OCI(Oracle Call Interface)的重要组成部分。OCI是Oracle提供的C语言接口,允许开发者直接与数据交互,执行SQL语句、处理结果集和管理事务等功能。确保系统能够找到oci.dll是连接数据的关键。 tnsnames.ora是Oracle的网络配置文件,用于定义数据服务名与网络连接参数的映射关系,包括服务器地址
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值