import os
from time import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from matplotlib import rcParams
from sklearn.metrics import classification_report, confusion_matrix
# Matplotlib setup: use the SimHei font so Chinese text renders, and draw the
# minus sign with a plain hyphen to avoid minus-sign display problems.
rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
# Dataset loading function
def data_load(data_dir, test_data_dir, img_height, img_width, batch_size):
    """Load the training and validation image datasets from class-per-folder directories.

    Args:
        data_dir: Root directory of the training images; each sub-directory is one class.
        test_data_dir: Root directory of the validation images, laid out the same way.
        img_height: Height every image is resized to.
        img_width: Width every image is resized to.
        batch_size: Number of images per batch.

    Returns:
        (train_ds, val_ds, class_names): batched datasets yielding
        (images, one-hot labels), and the class names in label-index order.
    """
    image_size = (img_height, img_width)
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        label_mode='categorical',
        seed=123,
        image_size=image_size,
        batch_size=batch_size)
    class_names = train_ds.class_names
    # Pin the validation label mapping to the training one. Without this, each
    # directory is indexed on its own, so a missing or extra folder would
    # silently shift label indices; with it, mismatched folders are rejected.
    # The validation set is not shuffled so evaluation order is deterministic.
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        test_data_dir,
        label_mode='categorical',
        class_names=class_names,
        shuffle=False,
        image_size=image_size,
        batch_size=batch_size)
    return train_ds, val_ds, class_names
