Resend me the complete code with these changes applied. Check it carefully for errors and omissions, and do not leave out any part (not even parts with similar logic):
import os
import re
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyproj import Transformer
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# Font configuration (SimHei supports CJK glyphs; unicode_minus keeps minus signs rendering correctly)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# ==============================
# 1. Directory-based data loader
# ==============================
class DirectoryBasedLoader:
def __init__(self, data_root):
self.data_root = data_root
        # ECEF (EPSG:4978) -> geodetic lat/lon/ellipsoidal height; EPSG:4979 is the 3D
        # geographic CRS (EPSG:4326 is 2D and has no height axis)
        self.transformer = Transformer.from_crs("EPSG:4978", "EPSG:4979")
self.site_cache = {}
self.feature_names = [
'trotot', 'tgntot', 'tgetot', 'stddev',
'lat', 'lon', 'alt', 'hour', 'doy_sin', 'doy_cos'
]
def _parse_coordinates(self, file_path):
"""解析站点坐标(带缓存)"""
if file_path in self.site_cache:
return self.site_cache[file_path]
coordinates = None
try:
with open(file_path, 'r', encoding='utf-8') as f:
current_section = None
for line in f:
line = line.strip()
if line.startswith('+'):
current_section = line[1:]
elif line.startswith('-'):
current_section = None
elif current_section == 'TROP/STA_COORDINATES' and not line.startswith('*'):
parts = line.split()
if len(parts) >= 7:
coordinates = {
'x': float(parts[4]),
'y': float(parts[5]),
'z': float(parts[6])
}
break
except UnicodeDecodeError:
try:
with open(file_path, 'r', encoding='latin-1') as f:
                    # Same parsing logic as above, retried with latin-1 encoding
current_section = None
for line in f:
line = line.strip()
if line.startswith('+'):
current_section = line[1:]
elif line.startswith('-'):
current_section = None
elif current_section == 'TROP/STA_COORDINATES' and not line.startswith('*'):
parts = line.split()
if len(parts) >= 7:
coordinates = {
'x': float(parts[4]),
'y': float(parts[5]),
'z': float(parts[6])
}
break
            except Exception as e:
                print(f"Coordinate parsing failed: {e}")
        except Exception as e:
            print(f"Coordinate parsing failed: {e}")
self.site_cache[file_path] = coordinates
return coordinates
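    # For reference, the parser above assumes a SINEX-TRO coordinate block shaped
    # roughly like the sketch below (an assumption inferred from the parts[4:7]
    # indexing, not a verified excerpt from the format spec):
    #
    #   +TROP/STA_COORDINATES
    #   *SITE PT SOLN T __STA_X_____ __STA_Y_____ __STA_Z_____ SYSTEM REMRK
    #    ABMF  A    1 P  2919785.712 -5383744.691  1774604.819 IGS14  XXX
    #   -TROP/STA_COORDINATES
    #
    # so parts[4], parts[5] and parts[6] carry the ECEF X/Y/Z coordinates in metres.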
def _convert_coords(self, coords):
"""坐标转换(带异常处理)"""
try:
lat, lon, alt = self.transformer.transform(
coords['x'], coords['y'], coords['z']
)
return lat, lon, alt
except Exception as e:
print(f"坐标转换失败: {str(e)}")
return None, None, None
def _parse_observations(self, file_path, site_code):
"""解析观测数据"""
records = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
current_section = None
for line in f:
line = line.strip()
if line.startswith('+'):
current_section = line[1:]
elif line.startswith('-'):
current_section = None
elif current_section == 'TROP/SOLUTION' and not line.startswith('*'):
parts = line.split()
if len(parts) >= 7:
records.append({
'epoch': parts[1],
'trotot': float(parts[2]),
'stddev': float(parts[3]),
'tgntot': float(parts[4]),
'tgetot': float(parts[6]),
                                'site': site_code  # directory name serves as the site code
})
except UnicodeDecodeError:
try:
with open(file_path, 'r', encoding='latin-1') as f:
                    # Same parsing logic as above, retried with latin-1 encoding
current_section = None
for line in f:
line = line.strip()
if line.startswith('+'):
current_section = line[1:]
elif line.startswith('-'):
current_section = None
elif current_section == 'TROP/SOLUTION' and not line.startswith('*'):
parts = line.split()
if len(parts) >= 7:
records.append({
'epoch': parts[1],
'trotot': float(parts[2]),
'stddev': float(parts[3]),
'tgntot': float(parts[4]),
'tgetot': float(parts[6]),
                                    'site': site_code  # directory name serves as the site code
})
            except Exception as e:
                print(f"Failed to read observation data: {e}")
        except Exception as e:
            print(f"Failed to read observation data: {e}")
return records
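    # Likewise, _parse_observations assumes solution lines shaped roughly like this
    # (a hedged sketch based on the indices used above; verify against your files):
    #
    #   +TROP/SOLUTION
    #   *SITE ____EPOCH______ TROTOT STDDEV TGNTOT STDDEV TGETOT STDDEV
    #    ABMF 22:001:00:30:00 2432.5    1.2 -0.214    0.3  0.041    0.3
    #   -TROP/SOLUTION
    #
    # parts[2]/parts[3] map to the zenith total delay and its sigma (mm), while
    # parts[4] and parts[6] map to the north/east gradients (their sigmas are skipped).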
def load_all_data(self):
"""加载目录结构数据"""
all_dfs = []
        # Collect all site directories
site_dirs = [d for d in glob.glob(os.path.join(self.data_root, '*'))
if os.path.isdir(d)]
for site_dir in site_dirs:
site_code = os.path.basename(site_dir)
print(f"正在加载站点: {site_code}")
# 加载该站点所有数据文件
for file_path in glob.glob(os.path.join(site_dir, '*.TRO')):
                # Parse coordinates
coords = self._parse_coordinates(file_path)
if not coords:
print(f"跳过无有效坐标的文件: {file_path}")
continue
                # Coordinate conversion
lat, lon, alt = self._convert_coords(coords)
if None in (lat, lon, alt):
continue
                # Parse observations
records = self._parse_observations(file_path, site_code)
if not records:
print(f"跳过无有效数据的文件: {file_path}")
continue
                # Build the DataFrame
df = pd.DataFrame(records)
df['lat'] = lat
df['lon'] = lon
df['alt'] = alt
all_dfs.append(df)
print(f"成功加载: {file_path} 记录数: {len(df)}")
return pd.concat(all_dfs) if all_dfs else pd.DataFrame()
# ==============================
# 2. Time-series dataset
# ==============================
class TemporalDataset(Dataset):
    def __init__(self, data, window_size=6, features=None):
        self.window_size = window_size
        # Default feature set (matches the columns produced by preprocessing)
        self.feature_names = features or [
            'trotot', 'tgntot', 'tgetot', 'stddev',
            'lat', 'lon', 'alt', 'hour', 'doy_sin', 'doy_cos'
        ]
        self.site_to_id = {site: idx for idx, site in enumerate(data['site'].unique())}
        # Data preprocessing: sort per site by time, drop rows missing any feature
        data = data.sort_values(['site', 'time']).dropna(subset=self.feature_names)
        if data.empty:
            raise ValueError("Input data is empty or contains missing values")
        # Build sliding-window sequences
self.sequences = []
self.targets = []
self.site_labels = []
self.timestamps = []
for site, group in data.groupby('site'):
if len(group) < self.window_size + 1:
continue
values = group[self.feature_names].values
times = group['time'].values
            # Unix timestamps in seconds (the 'time' column is already datetime64)
            unix_times = (times - np.datetime64('1970-01-01T00:00:00')) / np.timedelta64(1, 's')
for i in range(len(values) - self.window_size):
self.sequences.append(values[i:i + self.window_size])
self.targets.append(values[i + self.window_size][0])
self.site_labels.append(self.site_to_id[site])
self.timestamps.append(unix_times[i + self.window_size])
self.num_samples = len(self.sequences)
if self.num_samples == 0:
raise ValueError("没有生成有效样本")
def __len__(self):
return self.num_samples
    def __getitem__(self, idx):
        # Small Gaussian jitter as augmentation (note: applied on every access, including evaluation)
        noise = torch.randn(self.window_size, len(self.feature_names)) * 0.01
return (
torch.FloatTensor(self.sequences[idx]) + noise,
torch.FloatTensor([self.targets[idx]]),
torch.tensor(self.site_labels[idx], dtype=torch.long),
torch.FloatTensor([self.timestamps[idx]])
)
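# A minimal shape check for TemporalDataset (an illustrative sketch; assumes the
# default 10-feature set and a preprocessed DataFrame `df` with 'site' and 'time'
# columns):
#
#   ds = TemporalDataset(df, window_size=6)
#   seq, target, site_id, ts = ds[0]
#   # seq: FloatTensor (6, 10), target: FloatTensor (1,),
#   # site_id: scalar LongTensor, ts: FloatTensor (1,)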
# ==============================
# 3. Site-aware LSTM model
# ==============================
class SiteLSTM(nn.Module):
def __init__(self, input_size, num_sites, hidden_size=64):
super().__init__()
self.embedding = nn.Embedding(num_sites, 8)
self.lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=2,
batch_first=True,
dropout=0.3
)
self.regressor = nn.Sequential(
nn.Linear(hidden_size + 8, 32),
nn.LayerNorm(32),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(32, 1))
def forward(self, x, site_ids):
lstm_out, _ = self.lstm(x)
site_emb = self.embedding(site_ids)
combined = torch.cat([lstm_out[:, -1, :], site_emb], dim=1)
return self.regressor(combined)
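# Forward-pass sanity sketch for SiteLSTM (illustrative; assumes 10 input features,
# 4 sites, a batch of 2 and a window of 6):
#
#   model = SiteLSTM(input_size=10, num_sites=4)
#   out = model(torch.randn(2, 6, 10), torch.tensor([0, 3]))
#   # out.shape == torch.Size([2, 1]): last LSTM state (64) + site embedding (8) -> regressor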
# ==============================
# 4. Training and evaluation
# ==============================
class TropoTrainer:
def __init__(self, data_root='./data'):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.loader = DirectoryBasedLoader(data_root)
self.scaler = StandardScaler()
    def _preprocess(self, raw_df):
        """Data preprocessing"""
        # Guard against empty input
        if raw_df.empty:
            raise ValueError("Raw data is empty; check the data-loading logic")
        # Robust epoch parsing
try:
            raw_df['time'] = pd.to_datetime(
                raw_df['epoch'].str.replace(r'^(\d{2}):', r'20\1:', regex=True),
                # ':' separator matches the expanded epochs above; adapt if your files
                # use SINEX seconds-of-day epochs (YY:DDD:SSSSS) instead
                format='%Y:%j:%H:%M:%S',
                errors='coerce'
            )
time_mask = raw_df['time'].isna()
if time_mask.any():
print(f"发现{time_mask.sum()}条无效时间记录,示例无效数据:")
print(raw_df[time_mask].head(2))
raw_df = raw_df[~time_mask].copy()
except Exception as e:
print(f"时间解析失败: {str(e)}")
raise
        # Feature engineering
if 'time' not in raw_df.columns:
raise KeyError("时间列缺失,预处理失败")
raw_df['hour'] = raw_df['time'].dt.hour
raw_df['doy_sin'] = np.sin(2 * np.pi * raw_df['time'].dt.dayofyear / 365.25)
raw_df['doy_cos'] = np.cos(2 * np.pi * raw_df['time'].dt.dayofyear / 365.25)
        # Validate feature columns
required_features = ['trotot', 'tgntot', 'tgetot', 'stddev',
'lat', 'lon', 'alt', 'hour']
missing_features = [f for f in required_features if f not in raw_df.columns]
if missing_features:
raise KeyError(f"缺失关键特征列: {missing_features}")
        # Standardization (doy_sin/doy_cos are already bounded in [-1, 1] and stay unscaled)
try:
self.scaler.fit(raw_df[required_features])
raw_df[required_features] = self.scaler.transform(raw_df[required_features])
except ValueError as e:
print(f"标准化失败: {str(e)}")
print("数据统计信息:")
print(raw_df[required_features].describe())
raise
return raw_df
def _inverse_transform(self, values):
"""反标准化"""
dummy = np.zeros((len(values), len(self.scaler.feature_names_in_)))
dummy[:, 0] = values
return self.scaler.inverse_transform(dummy)[:, 0]
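    # Illustrative example of the dummy-matrix trick above: since 'trotot' is
    # column 0 of the fitted scaler, _inverse_transform(np.array([0.0])) returns
    # the training-set mean of 'trotot', and a scaled value v maps back to
    # v * std_trotot + mean_trotot; the zero-filled helper columns are discarded.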
def train(self, window_size=6, epochs=100, batch_size=32):
try:
            # Load data
raw_df = self.loader.load_all_data()
if raw_df.empty:
raise ValueError("数据加载器返回空DataFrame")
            # Preprocess
processed_df = self._preprocess(raw_df)
print(f"预处理后数据量: {len(processed_df)}条")
print("特征矩阵示例:")
print(processed_df[['site', 'time', 'trotot']].head(3))
            # Build dataset
full_dataset = TemporalDataset(processed_df, window_size)
print(f"生成有效样本数: {len(full_dataset)}")
            # Split dataset (random split with a fixed seed; note: not a chronological split)
train_size = int(0.8 * len(full_dataset))
train_dataset, test_dataset = torch.utils.data.random_split(
full_dataset, [train_size, len(full_dataset) - train_size],
generator=torch.Generator().manual_seed(42)
)
            # Initialize model
            model = SiteLSTM(
                input_size=len(full_dataset.feature_names),
                num_sites=len(full_dataset.site_to_id)
            ).to(self.device)
            # Training configuration
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
criterion = nn.MSELoss()
            # Training loop
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(test_dataset, batch_size=128)
best_loss = float('inf')
history = {'train': [], 'val': []}
for epoch in range(epochs):
                # Training phase
model.train()
train_loss = 0
for seq, target, site, _ in train_loader:
seq = seq.to(self.device)
target = target.to(self.device)
site = site.to(self.device)
optimizer.zero_grad()
pred = model(seq, site)
loss = criterion(pred, target)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
train_loss += loss.item()
                # Validation phase
model.eval()
val_loss = 0
predictions = []
                with torch.no_grad():
for seq, target, site, _ in val_loader:
seq = seq.to(self.device)
target = target.to(self.device)
site = site.to(self.device)
pred = model(seq, site)
val_loss += criterion(pred, target).item()
predictions.append(pred.cpu().numpy())
                # Record history
avg_train = train_loss / len(train_loader)
avg_val = val_loss / len(val_loader)
history['train'].append(avg_train)
history['val'].append(avg_val)
scheduler.step(avg_val)
                # Save the best model so far
if avg_val < best_loss:
best_loss = avg_val
torch.save(model.state_dict(), 'best_model.pth')
print(f"Epoch {epoch + 1:03d} | Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}")
            # Load best model
            model.load_state_dict(torch.load('best_model.pth', map_location=self.device))
except Exception as e:
print(f"\n{'*' * 40}")
print(f"训练失败详细原因: {str(e)}")
print(f"错误类型: {type(e).__name__}")
import traceback
traceback.print_exc()
print(f"{'*' * 40}\n")
raise
return model, history
def evaluate(self, model, output_dir='results'):
"""评估并保存结果"""
os.makedirs(output_dir, exist_ok=True)
        # Reload the full dataset
raw_df = self.loader.load_all_data()
processed_df = self._preprocess(raw_df)
full_dataset = TemporalDataset(processed_df, window_size=6)
        # Create data loader
test_loader = DataLoader(full_dataset, batch_size=128)
        # Collect results
results = []
model.eval()
with torch.no_grad():
for seq, target, site, timestamp in test_loader:
                # Convert batch timestamps back to datetimes
                timestamp = timestamp.numpy().flatten()
                datetime_objs = pd.to_datetime(timestamp, unit='s')
seq = seq.to(self.device)
site = site.to(self.device)
pred = model(seq, site).cpu().numpy().flatten()
true = target.numpy().flatten()
                # Undo standardization
pred = self._inverse_transform(pred)
true = self._inverse_transform(true)
                # Gather rows
                for p, t, s, ts in zip(pred, true, site, datetime_objs):
                    results.append({
                        'site': list(full_dataset.site_to_id.keys())[int(s)],
'timestamp': ts,
'true': t,
'pred': p
})
        # Convert to DataFrame
result_df = pd.DataFrame(results)
        # Save per-site results
for site, group in result_df.groupby('site'):
site_dir = os.path.join(output_dir, site)
os.makedirs(site_dir, exist_ok=True)
            # CSV file
csv_path = os.path.join(site_dir, f'{site}_predictions.csv')
group.to_csv(csv_path, index=False)
            # Excel file
excel_path = os.path.join(site_dir, f'{site}_predictions.xlsx')
group.to_excel(excel_path, index=False)
            # Comparison plot
plt.figure(figsize=(12, 6))
            plt.plot(group['timestamp'], group['true'], label='Truth')
            plt.plot(group['timestamp'], group['pred'], label='Prediction', linestyle='--')
            plt.title(f'Site {site}: tropospheric delay, prediction vs. truth')
            plt.xlabel('Time')
            plt.ylabel('Delay (mm)')
plt.legend()
plt.gcf().autofmt_xdate()
plot_path = os.path.join(site_dir, f'{site}_comparison.png')
plt.savefig(plot_path, bbox_inches='tight')
plt.close()
        # Save aggregated files
result_df.to_csv(os.path.join(output_dir, 'all_predictions.csv'), index=False)
result_df.to_excel(os.path.join(output_dir, 'all_predictions.xlsx'), index=False)
print(f"结果已保存至 {output_dir} 目录")
return result_df
# ==============================
# Main program
# ==============================
if __name__ == "__main__":
    # Initialize trainer
trainer = TropoTrainer(data_root='./data')
try:
        # Train model
model, history = trainer.train(epochs=100)
        # Visualize training history
plt.figure(figsize=(10, 5))
        plt.plot(history['train'], label='Train loss')
        plt.plot(history['val'], label='Validation loss')
        plt.title('Training history')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.legend()
plt.savefig('training_history.png', bbox_inches='tight')
        # Evaluate and save results
results = trainer.evaluate(model)
        # Per-site statistics report
report = results.groupby('site').apply(lambda x: pd.Series({
'MAE(mm)': np.mean(np.abs(x['pred'] - x['true'])),
'Max_True': np.max(x['true']),
'Min_True': np.min(x['true']),
'Max_Pred': np.max(x['pred']),
'Min_Pred': np.min(x['pred']),
'Samples': len(x)
})).reset_index()
print("\n站点预测性能报告:")
print(report.to_markdown(index=False))
except Exception as e:
print(f"运行出错: {str(e)}")