自制AI的图片分类工具- 智能图片分类器

亲爱的科技爱好者们,今天我要为大家介绍一款自制的小工具 ——智能图片分类器!这款软件能够通过部署你的模型,就像是一位AI魔法师,能够瞬间识别各种图片,让你的图片管理工作变得轻松又有趣。

工具功能


单图识别:只需轻点"选择图片"按钮,选中你想识别的图片,然后点击"开始识别"。瞧!我们的AI魔法师会立刻开始施法,几秒钟后就能告诉你这是什么图片,而且还会给出信心指数哦!

批量预测:如果你有一整个文件夹的图片需要分类,别担心!点击"批量预测"按钮,选择你的图片文件夹,然后就可以坐等魔法发生了。AI魔法师会自动处理所有图片,最后还会给你一份详细的成绩单,告诉你它猜对了多少张。

炫酷界面:我们的AI魔法师不仅法力高强,还很注重外表呢!软件界面采用了优雅的渐变背景,看起来就像魔法水晶球一样神秘又迷人。

实时反馈:当AI魔法师在施法时,你会看到一个可爱的加载动画,让你知道魔法正在进行中。识别结果会以不同的颜色显示,绿色代表非常自信,橙色表示有点犹豫,红色则意味着AI魔法师也不太确定呢!

全能选手:我们的AI魔法师学识渊博,能识别多种类别的图片。无论是动物、植物、风景还是日常物品,它都能应对自如。


下面看看我怎么来实现:其中需要注意的是模型实例化和模型权重的加载需要改成自己的

import tkinter as tk
from tkinter import filedialog, ttk
from PIL import Image, ImageTk, ImageDraw
import torch
from torchvision import transforms
import os
from models.faster_vit import StudentFasterViT
import threading
import time

