import torch
from torch import nn
class model(nn.Module):
def __init__(self):
super(model, self).__init__()
self.layer = nn.Sequential(
nn.Conv2d(3, 32, 3, 1),
nn.Conv2d(32, 64, 3, 1)
)
def forward(self, x):
x = self.layer(x)
return x
def save(net, input, save_path):
net.eval()
traced_script_module = torch.jit.trace(net, input)
traced_script_module.save(save_path)
def load(model_path):
return torch.jit.load(model_path)
if __name__ == '__main__':
input = torch.Tensor(1, 3, 100, 100)
model_path = 'model.pt'
# # net.load_state_dict(torch.load(model_path))
# # save(net, input, './model.pt')
# net = model()
# out = net(input)
# print(out.size())
# save(net, input, 'model.pt')
# # net.load_state_dict(torch.load(model_path))
net = load(model_path)
out = net(input)
print(out.size())
【Pytorch】torch.jit
最新推荐文章于 2025-09-19 03:38:13 发布
本文介绍了一个基于PyTorch的卷积神经网络模型的实现与序列化过程。通过定义一个包含两层卷积层的自定义模型,并演示了如何使用PyTorch的jit模块将模型转换为可序列化的形式,以便于模型的保存和加载。此外,还展示了如何加载已序列化的模型并进行预测。
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
PyTorch 2.5
PyTorch
Cuda
PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理
1118

被折叠的 条评论
为什么被折叠?



