图像分类批量图片预测(pth)

import os
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont

def predict_image_class(img_path, checkpoints_path):
    # 加载模型
    model = torch.load(checkpoints_path)

    # 预处理
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    img_pil = Image.open(img_path)
    input_img = test_transform(img_pil)  # 预处理
    input_img = input_img.unsqueeze(0)  # 添加 batch 维度

    # 有 GPU 就用 GPU,没有就用 CPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = model.eval().to(device)
    input_img = input_img.to(device)

    # 执行前向预测,得到所有类别的 logit 预测分数
    with torch.no_grad():
        pred_logits = model(input_img)
        pred_softmax = F.softmax(pred_logits, dim=1)  # 对 logit 分数做 softmax 运算

    return pred_softmax

def add_classification_results_to_image(img_pil, pred_softmax, idx_to_labels, n=2):
    top_n = torch.topk(pred_softmax, n)  # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze()  # 解析出类别
    confs = top_n[0].cpu().detach().numpy().squeeze()  # 解析出置信度

    font = ImageFont.load_default()  # 使用PIL默认字体
    draw = ImageDraw.Draw(img_pil)

    for i in range(n):
        class_name = idx_to_labels[pred_ids[i]]  # 获取类名
        confidence = confs[i] * 100  # 获取置信度
        text = '{:<15} {:>.4f}'.format(class_name, confidence)  # 格式化文本包括类名和置信度

        text_x = 10
        text_y = 10 + (30 + 5) * i
        draw.text((text_x, text_y), text, font=font, fill=(255, 0, 0, 1))

    return img_pil

def batch_predict_images(folder_path, checkpoints_path, save_csv_path, save_images_path):
    idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()

    images = os.listdir(folder_path)
    results = []

    for img_file in images:
        img_path = os.path.join(folder_path, img_file)

        result = predict_image_class(img_path, checkpoints_path)
        results.append(result)

        img_pil = Image.open(img_path)
        img_with_results = add_classification_results_to_image(img_pil, result, idx_to_labels, n=2)
        save_img_path = os.path.join(save_images_path, img_file)
        img_with_results.save(save_img_path)

    df = pd.DataFrame([res[0].cpu().detach().numpy() for res in results], columns=list(idx_to_labels.values()))
    df.to_csv(save_csv_path, index=False)

if __name__ == '__main__':
    checkpoints_path = "checkponits/resnet_best.pth"
    folder_path = 'E:\Python_Project\deep-learning-for-image-processing-master\data_set\\flower_data\\val\\roses'
    save_csv_path = 'Apredicted_results.csv'
    save_images_path = 'A'

    batch_predict_images(folder_path, checkpoints_path, save_csv_path, save_images_path)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Make_magic

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值