def load_network(self, load_path, network, strict=True):
if isinstance(network, nn.DataParallel):
network = network.module
model_dict = torch.load(load_path)
filtered = {k: v for k, v in model_dict.items() if 'num_batches_tracked' not in k}
network.load_state_dict(filtered, strict=strict)
# network.load_state_dict(torch.load(load_path), strict=strict)
pytorch 0.4版本加载0.4.1 1.0更高版本的model
最新推荐文章于 2021-09-23 21:00:59 发布
本文介绍了一种从指定路径加载神经网络模型的方法,特别关注于如何处理DataParallel模型和过滤不必要的参数跟踪信息。通过使用PyTorch库,文章详细解释了如何自定义加载过程,确保模型与保存状态的一致性。
172

被折叠的 条评论
为什么被折叠?



