在计算机视觉领域,图像分类是一项基础且重要的任务。本文将带领大家完成一个完整的水果图像分类项目,从数据加载与预处理开始,到使用 VGG16 预训练模型构建分类器,再到模型评估与可视化,最后通过 Gradio 构建交互式演示界面。整个项目流程清晰,代码完整,适合深度学习初学者学习实践。
项目背景与数据集介绍
水果图像分类在农业自动化、智能零售、食品加工等领域有着广泛的应用前景。本项目使用的数据集为公开的水果图像数据集,包含多种常见水果的图像,可用于训练一个能够准确识别不同水果类别的深度学习模型。
数据集结构如下:
train_zip/train/:训练集图像文件夹,图像文件名格式为 "类别_编号.jpg"test_zip/test/:测试集图像文件夹,结构与训练集类似
该数据集的特点是:
- 包含多种水果类别
- 图像分辨率和质量各不相同
- 部分水果外观相似,增加了分类难度
环境配置与依赖安装
在开始项目之前,我们需要先配置好运行环境并安装必要的依赖库。项目主要依赖以下 Python 库:
import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, applications
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import gradio as gr
如果在运行时遇到ModuleNotFoundError错误,说明缺少相应的库,可以使用pip命令进行安装:
pip install opencv-python numpy pandas matplotlib tensorflow scikit-learn seaborn gradio
特别地,Gradio 库用于构建交互式网页界面,如果尚未安装,需要单独安装:
pip install gradio
数据加载与预处理
数据加载与标注解析
首先,我们需要加载图像数据并解析其标注信息。由于数据集中的图像文件名包含了类别信息(如 "apple_1.jpg" 表示苹果类别),我们可以通过解析文件名来获取图像的类别标签。
下面是数据加载函数的实现:
# 数据加载与预处理
data_dir = '/kaggle/input/fruit-images-for-object-detection' # Kaggle环境路径
# 如果在本地运行,请替换为实际路径,例如:
# data_dir = "D:/datasets/fruit_images"
train_image_dir = os.path.join(data_dir, 'train_zip', 'train')
test_image_dir = os.path.join(data_dir, 'test_zip', 'test')
def get_image_data(image_dir):
"""从图像文件夹中加载图像文件名并解析类别标签"""
image_files = []
classes = []
for filename in os.listdir(image_dir):
if filename.endswith(('.jpg', '.jpeg', '.png')):
class_name = filename.split('_')[0] # 从文件名中解析类别
image_files.append(filename)
classes.append(class_name)
return pd.DataFrame({'filename': image_files, 'class': classes})
# 加载训练集和测试集标注
train_annotations = get_image_data(train_image_dir)
test_annotations = get_image_data(test_image_dir)
样本可视化
在开始模型训练之前,先可视化一些样本图像,了解数据集的基本情况:
# 可视化样本
def visualize_samples(n=9):
"""可视化n个随机样本图像及其类别"""
plt.figure(figsize=(15, 12))
for i, (_, row) in enumerate(train_annotations.sample(n).iterrows(), 1):
img_path = os.path.join(train_image_dir, row['filename'])
img = cv2.imread(img_path)
if img is not None:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # OpenCV默认读取为BGR,转换为RGB
plt.subplot(3, 3, i)
plt.imshow(img)
plt.title(f"类别: {row['class']}")
plt.axis('off')
plt.tight_layout()
plt.show()
# 执行可视化
if __name__ == "__main__":
print(f"训练集样本数量: {len(train_annotations)}")
print(f"测试集样本数量: {len(test_annotations)}")
print(f"检测到的类别: {train_annotations['class'].unique()}")
# 调用可视化函数
visualize_samples()
运行上述代码,我们可以看到训练集中不同类别的水果图像示例,这有助于我们直观了解数据集的分布情况。
数据增强与生成器配置
为了提高模型的泛化能力,减少过拟合,我们需要对训练数据进行数据增强。数据增强通过对原始图像进行各种变换(如旋转、缩放、翻转等)来生成新的训练样本,从而扩充数据集的规模和多样性。
下面是数据生成器的配置代码:
# 数据生成器
def prepare_data():
"""配置数据生成器并生成训练、验证和测试数据迭代器"""
# 训练集数据增强配置
train_datagen = ImageDataGenerator(
rescale=1./255, # 像素值归一化
rotation_range=20, # 随机旋转角度范围
width_shift_range=0.2, # 水平平移范围
height_shift_range=0.2, # 垂直平移范围
shear_range=0.2, # 剪切变换范围
zoom_range=0.2, # 随机缩放范围
horizontal_flip=True, # 水平翻转
validation_split=0.2 # 划分20%的训练数据作为验证集
)
# 测试集仅需归一化
test_datagen = ImageDataGenerator(rescale=1./255)
# 训练数据生成器
train_generator = train_datagen.flow_from_dataframe(
dataframe=train_annotations,
directory=train_image_dir,
x_col="filename",
y_col="class",
target_size=(224, 224), # 调整图像大小为224x224
batch_size=32, # 批次大小
class_mode='categorical', # 多分类模式
subset='training' # 使用训练子集
)
# 验证数据生成器
validation_generator = train_datagen.flow_from_dataframe(
dataframe=train_annotations,
directory=train_image_dir,
x_col="filename",
y_col="class",
target_size=(224, 224),
batch_size=32,
class_mode='categorical',
subset='validation' # 使用验证子集
)
# 测试数据生成器
test_generator = test_datagen.flow_from_dataframe(
dataframe=test_annotations,
directory=test_image_dir,
x_col="filename",
y_col="class",
target_size=(224, 224),
batch_size=32,
class_mode='categorical'
)
return train_generator, validation_generator, test_generator
# 准备数据
train_generator, validation_generator, test_generator = prepare_data()
class_indices = train_generator.class_indices # 获取类别索引映射
num_classes = len(class_indices)
print(f"类别数量: {num_classes}")
print(f"类别映射: {class_indices}")
通过数据增强,我们可以在不增加实际数据量的情况下,为模型提供更多样化的训练样本,这对于提高模型的泛化能力非常重要。
基于 VGG16 的迁移学习模型构建
迁移学习简介
迁移学习是深度学习中一种常用的技术,它利用在大规模数据集上预训练好的模型作为基础,然后针对特定任务进行微调。这种方法可以大大减少训练所需的数据量和时间,同时提高模型的性能。
本项目使用 VGG16 作为基础模型,VGG16 是一种经典的卷积神经网络架构,在 ImageNet 大规模图像分类数据集上进行了预训练,具有很强的特征提取能力。
模型构建代码
下面是基于 VGG16 构建水果分类模型的代码:
# 构建模型
def build_vgg_model(num_classes):
"""基于VGG16构建水果分类模型"""
# 加载预训练的VGG16模型,不包含顶层分类层
base_model = applications.VGG16(
weights='imagenet', # 使用ImageNet预训练权重
include_top=False, # 不包含顶层分类层
input_shape=(224, 224, 3) # 输入图像尺寸
)
# 冻结基础模型的所有层,避免在微调时破坏预训练的特征
for layer in base_model.layers:
layer.trainable = False
# 构建完整的模型
model = models.Sequential([
base_model, # 基础特征提取层
layers.GlobalAveragePooling2D(), # 全局平均池化层
layers.Dense(256, activation='relu'), # 全连接层
layers.Dropout(0.5), # dropout层,防止过拟合
layers.Dense(num_classes, activation='softmax') # 分类输出层
])
# 编译模型
model.compile(
optimizer='adam', # 优化器
loss='categorical_crossentropy', # 多分类损失函数
metrics=['accuracy'] # 评估指标
)
return model
# 初始化模型
vgg_model = build_vgg_model(num_classes)
vgg_model.summary() # 打印模型结构
在上述代码中,我们首先加载了预训练的 VGG16 模型,然后在其基础上添加了自定义的分类层。通过冻结 VGG16 的所有层,我们可以利用其已经学习到的通用图像特征,只需训练新添加的分类层即可完成水果分类任务。
模型训练与优化
训练配置
为了获得更好的训练效果,我们需要配置合适的训练参数和回调函数:
# 训练配置
def lr_scheduler(epoch):
"""学习率调度函数,根据训练轮次调整学习率"""
if epoch < 10:
return 1e-3 # 前10轮使用较大的学习率
elif epoch < 15:
return 5e-4 # 中间5轮降低学习率
else:
return 1e-4 # 最后阶段使用更小的学习率
# 模型检查点,保存最佳模型
checkpoint = ModelCheckpoint(
'vgg_fruit_20epochs.h5',
monitor='val_accuracy', # 监控验证集准确率
save_best_only=True, # 只保存最佳模型
mode='max', # 最大化验证准确率
verbose=1 # 输出信息
)
# 早停策略,防止过拟合
early_stopping = EarlyStopping(
monitor='val_loss', # 监控验证集损失
patience=5, # 连续5轮没有改善则停止
restore_best_weights=True, # 恢复最佳权重
verbose=1
)
# 学习率调度器
lr_sched = LearningRateScheduler(lr_scheduler)
模型训练
完成配置后,我们可以开始训练模型了:
# 训练模型
epochs = 20 # 训练轮次
history = vgg_model.fit(
train_generator, # 训练数据生成器
steps_per_epoch=train_generator.samples // train_generator.batch_size,
validation_data=validation_generator, # 验证数据生成器
validation_steps=validation_generator.samples // validation_generator.batch_size,
epochs=epochs, # 训练轮次
callbacks=[checkpoint, early_stopping, lr_sched] # 回调函数
)
在训练过程中,模型会在每轮结束后评估验证集上的性能,并根据配置的回调函数进行模型保存、学习率调整和早停判断。
模型评估与可视化
模型评估
训练完成后,我们需要在测试集上评估模型的性能:
# 加载最佳模型
vgg_model = tf.keras.models.load_model('vgg_fruit_20epochs.h5')
# 模型评估
def evaluate_model(model, test_generator, class_indices):
"""评估模型在测试集上的性能"""
# 评估模型
test_loss, test_acc = model.evaluate(test_generator)
print(f'测试集准确率: {test_acc:.4f}')
# 获取真实标签和预测结果
y_true = test_generator.classes
y_pred = model.predict(test_generator)
y_pred_classes = np.argmax(y_pred, axis=1)
# 获取类别名称
class_names = list(class_indices.keys())
# 生成分类报告
report = classification_report(y_true, y_pred_classes, target_names=class_names)
print('分类报告:\n', report)
# 生成混淆矩阵
cm = confusion_matrix(y_true, y_pred_classes)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=class_names, yticklabels=class_names)
plt.xlabel('预测标签')
plt.ylabel('真实标签')
plt.title('混淆矩阵')
plt.show()
# 执行评估
evaluate_model(vgg_model, test_generator, class_indices)
分类报告和混淆矩阵可以帮助我们全面了解模型在各个类别上的表现,识别模型的优势和不足。
训练过程可视化
可视化训练过程中的准确率和损失变化,可以帮助我们理解模型的训练情况:
# 可视化训练历史
def plot_training_history(history):
"""绘制训练过程中的准确率和损失曲线"""
plt.figure(figsize=(16, 5))
# 绘制准确率曲线
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('模型准确率')
plt.ylabel('准确率')
plt.xlabel('轮次')
plt.legend(loc='lower right')
# 绘制损失曲线
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('模型损失')
plt.ylabel('损失')
plt.xlabel('轮次')
plt.legend(loc='upper right')
plt.tight_layout()
plt.show()
# 显示训练历史
plot_training_history(history)
通过观察训练曲线,我们可以判断模型是否过拟合,训练是否收敛,以及学习率设置是否合理等。
构建交互式演示界面
为了方便用户体验我们的水果分类模型,我们可以使用 Gradio 库构建一个简单的交互式网页界面:
# 构建Gradio演示界面
def predict_fruit(image):
"""对输入图像进行水果分类预测"""
if image is None:
return "请上传一张水果图像"
# 图像预处理
img = cv2.resize(image, (224, 224))
img = img / 255.0
img = np.expand_dims(img, axis=0)
# 预测
predictions = vgg_model.predict(img)
pred_class = np.argmax(predictions[0])
class_name = list(class_indices.keys())[pred_class]
confidence = predictions[0][pred_class]
return f"预测结果: {class_name}, 置信度: {confidence:.4f}"
# 创建Gradio界面
iface = gr.Interface(
fn=predict_fruit,
inputs=gr.inputs.Image(shape=(224, 224), type="numpy"),
outputs="text",
title="水果图像分类器",
description="上传一张水果图像,模型将预测其类别",
examples=[["apple.jpg"], ["banana.jpg"], ["orange.jpg"]]
)
# 启动界面
iface.launch()
运行上述代码后,会生成一个可访问的链接,打开链接即可在浏览器中使用我们的水果分类模型。用户可以上传水果图像,模型会返回预测的类别和置信度。
项目优化与扩展方向
-
模型优化:
- 可以尝试解冻 VGG16 的顶层卷积层进行微调,进一步提高模型性能
- 尝试使用其他预训练模型如 ResNet、EfficientNet 等进行比较
- 调整模型的超参数,如学习率、批次大小、正则化强度等
-
数据增强优化:
- 根据水果图像的特点,设计更有针对性的数据增强策略
- 考虑使用 Mixup、CutMix 等高级数据增强技术
-
模型轻量化:
- 对于实际应用场景,可以考虑使用模型压缩技术减小模型体积
- 尝试使用 TensorFlow Lite 将模型转换为适合移动设备部署的格式
-
功能扩展:
- 在分类基础上增加目标检测功能,实现水果的定位与分类
- 构建完整的水果识别系统,包括图像采集
599

被折叠的 条评论
为什么被折叠?



