Question about center & scale

本文探讨了人体姿态估计任务中图像裁剪的方法,包括如何设置裁剪区域的中心和比例,确保人物居中并占据约70-80%的高度。还讨论了针对截断或部分可见人物的解决方案。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

https://github.com/umich-vl/pose-hg-demo/issues/8

我的建议是将“center”设置为边界框的中心(这一步很简单);至于“scale”,则需要做一些试验:取边界框的最大边长(高度或宽度)除以一个常数因子。我一时无法说出这个常数的最佳取值,建议直接在代码里调用裁剪函数并显示生成的图像来调试。理想情况下,站立的人物应当居中,并占据图像高度的约 70%–80%。

边界框的主要问题在于很难判断人物是否被截断;对于只有上半身可见的人物,模型的表现可能较差。这种情况下,可以尝试根据边界框的纵横比来调整 center 和 scale。归根结底,这种严重依赖 center 和 scale 参数的训练方式的局限之一,就是难以直接推广到任意边界框;但我认为你仍然应该能得到合理的预测结果。

# =============================================================================
# NOTE(review): this dump contains TWO separate scripts pasted together
# (originally `train(epochs=10)` was immediately followed by the second
# script's imports with no newline — a syntax error). They are reformatted
# below in order. If kept in one module, BOTH `__main__` guards execute
# (training first, then the GUI); they should really live in two files.
# =============================================================================

# -----------------------------------------------------------------------------
# Script 1: MobileNetV2 transfer-learning trainer for a 15-class vegetable set.
# -----------------------------------------------------------------------------
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
from time import time
from matplotlib import rcParams

rcParams['font.sans-serif'] = ['SimHei']  # CJK-capable font so Chinese labels render
rcParams['axes.unicode_minus'] = False    # avoid broken minus glyphs with CJK fonts


def data_load(data_dir, test_data_dir, img_height, img_width, batch_size):
    """Load train/validation datasets from class-per-folder directory trees.

    Labels are one-hot (`label_mode='categorical'`).  Returns
    (train_ds, val_ds, class_names); class_names follows Keras'
    alphabetical folder ordering.
    """
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        label_mode='categorical',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        test_data_dir,
        label_mode='categorical',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    class_names = train_ds.class_names
    return train_ds, val_ds, class_names


def model_load(IMG_SHAPE=(224, 224, 3), class_num=15):
    """Build a frozen-backbone MobileNetV2 classifier with a softmax head."""
    base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                                   include_top=False,
                                                   weights='imagenet')
    base_model.trainable = False  # transfer learning: train only the new head
    model = tf.keras.models.Sequential([
        # Map pixels from [0, 255] to [-1, 1], MobileNetV2's expected range.
        tf.keras.layers.experimental.preprocessing.Rescaling(
            1. / 127.5, offset=-1, input_shape=IMG_SHAPE),
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(class_num, activation='softmax')
    ])
    model.summary()
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def show_loss_acc(history):
    """Plot training/validation accuracy and loss curves and save them."""
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.ylim([min(plt.ylim()), 1])
    plt.title('Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Cross Entropy')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')

    # FIX: savefig fails if the output directory does not exist.
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/results_mobilenet.png', dpi=100)


def evaluate_model(model, val_ds, class_names):
    """Print a confusion matrix and classification report for val_ds."""
    y_true = []
    y_pred = []
    for images, labels in val_ds:
        predictions = model.predict(images)
        y_pred.extend(np.argmax(predictions, axis=1))
        y_true.extend(np.argmax(labels, axis=1))

    cm = confusion_matrix(y_true, y_pred)
    print("Confusion Matrix:")
    print(cm)
    print("Classification Report:")
    print(classification_report(y_true, y_pred, target_names=class_names, digits=4))

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()


