import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import efficientnet_b0 as create_model
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
img_size = {"B0": 224,
"B1": 240,
"B2": 260,
"B3": 300,
"B4": 380,
"B5": 456,
"B6": 528,
"B7": 600}
num_model = "B0"
data_transform = transforms.Compose(
[transforms.Resize(img_size[num_model]),
transforms.CenterCrop(img_size[num_model]),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# 建立循环,将目标目录下的图片路径取出
# 需要改动的有两个部分,输入文件夹的路径、预测结果与文件夹对应结果的对比
path = r'C:\Users\sun\Desktop\b_classification\classification-pytorch-main\color_datasets\train\No-Anomaly'
# 获取该路径下所有图片
filelist = os.listdir(path)
a = 1
for files in filelist:
# load image
img_path = os.path.join(path, files)
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
img = Image.open(img_path)
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
# read class_indict
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
json_file = open(json_path, "r")
class_indict = json.load(json_file)
# create model
model = create_model(num_classes=5).to(device)
# load model weights
model_weight_path = "./weights/model-29.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
with torch.no_grad():
# predict class
output = torch.squeeze(model(img.to(device))).cpu()
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
if predict_cla == "No-Anomaly":
a = a + 1
else:
a = a
# 这里的result就是该路径下的图片得到的准确率
result = a / 2000
print(result)
# print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
# predict[predict_cla].numpy())
# 不需要展示运行结果所以该部分注释掉
# plt.title(print_res)
# for i in range(len(predict)):
# print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
# predict[i].numpy()))
# plt.show()
if __name__ == '__main__':
main()
修改predict实现acc计算
最新推荐文章于 2024-11-24 18:00:00 发布
该博客介绍了一个使用EfficientNet_B0模型进行图像分类的Python代码实现。它加载预训练权重,对指定目录下的图像进行预测,并计算No-Anomaly类别的准确率。代码包括数据预处理、模型创建、权重加载和预测过程。
595

被折叠的 条评论
为什么被折叠?



