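# Example: input image — a cat; predicted result — class 3.
# Class-to-index mapping: each output index of the network corresponds to one class name.
# Full code: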
import torch
import torchvision.transforms
from torch import nn
from PIL import Image
from torch.nn import Conv2d,MaxPool2d,Flatten,Linear,Sequential
img_path = "./cat.png"
# A PNG may carry an alpha channel (RGBA, 4 channels) or be grayscale/palette;
# convert to 3-channel RGB because the network's first convolution expects
# exactly 3 input channels. convert() loads the pixels into a new image, so the
# context manager can safely close the underlying file afterwards.
with Image.open(img_path) as image:
    img = image.convert("RGB")
# Resize to the 32x32 resolution the network below is built for, then convert
# the PIL image to a float tensor laid out as (C, H, W).
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
# Transform the RGB-converted image, not the raw `image`: with an RGBA PNG the
# raw image would yield a 4-channel tensor that the model cannot accept.
img = transform(img)
print(img.shape)
class MyModule(nn.Module):
    """Small CNN classifier: three conv/max-pool stages followed by two linear layers.

    Input is a batch of 3-channel images shaped (N, 3, H, W); output is an
    (N, 10) tensor of unnormalized scores (no softmax is applied). Each 5x5
    convolution uses padding=2 and so keeps the spatial size, while each
    MaxPool2d(2) halves it. For a 32x32 input the final feature map is
    64 x 4 x 4 = 1024 values, which is what the first Linear layer expects.

    The attribute name ``model1`` and the layer order must stay as they are:
    a saved checkpoint refers to the layers by those names/indices
    (``model1.0``, ``model1.2``, ..., ``model1.8``).
    """

    def __init__(self):
        super().__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),   # (3, 32, 32)  -> (32, 32, 32)
            MaxPool2d(2),                  #              -> (32, 16, 16)
            Conv2d(32, 32, 5, padding=2),  #              -> (32, 16, 16)
            MaxPool2d(2),                  #              -> (32, 8, 8)
            Conv2d(32, 64, 5, padding=2),  #              -> (64, 8, 8)
            MaxPool2d(2),                  #              -> (64, 4, 4)
            Flatten(),                     # keeps batch dim -> (1024,)
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x):
        """Return the (N, 10) score tensor for the input batch ``x``."""
        return self.model1(x)
# The checkpoint holds a whole pickled module object (it is called directly as
# `model` below), not just a state_dict. Unpickling presumably needs MyModule to
# be importable under the same module path it was saved from — confirm against
# the training script.
# weights_only=False: since PyTorch 2.6 torch.load defaults to weights_only=True,
# which refuses to unpickle full module objects. Unpickling can execute
# arbitrary code, so only load checkpoints from a trusted source.
# map_location="cpu": the input image below stays on the CPU, so load every
# tensor onto the CPU even if the checkpoint was saved from a GPU.
model = torch.load("./models/mymodule_9.pth", map_location="cpu", weights_only=False)
# The model requires a 4-D input (N, C, H, W): add a leading batch dimension
# to the single (C, H, W) image tensor.
img = img.unsqueeze(0)
model.eval()  # inference mode (affects layers such as dropout/batch-norm, if present)
with torch.no_grad():  # no autograd graph is needed for prediction
    output = model(img)
print(output)
# Position of the highest score in each row = predicted class index.
print(output.argmax(1))