boost::tuple Decoded

This article looks at how template specialization can be used in C++ to build an efficient tuple-like wrapper: a worked example shows the concrete specializations for each number of type parameters, and how the compiler selects the right one for a given use.
Everyone has used tuple at some point. At its core, it is simply another application of full and partial template specialization. A simplified implementation looks like this:


 // Primary template: selected when all four type parameters are non-void.
 // traits::type_traits<T>::const_reference is the cheap-to-pass parameter
 // type for T (a sketch of this helper follows the listing).
 template<typename T1 = void, typename T2 = void, typename T3 = void, typename T4 = void>
 struct Tuple
 {
  Tuple(){}
  Tuple(typename traits::type_traits<T1>::const_reference a,
   typename traits::type_traits<T2>::const_reference b,
   typename traits::type_traits<T3>::const_reference c,
   typename traits::type_traits<T4>::const_reference d)
   :a_(a),b_(b),c_(c),d_(d)
  {}

  T1 a_;
  T2 b_;
  T3 c_;
  T4 d_;
 };

 // Partial specialization: three elements (T4 defaulted to void).
 template<typename T1, typename T2, typename T3>
 struct Tuple<T1,T2,T3,void>
 {
  Tuple(){}
  Tuple(typename traits::type_traits<T1>::const_reference a,
   typename traits::type_traits<T2>::const_reference b,
   typename traits::type_traits<T3>::const_reference c)
   :a_(a),b_(b),c_(c)
  {}

  T1 a_;
  T2 b_;
  T3 c_;
 };

 // Partial specialization: two elements.
 template<typename T1, typename T2>
 struct Tuple<T1,T2,void,void>
 {
  Tuple(){}
  Tuple(typename traits::type_traits<T1>::const_reference a,
   typename traits::type_traits<T2>::const_reference b)
   :a_(a),b_(b)
  {}

  T1 a_;
  T2 b_;
 };

 // Partial specialization: one element.
 template<typename T1>
 struct Tuple<T1,void,void,void>
 {
  Tuple(){}
  Tuple(typename traits::type_traits<T1>::const_reference a)
   :a_(a)
  {}

  T1 a_;
 };

 // Full specialization: the empty tuple.
 template<>
 struct Tuple<void,void,void,void>
 {
  Tuple(){}
 };
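
The listing relies on a traits::type_traits helper that the original does not show. Below is a minimal sketch of what it might look like; the namespace and the member name const_reference come from the code above, while the implementation itself is an assumption (it plays the same role as boost::call_traits<T>::param_type, which additionally passes small scalar types by value):

 namespace traits
 {
  // Minimal sketch (assumption): pass everything as const T&.
  // A production version, like boost::call_traits, would pass
  // small built-in types by value instead.
  template<typename T>
  struct type_traits
  {
   typedef const T& const_reference;
  };
 }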

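Given that helper, the dispatch works as described: the compiler picks the most specialized Tuple whose trailing arguments are void. A small usage sketch (hypothetical example, assuming the Tuple template and the traits sketch above are in the same translation unit):

 #include <iostream>
 #include <string>

 int main()
 {
  Tuple<int, std::string> t2(42, std::string("hello")); // matches Tuple<T1,T2,void,void>
  Tuple<int> t1(7);                                     // matches Tuple<T1,void,void,void>
  Tuple<> t0;                                           // matches Tuple<void,void,void,void>

  std::cout << t2.a_ << ", " << t2.b_ << ", " << t1.a_ << std::endl;
  return 0;
 }

Since C++11, variadic templates (template<typename... Ts> struct Tuple;) make this trailing-void trick unnecessary, which is why std::tuple needs no fixed upper bound on its arity.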
