错误提示:Project ... is missing required source folder: 'gen'

本文介绍如何解决 Eclipse 3.5.2 + Android ADT 2.1 创建的 Android 项目中出现的“Project is missing required source folder: 'gen'”错误:删除 src 包中的 R.java 文件并重启 Eclipse 即可解决此问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Eclipse3.5.2 + Android ADT2.1
创建新的Android项目后总是提示错误:Project ... is missing required source folder: 'gen'。
解决办法:将 Eclipse 自动生成的 R.java 删除并刷新项目,R.java 便会重新生成。

我是把项目 src 包中的 R.java 删除后才解决的,而不是简单地删掉 gen 目录中的 R.java。

最后重启 Eclipse,问题即可解决。

# Cat-vs-dog image classifier: dataset download/preparation, a small CNN,
# training with metric plots, and a PyQt5 GUI for inference.
# (Originally posted together with a request to optimize this code.)
import contextlib
import copy
import importlib.metadata as importlib_metadata
import io
import json
import os
import re
import shutil
import subprocess
import sys
import threading  # noqa: F401  (kept from the original script)
import time  # noqa: F401  (kept from the original script)
import zipfile
from concurrent.futures import ThreadPoolExecutor

# pip distribution names this script needs; installed on first run if missing.
REQUIRED_PACKAGES = frozenset({
    'torch', 'torchvision', 'numpy', 'matplotlib', 'tqdm', 'requests',
    'pillow', 'scikit-learn', 'pyqt5', 'torchsummary',
})


def _normalize_dist_name(name):
    """Normalize a distribution name (PEP 503) so 'PyQt5' matches 'pyqt5'."""
    return re.sub(r'[-_.]+', '-', name).lower()


def ensure_dependencies(required=REQUIRED_PACKAGES):
    """Install, via pip, every distribution in ``required`` that is missing.

    Uses ``importlib.metadata`` instead of the deprecated ``pkg_resources``,
    which is unavailable whenever setuptools is not installed.
    """
    installed = set()
    for dist in importlib_metadata.distributions():
        name = dist.metadata['Name']
        if name:
            installed.add(_normalize_dist_name(name))
    missing = sorted(pkg for pkg in required
                     if _normalize_dist_name(pkg) not in installed)
    if missing:
        print(f"安装缺失的依赖: {', '.join(missing)}")
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', *missing])


# Must run before the third-party imports below (hence they are not at the top).
ensure_dependencies()

import matplotlib

# Pick the non-interactive backend *before* pyplot is imported: plots are only
# written to PNG files, which the Qt GUI then displays through QPixmap.
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import font_manager
from PIL import Image
from PyQt5.QtCore import Qt, QObject, pyqtSignal
from PyQt5.QtGui import QPixmap
from PyQt5.QtWidgets import (QApplication, QWidget, QVBoxLayout, QHBoxLayout,
                             QPushButton, QLabel, QScrollArea, QFileDialog,
                             QMessageBox, QTextEdit)
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchsummary import summary
from torchvision import datasets, transforms, models  # noqa: F401
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Shared constants
# ---------------------------------------------------------------------------
IMG_SIZE = 150  # CatDogCNN.fc1 is sized for 150x150 inputs (150->75->37->18->9)
NORM_MEAN = [0.485, 0.456, 0.406]  # ImageNet channel means
NORM_STD = [0.229, 0.224, 0.225]  # ImageNet channel standard deviations
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
CLASS_NAMES = ("猫", "狗")  # label 0 = cat, label 1 = dog
DATA_PATH = os.path.join("data", "dogs-vs-cats")
MODEL_PATH = "catdog_cnn_model_with_extra_dropout.pth"
BEST_MODEL_PATH = 'best_cnn_model.pth'
MAX_PREVIEW_IMAGES = 20
DATASET_URL = ("https://download.microsoft.com/download/3/E/1/"
               "3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip")
_CJK_FONT_CANDIDATES = ('SimHei', 'Microsoft YaHei', 'PingFang SC',
                        'Noto Sans CJK SC', 'Noto Sans SC', 'WenQuanYi Micro Hei')
_FALLBACK_FONT_URL = ("https://github.com/googlefonts/noto-cjk/raw/main/Sans/OTF/"
                      "SimplifiedChinese/NotoSansSC-Regular.otf")
_FALLBACK_FONT_PATH = "NotoSansSC-Regular.otf"


def build_eval_transform():
    """Return the deterministic resize/to-tensor/normalize pipeline used for inference."""
    return transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ])


