// widget.cpp
#include "widget.h"
#include "ui_widget.h"
#include <QPainter>
#include <QMouseEvent>
#include <QFileDialog>
#include <QMessageBox>
#include<QGraphicsItem>
#include <QPixmap>
#include <QImage>
#undef slots
#include <torch/script.h>
#include <ATen/ATen.h>
#include <torch/torch.h>
#define slots Q_SLOTS
#include <ATen/ATen.h>
#include "customgraphicsview.h"
/**
 * Builds the handwriting-recognition widget.
 *
 * @param model  TorchScript module used for inference; taken by value and
 *               moved into the `model` member.
 * @param parent Optional Qt parent widget.
 *
 * Creates the uic form and a QGraphicsScene for the drawing area, configures
 * the view (border, scroll bars, viewport, antialiasing), and wires
 * CustomGraphicsView's mouse signals to the stroke-drawing handlers.
 */
Widget::Widget(torch::jit::Module model, QWidget *parent)
: QWidget(parent)
, ui(new Ui::Widget) // deleted in ~Widget()
, model(std::move(model))
{
ui->setupUi(this);
// Create the drawing scene, parented to this widget (Qt object tree).
scene = new QGraphicsScene(this);
scene->setBackgroundBrush(Qt::black); // black background
ui->writingArea->setScene(scene);
initScene();
// Style the drawing area's frame.
ui->writingArea->setStyleSheet(
"QGraphicsView {"
" border: 2px solid black;" // black border
" border-radius: 5px;"
"}"
);
// Disable both scroll bars.
ui->writingArea->setHorizontalScrollBarPolicy(Qt::ScrollBarAlwaysOff);
ui->writingArea->setVerticalScrollBarPolicy(Qt::ScrollBarAlwaysOff);
// Replace the view's viewport with a fresh plain QWidget.
// NOTE(review): the original comment said this makes the viewport background
// transparent so the scene background shows through, but no transparency
// attribute is set here -- confirm this call is still needed.
ui->writingArea->setViewport(new QWidget);
ui->writingArea->setRenderHint(QPainter::Antialiasing, true);
// Enable mouse tracking (move events even with no button pressed).
// NOTE(review): handleMouseMove only acts while the left button is held.
setMouseTracking(true);
ui->writingArea->setMouseTracking(true);
// Route CustomGraphicsView's mouse signals to the drawing handlers
// (pointer-to-member connect syntax).
connect(ui->writingArea, &CustomGraphicsView::mousePressed, this, &Widget::handleMousePress);
connect(ui->writingArea, &CustomGraphicsView::mouseMoved, this, &Widget::handleMouseMove);
connect(ui->writingArea, &CustomGraphicsView::mouseReleased, this, &Widget::handleMouseRelease);
}
/// Releases the uic form object allocated with `new` in the constructor.
Widget::~Widget()
{
    delete ui;
    ui = nullptr; // no further use; guards against accidental reuse
}
// 初始化场景
void Widget::initScene()
{
scene->clear();
scene->setSceneRect(0, 0, 280, 280);
ui->writingArea->setFixedSize(280, 280);
// 添加空白矩形项,确保 sceneRect 生效
scene->addRect(scene->sceneRect(), QPen(Qt::NoPen), QBrush(Qt::black));
}
/**
 * Classifies a rendered canvas image as a single digit.
 *
 * Pipeline:
 *  1. Downscale to 28x28 and convert to 8-bit grayscale.
 *  2. Find the bounding box of "ink" pixels (brighter than a small threshold;
 *     the canvas draws white strokes on black).
 *  3. Re-center that box on a black 28x28 image.
 *  4. Scale to [0, 1], normalise with (x - 0.1307) / 0.3081 (presumably the
 *     training-time mean/std -- the usual MNIST values), and run the model.
 *
 * @param inputImage Rendered canvas, strokes brighter than the background.
 * @return The arg-max class index as a string, or "NULL" when no pixel is
 *         above the ink threshold.
 * @throws Whatever model.forward() throws (torch errors derive from
 *         std::exception); callers are expected to catch it.
 */