class PrintEvalCallback(tf.keras.callbacks.Callback):
    """Callback that prints the key metrics at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        print(f"Epoch {epoch + 1}: "
              f"loss={logs['loss']:.4f}, accuracy={logs['accuracy']:.4f}, "
              f"val_loss={logs['val_loss']:.4f}, val_accuracy={logs['val_accuracy']:.4f}")


def train(epochs):
    """Full training pipeline: load data, train, save, plot, evaluate."""
    begin_time = time()
    train_ds, val_ds, class_names = data_load(
        "D:/1/WeChat Files/wxid_c6n47k03okeu22/FileStorage/File/2024-05/vegetables_tf2.3-master/data/train",
        "D:/1/WeChat Files/wxid_c6n47k03okeu22/FileStorage/File/2024-05/vegetables_tf2.3-master/data/valid",
        224, 224, 16
    )
    print("Class Names:", class_names)
    model = model_load(class_num=len(class_names))
    eval_callback = PrintEvalCallback()
    history = model.fit(train_ds, validation_data=val_ds, epochs=epochs,
                        callbacks=[eval_callback])
    # FIX: model.save fails if the output directory does not exist.
    os.makedirs('models', exist_ok=True)
    model.save("models/mobilenet_fv.h5")
    end_time = time()
    print('Training time:', end_time - begin_time, "seconds")
    show_loss_acc(history)
    evaluate_model(model, val_ds, class_names)


if __name__ == '__main__':
    train(epochs=10)


# -----------------------------------------------------------------------------
# Script 2: PyQt5 GUI that loads the trained model and classifies an image.
# -----------------------------------------------------------------------------
import tensorflow as tf
from PyQt5.QtGui import QIcon, QFont, QPixmap
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
                             QHBoxLayout, QLabel, QPushButton, QTabWidget,
                             QFileDialog, QMessageBox)
import sys
import cv2
from PIL import Image
import numpy as np
import shutil


class MainWindow(QMainWindow):
    """Main window: left pane shows the image, right pane runs prediction."""

    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon('images/logo.png'))  # make sure the path exists
        self.setWindowTitle('蔬菜识别系统')
        self.model = tf.keras.models.load_model("models/mobilenet_fv.h5")
        self.to_predict_name = "images/tim9.jpeg"
        # NOTE(review): this order must match train_ds.class_names (Keras sorts
        # class folders alphabetically) — verify against the training output.
        self.class_names = ['豆', '苦瓜', '葫芦瓜', '茄子', '西兰花', '胡萝卜',
                            '包菜', '辣椒', '西红柿', '花菜', '土豆', '黄瓜',
                            '番木瓜', '南瓜', '白萝卜', ]
        self.resize(900, 700)
        self.initUI()

    def initUI(self):
        """Build the two tabs: the classifier page and the about page."""
        main_widget = QWidget()
        main_layout = QHBoxLayout()
        font = QFont('楷体', 15)

        # Left pane: preview of the image to classify.
        left_widget = QWidget()
        left_layout = QVBoxLayout()
        img_title = QLabel("样本")
        img_title.setFont(font)
        img_title.setAlignment(Qt.AlignCenter)
        self.img_label = QLabel()
        img_init = cv2.imread(self.to_predict_name)
        if img_init is None:
            print(f"Error: 无法加载图像文件 '{self.to_predict_name}'。请检查文件路径和文件名。")
            sys.exit(1)
        h, w, c = img_init.shape
        scale = 400 / h  # scale the preview to 400 px tall
        img_show = cv2.resize(img_init, (0, 0), fx=scale, fy=scale)
        cv2.imwrite("images/show.png", img_show)
        img_init = cv2.resize(img_init, (224, 224))  # model input size
        cv2.imwrite('images/target.png', img_init)
        self.img_label.setPixmap(QPixmap("images/show.png"))
        left_layout.addWidget(img_title)
        left_layout.addWidget(self.img_label, 1, Qt.AlignCenter)
        left_widget.setLayout(left_layout)

        # Right pane: upload + predict buttons and the result label.
        right_widget = QWidget()
        right_layout = QVBoxLayout()
        btn_change = QPushButton("上传图片")
        btn_change.clicked.connect(self.change_img)
        btn_change.setFont(font)
        btn_predict = QPushButton("开始识别")
        btn_predict.setFont(font)
        btn_predict.clicked.connect(self.predict_img)
        label_result = QLabel('蔬菜名称')
        self.result = QLabel("等待识别")
        label_result.setFont(QFont('楷体', 16))
        self.result.setFont(QFont('楷体', 24))
        right_layout.addStretch()
        right_layout.addWidget(label_result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addWidget(self.result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addStretch()
        right_layout.addWidget(btn_change)
        right_layout.addWidget(btn_predict)
        right_layout.addStretch()
        right_widget.setLayout(right_layout)

        main_layout.addWidget(left_widget)
        main_layout.addWidget(right_widget)
        main_widget.setLayout(main_layout)

        # About tab.
        about_widget = QWidget()
        about_layout = QVBoxLayout()
        about_title = QLabel('欢迎使用蔬菜识别系统')
        about_title.setFont(QFont('楷体', 18))
        about_title.setAlignment(Qt.AlignCenter)
        about_img = QLabel()
        about_img.setPixmap(QPixmap('images/bj.jpg'))
        about_img.setAlignment(Qt.AlignCenter)
        label_super = QLabel("作者:杨益")
        label_super.setFont(QFont('楷体', 12))
        label_super.setAlignment(Qt.AlignRight)
        about_layout.addWidget(about_title)
        about_layout.addStretch()
        about_layout.addWidget(about_img)
        about_layout.addStretch()
        about_layout.addWidget(label_super)
        about_widget.setLayout(about_layout)

        self.tab_widget = QTabWidget()
        self.tab_widget.addTab(main_widget, '主页')
        self.tab_widget.addTab(about_widget, '关于')
        self.tab_widget.setTabIcon(0, QIcon('images/主页面.png'))
        self.tab_widget.setTabIcon(1, QIcon('images/关于.png'))
        self.setCentralWidget(self.tab_widget)

    def change_img(self):
        """Let the user pick a new image; refresh preview and model input."""
        # FIX: original filter had '*jpeg' (missing dot), so .jpeg files
        # could not be selected.
        openfile_name, _ = QFileDialog.getOpenFileName(
            self, '选择文件', '', 'Image files(*.jpg *.png *.jpeg)')
        if openfile_name:
            target_image_name = "images/tmp_up." + openfile_name.split(".")[-1]
            shutil.copy(openfile_name, target_image_name)
            self.to_predict_name = target_image_name
            img_init = cv2.imread(self.to_predict_name)
            if img_init is None:
                print(f"Error: 无法加载图像文件 '{self.to_predict_name}'。请检查文件路径和文件名。")
                return
            h, w, c = img_init.shape
            scale = 400 / h
            img_show = cv2.resize(img_init, (0, 0), fx=scale, fy=scale)
            cv2.imwrite("images/show.png", img_show)
            img_init = cv2.resize(img_init, (224, 224))
            cv2.imwrite('images/target.png', img_init)
            self.img_label.setPixmap(QPixmap("images/show.png"))
            self.result.setText("等待识别")

    def predict_img(self):
        """Run the model on images/target.png and show the predicted class."""
        # FIX: force 3-channel RGB — a grayscale or RGBA image would make
        # the reshape to (1, 224, 224, 3) raise.
        img = Image.open('images/target.png').convert('RGB')
        img = np.asarray(img)
        outputs = self.model.predict(img.reshape(1, 224, 224, 3))
        result_index = int(np.argmax(outputs))
        result = self.class_names[result_index]
        self.result.setText(result)

    def closeEvent(self, event):
        """Ask for confirmation before closing the window."""
        reply = QMessageBox.question(self, '退出', "是否要退出程序?",
                                     QMessageBox.Yes | QMessageBox.No,
                                     QMessageBox.No)
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()


if __name__ == "__main__":
    app = QApplication(sys.argv)
    x = MainWindow()
    x.show()
    sys.exit(app.exec_())

# (Original poster's trailing question: 这个代码是否正确 — "Is this code correct?")
最新发布
06-02
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值