下面是《深度学习框架PyTorch：入门与实践》中的代码，是一个标准的模型训练器，主要包括 4 个步骤：
- 准备数据
- 准备网络模型
- 准备损失函数和优化器
- 训练
def train(**kwargs):
    """Train the model named by ``opt.model`` and validate it after each epoch.

    The routine runs four stages:

    1. Build the training and validation ``DogCat`` datasets (both read from
       ``opt.train_data_root``) and wrap them in ``DataLoader`` objects.
    2. Instantiate the network, optionally restoring it from
       ``opt.load_model_path`` and moving it to the GPU.
    3. Create the cross-entropy criterion and the Adam optimizer.
    4. Run ``opt.max_epoch`` epochs; after every epoch save the model,
       evaluate it with ``val`` and multiply the learning rate by
       ``opt.lr_decay`` when the average training loss went up.

    Args:
        **kwargs: Not read anywhere in this function.
            NOTE(review): presumably meant to override fields of ``opt``
            -- confirm against the caller.

    Returns:
        None. Progress is reported through ``vis`` and the model is
        persisted with ``model.save()``.
    """
    # 1. Prepare the data.
    train_data = DogCat(opt.train_data_root, train=True)
    # The validation set is requested with the keyword ``train=False``
    # (the garbled expression ``train + false`` raised NameError).
    val_data = DogCat(opt.train_data_root, train=False)
    train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data, opt.batch_size, shuffle=False,
                                num_workers=opt.num_workers)

    # 2. Prepare the network model.
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu:
        model.cuda()

    # 3. Prepare the loss function and the optimizer.
    criterion = t.nn.CrossEntropyLoss()
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(), lr=lr,
                             weight_decay=opt.weight_decay)

    # 4. Train.
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(2)
    previous_loss = 1e100

    for epoch in range(opt.max_epoch):
        # Start every epoch with empty meters so that the logged loss, the
        # training confusion matrix and the learning-rate check below
        # describe this epoch only instead of everything seen so far.
        loss_meter.reset()
        confusion_matrix.reset()

        for i, (data, label) in enumerate(train_dataloader):
            # Update the model parameters on one mini-batch.
            inputs = Variable(data)
            targets = Variable(label)
            if opt.use_gpu:
                inputs = inputs.cuda()
                targets = targets.cuda()

            optimizer.zero_grad()
            score = model(inputs)
            loss = criterion(score, targets)
            loss.backward()
            optimizer.step()

            # Update the running statistics and the visualisation.
            # NOTE(review): ``loss.data[0]`` is the pre-0.4 PyTorch way of
            # reading a scalar loss; on PyTorch >= 0.4 this should be
            # ``loss.item()`` -- confirm the targeted version.
            loss_meter.add(loss.data[0])
            confusion_matrix.add(score.data, targets.data)

            # NOTE(review): ``print_freg`` looks like a typo for
            # ``print_freq``; kept because the config object is defined
            # outside this block -- confirm.
            if i % opt.print_freg == opt.print_freg - 1:
                vis.plot('loss', loss_meter.value()[0])

        # Save the model after every epoch.
        model.save()

        # Nothing is added to the meter after the batch loop, so the mean
        # can be read once and reused below.
        epoch_loss = loss_meter.value()[0]

        # Compute and visualise the metrics on the validation set.
        val_cm, val_accuracy = val(model, val_dataloader)
        vis.plot('val_accuracy', val_accuracy)
        # The ``{val_cm}`` field must not contain a leading space, otherwise
        # ``str.format`` looks up the key ' val_cm' and raises KeyError.
        vis.log(
            "epoch: {epoch}, lr: {lr}, loss: {loss}, train_cm: {train_cm},val_cm : {val_cm}"
            .format(
                epoch=epoch,
                loss=epoch_loss,
                val_cm=str(val_cm.value()),
                train_cm=str(confusion_matrix.value()),
                lr=lr))

        # If the loss is no longer decreasing, lower the learning rate.
        if epoch_loss > previous_loss:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = epoch_loss