def _download_file(url, dest, desc="下载进度", timeout=60):
    """Stream ``url`` into ``dest`` with a tqdm progress bar.

    Data is written to ``dest + '.part'`` and renamed only on success, so an
    interrupted download is never mistaken for a complete file on the next run.

    Raises:
        requests.RequestException: on network errors or a non-2xx response.
        OSError: if the file cannot be written.
    """
    tmp_path = dest + '.part'
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(tmp_path, 'wb') as f, tqdm(desc=desc, total=total_size, unit='B',
                                              unit_scale=True, unit_divisor=1024) as bar:
            for chunk in response.iter_content(chunk_size=1 << 16):
                bar.update(f.write(chunk))
    os.replace(tmp_path, dest)


def setup_chinese_font():
    """Configure matplotlib so the Chinese plot labels render correctly.

    The original assigned ``font.sans-serif = ['SimHei']`` inside ``try`` and
    relied on an exception to reach the download fallback. That assignment
    never raises, so the fallback never ran on machines without SimHei. Here
    the installed fonts are checked explicitly, and a downloaded font is
    registered with ``addfont`` so matplotlib can actually find it.
    """
    plt.rcParams['axes.unicode_minus'] = False
    installed = {entry.name for entry in font_manager.fontManager.ttflist}
    for name in _CJK_FONT_CANDIDATES:
        if name in installed:
            plt.rcParams['font.sans-serif'] = [name] + list(plt.rcParams['font.sans-serif'])
            return
    try:
        if not os.path.exists(_FALLBACK_FONT_PATH):
            _download_file(_FALLBACK_FONT_URL, _FALLBACK_FONT_PATH)
        font_manager.fontManager.addfont(_FALLBACK_FONT_PATH)
        name = font_manager.FontProperties(fname=_FALLBACK_FONT_PATH).get_name()
        plt.rcParams['font.sans-serif'] = [name] + list(plt.rcParams['font.sans-serif'])
    except (requests.RequestException, OSError, RuntimeError, ValueError) as e:
        # Best effort: plots still work, only the CJK glyphs may be missing.
        print(f"警告: 无法设置中文字体 ({e})")


setup_chinese_font()


# ---------------------------------------------------------------------------
# Part 2: download and prepare the dataset
# ---------------------------------------------------------------------------
def _move_class_images(source_dir, train_folder, prefix):
    """Move image files from ``source_dir`` to ``train_folder`` as ``<prefix>.<name>``."""
    if not os.path.isdir(source_dir):
        return
    for file in os.listdir(source_dir):
        if not file.lower().endswith(IMAGE_EXTENSIONS):
            continue  # e.g. Thumbs.db; it could never be loaded as a sample
        src = os.path.join(source_dir, file)
        dst = os.path.join(train_folder, f"{prefix}.{file}")
        if not os.path.exists(dst):
            shutil.move(src, dst)


def _split_test_set(train_folder, test_folder, ratio=0.2, seed=42):
    """Move a reproducible random ``ratio`` of the training files to ``test_folder``."""
    # Sorted so the seeded choice picks the same files on every platform
    # (os.listdir order is arbitrary).
    train_files = sorted(os.listdir(train_folder))
    n_test = int(len(train_files) * ratio)
    if n_test == 0:
        return
    rng = np.random.default_rng(seed)
    for file in rng.choice(train_files, size=n_test, replace=False):
        src = os.path.join(train_folder, file)
        dst = os.path.join(test_folder, file)
        if os.path.exists(src) and not os.path.exists(dst):
            shutil.move(src, dst)


