import os
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
def predict_image_class(img_path, checkpoints_path):
    """Run a saved classification model on one image and return class probabilities.

    Args:
        img_path: Path to the image file to classify.
        checkpoints_path: Path to a checkpoint holding a whole pickled model
            object (as written by ``torch.save(model)``), not a ``state_dict``;
            the loaded object is called directly.

    Returns:
        The softmax of the model output along dim 1, with a leading batch
        dimension of 1, on ``cuda:0`` if available, otherwise on the CPU.
        Presumably shape (1, num_classes) — depends on the model's output.
    """
    import inspect  # local import: only needed to probe torch.load's signature

    # Use the GPU when available, otherwise fall back to the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load the model. map_location lets a checkpoint saved on a GPU machine be
    # loaded on a CPU-only machine. torch >= 2.6 defaults to weights_only=True,
    # which refuses whole pickled model objects, so opt out explicitly when the
    # installed torch supports the argument.
    # NOTE: this unpickles arbitrary objects — only load trusted checkpoints.
    load_kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(checkpoints_path, **load_kwargs)

    # Preprocessing: resize, center-crop to 224, convert to tensor, normalize with
    # the standard ImageNet mean/std (presumably matching the training transform).
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # `with` closes the file handle. Normalize expects 3 channels, so grayscale,
    # palette or RGBA images are converted to RGB first.
    with Image.open(img_path) as img_pil:
        input_img = test_transform(img_pil.convert('RGB'))
    input_img = input_img.unsqueeze(0).to(device)  # add the batch dimension

    model = model.eval().to(device)

    # Forward pass to get per-class logits, then softmax them into probabilities.
    with torch.no_grad():
        pred_logits = model(input_img)
    pred_softmax = F.softmax(pred_logits, dim=1)
    return pred_softmax
def add_classification_results_to_image(img_pil, pred_softmax, idx_to_labels, n=2):
    """Draw the top-``n`` predicted classes and their confidences onto an image.

    Args:
        img_pil: PIL image to annotate. It is drawn on in place and returned.
        pred_softmax: Probability tensor. If it has more than one dimension,
            only the first row (the first image of the batch) is used.
        idx_to_labels: Mapping from class index (int) to class name.
        n: Number of top classes to draw. Clamped to the number of classes.

    Returns:
        The same ``img_pil`` object with one text line per class drawn on it.
    """
    probs = pred_softmax.detach().cpu()
    if probs.dim() > 1:
        # Take the first image's scores explicitly instead of squeeze(), which
        # collapsed n == 1 results to 0-d arrays that cannot be indexed.
        probs = probs[0]
    n = min(n, probs.numel())  # topk raises if k exceeds the number of classes
    confs, pred_ids = torch.topk(probs, n)  # the n most confident classes

    font = ImageFont.load_default()  # PIL's built-in default font
    draw = ImageDraw.Draw(img_pil)
    for row, (class_id, conf) in enumerate(zip(pred_ids.tolist(), confs.tolist())):
        class_name = idx_to_labels[class_id]
        # Class name left-aligned in 15 columns, confidence as a percentage.
        text = '{:<15} {:>.4f}'.format(class_name, conf * 100)
        # Opaque red: the former alpha of 1 made the text nearly invisible on
        # RGBA images.
        draw.text((10, 10 + (30 + 5) * row), text, font=font, fill=(255, 0, 0))
    return img_pil
def batch_predict_images(folder_path, checkpoints_path, save_csv_path, save_images_path,
                         idx_to_labels_path='idx_to_labels.npy',
                         extensions=('.jpg', '.jpeg', '.png', '.bmp', '.gif',
                                     '.tif', '.tiff', '.webp')):
    """Classify every image in a folder, save annotated copies and a CSV of probabilities.

    Args:
        folder_path: Directory whose image files are classified (not recursive).
        checkpoints_path: Model checkpoint, passed through to ``predict_image_class``.
        save_csv_path: Output CSV path. It gets one row per image (sorted by file
            name) and one column per class, ordered by class index.
        save_images_path: Output directory for the annotated images, created if
            missing. Each image is saved under its original file name.
        idx_to_labels_path: ``.npy`` file holding a pickled dict mapping class
            index to class name.
        extensions: Case-insensitive file suffixes treated as images. Other
            entries (sub-directories, ``Thumbs.db``, ...) are skipped.
    """
    # NOTE: allow_pickle=True unpickles arbitrary objects — only load trusted files.
    idx_to_labels = np.load(idx_to_labels_path, allow_pickle=True).item()
    # Columns must follow class-index order (the order of the softmax outputs),
    # not whatever insertion order the saved dict happens to have.
    class_names = [idx_to_labels[idx] for idx in sorted(idx_to_labels)]

    os.makedirs(save_images_path, exist_ok=True)
    csv_dir = os.path.dirname(save_csv_path)
    if csv_dir:
        os.makedirs(csv_dir, exist_ok=True)

    suffixes = tuple(ext.lower() for ext in extensions)
    # Sorted so the CSV row order is deterministic (os.listdir order is arbitrary).
    image_files = sorted(
        name for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name))
        and name.lower().endswith(suffixes)
    )

    rows = []
    for img_file in image_files:
        img_path = os.path.join(folder_path, img_file)
        # NOTE(review): predict_image_class reloads the checkpoint on every call.
        pred_softmax = predict_image_class(img_path, checkpoints_path)
        rows.append(pred_softmax[0].cpu().detach().numpy())
        with Image.open(img_path) as img_pil:
            # Draw on an RGB copy so the red text renders the same way for
            # grayscale, palette and RGBA inputs.
            annotated = add_classification_results_to_image(
                img_pil.convert('RGB'), pred_softmax, idx_to_labels, n=2)
        annotated.save(os.path.join(save_images_path, img_file))

    df = pd.DataFrame(rows, columns=class_names)
    df.to_csv(save_csv_path, index=False)
if __name__ == '__main__':
    # NOTE(review): 'checkponits' looks like a typo for 'checkpoints' — kept as-is
    # in case the directory on disk is really named this way; confirm.
    checkpoints_path = "checkponits/resnet_best.pth"
    # Raw string: the old literal relied on invalid escapes such as '\P' and '\d'
    # (SyntaxWarning on Python 3.12+). The resulting path value is unchanged.
    folder_path = r'E:\Python_Project\deep-learning-for-image-processing-master\data_set\flower_data\val\roses'
    save_csv_path = 'Apredicted_results.csv'
    save_images_path = 'A'
    batch_predict_images(folder_path, checkpoints_path, save_csv_path, save_images_path)
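# 图像分类批量图片预测(pth) — Image classification: batch prediction on images with a .pth model
# 于 2023-11-28 16:01:10 首次发布 — First published 2023-11-28 16:01:10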