sys standard input/output, input() vs raw_input(), and StringIO

sys.stdout and sys.stdin

print is (roughly) equivalent to
sys.stdout.write('Hello World!' + '\n')
raw_input is equivalent to
sys.stdin.readline()[:-1]
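A minimal sketch (Python 2) showing both equivalences side by side:

# coding:utf-8
import sys

# print appends a newline; write() does not, so add '\n' explicitly
print 'Hello World!'
sys.stdout.write('Hello World!' + '\n')

# Both of the following read one line from standard input
# without its trailing '\n'
s1 = sys.stdin.readline()[:-1]
s2 = raw_input()
print s1, s2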

In fact, all of this can be understood as redirecting standard input and output:
# coding:utf-8
from StringIO import StringIO
import sys

# Create a StringIO object
buff = StringIO()
# Save the original standard output stream
temp = sys.stdout
# Redirect standard output to the buff object
sys.stdout = buff
# This print does not reach the terminal; the text goes into buff instead
print 42, 'hello', 0.001
# Restore standard output
sys.stdout = temp
print "2***"
print buff.getvalue()
import sys

# Redirect standard output to a file
f1 = open('22.txt', 'w')
temp = sys.stdout
sys.stdout = f1
print 'abcd'                    # written into 22.txt, not to the terminal
sys.stdout = temp
f1.close()

# Redirect standard input from a file (33.txt must already exist)
f2 = open('33.txt', 'r')
temp = sys.stdin
sys.stdin = f2
a = sys.stdin.readline()[:-1]   # a = raw_input() works just as well here
print a
sys.stdin = temp
f2.close()
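To guarantee the original stream is restored even when an error occurs, the same redirection idea can be wrapped in a small helper. This is only an illustrative sketch; capture_stdout is a made-up name, not a standard-library function:

# coding:utf-8
import sys
from contextlib import contextmanager
from StringIO import StringIO

@contextmanager
def capture_stdout():
    # Temporarily point sys.stdout at a StringIO buffer
    buff = StringIO()
    temp = sys.stdout
    sys.stdout = buff
    try:
        yield buff
    finally:
        sys.stdout = temp   # always restore, even if the body raises

with capture_stdout() as out:
    print 'this line goes into the buffer'
print out.getvalue()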

Differences between input() and raw_input()

When the input is a pure number:

input() returns a numeric type, such as int or float.
raw_input() returns a string (raw_input() always returns a string, no matter what is typed).

When the input is an expression:

input() evaluates a numeric expression contained in the input, while raw_input() does not.

For example, when you type "57 + 3":

input() returns the integer 60.
raw_input() returns the string "57 + 3".

In other words, input() requires the input to follow Python syntax exactly: strings must be quoted, and numbers are taken as numbers. (In fact, in Python 2, input() is implemented on top of raw_input(), essentially as eval(raw_input()).)
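A minimal sketch (Python 2); run it and type 57 + 3 at both prompts:

# Type the same text, 57 + 3, at each prompt
a = input()        # the typed expression is evaluated -> 60 (int)
b = raw_input()    # the same keystrokes come back verbatim -> '57 + 3' (str)
print type(a), a
print type(b), b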

StringIO

from StringIO import StringIO

# Create a StringIO object whose buffer initially contains ABCDEF
s = StringIO('ABCDEF')

# Writing from the start overwrites ABC
s.write('abc')

# Before each read(), seek() to the position you want to read from;
# here, rewind to the beginning
s.seek(0)

# Prints abcDEF
print s.read()

# Move to offset 2, the character 'c'
s.seek(2)

# Read from the current position to the end; prints cDEF
print s.read()

s.seek(3)
# Read two characters starting at offset 3; prints DE
print s.read(2)

s.seek(6)
# Write at the specified position (offset 6 is the end, so this appends)
s.write('GH')

s.seek(0)
# Prints abcDEFGH
print s.read()

# To get the whole buffer regardless of the current position, use getvalue()
# Prints abcDEFGH
print s.getvalue()
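Beyond capturing redirected output, StringIO is also handy for building a long string from many small pieces; getvalue() then returns the whole buffer regardless of the current read position. A small sketch:

# coding:utf-8
from StringIO import StringIO

out = StringIO()
out.write('name,score\n')
for name, score in [('alice', 90), ('bob', 85)]:
    out.write('%s,%d\n' % (name, score))

# Everything written so far, as one string
print out.getvalue()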