Super PI Records

This note collects Super PI 1M times for a number of CPUs (Pentium 4 HT, Core Duo T2400, Core 2 Duo T7200, and others), including my own runs under different environments (Vista, Ubuntu).

1M (CPU, source, time)
==
Pentium 4 HT, office, 37s
T2080, forum, 38s
T5300, forum, 24s
E6300, forum, 30s
P4, forum, 45s
Core Duo T2400, pic, 33s
Core 2 Duo T7200, pic, 25s
T5500, web, 20s
T2130, web, 36s
T5250*,
T2130*, web, 45s?

T5500, mine (Vista), 35s
T5500, mine (Ubuntu + Wine), 31s

(Interestingly, the same T5500 finishes about 4s faster under Wine on Ubuntu than under Vista.)
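
For quick comparison, here is a minimal Python sketch (purely illustrative, not part of the original records; the tuples below just restate the entries above as (CPU, source, seconds)) that prints the results fastest-first:

    # Super PI 1M records as (CPU, source, seconds) tuples.
    # Entries without a recorded time (e.g. T5250*) are omitted.
    records = [
        ("Pentium 4 HT", "office", 37),
        ("T2080", "forum", 38),
        ("T5300", "forum", 24),
        ("E6300", "forum", 30),
        ("P4", "forum", 45),
        ("Core Duo T2400", "pic", 33),
        ("Core 2 Duo T7200", "pic", 25),
        ("T5500", "web", 20),
        ("T2130", "web", 36),
        ("T5500", "mine (Vista)", 35),
        ("T5500", "mine (Ubuntu + Wine)", 31),
    ]

    # Lower time means better Super PI performance, so sort ascending.
    for cpu, source, seconds in sorted(records, key=lambda r: r[2]):
        print(f"{seconds:>3}s  {cpu:<18} ({source})")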
