import torch

def get_optimizer(network, args):
    # SGD with momentum and weight decay, configured from the command-line args
    optimizer = torch.optim.SGD(network.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = None
    if args.lr_milestone is not None:
        # Multiply the learning rate by lr_gamma at each milestone epoch
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         milestones=args.lr_milestone,
                                                         gamma=args.lr_gamma)
    return optimizer, scheduler
# Typical use inside the training loop
optimizer.zero_grad()
...
loss.backward()
...
optimizer.step()
# adjust the learning rate once per epoch
if scheduler is not None:
    scheduler.step()
This snippet shows how to create an SGD optimizer in PyTorch and adjust the learning rate at preset milestones: the MultiStepLR scheduler lowers the learning rate at specific stages of training, which typically improves how the network trains.
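For context, here is a minimal, self-contained sketch of how get_optimizer might be wired into a training run. The argparse fields (lr, momentum, weight_decay, lr_milestone, lr_gamma), the toy linear model, and the random data are illustrative assumptions, not part of the original post:

import argparse
import torch
import torch.nn as nn

# Hypothetical arguments mirroring the fields get_optimizer reads.
parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--lr_milestone', type=int, nargs='+', default=[30, 60])
parser.add_argument('--lr_gamma', type=float, default=0.1)
args = parser.parse_args([])

# Toy model and random data stand in for a real network and DataLoader.
network = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer, scheduler = get_optimizer(network, args)

for epoch in range(90):
    inputs = torch.randn(32, 10)
    targets = torch.randint(0, 2, (32,))
    optimizer.zero_grad()
    loss = criterion(network(inputs), targets)
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        # With milestones [30, 60] and gamma 0.1, the lr drops 0.1 -> 0.01 -> 0.001.
        scheduler.step()

Because MultiStepLR counts its own steps, scheduler.step() is called once per epoch, after optimizer.step().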