def download_and_extract_dataset(base_dir="data"):
    """Download the Microsoft cats-vs-dogs archive and lay it out for training.

    The result is ``<base_dir>/dogs-vs-cats/{train,test}`` with files renamed to
    ``cat.*`` / ``dog.*``. Roughly 20% of the files go to ``test``. Nothing is
    downloaded when ``train`` already has more than 1000 images of each class.
    Network, file and archive errors are reported together with manual
    download instructions; they are not raised.
    """
    data_path = os.path.join(base_dir, "dogs-vs-cats")
    train_folder = os.path.join(data_path, 'train')
    test_folder = os.path.join(data_path, 'test')
    os.makedirs(train_folder, exist_ok=True)
    os.makedirs(test_folder, exist_ok=True)

    existing = os.listdir(train_folder)
    cat_count = sum(f.startswith('cat') for f in existing)
    dog_count = sum(f.startswith('dog') for f in existing)
    if cat_count > 1000 and dog_count > 1000:
        print("数据集已存在,跳过下载")
        return

    print("正在下载数据集...")
    zip_path = os.path.join(base_dir, "catsdogs.zip")
    extracted_dir = os.path.join(base_dir, "PetImages")
    try:
        if not os.path.exists(zip_path):
            _download_file(DATASET_URL, zip_path)
        print("下载完成,正在解压...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(base_dir)
        print("数据集解压完成!")

        _move_class_images(os.path.join(extracted_dir, "Cat"), train_folder, "cat")
        _move_class_images(os.path.join(extracted_dir, "Dog"), train_folder, "dog")
        _split_test_set(train_folder, test_folder)

        # Remove temporary files.
        if os.path.exists(extracted_dir):
            shutil.rmtree(extracted_dir)
        if os.path.exists(zip_path):
            os.remove(zip_path)
        print(
            f"数据集设置完成!训练集: {len(os.listdir(train_folder))} 张图片, 测试集: {len(os.listdir(test_folder))} 张图片")
    except (requests.RequestException, OSError, zipfile.BadZipFile) as e:
        if isinstance(e, zipfile.BadZipFile):
            # A corrupt archive would otherwise be reused on every run.
            with contextlib.suppress(OSError):
                os.remove(zip_path)
        print(f"下载或设置数据集时出错: {str(e)}")
        print("请手动下载数据集并解压到 data/dogs-vs-cats 目录")
        print("下载地址: https://www.microsoft.com/en-us/download/details.aspx?id=54765")


# ---------------------------------------------------------------------------
# Part 3: dataset
# ---------------------------------------------------------------------------
class DogsVSCats(Dataset):
    """Images from one folder, labelled by filename prefix (``cat.*`` -> 0, ``dog.*`` -> 1).

    Files that cannot be verified, or whose names start with neither prefix,
    are skipped with a message. If no usable image is found, ten grey
    placeholder images are written to ``data_dir`` so that training can still
    run end to end.
    """

    def __init__(self, data_dir, transform=None):
        self.image_paths = []
        self.labels = []
        os.makedirs(data_dir, exist_ok=True)
        for file in sorted(os.listdir(data_dir)):
            if not file.lower().endswith(IMAGE_EXTENSIONS):
                continue
            img_path = os.path.join(data_dir, file)
            label = self._label_from_name(file)
            if label is None:
                # The original silently labelled these as cats, which corrupts training.
                print(f"跳过无法识别类别的图片: {img_path}")
                continue
            try:
                # Check the image is intact before it is used for training.
                with Image.open(img_path) as img:
                    img.verify()
            except (OSError, SyntaxError, ValueError, Image.DecompressionBombError) as e:
                print(f"跳过损坏图片: {img_path} - {str(e)}")
                continue
            self.image_paths.append(img_path)
            self.labels.append(label)

        if not self.image_paths:
            self._create_placeholder_images(data_dir)

        self.transform = transform if transform is not None else build_eval_transform()

    @staticmethod
    def _label_from_name(filename):
        """Return 0 for ``cat*``, 1 for ``dog*``, or None when the class is unknown."""
        if filename.startswith('cat'):
            return 0
        if filename.startswith('dog'):
            return 1
        return None

    def _create_placeholder_images(self, data_dir, count=10):
        """Write ``count`` grey images with alternating cat/dog labels into ``data_dir``."""
        print(f"错误: 在 {data_dir} 中没有找到有效图片!")
        for i in range(count):
            label = 0 if i % 2 == 0 else 1
            # Name prefix matches the label, so the file is labelled the same on the next run.
            prefix = 'cat' if label == 0 else 'dog'
            img_path = os.path.join(data_dir, f"{prefix}.example_{i}.jpg")
            Image.new('RGB', (224, 224), color=(i * 25, i * 25, i * 25)).save(img_path)
            self.image_paths.append(img_path)
            self.labels.append(label)
        print(f"已创建 {len(self.image_paths)} 个示例图片")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(transformed image tensor, label tensor)``; unreadable files become grey images."""
        try:
            with Image.open(self.image_paths[idx]) as img:
                image = img.convert('RGB')
        except (OSError, SyntaxError, ValueError, Image.DecompressionBombError) as e:
            print(f"无法加载图片: {self.image_paths[idx]}, 使用占位符 - {str(e)}")
            image = Image.new('RGB', (IMG_SIZE, IMG_SIZE), color=(100, 100, 100))
        image = self.transform(image)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label


# ---------------------------------------------------------------------------
# Part 6: CNN model (with an extra dropout layer)
# ---------------------------------------------------------------------------
class CatDogCNN(nn.Module):
    """Four conv(3x3, pad 1) + ReLU + 2x2 max-pool stages and a two-layer classifier.

    Expects ``(batch, 3, 150, 150)`` inputs; the spatial size after the pools
    is 150 -> 75 -> 37 -> 18 -> 9. Outputs 2 logits (cat, dog).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 9 * 9, 512)
        self.fc2 = nn.Linear(512, 2)
        self.dropout1 = nn.Dropout(0.5)  # after fc1
        self.dropout2 = nn.Dropout(0.5)  # the extra layer, on the flattened conv features

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(F.relu(conv(x)))
        # flatten(x, 1) keeps the batch dimension. The old view(-1, 256*9*9)
        # silently reshaped wrongly sized inputs into a different batch size.
        x = torch.flatten(x, 1)
        # Previously both dropouts were applied back to back after fc1, which
        # equals a single dropout with p=0.75. Placing the extra one before fc1
        # regularizes the largest layer. In eval mode both are no-ops, so
        # inference with existing weights is unchanged.
        x = self.dropout2(x)
        x = self.dropout1(F.relu(self.fc1(x)))
        return self.fc2(x)


