
TensorFlow Image Classification Test Error
This post resolves a problem encountered while testing image classification with TensorFlow: a TypeError caused by an outdated TensorFlow version, triggered by passing the unsupported argv argument to tf.app.run. The fix is to upgrade TensorFlow to version 0.12 or later.

When testing image classification with TensorFlow, the following error appeared on the call

tf.app.run(main=main, argv=[sys.argv[0]])

The key clue is the error message itself:

TypeError: run() got an unexpected keyword argument 'argv'
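The call that triggers this error usually sits in a small launcher script along the lines of the sketch below (a hypothetical minimal example; the flag and the body of main are illustrative, only the tf.app.run line is taken from the report above):

import sys
import tensorflow as tf

# Hypothetical flag; the real script's flags are not shown in the original post.
flags = tf.app.flags
flags.DEFINE_string('image_dir', '', 'Directory with the images to classify')
FLAGS = flags.FLAGS

def main(argv=None):
    # ... load the model and run the image-classification test here ...
    print('image_dir = ' + FLAGS.image_dir)

if __name__ == '__main__':
    # On r0.11 and earlier, run() has no argv parameter, so this line raises:
    #   TypeError: run() got an unexpected keyword argument 'argv'
    tf.app.run(main=main, argv=[sys.argv[0]])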

This looks like a TensorFlow version issue.
In r0.11 and earlier, tf.app.run is implemented as:

def run(main=None):
  f = flags.FLAGS
  flags_passthrough = f._parse_flags()
  main = main or sys.modules['__main__'].main
  sys.exit(main(sys.argv[:1] + flags_passthrough))

There is no argv parameter; argv was only added in r0.12.
So your TensorFlow version is presumably r0.11 or earlier, and you need to upgrade it to r0.12 or later.
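For comparison, in r0.12 and later tf.app.run also accepts an optional argv list, roughly along these lines (paraphrased from the newer implementation, not an exact copy of the source):

def run(main=None, argv=None):
  f = flags.FLAGS
  # Parse flags from the explicit argv list when one is given,
  # otherwise fall back to the command line as before.
  args = argv[1:] if argv else None
  flags_passthrough = f._parse_flags(args=args)
  main = main or sys.modules['__main__'].main
  sys.exit(main(sys.argv[:1] + flags_passthrough))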

The solution, then, is to upgrade TensorFlow to r0.12 or later; one way to do this (the CPU-only wheel for Python 2.7 on 64-bit Linux) is:

sudo pip --no-cache-dir install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.0rc0-cp27-none-linux_x86_64.whl
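After upgrading, it is worth confirming which version actually got installed before re-running the test. A quick check (assuming the upgrade succeeded):

import tensorflow as tf
print(tf.__version__)  # should now report a 0.12 (or newer) version string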
OK, that resolves the problem.