QString Widget::recognizeImage(const QImage &inputImage)
{
    constexpr int kSide = 28;            // model input width/height
    constexpr uchar kInkThreshold = 0x10;

    // 1. Downscale, then convert (conversion after scaling guarantees 1 byte/pixel).
    const QImage resizedImage =
        inputImage.scaled(kSide, kSide, Qt::IgnoreAspectRatio, Qt::SmoothTransformation)
            .convertToFormat(QImage::Format_Grayscale8);

    // 2. Bounding box of all pixels brighter than the threshold (i.e. stroke pixels).
    QRect boundingBox;
    bool hasContent = false;
    for (int y = 0; y < resizedImage.height(); ++y) {
        const uchar *line = resizedImage.constScanLine(y);
        for (int x = 0; x < resizedImage.width(); ++x) {
            if (line[x] > kInkThreshold) {
                boundingBox = boundingBox.united(QRect(x, y, 1, 1));
                hasContent = true;
            }
        }
    }
    if (!hasContent) return "NULL";

    // 3. Paste the content centred on a black canvas.
    QImage centeredImage(kSide, kSide, QImage::Format_Grayscale8);
    centeredImage.fill(Qt::black);
    {
        // Scoped so the painter is ended before the pixels are read below.
        QPainter painter(&centeredImage);
        const int offsetX = kSide / 2 - boundingBox.width() / 2;
        const int offsetY = kSide / 2 - boundingBox.height() / 2;
        painter.drawImage(offsetX, offsetY, resizedImage.copy(boundingBox));
    }

    // 4. Row-major pixel data in [0, 1]. Read per scanline: rows may be padded
    //    (bytesPerLine() can exceed width()).
    std::vector<float> pixelData;
    pixelData.reserve(kSide * kSide);
    for (int y = 0; y < centeredImage.height(); ++y) {
        const uchar *row = centeredImage.constScanLine(y);
        for (int x = 0; x < centeredImage.width(); ++x)
            pixelData.push_back(row[x] / 255.0f);
    }

    // from_blob does not copy: pixelData must stay alive until the normalisation
    // below has produced a tensor that owns its storage. Shape [N, C, H, W] = [1, 1, 28, 28].
    auto tensor = torch::from_blob(pixelData.data(), {1, 1, kSide, kSide}, torch::kFloat32);
    tensor = (tensor - 0.1307) / 0.3081;

    // Inference only: do not record an autograd graph.
    torch::NoGradGuard noGrad;
    model.to(torch::kCPU);
    tensor = tensor.to(torch::kCPU);
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(tensor);
    const at::Tensor output = model.forward(inputs).toTensor();
    const int64_t predicted_label = output.argmax(1).item<int64_t>();
    return QString::number(predicted_label);
}
//识别按钮点击槽函数
void Widget::handleRecognitionAction()
{
try {
// 创建空白图像并渲染场景内容
QImage image(ui->writingArea->viewport()->size(), QImage::Format_ARGB32);
image.fill(Qt::white);
QPainter painter(&image);
scene->render(&painter);
// 检查是否为空画布
if (image.allGray() && image.pixel(10, 10) == qRgb(255, 255, 255)) {
ui->resultLabel->setText("NULL");
return;
}
// 调用识别函数
QString result = recognizeImage(image);
ui->resultLabel->setText(result);
} catch (const std::exception& e) {
QMessageBox::critical(this, "错误", QString("识别失败: %1").arg(e.what()));
ui->resultLabel->setText("ERROR");
}
}
// 清空画布按钮
void Widget::handleDeletionAction()
{
scene->clear();
ui->resultLabel->setText("NULL");
}
// 加载图片按钮
void Widget::handlePushAction()
{
QString fileName = QFileDialog::getOpenFileName(
this, "选择图像文件", "", "图像文件 (*.png *.jpg *.jpeg *.bmp)"
);
if (!fileName.isEmpty()) {
QPixmap pixmap(fileName);
scene->clear();
scene->addPixmap(pixmap.scaled(280, 280));
}
}
/**
 * Slot for CustomGraphicsView::mousePressed.
 *
 * On a left-button press while the cursor is over the drawing area, enters
 * drawing mode. If the press lies inside the scene rectangle, it also starts
 * a stroke there, shown as a white path item 8 px wide.
 */
