I have recently been using Apex to reduce GPU memory consumption. However, my target model is a GAN, so there is more than one model and more than one optimizer. How should Apex be used in this case?
First, download the source code from GitHub
https://github.com/NVIDIA/apex
and install it.
Most tutorials found online show something like this:
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")  # opt_level is "O1" (letter O, digit one), not "01" (zero one)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
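For reference, here is a minimal, self-contained sketch of that single-model pattern. The Linear model, Adam optimizer, MSE loss, and dummy data are assumptions made purely to make the snippet runnable; they are not part of the original tutorial.

import torch
from apex import amp

# Hypothetical model, optimizer, and data, used only for illustration
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

# amp.initialize is called after the model is on the GPU and the optimizer exists
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(32, 10).cuda()
y = torch.randn(32, 1).cuda()

for step in range(100):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    # Scale the loss so FP16 gradients do not underflow, then backpropagate
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()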
But if the target model is a GAN, there is more than one model and more than one optimizer. How should Apex be used in that situation?
Looking at the docstring of amp.initialize in the source code:
Args:
models (torch.nn.Module or list of torch.nn.Modules): Models to modify/cast.
optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers): Optimizers to modify/cast. REQUIRED for training, optional for inference.
Therefore, simply pass the models and optimizers as lists:
[self.genA2B, self.genB2A, self.disLA, self.disLB, self.disGA, self.disGB], [self.G_optim, self.D_optim] = amp.initialize(
    models=[self.genA2B, self.genB2A, self.disLA, self.disLB, self.disGA, self.disGB],
    optimizers=[self.G_optim, self.D_optim],
    opt_level="O1")

with amp.scale_loss(Discriminator_loss, self.D_optim) as D_loss:
    D_loss.backward()
with amp.scale_loss(Generator_loss, self.G_optim) as G_loss:
    G_loss.backward()
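To show where these scaled backward passes sit inside a full training step, here is a hedged sketch with a single generator/discriminator pair. It is a simplified stand-in for the six sub-networks above; the toy Linear networks, the non-saturating GAN losses, and the dummy data are all assumptions for illustration.

import torch
import torch.nn.functional as F
from apex import amp

# Toy stand-in networks: one generator and one discriminator are enough to show
# the pattern; the post above applies the same pattern to six sub-networks.
G = torch.nn.Linear(16, 16).cuda()
D = torch.nn.Linear(16, 1).cuda()
G_optim = torch.optim.Adam(G.parameters(), lr=1e-4)
D_optim = torch.optim.Adam(D.parameters(), lr=1e-4)

# Passing lists returns the wrapped models/optimizers as lists, in the same order
[G, D], [G_optim, D_optim] = amp.initialize([G, D], [G_optim, D_optim], opt_level="O1")

real = torch.randn(8, 16).cuda()
noise = torch.randn(8, 16).cuda()

for step in range(100):
    # Discriminator update: detach fake samples so G receives no gradient here
    D_optim.zero_grad()
    fake = G(noise).detach()
    Discriminator_loss = F.softplus(-D(real)).mean() + F.softplus(D(fake)).mean()
    with amp.scale_loss(Discriminator_loss, D_optim) as D_loss:
        D_loss.backward()
    D_optim.step()

    # Generator update: non-saturating loss on freshly generated samples
    G_optim.zero_grad()
    Generator_loss = F.softplus(-D(G(noise))).mean()
    with amp.scale_loss(Generator_loss, G_optim) as G_loss:
        G_loss.backward()
    G_optim.step()

Note that amp.initialize also accepts a num_losses argument and amp.scale_loss a loss_id, which the Apex documentation suggests when several losses should get independent loss scalers; the sketch above simply keeps the defaults.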
Meaning of the optimization parameters (opt_level)