# ---------------------------------------------------------------------------
# Part 7: training and visualization
# ---------------------------------------------------------------------------
class Trainer:
    """Train a classifier with Adam + ReduceLROnPlateau and record per-epoch metrics."""

    def __init__(self, model, train_loader, val_loader, lr=0.001):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"使用设备: {self.device}")
        self.model = model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        # mode='max' because the scheduler is stepped with validation accuracy.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.1, patience=2)

        self.train_losses = []
        self.train_accuracies = []
        self.val_losses = []
        self.val_accuracies = []

    def train(self, num_epochs, restore_best=True):
        """Train for ``num_epochs`` epochs, validating after each one.

        The best-validation weights are saved to ``best_cnn_model.pth``. When
        ``restore_best`` is true they are also loaded back into the model at
        the end, so callers save and use the best model, not the last epoch.
        """
        best_accuracy = 0.0
        best_state = None
        for epoch in range(num_epochs):
            epoch_train_loss, epoch_train_acc = self._train_one_epoch(epoch, num_epochs)
            self.train_losses.append(epoch_train_loss)
            self.train_accuracies.append(epoch_train_acc)

            val_loss, val_acc = self.validate()
            self.val_losses.append(val_loss)
            self.val_accuracies.append(val_acc)

            self.scheduler.step(val_acc)

            if val_acc > best_accuracy:
                best_accuracy = val_acc
                # CPU copy: keeps a snapshot without holding extra GPU memory.
                best_state = {k: v.detach().cpu().clone()
                              for k, v in self.model.state_dict().items()}
                torch.save(best_state, BEST_MODEL_PATH)
                print(f"保存最佳模型,验证准确率: {best_accuracy:.4f}")

            print(f"Epoch {epoch + 1}/{num_epochs} | "
                  f"训练损失: {epoch_train_loss:.4f} | 训练准确率: {epoch_train_acc:.4f} | "
                  f"验证损失: {val_loss:.4f} | 验证准确率: {val_acc:.4f}")

        if restore_best and best_state is not None:
            self.model.load_state_dict(best_state)
            print(f"已恢复验证准确率最高的模型权重: {best_accuracy:.4f}")

        self.visualize_training_results()

    def _train_one_epoch(self, epoch, num_epochs):
        """Run one training epoch; return (mean loss, accuracy), or (nan, 0.0) if empty."""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        train_bar = tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [训练]")
        for images, labels in train_bar:
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            batch = labels.size(0)
            running_loss += loss.item() * batch
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch
            train_bar.set_postfix(loss=running_loss / total, acc=correct / total)
        if total == 0:
            return float('nan'), 0.0
        return running_loss / total, correct / total

    def _evaluate(self, loader, desc):
        """Evaluate on ``loader``; return (mean loss, accuracy, true labels, predictions)."""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        y_true, y_pred = [], []
        with torch.no_grad():
            bar = tqdm(loader, desc=desc)
            for images, labels in bar:
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                preds = outputs.argmax(dim=1)

                batch = labels.size(0)
                running_loss += loss.item() * batch
                correct += (preds == labels).sum().item()
                total += batch
                y_true.extend(labels.tolist())
                y_pred.extend(preds.tolist())
                bar.set_postfix(loss=running_loss / total, acc=correct / total)
        if total == 0:
            # Empty loader: avoid ZeroDivisionError; accuracy 0 never counts as "best".
            return float('nan'), 0.0, y_true, y_pred
        return running_loss / total, correct / total, y_true, y_pred

    def validate(self):
        """Return (mean loss, accuracy) on the validation loader."""
        val_loss, val_acc, _, _ = self._evaluate(self.val_loader, "[验证]")
        return val_loss, val_acc

    def test(self, test_loader, save_path='confusion_matrix.png'):
        """Evaluate on a held-out loader, save a confusion-matrix plot, return (loss, acc)."""
        test_loss, test_acc, y_true, y_pred = self._evaluate(test_loader, "[测试]")
        print(f"测试损失: {test_loss:.4f} | 测试准确率: {test_acc:.4f}")
        if y_true:
            cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
            fig, ax = plt.subplots(figsize=(6, 6))
            ConfusionMatrixDisplay(confusion_matrix=cm,
                                   display_labels=list(CLASS_NAMES)).plot(ax=ax, cmap='Blues')
            ax.set_title('测试集混淆矩阵')
            fig.savefig(save_path)
            plt.close(fig)
            print(f"混淆矩阵已保存为 {save_path}")
        return test_loss, test_acc

    @staticmethod
    def _plot_metric(ax, epochs, train_values, val_values, title, ylabel):
        """Draw the train (blue) vs. validation (red) curves of one metric on ``ax``."""
        ax.plot(epochs, train_values, 'bo-', label=f'训练{ylabel}')
        ax.plot(epochs, val_values, 'ro-', label=f'验证{ylabel}')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    def visualize_training_results(self):
        """Save accuracy/loss plots (combined and separate) plus the raw metrics as JSON."""
        epochs = list(range(1, len(self.train_accuracies) + 1))
        accuracy = (self.train_accuracies, self.val_accuracies, '训练和验证准确率', '准确率')
        loss = (self.train_losses, self.val_losses, '训练和验证损失', '损失')

        fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 6))
        self._plot_metric(ax_acc, epochs, *accuracy)
        self._plot_metric(ax_loss, epochs, *loss)
        fig.tight_layout()
        fig.savefig('training_visualization.png')
        plt.close(fig)  # figures are never shown; close them to free memory
        print("训练结果可视化图表已保存为 training_visualization.png")

        for metric, path, message in (
                (accuracy, 'accuracy_curve.png', "准确率曲线已保存为 accuracy_curve.png"),
                (loss, 'loss_curve.png', "损失曲线已保存为 loss_curve.png")):
            fig, ax = plt.subplots(figsize=(8, 6))
            self._plot_metric(ax, epochs, *metric)
            fig.savefig(path)
            plt.close(fig)
            print(message)

        results = {
            'epochs': epochs,
            'train_losses': self.train_losses,
            'train_accuracies': self.train_accuracies,
            'val_losses': self.val_losses,
            'val_accuracies': self.val_accuracies,
        }
        with open('training_results.json', 'w', encoding='utf-8') as f:
            json.dump(results, f)
        print("训练结果已保存为 training_results.json")


