restore_utf8 and utf8togbk

This post provides two Python scripts for Windows: utf8togbk.py batch-converts UTF-8 source files (*.h, *.cpp) to GBK encoding, keeping each original as a *.utf8 backup, and restore_utf8.py renames those *.h.utf8 backups back to the corresponding *.h files.
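Both scripts revolve around one core operation: read the raw bytes, strip the UTF-8 byte-order mark if present, and re-encode the text as GBK. A minimal sketch of that step for a single file, assuming a hypothetical demo.h exists in the current directory:

import codecs

# Minimal sketch: re-encode one UTF-8 file (with BOM) as GBK.
# 'demo.h' is a hypothetical file name used only for illustration.
data = open('demo.h', 'rb').read()
if data.startswith(codecs.BOM_UTF8):
    text = data[len(codecs.BOM_UTF8):].decode('utf-8')
    open('demo.h', 'wb').write(text.encode('gbk'))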
restore_utf8.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Run "python restore_utf8.py" to rename *.h.utf8 to *.h.
#

import os

def restore_utf8(dir):
    resultfn = ''
    for fn in os.listdir(dir):
        sfile = os.path.join(dir, fn)
        if os.path.isdir(sfile):
            resultfn += restore_utf8(sfile)
            continue
        if fn.endswith('.utf8'):
            orgfile = sfile[:-5]                # strip the '.utf8' suffix
            try:
                if os.path.exists(orgfile):
                    os.remove(orgfile)          # drop the converted GBK copy
                os.rename(sfile, orgfile)       # restore the UTF-8 original
                resultfn += fn[:-5] + ' '
            except OSError:
                print('except for %s' % (fn,))
    return resultfn

if __name__ == "__main__":
    resultfn = restore_utf8(os.path.abspath('.'))
    resultfn += restore_utf8(os.path.abspath('../core'))
    resultfn += restore_utf8(os.path.abspath('../android'))
    if resultfn != '':
        print('restore files: ' + resultfn)




utf8togbk.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Run "python utf8togbk.py" to convert source code files to the GBK format on Windows.
#

import os, codecs

def utf8togbk(dir):
    resultfn = ''
    for fn in os.listdir(dir):
        sfile = os.path.join(dir, fn)
        if os.path.isdir(sfile):
            resultfn += utf8togbk(sfile)
            continue
        if fn.endswith('.h') or fn.endswith('.cpp'):
            if os.path.exists(sfile + '.utf8'):
                continue                            # already converted, skip
            text = open(sfile, 'rb').read()
            oldtext = text
            try:
                if text[:3] == codecs.BOM_UTF8:     # UTF-8 file with a BOM
                    u = text[3:].decode('utf-8')
                    text = u.encode('gbk')
            except (UnicodeDecodeError, UnicodeEncodeError):
                continue                            # not valid UTF-8, or GBK cannot represent it
            try:
                text = text.replace(b'\r\n', b'\n')
                text = text.replace(b'\n', b'\r\n') # normalize line endings to CRLF
                if text != oldtext:
                    os.rename(sfile, sfile + '.utf8')    # keep the original as a backup
                    open(sfile, 'wb').write(text)
                    resultfn += fn + ' '
                    st = os.stat(sfile + '.utf8')
                    os.utime(sfile, (st.st_atime, st.st_mtime))  # preserve the timestamp
            except OSError:
                print('except for %s' % (fn,))
    return resultfn

if __name__ == "__main__":
    resultfn = utf8togbk(os.path.abspath('.'))
    resultfn += utf8togbk(os.path.abspath('../core'))
    resultfn += utf8togbk(os.path.abspath('../android'))
    if resultfn != '':
        print('utf8->gbk: ' + resultfn)
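A note on the design: utf8togbk.py renames each original to *.utf8 before writing the GBK copy, then copies the backup's timestamps onto the converted file with os.utime, so incremental builds do not treat the conversion as a source change. restore_utf8.py simply reverses that rename, which is why the round trip is lossless.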