pytorch+Apex,面对多optimizer和多model如何使用

本人最近使用Apex,来降低显存消耗。但目标模型是GAN结构,model和optimizer都不止一个。这种情况下该如何使用。

  1. 首先在github下载源码https://github.com/NVIDIA/apex 并安装。

  2. 大部分网上给的教程是这样的:

from apex import amp 
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是零一
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
  1. 但如果目标模型是GAN结构,model和optimizer都不止一个。这种情况下该如何使用呢?
  • 查看源码注释:
    Args:
    models (torch.nn.Module or list of torch.nn.Modules): Models to modify/cast.
    optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers): Optimizers to modify/cast. REQUIRED for training, optional for inference.

  • 因此,将参数变为list形式即可。即

[self.genA2B,self.genB2A,self.disLA,self.disLB,self.disGA,self.disGB], [self.G_optim, self.D_optim] =  
amp.initialize(models =[self.genA2B,self.genB2A,self.disLA,self.disLB,self.disGA,self.disGB], optimizers=
[self.G_optim, self.D_optim], opt_level="O1")

with amp.scale_loss(Discriminator_loss, self.D_optim) as D_loss:
                 D_loss.backward()

with amp.scale_loss(Generator_loss, self.G_optim) as G_loss:
                G_loss.backward()
  1. 优化参数含义优化参数含义
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值