亲爱的科技爱好者们,今天我要为大家介绍一款自制的小工具——智能图片分类器!部署好你自己训练的模型后,它就像一位AI魔法师,能够瞬间识别各种图片,让你的图片管理工作变得轻松又有趣。
工具功能
单图识别:只需轻点“选择图片”按钮,选中你想识别的图片,然后点击“开始识别”。瞧!我们的AI魔法师会立刻开始施法,几秒钟后就能告诉你这是什么图片,而且还会给出置信度哦!
批量预测:如果你有一整个文件夹的图片需要分类,别担心!点击“批量预测”按钮,选择你的图片文件夹,然后就可以坐等魔法发生了。AI魔法师会自动处理所有图片,最后还会给你一份详细的成绩单,告诉你它猜对了多少张。注意:程序会把每张图片所在子文件夹的名称当作它的真实类别,所以所选文件夹应当像训练集一样按类别分成子文件夹。
炫酷界面:我们的AI魔法师不仅法力高强,还很注重外表呢!软件界面采用了优雅的渐变背景,看起来就像魔法水晶球一样神秘又迷人。
实时反馈:当AI魔法师在施法时,你会看到一个可爱的加载动画,让你知道魔法正在进行中。识别结果会以不同的颜色显示:置信度高于70%为绿色,代表非常自信;50%~70%为橙色,表示有点犹豫;不高于50%为红色,意味着AI魔法师也不太确定呢!
全能选手:我们的AI魔法师学识渊博,能识别的类别取决于你训练模型所用的数据集——程序会把 dataset/train 下各子文件夹的名称作为类别标签。只要训练集覆盖了相应类别,无论是动物、植物、风景还是日常物品,它都能应对自如。
下面来看看具体实现。需要注意的是,模型的实例化(load_model 中的 StudentFasterViT)和模型权重的路径('模型权重.pth')都需要改成你自己的;类别标签则从 dataset/train 目录读取。
import tkinter as tk
from tkinter import filedialog, ttk
from PIL import Image, ImageTk, ImageDraw
import torch
from torchvision import transforms
import os
from models.faster_vit import StudentFasterViT
import threading
import time
class ModelPredictorApp:
    """Tkinter GUI that classifies images with a PyTorch model.

    Single-image mode: "选择图片" loads an image and "开始识别" classifies it on
    a background thread. The UI then shows the predicted label and its softmax
    confidence.

    Batch mode: "批量预测" walks a folder recursively and classifies every
    .png/.jpg/.jpeg file. It reports accuracy using each image's parent-folder
    name as the ground-truth label.

    Class labels are the sorted sub-directory names of ``dataset/train``. Adapt
    the model architecture in ``load_model`` and the weights path in
    ``__init__`` to your own model.
    """

    def __init__(self, master):
        """Build the window, load labels/model/preprocessing, and create the widgets.

        Args:
            master: The Tk root (or Toplevel) window hosting the application.
        """
        self.master = master
        master.title("智能图片分类器 v1.0")
        master.geometry("900x700")

        # State must exist before any widget callback can fire.
        self.is_predicting = False
        self.current_image = None
        self._animation_job = None  # pending after() id of the loading animation

        # Paint the gradient first so later widgets are stacked above it.
        self.create_gradient_background()

        self.class_labels = self.get_class_labels('dataset/train')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.load_model('模型权重.pth')  # replace with the path to your own weights
        # ImageNet-style preprocessing: resize, center-crop 224, tensor, normalize
        # with the ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # create_widgets() already adds the single "批量预测" button. The old
        # duplicate frame/button built here was removed: it overwrote
        # self.btn_batch_predict.
        self.create_widgets()

    def create_gradient_background(self, width=900, height=700):
        """Paint a vertical gradient behind all widgets.

        The gradient runs from (255, 200, 255) at the top toward black at the
        bottom.

        Args:
            width: Gradient image width in pixels (defaults to the window width).
            height: Gradient image height in pixels (defaults to the window height).
        """
        gradient_image = Image.new('RGB', (width, height), color='#FFFFFF')
        draw = ImageDraw.Draw(gradient_image)
        for y in range(height):
            fade = 1 - y / height
            draw.line([(0, y), (width, y)],
                      fill=(int(255 * fade), int(200 * fade), int(255 * fade)))
        # Keep a reference on self: Tk does not, and the image would be garbage-collected.
        self.bg_image = ImageTk.PhotoImage(gradient_image)
        self.bg_label = tk.Label(self.master, image=self.bg_image)
        self.bg_label.place(x=0, y=0, relwidth=1, relheight=1)

    def create_widgets(self):
        """Create the title, control buttons, image preview, result box and status labels."""
        # Title
        title_label = tk.Label(self.master, text="智能图片分类器", font=("Segoe UI", 30, "bold"), bg='#E6F3FF',
                               fg='#4A4A4A')
        title_label.pack(pady=20)

        # Control panel with the three action buttons
        self.control_frame = tk.Frame(self.master, bg='#E6F3FF')
        self.control_frame.pack(pady=10)
        button_style = {'font': ('Segoe UI', 12), 'bg': '#4CAF50', 'fg': 'white', 'padx': 10, 'pady': 5, 'bd': 0}
        self.btn_load = tk.Button(self.control_frame, text="选择图片", command=self.load_image, **button_style)
        self.btn_load.pack(side=tk.LEFT, padx=5)
        self.btn_predict = tk.Button(self.control_frame, text="开始识别", command=self.start_prediction,
                                     **button_style)
        self.btn_predict.pack(side=tk.LEFT, padx=5)
        self.btn_batch_predict = tk.Button(self.control_frame, text="批量预测", command=self.batch_predict,
                                           **button_style)
        self.btn_batch_predict.pack(side=tk.LEFT, padx=5)

        # Image preview area
        self.image_frame = tk.Frame(self.master, bg='#E6F3FF', bd=2, relief=tk.GROOVE)
        self.image_frame.pack(pady=10)
        self.image_label = tk.Label(self.image_frame, bg='#FFFFFF')
        self.image_label.pack(padx=10, pady=10)

        # Loading-animation text
        self.loading_label = tk.Label(self.master, text="", font=('Segoe UI', 14), bg='#E6F3FF', fg='#4A4A4A')
        self.loading_label.pack()

        # Read-only result box
        self.result_frame = tk.Frame(self.master, bg='#E6F3FF', bd=2, relief=tk.GROOVE)
        self.result_frame.pack(pady=10)
        self.result_text = tk.Text(self.result_frame, height=4, width=50, font=('Segoe UI', 12), bg='#FFFFFF',
                                   fg='#4A4A4A')
        self.result_text.pack(padx=10, pady=10)
        self.result_text.insert(tk.END, "识别结果将显示在这里...")
        self.result_text.config(state=tk.DISABLED)

        # Bottom status bar (currently left empty)
        self.status_bar = tk.Label(self.master, relief=tk.SUNKEN, anchor=tk.W, bg='#E6F3FF', font=('Segoe UI', 10))
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)

        # Status text updated by update_status()
        self.status_label = tk.Label(self.master, text="就绪", font=('Segoe UI', 12), bg='#E6F3FF', fg='#4A4A4A')
        self.status_label.pack()

    def load_model(self, model_path):
        """Instantiate the network, load its state_dict from *model_path*, and switch to eval mode.

        Replace the StudentFasterViT construction below with your own model.
        """
        model = StudentFasterViT(dim=32, in_dim=16, depths=[2, 2, 4, 2], num_heads=[1, 2, 4, 8],
                                 window_size=[7, 7, 7, 7], ct_size=1, mlp_ratio=2).to(self.device)
        # NOTE(review): torch.load unpickles the file. Only load weights you trust,
        # or pass weights_only=True on torch versions that support it.
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        model.eval()
        return model

    def get_class_labels(self, train_dir):
        """Return the sorted names of the sub-directories of *train_dir* (one per class)."""
        return sorted(entry.name for entry in os.scandir(train_dir) if entry.is_dir())

    @staticmethod
    def _open_rgb(file_path):
        """Open *file_path* and return a fully decoded RGB copy.

        The file handle is closed before returning. Converting to RGB keeps
        RGBA and grayscale images compatible with the 3-channel Normalize step.
        Raises OSError (including PIL's UnidentifiedImageError) on unreadable
        files.
        """
        with Image.open(file_path) as img:
            return img.convert("RGB")

    def load_image(self):
        """Let the user pick an image file, then preview it and remember it for prediction."""
        file_path = filedialog.askopenfilename(filetypes=[("Image files", "*.jpg *.jpeg *.png")])
        if not file_path:
            return
        try:
            image = self._open_rgb(file_path)
        except OSError as exc:
            self.show_error(f"无法打开图片: {exc}")
            return
        self.current_image = image
        self.display_image(image)
        self.update_status(f"已加载图片: {os.path.basename(file_path)}")

    def display_image(self, image):
        """Show *image* in the preview area, resized to 300x300."""
        image = image.resize((300, 300), Image.Resampling.LANCZOS)
        photo = ImageTk.PhotoImage(image)
        self.image_label.config(image=photo)
        self.image_label.image = photo  # keep a reference so Tk does not lose the image

    def _begin_prediction(self):
        """Enter the busy state: lock both prediction buttons and start the animation.

        Returns:
            False (doing nothing) if a prediction is already running, True otherwise.
        """
        if self.is_predicting:
            return False
        self.is_predicting = True
        self.btn_predict.config(state=tk.DISABLED)
        self.btn_batch_predict.config(state=tk.DISABLED)
        self.loading_animation()
        return True

    def _end_prediction(self):
        """Leave the busy state: stop the animation and re-enable both prediction buttons."""
        self.is_predicting = False
        if self._animation_job is not None:
            self.master.after_cancel(self._animation_job)
            self._animation_job = None
        self.loading_label.config(text="")
        self.btn_predict.config(state=tk.NORMAL)
        self.btn_batch_predict.config(state=tk.NORMAL)

    def _prediction_failed(self, exc):
        """Tk-thread callback for worker errors: reset the UI and show the error."""
        self._end_prediction()
        self.update_status("预测失败")
        self.show_error(f"预测失败: {exc}")

    def start_prediction(self):
        """Classify the currently loaded image on a background thread."""
        if self.current_image is None:
            self.show_error("请先选择图片!")
            return
        if not self._begin_prediction():
            return
        # Pass the image explicitly so loading another one mid-prediction cannot race.
        threading.Thread(target=self.predict, args=(self.current_image,), daemon=True).start()

    def update_status(self, message):
        """Show *message* in the status label."""
        self.status_label.config(text=message)

    def loading_animation(self):
        """Cycle "识别中." / "识别中.." / "识别中..." every 0.5 s while a prediction runs.

        Scheduled with ``after`` so the Tk event loop is never blocked (the
        previous version slept on the Tk thread).
        """
        if self._animation_job is not None:
            self.master.after_cancel(self._animation_job)
            self._animation_job = None

        def animate(step=0):
            if not self.is_predicting:
                self._animation_job = None
                self.loading_label.config(text="")
                return
            self.loading_label.config(text="识别中" + "." * (step % 3 + 1))
            self._animation_job = self.master.after(500, animate, step + 1)

        animate()

    def _classify(self, image):
        """Run the model on one PIL image and return ``(label, confidence)``.

        The image is preprocessed with ``self.transform``. Confidence is the
        softmax probability of the arg-max class. Models that return a tuple or
        list (e.g. ``(logits, features)``) are handled by taking the first
        element; a plain logits tensor is used directly.
        """
        input_batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            output = self.model(input_batch)
        logits = output[0] if isinstance(output, (tuple, list)) else output
        probabilities = torch.softmax(logits, dim=1)
        pred_index = int(torch.argmax(probabilities, dim=1)[0])
        confidence = probabilities[0, pred_index].item()
        return self.class_labels[pred_index], confidence

    def predict(self, image=None):
        """Classify one image (worker thread) and post the result to the Tk thread.

        Args:
            image: PIL image to classify; defaults to ``self.current_image``.
        """
        if image is None:
            image = self.current_image
        try:
            pred_label, confidence = self._classify(image)
        except Exception as exc:  # thread boundary: report instead of leaving the UI locked
            self.master.after(0, self._prediction_failed, exc)
            return
        self.master.after(0, self.update_result, pred_label, confidence)

    def update_result(self, pred_label, confidence):
        """Display a single-image result, colouring the confidence line.

        Colours: green if confidence > 0.7, orange if > 0.5, red otherwise.
        """
        self._end_prediction()
        if confidence > 0.7:
            color = "#4CAF50"  # Green
        elif confidence > 0.5:
            color = "#FFA500"  # Orange
        else:
            color = "#FF0000"  # Red
        self.result_text.config(state=tk.NORMAL)
        self.result_text.delete("1.0", tk.END)
        self.result_text.insert(tk.END, f"识别结果:{pred_label}\n", "result")
        self.result_text.insert(tk.END, f"置信度:{confidence:.2%}\n", "confidence")
        self.result_text.insert(tk.END, f"支持类别:{', '.join(self.class_labels)}")
        self.result_text.tag_config("result", font=("Segoe UI", 14, "bold"))
        self.result_text.tag_config("confidence", foreground=color)
        self.result_text.config(state=tk.DISABLED)
        self.update_status("预测完成!")

    def show_error(self, message):
        """Show *message* in a modal error dialog."""
        # `import tkinter` alone does not load the messagebox submodule.
        from tkinter import messagebox
        messagebox.showerror("错误", message)

    def batch_predict(self):
        """Ask for a folder and classify all images in it on a background thread."""
        if self.is_predicting:
            return
        folder_path = filedialog.askdirectory(title="选择包含图片的文件夹")
        if not folder_path:
            return
        if not self._begin_prediction():
            return
        threading.Thread(target=self.run_batch_prediction, args=(folder_path,), daemon=True).start()

    def run_batch_prediction(self, folder_path):
        """Classify every image under *folder_path* (recursively) and report accuracy.

        The ground truth for each image is the name of the directory that
        directly contains it, so *folder_path* should be laid out like the
        training set (one sub-folder per class).

        Images that cannot be opened or decoded are skipped and not counted.
        This runs on a worker thread; UI updates are posted via ``after``.
        """
        total_images = 0
        correct_predictions = 0
        try:
            for root, _dirs, files in os.walk(folder_path):
                true_label = os.path.basename(root)
                for file in files:
                    if not file.lower().endswith(('.png', '.jpg', '.jpeg')):
                        continue
                    try:
                        image = self._open_rgb(os.path.join(root, file))
                    except OSError:
                        continue  # unreadable / corrupt image: skip it
                    pred_label, _ = self._classify(image)
                    total_images += 1
                    if pred_label == true_label:
                        correct_predictions += 1
                    self.master.after(0, self.update_status, f"正在处理: {file}")
        except Exception as exc:  # thread boundary: report instead of leaving the UI locked
            self.master.after(0, self._prediction_failed, exc)
            return
        accuracy = correct_predictions / total_images if total_images > 0 else 0
        self.master.after(0, self.show_batch_result, total_images, correct_predictions, accuracy)

    def show_batch_result(self, total_images, correct_predictions, accuracy):
        """Display the batch summary (count, correct count, accuracy) and unlock the UI."""
        self._end_prediction()
        result_message = f"批量预测完成!\n总图片数: {total_images}\n正确预测数: {correct_predictions}\n准确率: {accuracy:.2%}"
        self.result_text.config(state=tk.NORMAL)
        self.result_text.delete("1.0", tk.END)
        self.result_text.insert(tk.END, result_message)
        self.result_text.config(state=tk.DISABLED)
        self.update_status("批量预测完成")
if __name__ == "__main__":
    # Script entry point: create the Tk root window, build the classifier UI on
    # it, then block in Tk's event loop until the window is closed.
    window = tk.Tk()
    app = ModelPredictorApp(window)
    window.mainloop()
该程序可以(例如借助 PyInstaller)打包成 .exe,界面如下图所示:
点击“选择图片”,再点击“开始识别”,即可对单张图片进行预测。
点击“批量预测”并选择按类别分好子文件夹的测试集,即可检验模型的准确率。
这样,一个图片识别工具最基础的功能就实现了。当然,更多内容敬请期待后续更新!