『ignite』PyTorch好用的工具包

本文介绍PyTorch Ignite库,它简化了深度学习模型的训练流程,通过创建训练器和评估器,管理训练周期,处理事件和指标计算,使模型训练更为高效。

尽管 PyTorch 已经为我们实现神经网络提供了不少便利,但是人的惰性是无极限的,这里介绍一个进一步抽象的工具包——ignite,它将 PyTorch 训练过程更加简化了。

1. 安装

pip install pytorch-ignite

2. 基础示例

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
criterion = nn.NLLLoss()

trainer = create_supervised_trainer(model, optimizer, criterion)

val_metrics = {
   
   
    "accuracy": Accuracy(),
    "nll": Loss(criterion)
}
evaluator = create_supervised_evaluator(model, metrics=val_metrics)

@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(trainer):
    print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print("Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))

trainer.run(train_loader, max_epochs=100)

显然,这里先创建网络模型,Dataloader,优化器以及目标函数,然后用 ignite 的方法 create_supervised_trainer 和 create_supervised_evaluator 简化以往繁琐的循环写法,另外,ignite 还提供了面向切面的处理方法,可以在epoch、iteration等开始前、结束后位置执行你希望的操作

3. Engine

这是 ignite 的核心类,它是一种抽象,它在提供的数据上循环给定的次数,执行处理函数并返回结果

while epoch < max_epochs:
    # run an epoch on data
    data_iter = iter(data)
    while True:
        try:
            batch = next(data_iter)
            output = process_function(batch)
            iter_counter += 1
        except StopIteration:
            data_iter = iter(data)

        if iter_counter == epoch_length:
            break

因此,模型训练器只是一个引擎,它在训练数据集上循环多次并更新模型参数。例如:

def train_step(trainer, batch):
    model.train()
    optimizer.zero_grad()
    x, y = prepare_batch(batch)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

【例 1】创建一个基本的训练器

def update_model(engine, batch):
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_model)

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training(engine):
    batch_loss = engine.state.output
    lr = optimizer.param_groups[0]['lr']
    e = engine.state.epoch
    n = engine.state.max_epochs
    i = engine.state.iteration
    print("Epoch {}/{} : {} - batch loss: {}, lr: {}".format(e, n, i, batch_los
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值