import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet
# 图片处理大小核标准化
transform = transforms.Compose(
[transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 实例化 LeNet
net = LeNet()
# 载入权重文件
net.load_state_dict(torch.load('Lenet.pth'))
im = Image.open('1.jpg') # 保存到目录下一张图片用来预测
im = transforms(im) # 预处理为 [C, H, W]
im = torch.unsqueeze(im, dim=0) # 增加新维度 tensor格式,用来网络中传播[N, C, H, W]
with torch.no_grad():
outputs = net(im)
predict = torch.max(outputs, dim=1)[1].data.numpy()
# 或者用softmax函数
# predict = torch.softmax(outputs, dim=1)
print(classes[int(predict)])
预测模型
最新推荐文章于 2024-08-03 15:44:27 发布
本文介绍了一种基于PyTorch实现的LeNet模型进行图像分类的方法。文章首先定义了图像预处理步骤,包括调整图像大小、转换为张量及标准化处理等。随后介绍了如何加载训练好的模型权重,并演示了如何利用该模型对单张图片进行预测。
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
PyTorch 2.5
PyTorch
Cuda
PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理
14万+

被折叠的 条评论
为什么被折叠?