void Widget::handleMousePress(QMouseEvent* event) {
    const bool leftPressOnCanvas =
        event->button() == Qt::LeftButton && ui->writingArea->underMouse();
    if (!leftPressOnCanvas)
        return;

    isDrawing = true;

    const QPointF pressPoint = ui->writingArea->mapToScene(event->pos());
    // Strokes may only begin inside the scene rectangle.
    if (!scene->sceneRect().contains(pressPoint))
        return;

    currentPath.moveTo(pressPoint);
    currentPathItem = scene->addPath(currentPath, QPen(Qt::white, 8));
}
/**
 * Slot for CustomGraphicsView::mouseMoved.
 *
 * While drawing with the left button held, appends the cursor position
 * (in scene coordinates) to the current stroke. Points outside the scene
 * rectangle are ignored.
 */
void Widget::handleMouseMove(QMouseEvent* event) {
    const bool dragging = isDrawing && (event->buttons() & Qt::LeftButton);
    if (!dragging)
        return;

    const QPointF dragPoint = ui->writingArea->mapToScene(event->pos());
    if (!scene->sceneRect().contains(dragPoint))
        return;

    currentPath.lineTo(dragPoint);
    // The item is null if the press happened outside the scene rectangle.
    if (currentPathItem != nullptr)
        currentPathItem->setPath(currentPath);
}
/**
 * Slot for CustomGraphicsView::mouseReleased.
 *
 * On left-button release, leaves drawing mode. The finished path item stays
 * in the scene; only the in-progress stroke state is reset.
 */
void Widget::handleMouseRelease(QMouseEvent* event)
{
    if (event->button() != Qt::LeftButton)
        return;

    currentPathItem = nullptr;    // stop editing the finished item
    currentPath = QPainterPath(); // next stroke starts from an empty path
    isDrawing = false;
}
// 删除旧的重复定义(确保只保留一个handleMousePress)
void Widget::handleFilechooseAction()
{
// 打开文件对话框选择图片
QString fileName = QFileDialog::getOpenFileName(
this,
"选择图像文件",
"",
"图像文件 (*.png *.jpg *.jpeg *.bmp)"
);
if (!fileName.isEmpty()) {
// 加载图片并显示到 writingArea
QPixmap pixmap(fileName);
if (pixmap.isNull()) {
QMessageBox::warning(this, "错误", "无法加载图像文件!");
return;
}
// 清空场景并添加缩放后的图片
scene->clear();
QGraphicsPixmapItem *item = scene->addPixmap(
pixmap.scaled(ui->writingArea->viewport()->size(),
Qt::IgnoreAspectRatio, // 忽略宽高比
Qt::SmoothTransformation)
);
// 强制图像填充整个视图
ui->writingArea->fitInView(item, Qt::IgnoreAspectRatio);
ui->writingArea->ensureVisible(item); // 确保内容可见
// Step 1: 直接从 QPixmap 转换为 QImage(必要步骤)
QImage image = pixmap.toImage().convertToFormat(QImage::Format_Grayscale8);
// Step 2: 缩放为 28x28(使用双线性插值)
QImage resizedImage = image.scaled(
28, 28,
Qt::IgnoreAspectRatio,
Qt::SmoothTransformation // 对应 Python 的 BILINEAR 插值 [[3]]
);
// Step 3: 提取像素数据并展平为 [1, 784]
std::vector<float> pixelData;
const uchar* data = resizedImage.bits(); // 直接访问内存缓冲区
int width = resizedImage.width();
int height = resizedImage.height();
for (int y = 0; y < height; ++y) {
for (int x = 0; x < width; ++x) {
float value = data[y * width + x] / 255.0f; // 行优先展平
pixelData.push_back(value);
}
}
// Step 4: 创建张量并 reshape 成 [1, 1, 28, 28]
auto tensor = torch::from_blob(pixelData.data(), {1, 28, 28}, torch::kFloat32); // [1, 28, 28]
tensor = tensor.unsqueeze(0); // [1, 1, 28, 28]
// Step 5: 标准化(与训练一致)
tensor = (tensor - 0.1307) / 0.3081;
// Step 6: 推理并显示结果
std::vector<torch::jit::IValue> inputs;
inputs.push_back(tensor);
at::Tensor output = model.forward(inputs).toTensor();
int64_t predicted_label = output.argmax(1).item<int64_t>();
ui->resultLabel->setText(QString::number(predicted_label));
}
}解释这串代码