# Build the MobileNet model
def model_load(IMG_SHAPE=(224, 224, 3), class_num=15):
    """Build and compile a frozen-backbone MobileNetV2 classifier.

    The first layer maps input values from [0, 255] to [-1, 1]
    (x / 127.5 - 1), so callers feed un-normalized images. The backbone is
    frozen; only the Dense head is trainable.

    Args:
        IMG_SHAPE: (height, width, channels) of the input images.
        class_num: Number of output classes (softmax units).

    Returns:
        A compiled ``tf.keras`` Sequential model (Adam optimizer, categorical
        cross-entropy loss, accuracy metric). Prints the model summary.
    """
    backbone = tf.keras.applications.MobileNetV2(
        input_shape=IMG_SHAPE,
        include_top=False,
        weights='imagenet',
    )
    # Keep the ImageNet features fixed; only the classification head learns.
    backbone.trainable = False

    rescale = tf.keras.layers.experimental.preprocessing.Rescaling(
        1. / 127.5, offset=-1, input_shape=IMG_SHAPE)
    pooling = tf.keras.layers.GlobalAveragePooling2D()
    head = tf.keras.layers.Dense(class_num, activation='softmax')

    model = tf.keras.models.Sequential([rescale, backbone, pooling, head])
    model.summary()
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
# Plot the loss and accuracy recorded during training
def show_loss_acc(history, save_path='results/results_mobilenet.png'):
    """Plot training/validation accuracy and loss curves and save the figure.

    Args:
        history: The object returned by ``model.fit``; its ``history`` dict must
            contain 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
        save_path: Where the PNG is written. Missing parent directories are
            created first.
    """
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(8, 8))

    # Top panel: accuracy, with the y-axis capped at 1.
    plt.subplot(2, 1, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.ylim([min(plt.ylim()), 1])
    plt.title('Training and Validation Accuracy')

    # Bottom panel: cross-entropy loss.
    plt.subplot(2, 1, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Cross Entropy')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')

    # savefig does not create directories and raises if the target folder is missing.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path, dpi=100)
# Model evaluation function
def evaluate_model(model, val_ds, class_names):
    """Print a confusion matrix and classification report, then plot the matrix.

    Args:
        model: Trained Keras model producing one score per class.
        val_ds: Iterable of ``(images, one_hot_labels)`` batches.
        class_names: Class names in label-index order.
    """
    y_true = []
    y_pred = []
    for images, labels in val_ds:
        # predict_on_batch runs a single batch directly; model.predict would set
        # up a new prediction loop (and progress bar) for every batch here.
        predictions = model.predict_on_batch(images)
        y_pred.extend(np.argmax(predictions, axis=1))
        y_true.extend(np.argmax(labels, axis=1))

    # Always use the full label set. Otherwise a class that never occurs in
    # y_true/y_pred shrinks the matrix, so it no longer lines up with the tick
    # labels. classification_report would also reject target_names of a
    # different length.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    print("Confusion Matrix:")
    print(cm)
    print("Classification Report:")
    print(classification_report(y_true, y_pred, labels=label_ids,
                                target_names=class_names, digits=4))

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()
# Custom callback that prints evaluation metrics as training progresses
class PrintEvalCallback(tf.keras.callbacks.Callback):
    """Print a one-line summary of each epoch's metrics.

    Reports loss, accuracy, val_loss and val_accuracy, in that order with four
    decimals. A metric that was not logged (e.g. the val_* entries when fit()
    runs without validation data) is skipped instead of raising KeyError.
    """

    # Metrics to report, in print order.
    _KEYS = ('loss', 'accuracy', 'val_loss', 'val_accuracy')

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        parts = [f"{key}={logs[key]:.4f}" for key in self._KEYS if key in logs]
        print(f"Epoch {epoch + 1}: " + ", ".join(parts))
# Main training function
def train(epochs,
          data_dir="D:/1/WeChat Files/wxid_c6n47k03okeu22/FileStorage/File/2024-05/vegetables_tf2.3-master/data/train",
          test_data_dir="D:/1/WeChat Files/wxid_c6n47k03okeu22/FileStorage/File/2024-05/vegetables_tf2.3-master/data/valid",
          img_size=224,
          batch_size=16,
          model_path="models/mobilenet_fv.h5"):
    """Train the MobileNetV2 classifier, save it, and report results.

    Args:
        epochs: Number of training epochs.
        data_dir: Training image root (one sub-folder per class).
        test_data_dir: Validation image root with the same class folders.
        img_size: Images are resized to ``img_size`` x ``img_size``.
        batch_size: Batch size for both datasets.
        model_path: Where the trained model is saved (HDF5). Missing parent
            directories are created.
    """
    begin_time = time()
    train_ds, val_ds, class_names = data_load(
        data_dir, test_data_dir, img_size, img_size, batch_size)
    # NOTE: any consumer that maps output indices to names (e.g. a hard-coded
    # list in a GUI) must use exactly this order.
    print("Class Names:", class_names)
    model = model_load(IMG_SHAPE=(img_size, img_size, 3), class_num=len(class_names))

    # Use the custom callback to print metrics after each epoch.
    eval_callback = PrintEvalCallback()
    history = model.fit(train_ds, validation_data=val_ds, epochs=epochs,
                        callbacks=[eval_callback])

    # model.save does not create missing directories.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    model.save(model_path)

    end_time = time()
    print('Training time:', end_time - begin_time, "seconds")
    show_loss_acc(history)
    evaluate_model(model, val_ds, class_names)
if __name__ == '__main__':
    train(epochs=10)

# NOTE(review): in the original paste the GUI program below was fused onto the
# previous statement ("train(epochs=10)import tensorflow as tf"), which is a
# syntax error. The GUI loads the model this script saves
# ("models/mobilenet_fv.h5") and should probably live in its own file.
import tensorflow as tf
from PyQt5.QtGui import QIcon, QFont, QPixmap
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QTabWidget, QFileDialog, QMessageBox
import sys
import cv2
from PIL import Image
import numpy as np
import shutil
class MainWindow(QMainWindow):
    """Main window of the vegetable recognition GUI.

    Shows a sample image, lets the user upload another one, and classifies the
    current image with the Keras model loaded from models/mobilenet_fv.h5.
    """

    # Side length (pixels) of the square image fed to the model; presumably
    # must match the training image size.
    INPUT_SIZE = 224
    # Height (pixels) of the preview shown in the left pane.
    PREVIEW_HEIGHT = 400
    # Scratch files: the on-screen preview and the model-sized input image.
    SHOW_PATH = "images/show.png"
    TARGET_PATH = "images/target.png"

    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon('images/logo.png'))  # make sure this path is correct
        self.setWindowTitle('蔬菜识别系统')
        self.model = tf.keras.models.load_model("models/mobilenet_fv.h5")
        self.to_predict_name = "images/tim9.jpeg"
        # NOTE(review): entry i must name the class the model predicts at output i,
        # i.e. the order the training script printed as "Class Names:" - confirm.
        self.class_names = ['豆', '苦瓜', '葫芦瓜', '茄子', '西兰花', '胡萝卜', '包菜', '辣椒', '西红柿', '花菜', '土豆', '黄瓜', '番木瓜', '南瓜', '白萝卜', ]
        self.resize(900, 700)
        self.initUI()

    def _load_image(self, path):
        """Read *path*, write the preview and model-input images, and show the preview.

        Returns:
            True on success, False if OpenCV could not read the file.
        """
        img = cv2.imread(path)
        if img is None:
            return False
        # Scale the preview to a fixed height, keeping the aspect ratio.
        scale = self.PREVIEW_HEIGHT / img.shape[0]
        preview = cv2.resize(img, (0, 0), fx=scale, fy=scale)
        cv2.imwrite(self.SHOW_PATH, preview)
        cv2.imwrite(self.TARGET_PATH, cv2.resize(img, (self.INPUT_SIZE, self.INPUT_SIZE)))
        self.img_label.setPixmap(QPixmap(self.SHOW_PATH))
        return True

    def initUI(self):
        """Build the two tabs: the recognition page and the about page."""
        main_widget = QWidget()
        main_layout = QHBoxLayout()
        font = QFont('楷体', 15)

        # Left pane: the current sample image.
        left_widget = QWidget()
        left_layout = QVBoxLayout()
        img_title = QLabel("样本")
        img_title.setFont(font)
        img_title.setAlignment(Qt.AlignCenter)
        self.img_label = QLabel()
        # The default sample image is required; exit if it cannot be read.
        if not self._load_image(self.to_predict_name):
            print(f"Error: 无法加载图像文件 '{self.to_predict_name}'。请检查文件路径和文件名。")
            sys.exit(1)
        left_layout.addWidget(img_title)
        left_layout.addWidget(self.img_label, 1, Qt.AlignCenter)
        left_widget.setLayout(left_layout)

        # Right pane: result label and the upload / predict buttons.
        right_widget = QWidget()
        right_layout = QVBoxLayout()
        btn_change = QPushButton("上传图片")
        btn_change.clicked.connect(self.change_img)
        btn_change.setFont(font)
        btn_predict = QPushButton("开始识别")
        btn_predict.setFont(font)
        btn_predict.clicked.connect(self.predict_img)
        label_result = QLabel('蔬菜名称')
        self.result = QLabel("等待识别")
        label_result.setFont(QFont('楷体', 16))
        self.result.setFont(QFont('楷体', 24))
        right_layout.addStretch()
        right_layout.addWidget(label_result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addWidget(self.result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addStretch()
        right_layout.addWidget(btn_change)
        right_layout.addWidget(btn_predict)
        right_layout.addStretch()
        right_widget.setLayout(right_layout)
        main_layout.addWidget(left_widget)
        main_layout.addWidget(right_widget)
        main_widget.setLayout(main_layout)

        # "About" tab.
        about_widget = QWidget()
        about_layout = QVBoxLayout()
        about_title = QLabel('欢迎使用蔬菜识别系统')
        about_title.setFont(QFont('楷体', 18))
        about_title.setAlignment(Qt.AlignCenter)
        about_img = QLabel()
        about_img.setPixmap(QPixmap('images/bj.jpg'))
        about_img.setAlignment(Qt.AlignCenter)
        label_super = QLabel("作者:杨益")
        label_super.setFont(QFont('楷体', 12))
        label_super.setAlignment(Qt.AlignRight)
        about_layout.addWidget(about_title)
        about_layout.addStretch()
        about_layout.addWidget(about_img)
        about_layout.addStretch()
        about_layout.addWidget(label_super)
        about_widget.setLayout(about_layout)

        self.tab_widget = QTabWidget()
        self.tab_widget.addTab(main_widget, '主页')
        self.tab_widget.addTab(about_widget, '关于')
        self.tab_widget.setTabIcon(0, QIcon('images/主页面.png'))
        self.tab_widget.setTabIcon(1, QIcon('images/关于.png'))
        self.setCentralWidget(self.tab_widget)

    def change_img(self):
        """Let the user pick an image, copy it under images/ and display it.

        The current image is replaced only if the new one loads successfully;
        failures are reported in a message box.
        """
        openfile_name, _ = QFileDialog.getOpenFileName(self, '选择文件', '', 'Image files(*.jpg *.png *.jpeg)')
        if not openfile_name:
            return
        target_image_name = "images/tmp_up." + openfile_name.split(".")[-1]
        # Exceptions escaping a PyQt slot can abort the application, so copy
        # failures are reported instead of propagated.
        try:
            shutil.copy(openfile_name, target_image_name)
        except shutil.SameFileError:
            pass  # the previous upload was re-selected; it is already in place
        except OSError as err:
            QMessageBox.warning(self, '错误', str(err))
            return
        if not self._load_image(target_image_name):
            message = f"Error: 无法加载图像文件 '{target_image_name}'。请检查文件路径和文件名。"
            print(message)
            QMessageBox.warning(self, '错误', message)
            return
        self.to_predict_name = target_image_name
        self.result.setText("等待识别")

    def predict_img(self):
        """Classify the model-input image from the last successful load and show the name."""
        # cv2 wrote this PNG from its BGR array, so PIL reads it back in RGB order
        # (presumably the order of the training images - confirm).
        with Image.open(self.TARGET_PATH) as img:
            pixels = np.asarray(img.convert('RGB'))
        outputs = self.model.predict(pixels.reshape(1, self.INPUT_SIZE, self.INPUT_SIZE, 3))
        result_index = int(np.argmax(outputs))
        if result_index < len(self.class_names):
            self.result.setText(self.class_names[result_index])
        else:
            # The model has more outputs than names in class_names.
            self.result.setText(f"未知类别 {result_index}")

    def closeEvent(self, event):
        """Ask for confirmation before closing the window."""
        reply = QMessageBox.question(self, '退出', "是否要退出程序?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()
# NOTE(review): the pasted source had a stray question ("Is this code correct?")
# fused onto the sys.exit line below, which was a syntax error; it now lives
# only in this comment.
if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(app.exec_())
最新发布