使用pytorch进行训练的步骤

下面是《深度学习框架PyTorch:入门与实践》中的代码,是一个标准的模型训练器,主要包括4个步骤:

  • 准备数据
  • 准备网络模型
  • 准备损失函数和优化器
  • 训练
def train(**kwargs):
    # 1、准备数据
    train_data = DogCat(opt.train_data_root, train=True)
    val_data = DogCat(opt.train_data_root, train + false)
    train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data, opt.batch_size, shuffle=False,
                                num_workers=opt.num_workers)

    # 2、准备网络模型
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu:
        model.cuda()

    # 3、准备损失函数和优化器
    criterion = t.nn.CrossEntropyLoss()
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(), lr=lr,
                             weight_decay=opt.weight_decay)

    # 4、训练
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(2)
    previous_loss = 1e100
    for epoch in range(opt.max_epoch):
        for i, (data, label) in enumerate(train_dataloader):

            # 训练模型参数
            input = Variable(data)
            target = Variable(label)
            if opt.use_gpu:
                input = input.cuda()
                target = target.cuda()
            optimizer.zero_grad()
            score = model(input)
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()

            # 更新统计指标及可视化
            loss_meter.add(loss.data[0])
            confusion_matrix.add(score.data, target.data)

            if i % opt.print_freg == opt.print_freg - 1:
                vis.plot('loss', loss_meter.value()[0])
                
        # 保存模型
        model.save()

        # 计算验证集上的指标及可视化
        val_cm, val_accuracy = val(model, val_dataloader)
        vis.plot('val_accuracy', val_accuracy)

        vis.log(
            "epoch: {epoch}, lr: {lr}, loss: {loss}, train_cm: {train_cm},val_cm : { val_cm}"
                .format(
                epoch=epoch,
                loss=loss_meter.value()[0],
                val_cm=str(val_cm.value()),
                train_cm=str(confusion_matrix.value()),
                lr=lr))

        # 如采损失不再下降,则降低学习率
        if loss_meter.value()[0] > previous_loss:
            lr = lr * opt.lr_decay
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        previous_loss = loss_meter.value()[0]

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值