# ---------------------------------------------------------------------------
# GUI: per-image inference job
# ---------------------------------------------------------------------------
class ImageProcessor(QObject):
    """Classify one image file and report it through ``result_signal(filename, text)``.

    ``process_image`` is run on a worker thread. The signal is delivered to
    the GUI thread by Qt's cross-thread (queued) connection.
    """

    result_signal = pyqtSignal(str, str)  # (filename, formatted result)

    def __init__(self, model, device, filename):
        super().__init__()
        self.model = model
        self.device = device
        self.filename = filename
        self.transform = build_eval_transform()

    def process_image(self):
        try:
            # Context manager closes the file handle (the original leaked it).
            with Image.open(self.filename) as img:
                image = img.convert('RGB')
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            self.model.eval()
            with torch.no_grad():
                probabilities = F.softmax(self.model(image_tensor), dim=1)[0]
            predicted = int(probabilities.argmax().item())
            confidence = probabilities[predicted].item()
            message = f"{CLASS_NAMES[predicted]} ({confidence * 100:.1f}%置信度)"
        except Exception as e:  # worker-thread boundary: report, never crash the GUI
            message = f"处理错误: {str(e)}"
        self.result_signal.emit(self.filename, message)


# ---------------------------------------------------------------------------
# GUI: main window
# ---------------------------------------------------------------------------
class CatDogClassifierApp(QWidget):
    """Main window: upload images, show previews with predictions, view training output."""

    def __init__(self, model, device):
        super().__init__()
        self.setWindowTitle("猫狗识别系统")
        self.setGeometry(100, 100, 1000, 700)
        self.model = model
        self.device = device
        # Processors must stay referenced until their queued result arrives.
        self.image_processors = []
        # One worker thread serializes inference, instead of one thread per
        # image all competing for the same CPU/GPU.
        self._executor = ThreadPoolExecutor(max_workers=1)
        # Parentless pop-up windows referenced only by a local variable are
        # garbage-collected as soon as the method returns; keep them here.
        self._results_window = None
        self._summary_window = None
        self.initUI()

    def initUI(self):
        main_layout = QVBoxLayout()

        title = QLabel("猫狗识别系统")
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 24px; font-weight: bold; margin: 10px;")
        main_layout.addWidget(title)

        button_layout = QHBoxLayout()
        button_specs = (
            ('upload_button', "上传图像", self.uploadImage),
            ('batch_process_button', "批量处理", self.batchProcess),
            ('clear_button', "清除所有", self.clearAll),
            ('results_button', "查看训练结果", self.showTrainingResults),
            ('model_summary_button', "查看模型结构", self.showModelSummary),
        )
        for attr_name, text, slot in button_specs:
            button = QPushButton(text)
            button.setStyleSheet("font-size: 16px; padding: 10px;")
            button.clicked.connect(slot)
            button_layout.addWidget(button)
            setattr(self, attr_name, button)
        main_layout.addLayout(button_layout)

        self.status_label = QLabel("就绪")
        self.status_label.setStyleSheet("font-size: 14px; color: #666; margin: 5px;")
        main_layout.addWidget(self.status_label)

        self.preview_area = QScrollArea()
        self.preview_area.setWidgetResizable(True)
        self.preview_area.setStyleSheet("background-color: #f0f0f0;")
        self.preview_widget = QWidget()
        self.preview_layout = QHBoxLayout()
        self.preview_layout.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        self.preview_widget.setLayout(self.preview_layout)
        self.preview_area.setWidget(self.preview_widget)
        main_layout.addWidget(self.preview_area)

        info_label = QLabel("基于卷积神经网络(CNN)的猫狗识别系统 | 支持上传单张或多张图片")
        info_label.setAlignment(Qt.AlignCenter)
        info_label.setStyleSheet("font-size: 12px; color: #888; margin: 10px;")
        main_layout.addWidget(info_label)

        self.setLayout(main_layout)

    def uploadImage(self):
        self.status_label.setText("正在选择图像...")
        filename, _ = QFileDialog.getOpenFileName(
            self, "选择图像", "", "图像文件 (*.png *.jpg *.jpeg)"
        )
        if not filename:
            self.status_label.setText("就绪")  # dialog cancelled
            return
        self.status_label.setText(f"正在处理: {os.path.basename(filename)}")
        self.displayImage(filename)

    def batchProcess(self):
        self.status_label.setText("正在选择多张图像...")
        filenames, _ = QFileDialog.getOpenFileNames(
            self, "选择多张图像", "", "图像文件 (*.png *.jpg *.jpeg)"
        )
        if not filenames:
            self.status_label.setText("就绪")  # dialog cancelled
            return
        self.status_label.setText(f"正在批量处理 {len(filenames)} 张图像...")
        for filename in filenames:
            if not self.displayImage(filename):
                break  # preview limit reached; warning already shown

    def _remove_container(self, widget):
        """Detach a preview container from the layout and schedule its deletion."""
        # Clear the name so a late result cannot be routed to it via findChild().
        widget.setObjectName("")
        self.preview_layout.removeWidget(widget)
        widget.hide()
        widget.deleteLater()

    def displayImage(self, filename):
        """Add a preview card for ``filename`` and queue its classification.

        An existing card for the same file is replaced. Returns False only when
        the preview limit prevents adding the image; otherwise returns True.
        """
        if not os.path.isfile(filename):
            QMessageBox.warning(self, "警告", "文件路径不安全或文件不存在")
            self.status_label.setText("错误: 文件不存在")
            return True

        container_name = f"container_{filename}"
        # Exact match: the old prefix match also removed e.g. 'a.jpg.png' when re-adding 'a.jpg'.
        for i in reversed(range(self.preview_layout.count())):
            widget = self.preview_layout.itemAt(i).widget()
            if widget is not None and widget.objectName() == container_name:
                self._remove_container(widget)

        # Enforce the limit *before* adding (the original only warned afterwards).
        if self.preview_layout.count() >= MAX_PREVIEW_IMAGES:
            QMessageBox.warning(self, "警告", f"最多只能同时处理{MAX_PREVIEW_IMAGES}张图像")
            return False

        container = QWidget()
        container.setObjectName(container_name)
        container.setStyleSheet("""
            background-color: white;
            border: 1px solid #ddd;
            border-radius: 5px;
            padding: 10px;
            margin: 5px;
        """)
        container.setFixedSize(300, 350)
        container_layout = QVBoxLayout(container)
        container_layout.setContentsMargins(5, 5, 5, 5)
        container_layout.setSpacing(5)

        filename_label = QLabel(os.path.basename(filename))
        filename_label.setStyleSheet("font-size: 12px; color: #555;")
        filename_label.setAlignment(Qt.AlignCenter)
        container_layout.addWidget(filename_label)

        pixmap = QPixmap(filename)
        if pixmap.width() > 280 or pixmap.height() > 200:
            pixmap = pixmap.scaled(280, 200, Qt.KeepAspectRatio, Qt.SmoothTransformation)
        preview_label = QLabel(container)
        preview_label.setPixmap(pixmap)
        preview_label.setAlignment(Qt.AlignCenter)
        preview_label.setFixedSize(280, 200)
        preview_label.setStyleSheet("border: 1px solid #eee;")
        container_layout.addWidget(preview_label)

        result_label = QLabel("识别中...", container)
        result_label.setObjectName(f"result_{filename}")
        result_label.setAlignment(Qt.AlignCenter)
        result_label.setStyleSheet("font-size: 16px; font-weight: bold; padding: 5px;")
        container_layout.addWidget(result_label)

        delete_button = QPushButton("删除", container)
        delete_button.setObjectName(f"button_{filename}")
        delete_button.setStyleSheet("""
            QPushButton {
                background-color: #ff6b6b;
                color: white;
                border: none;
                border-radius: 3px;
                padding: 5px;
            }
            QPushButton:hover {
                background-color: #ff5252;
            }
        """)
        # fn=filename binds the current value (avoids the late-binding closure pitfall).
        delete_button.clicked.connect(lambda _, fn=filename: self.deleteImage(fn))
        container_layout.addWidget(delete_button)

        self.preview_layout.addWidget(container)

        processor = ImageProcessor(self.model, self.device, filename)
        processor.result_signal.connect(self.updateUIWithResult)
        self.image_processors.append(processor)
        self._executor.submit(processor.process_image)
        return True

    def deleteImage(self, filename):
        container = self.findChild(QWidget, f"container_{filename}")
        if container is not None:
            self._remove_container(container)
            self.status_label.setText(f"已删除: {os.path.basename(filename)}")

    def updateUIWithResult(self, filename, result):
        """Show a classification result on the matching card, colored by class."""
        container = self.findChild(QWidget, f"container_{filename}")
        if container is None:
            return  # the card was deleted or replaced before the result arrived
        result_label = container.findChild(QLabel, f"result_{filename}")
        if result_label is None:
            return
        # Match on the leading class name; a substring test could also match
        # an error message whose path happens to contain 猫 or 狗.
        if result.startswith(CLASS_NAMES[0]):
            color = "#1a73e8"
        elif result.startswith(CLASS_NAMES[1]):
            color = "#e91e63"
        else:
            color = "#f57c00"
        result_label.setStyleSheet(f"color: {color}; font-size: 16px; font-weight: bold;")
        result_label.setText(result)
        self.status_label.setText(f"完成识别: {os.path.basename(filename)} -> {result}")

    def clearAll(self):
        while self.preview_layout.count():
            widget = self.preview_layout.takeAt(0).widget()
            if widget is not None:
                widget.setObjectName("")
                widget.deleteLater()
        self.image_processors = []
        self.status_label.setText("已清除所有图像")

    @staticmethod
    def _scaled_image_label(path, width, height):
        """Return a QLabel showing the image at ``path`` scaled to fit width x height."""
        label = QLabel()
        label.setPixmap(QPixmap(path).scaled(width, height, Qt.KeepAspectRatio,
                                             Qt.SmoothTransformation))
        return label

    def showTrainingResults(self):
        """Open a window with the training plots written by Trainer.visualize_training_results."""
        if not os.path.exists('training_visualization.png'):
            QMessageBox.information(self, "提示", "训练结果可视化图表尚未生成")
            return
        try:
            results_window = QWidget()
            results_window.setWindowTitle("训练结果可视化")
            results_window.setGeometry(200, 200, 1200, 800)
            layout = QVBoxLayout()

            title = QLabel("模型训练结果可视化")
            title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
            title.setAlignment(Qt.AlignCenter)
            layout.addWidget(title)

            layout.addWidget(QLabel("训练和验证准确率/损失:"))
            layout.addWidget(self._scaled_image_label('training_visualization.png', 1000, 500))

            h_layout = QHBoxLayout()
            for caption, path in (("准确率曲线:", 'accuracy_curve.png'),
                                  ("损失曲线:", 'loss_curve.png')):
                vbox = QVBoxLayout()
                vbox.addWidget(QLabel(caption))
                vbox.addWidget(self._scaled_image_label(path, 450, 350))
                h_layout.addLayout(vbox)
            layout.addLayout(h_layout)

            close_button = QPushButton("关闭")
            close_button.setStyleSheet("font-size: 16px; padding: 8px;")
            close_button.clicked.connect(results_window.close)
            layout.addWidget(close_button, alignment=Qt.AlignCenter)

            results_window.setLayout(layout)
            self._results_window = results_window  # keep alive (see __init__)
            results_window.show()
        except Exception as e:
            QMessageBox.critical(self, "错误", f"加载训练结果时出错: {str(e)}")

    def showModelSummary(self):
        """Open a window with the torchsummary table for the model."""
        summary_window = QWidget()
        summary_window.setWindowTitle("模型结构摘要")
        summary_window.setGeometry(200, 200, 800, 600)
        layout = QVBoxLayout()

        title = QLabel("模型各层参数状况")
        title.setStyleSheet("font-size: 20px; font-weight: bold; margin: 10px;")
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        summary_text = QTextEdit()
        summary_text.setReadOnly(True)
        summary_text.setStyleSheet("font-family: monospace; font-size: 12px;")

        buffer = io.StringIO()
        try:
            # redirect_stdout restores sys.stdout even if summary() raises; the
            # original left stdout redirected in that case.
            with contextlib.redirect_stdout(buffer):
                summary(self.model, input_size=(3, IMG_SIZE, IMG_SIZE),
                        device=self.device.type)
            summary_text.setPlainText(buffer.getvalue())
        except Exception as e:
            summary_text.setPlainText(f"生成模型摘要时出错: {str(e)}")
        layout.addWidget(summary_text)

        close_button = QPushButton("关闭")
        close_button.setStyleSheet("font-size: 16px; padding: 8px;")
        close_button.clicked.connect(summary_window.close)
        layout.addWidget(close_button, alignment=Qt.AlignCenter)

        summary_window.setLayout(layout)
        self._summary_window = summary_window  # keep alive (see __init__)
        summary_window.show()


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def _train_new_model(model, train_folder, test_folder, num_epochs=15, batch_size=32):
    """Train ``model`` on ``train_folder`` (80/20 train/val), test it, save it; return it."""
    # Augmentation for training only. Resizing first makes the rotation and
    # colour jitter operate on small 150x150 images instead of full-size photos.
    train_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomRotation(15),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ])
    base_transform = build_eval_transform()

    full_train_dataset = DogsVSCats(train_folder, transform=train_transform)
    test_dataset = DogsVSCats(test_folder, transform=base_transform)

    # random_split subsets share their parent's transform. Validating on the
    # augmented dataset (as before) gave noisy, randomized validation scores.
    # Use a shallow copy with the deterministic transform for validation; it
    # shares the path/label lists, so the same indices select the same files.
    eval_view = copy.copy(full_train_dataset)
    eval_view.transform = base_transform

    train_size = int(0.8 * len(full_train_dataset))
    val_size = len(full_train_dataset) - train_size
    gen = torch.Generator().manual_seed(42)
    train_dataset, val_split = random_split(
        full_train_dataset, [train_size, val_size], generator=gen
    )
    val_dataset = Subset(eval_view, val_split.indices)

    loader_kwargs = dict(batch_size=batch_size, num_workers=0,
                         pin_memory=torch.cuda.is_available())
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    trainer = Trainer(model, train_loader, val_loader)
    print(f"开始训练(带额外Dropout层和数据增强),共 {num_epochs} 个epoch...")
    trainer.train(num_epochs)  # restores the best-validation weights at the end
    trainer.test(test_loader)  # test_loader was previously built but never used

    torch.save(model.state_dict(), MODEL_PATH)
    print(f"模型已保存为 {MODEL_PATH}")

    print("\n模型各层参数状况:")
    summary(model, input_size=(3, IMG_SIZE, IMG_SIZE), device=trainer.device.type)
    return model


def main():
    """Load or train the model, then start the GUI."""
    train_folder = os.path.join(DATA_PATH, 'train')
    test_folder = os.path.join(DATA_PATH, 'test')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    model = CatDogCNN()
    if os.path.exists(MODEL_PATH):
        print("加载已训练的模型...")
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        model = model.to(device)
        model.eval()
        print("模型加载完成")
    else:
        print("未找到训练好的模型,开始训练新模型...")
        # The dataset is only needed for training. Downloading here (instead
        # of at import time) skips the large download when a model exists.
        download_and_extract_dataset()
        model = _train_new_model(model, train_folder, test_folder)
        model.eval()

    app = QApplication(sys.argv)
    window = CatDogClassifierApp(model, device)
    window.show()
    sys.exit(app.exec_())


if __name__ == "__main__":
    main()
最新发布
06-26
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值