Pytorch模型训练函数如何编写

📚博客主页:knighthood2001
公众号:认知up吧 (目前正在带领大家一起提升认知,感兴趣可以来围观一下)
🎃知识星球:【认知up吧|成长|副业】介绍
❤️如遇文章付费,可先看看我公众号中是否发布免费文章❤️
🙏笔者水平有限,欢迎各位大佬指点,相互学习进步!

今天带大家一步步编写模型的训练函数。

定义训练模型函数

def train_model_process(model, train_dataloader, val_dataloader, num_epochs):

首先,我们定义这样的一个函数,用来表示训练模型。其中,

参数

  • model (nn.Module): 待训练的模型
  • train_dataloader (DataLoader): 训练集的数据加载器
  • val_dataloader (DataLoader): 验证集的数据加载器
  • num_epochs (int): 训练轮数

返回

  • pd.DataFrame: 包含训练过程中各指标的DataFrame,包括epoch、训练集损失、验证集损失、训练集准确率、验证集准确率,这个可以用来具体的绘图。

定义设备,优化器,损失函数

	# 设定训练所用到的设备,有GPU用GPU没有GPU用CPU
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	# 使用Adam优化器,学习率为0.001
	optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
	# 损失函数为交叉熵函数
	criterion = nn.CrossEntropyLoss()

设备、优化器、损失函数,每次都会用到,当然你也可以选择不同的优化器和损失函数。

模型放入设备中

    # 将模型放入到训练设备中
    model = model.to(device)
    # 复制当前模型的参数
    best_model_wts = copy.deepcopy(model.state_dict())

这一步,除了将模型放到训练设备中,我们还复制了当前模型中的参数给最好模型,方便后续进行更新。

初始化参数

    # 初始化参数
    # 最高准确度
    best_accuracy = 0.0
    # 训练集损失列表
    train_loss_all = []
    # 验证集损失列表
    val_loss_all = []
    # 训练集准确度列表
    train_acc_all = []
    # 验证集准确度列表
    val_acc_all = []

这些参数后续可以返回,用来绘图。

    # 当前时间
    since = time.time()

当然,你也可以在这时候记录一下当前时间。

定义每一个epoch

    # 大循环,用来训练一次全部数据
    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs-1))
        print("-"*10)

        # 初始化参数
        # 训练集损失函数
        train_loss = 0.0
        # 训练集准确度
        train_corrects = 0
        # 验证集损失函数
        val_loss = 0.0
        # 验证集准确度
        val_corrects = 0
        # 训练集样本数量
        train_num = 0
        # 验证集样本数量
        val_num = 0

这时候,就进入到了每一个epoch,每一轮,你也需要初始化一些参数。

训练集对每一个batch进行训练

        # 对每一个mini-batch训练和计算
        for step, (b_x, b_y) in enumerate(train_dataloader):
            # 将特征放入到训练设备中
            b_x = b_x.to(device)
            # 将标签放入到训练设备中
            b_y = b_y.to(device)
            # 设置模型为训练模式
            model.train()

具体地说,train_dataloader 返回的通常是一个迭代器(iterator)或者一个可以迭代的对象,每次迭代返回一个批量的数据。它负责在训练过程中批量地、高效地加载数据。使用 enumerate(train_dataloader) 可以遍历这个数据加载器,并在每次迭代中获取一对数据 (b_x, b_y),其中 b_x 是一个批量的输入数据b_y 是相应的标签或输出数据。

如下图所示,由于我在DataLoader中设置的batch_size=16,也就是一批量是16,因此这里b_x的批数也是16。同样的,我这里的b_y也是16个标签,比如b_y中0代表猫,1代表狗。

(16,) 并不表示 16 行或者 16 列,

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

knighthood2001

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值