Table of Contents
模型保存
# Option 1: save the entire network object (the whole module is serialized,
# not just its weights).
torch.save(net, PATH)
# Option 2: save only the network's parameters (its state_dict) -- faster and
# takes less space than saving the whole module.
net_params = net.state_dict()
torch.save(net_params, PATH)
常规模型加载
# Load a state_dict file (e.g. one written by torch.save(net.state_dict(), PATH))
# and copy its tensors into an already-constructed `model` whose architecture
# matches the saved one.
model_dict = torch.load(PATH)
# load_state_dict() copies the weights into `model` in place and returns a
# (missing_keys, unexpected_keys) result -- not a dict -- so keep that result
# under its own name instead of overwriting model_dict, and reuse the dict
# loaded above instead of reading PATH from disk a second time.
load_result = model.load_state_dict(model_dict)
保存模型及训练状态（检查点）
import os

# Save a detailed training checkpoint: an epoch counter, model and optimizer
# state, the best loss so far, and the `alpha` / `gamma` attributes of the
# loss object. The file name embeds the launch timestamp and the best loss.
# NOTE(review): 'epoch' stores epochID + 1 -- presumably the number of epochs
# completed if epochID is 0-based; confirm against the training loop.
checkpoint = {
    'epoch': epochID + 1,
    'state_dict': model.state_dict(),
    'best_loss': lossMIN,
    'optimizer': optimizer.state_dict(),
    'alpha': loss.alpha,
    'gamma': loss.gamma,
}
# os.path.join instead of '+' concatenation, so checkpoint_path works with or
# without a trailing separator; the format spec already yields a str, so the
# redundant str() wrapper is gone.
checkpoint_file = os.path.join(
    checkpoint_path, 'm-{}-{:.4f}.pth.tar'.format(launchTimestamp, lossMIN))
torch.save(checkpoint, checkpoint_file)
例如：
# Train one epoch at a time and write a checkpoint after every epoch.
for epoch in range(init_epoch, basic_configs['num_epochs']):
    # epoch + 1 is the epoch number shown in the log and used in file names.
    epoch_num = epoch + 1
    logger.info("Begin training epoch {}".format(epoch_num))
    train_function(epoch)

    # Checkpoint file "<exp>_net_epoch<epoch_num>" inside exp_ckpt_dir.
    net_checkpoint_name = args.exp + "_net_epoch" + str(epoch_num)
    net_checkpoint_path = os.path.join(exp_ckpt_dir, net_checkpoint_name)
    # NOTE(review): `net.module` implies net is a wrapper (presumably
    # nn.DataParallel / DistributedDataParallel) -- confirm; the inner module's
    # weights are what gets saved under the "network" key.
    net_state = dict(epoch=epoch_num, network=net.module.state_dict())
    torch.save(net_state, net_checkpoint_path)
模型参数冻结
# Freeze every parameter currently registered on this module so autograd
# stops computing gradients for them. Meant to run inside a method of an
# nn.Module subclass (it uses `self`).
for param in self.parameters():
    param.requires_grad = False
class RESNET_YYT(nn.Module):
    """Backbone network with frozen weights plus three trainable 1x1 conv heads.

    Args:
        model: callable that builds the backbone when called as
            ``model(pretrained)`` -- presumably a torchvision ResNet factory;
            confirm at the call sites.
        pretrained: passed positionally to ``model``.
    """

    def __init__(self, model, pretrained):
        # Must be named __init__: under the name `init` it was never invoked,
        # so RESNET_YYT(model, pretrained) went straight to nn.Module.__init__,
        # which does not accept these arguments, and no layers were created.
        super().__init__()
        self.resnet = model(pretrained)
        # Freeze the backbone. Only self.resnet is registered at this point,
        # so this loop touches just its parameters; f, g and h are created
        # afterwards and therefore stay trainable.
        for p in self.parameters():
            p.requires_grad = False
        # 1x1 convolutions on a 2048-channel input, each wrapped in
        # SpectralNorm (defined elsewhere). NOTE(review): 2048 channels is
        # assumed to match the backbone's output -- confirm for the ResNet used.
        self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
        self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
        self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
打印网络各层参数的序号与名称
# Load a checkpoint dict, take the weights stored under its 'network' key,
# and print each entry's position and name (the state_dict keys).
model_dict = torch.load('/001_net_epoch5loss0.08947')['network']
dict_name = [*model_dict]
for index, name in enumerate(dict_name):
    print(index, name)
4万+

被折叠的 条评论
为什么被折叠?



