今天pytorch官网更新了pytorch2.0稳定版,迫不及待的我直接更新了,确实像官方所说,只需加入model=torch.compile(model)一行代码即可加速,加入的位置如下。
cpu训练:
model=UNet(deep_supervision=True)
model=torch.compile(model)
单卡训练:
model=UNet(deep_supervision=True)
model.to(Device)
model=torch.compile(model)
多卡训练:
model=UNet(deep_supervision=True)
model.to(Device)
model=nn.parallel.DistributedDataParallel(
model,
device_ids=[lo

PyTorch官网推出了2.0稳定版,新特性torch.compile允许用户只需添加model=torch.compile(model)即可实现模型训练加速。无论是CPU训练、单卡训练还是多卡训练(DistributedDataParallel),只需关注这行代码的位置,无需大量修改原有代码,即可享受性能提升。
最低0.47元/天 解锁文章
1万+