class ModelPredictorApp:
    def __init__(self, master):
        self.master = master
        master.title("智能图片分类器 v1.0")
        master.geometry("900x700")

        # 创建渐变背景
        self.create_gradient_background()

        self.class_labels = self.get_class_labels('dataset/train')

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.load_model('模型权重.pth') #请改成你的模型权重

        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        self.create_widgets()
        self.is_predicting = False

        # 定义控制面板和按钮样式
        control_frame = tk.Frame(self.master, bg='#E6F3FF')
        control_frame.pack(pady=10)

        button_style = {'font': ('Segoe UI', 12), 'bg': '#4CAF50', 'fg': 'white', 'padx': 10, 'pady': 5, 'bd': 0}

        # 添加批量预测按钮
        self.btn_batch_predict = tk.Button(control_frame, text="批量预测", command=self.batch_predict, **button_style)
        self.btn_batch_predict.pack(side=tk.LEFT, padx=5)

    def create_gradient_background(self):
        gradient_image = Image.new('RGB', (900, 700), color='#FFFFFF')
        draw = ImageDraw.Draw(gradient_image)
        for y in range(700):
            r = int(255 * (1 - y / 700))
            g = int(200 * (1 - y / 700))
            b = int(255 * (1 - y / 700))
            draw.line([(0, y), (900, y)], fill=(r, g, b))
        self.bg_image = ImageTk.PhotoImage(gradient_image)
        self.bg_label = tk.Label(self.master, image=self.bg_image)
        self.bg_label.place(x=0, y=0, relwidth=1, relheight=1)

    def create_widgets(self):
        # 标题
        title_label = tk.Label(self.master, text="智能图片分类器", font=("Segoe UI", 30, "bold"), bg='#E6F3FF',
                               fg='#4A4A4A')
        title_label.pack(pady=20)

        # 控制面板
        self.control_frame = tk.Frame(self.master, bg='#E6F3FF')
        self.control_frame.pack(pady=10)

        button_style = {'font': ('Segoe UI', 12), 'bg': '#4CAF50', 'fg': 'white', 'padx': 10, 'pady': 5, 'bd': 0}

        self.btn_load = tk.Button(self.control_frame, text="选择图片", command=self.load_image, **button_style)
        self.btn_load.pack(side=tk.LEFT, padx=5)

        self.btn_predict = tk.Button(self.control_frame, text="开始识别", command=self.start_prediction, **button_style)
        self.btn_predict.pack(side=tk.LEFT, padx=5)

        # 添加批量预测按钮
        self.btn_batch_predict = tk.Button(self.control_frame, text="批量预测", command=self.batch_predict,
                                           **button_style)
        self.btn_batch_predict.pack(side=tk.LEFT, padx=5)

        # 图片显示区域
        self.image_frame = tk.Frame(self.master, bg='#E6F3FF', bd=2, relief=tk.GROOVE)
        self.image_frame.pack(pady=10)
        self.image_label = tk.Label(self.image_frame, bg='#FFFFFF')
        self.image_label.pack(padx=10, pady=10)

        # 加载动画
        self.loading_label = tk.Label(self.master, text="", font=('Segoe UI', 14), bg='#E6F3FF', fg='#4A4A4A')
        self.loading_label.pack()

        # 结果显示区域
        self.result_frame = tk.Frame(self.master, bg='#E6F3FF', bd=2, relief=tk.GROOVE)
        self.result_frame.pack(pady=10)

        self.result_text = tk.Text(self.result_frame, height=4, width=50, font=('Segoe UI', 12), bg='#FFFFFF',
                                   fg='#4A4A4A')
        self.result_text.pack(padx=10, pady=10)
        self.result_text.insert(tk.END, "识别结果将显示在这里...")
        self.result_text.config(state=tk.DISABLED)

        # 状态栏
        self.status_bar = tk.Label(self.master, relief=tk.SUNKEN, anchor=tk.W, bg='#E6F3FF', font=('Segoe UI', 10))
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)

        # 状态标签
        self.status_label = tk.Label(self.master, text="就绪", font=('Segoe UI', 12), bg='#E6F3FF', fg='#4A4A4A')
        self.status_label.pack()

    def load_model(self, model_path):
        """加载训练好的模型"""
        model = StudentFasterViT(dim=32, in_dim=16, depths=[2, 2, 4, 2], num_heads=[1, 2, 4, 8],
                                 window_size=[7, 7, 7, 7], ct_size=1, mlp_ratio=2).to(self.device)  #模型实例化可以改成你们的模型
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        model.eval()
        return model

    def get_class_labels(self, train_dir):
        """从训练目录获取类别标签"""
        return sorted([d.name for d in os.scandir(train_dir) if d.is_dir()])

    def load_image(self):
        """加载图片并显示"""
        file_path = filedialog.askopenfilename(filetypes=[("Image files", "*.jpg *.jpeg *.png")])
        if file_path:
            self.current_image = Image.open(file_path)
            self.display_image(self.current_image)
            self.update_status(f"已加载图片: {os.path.basename(file_path)}")

    def display_image(self, image):
        """在GUI中显示图片"""
        image = image.resize((300, 300), Image.Resampling.LANCZOS)
        photo = ImageTk.PhotoImage(image)
        self.image_label.config(image=photo)
        self.image_label.image = photo  # 保持引用

    def start_prediction(self):
        if not hasattr(self, 'current_image'):
            self.show_error("请先选择图片!")
            return

        if self.is_predicting:
            return

        self.is_predicting = True
        self.btn_predict.config(state=tk.DISABLED)
        self.loading_animation()
        threading.Thread(target=self.predict, daemon=True).start()

    def update_status(self, message):
        self.status_label.config(text=message)

    def loading_animation(self):
        def animate():
            if not self.is_predicting:
                self.loading_label.config(text="")
                return
            for i in range(3):
                if not self.is_predicting:
                    self.loading_label.config(text="")
                    return
                self.loading_label.config(text="识别中" + "." * (i + 1))
                self.master.update()
                time.sleep(0.5)
            self.master.after(100, animate)

        animate()

    def predict(self):
        input_tensor = self.transform(self.current_image)
        input_batch = input_tensor.unsqueeze(0).to(self.device)

        with torch.no_grad():
            class_output, _ = self.model(input_batch)
            probabilities = torch.softmax(class_output, dim=1)

        pred_index = torch.argmax(probabilities).item()
        confidence = probabilities[0][pred_index].item()
        pred_label = self.class_labels[pred_index]

        self.master.after(0, self.update_result, pred_label, confidence)

    def update_result(self, pred_label, confidence):
        self.is_predicting = False
        self.result_text.config(state=tk.NORMAL)
        self.result_text.delete(1.0, tk.END)

        if confidence > 0.7:
            color = "#4CAF50"  # Green
        elif confidence > 0.5:
            color = "#FFA500"  # Orange
        else:
            color = "#FF0000"  # Red

        self.result_text.insert(tk.END, f"识别结果:{pred_label}\n", "result")
        self.result_text.insert(tk.END, f"置信度:{confidence:.2%}\n", f"confidence")
        self.result_text.insert(tk.END, f"支持类别:{', '.join(self.class_labels)}")

        self.result_text.tag_config("result", font=("Segoe UI", 14, "bold"))
        self.result_text.tag_config("confidence", foreground=color)

        self.result_text.config(state=tk.DISABLED)
        self.update_status("预测完成!")
        self.btn_predict.config(state=tk.NORMAL)
        self.loading_label.config(text="")

    def show_error(self, message):
        tk.messagebox.showerror("错误", message)

    def batch_predict(self):
        folder_path = filedialog.askdirectory(title="选择包含图片的文件夹")
        if not folder_path:
            return

        self.is_predicting = True
        self.btn_batch_predict.config(state=tk.DISABLED)
        self.loading_animation()
        threading.Thread(target=self.run_batch_prediction, args=(folder_path,), daemon=True).start()

    def run_batch_prediction(self, folder_path):
        total_images = 0
        correct_predictions = 0

        for root, dirs, files in os.walk(folder_path):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    total_images += 1
                    file_path = os.path.join(root, file)
                    true_label = os.path.basename(root)

                    image = Image.open(file_path)
                    input_tensor = self.transform(image)
                    input_batch = input_tensor.unsqueeze(0).to(self.device)

                    with torch.no_grad():
                        class_output, _ = self.model(input_batch)
                        probabilities = torch.softmax(class_output, dim=1)

                    pred_index = torch.argmax(probabilities).item()
                    pred_label = self.class_labels[pred_index]

                    if pred_label == true_label:
                        correct_predictions += 1

                    self.master.after(0, self.update_status, f"正在处理: {file}")

        accuracy = correct_predictions / total_images if total_images > 0 else 0
        self.master.after(0, self.show_batch_result, total_images, correct_predictions, accuracy)

    def show_batch_result(self, total_images, correct_predictions, accuracy):
        self.is_predicting = False
        self.btn_batch_predict.config(state=tk.NORMAL)
        self.loading_label.config(text="")

        result_message = f"批量预测完成!\n总图片数: {total_images}\n正确预测数: {correct_predictions}\n准确率: {accuracy:.2%}"

        self.result_text.config(state=tk.NORMAL)
        self.result_text.delete(1.0, tk.END)
        self.result_text.insert(tk.END, result_message)
        self.result_text.config(state=tk.DISABLED)

        self.update_status("批量预测完成")

if __name__ == "__main__":
    root = tk.Tk()
    app = ModelPredictorApp(root)
    root.mainloop()

该程序可以打包成.exe,其界面长成这个样子:

点击选择图片后开始识别可进行单个图片的预测

 

点击批量预测后可以选择测试集来检验准确率

 

这样,一个识别工具的最基础的功能就实现了,当然,更多的内容请等待后续更新! 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值