import os
import re
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyproj import Transformer
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
# Matplotlib font setup: SimHei covers CJK glyphs; keep minus signs rendering correctly
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# ==============================
# 1. Enhanced file loader
# ==============================
class EnhancedTropoLoader:
def __init__(self, data_root):
self.data_root = data_root
        # ECEF (EPSG:4978) → 3-D geographic WGS84 (EPSG:4979), so the ellipsoidal
        # height comes back alongside latitude/longitude
        self.transformer = Transformer.from_crs("EPSG:4978", "EPSG:4979")
self.site_cache = {}
self.feature_names = [
'trotot', 'tgntot', 'tgetot', 'stddev',
'lat', 'lon', 'alt', 'hour'
]
def _parse_site_code(self, filename):
"""改进版站点代码解析,支持多种格式"""
        patterns = [
            r"_([A-Z0-9]{4})\d{2}[A-Z]{3}_TRO\.TRO$",  # long names: ABMF00GLP → ABMF, AC2300USA → AC23
            r"_([A-Z0-9]{4})_TRO\.TRO$"                # short names: ABPO_TRO.TRO → ABPO
        ]
        for pattern in patterns:
            match = re.search(pattern, filename)
            if match:
                # An IGS long name is a 4-character site ID (letters or digits),
                # two monument/receiver digits, and a 3-letter country code, so
                # the capture group is already the bare site ID. The character
                # class must include digits, or IDs like AC23 never match.
                return match.group(1)
        return None
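    # Illustrative parser checks (the first two filenames also appear in the
    # run log at the end of this post):
    #   _parse_site_code("IGS0OPSFIN_20250010000_01D_05M_ABMF00GLP_TRO.TRO") -> "ABMF"
    #   _parse_site_code("IGS0OPSFIN_20250010000_01D_05M_AC2300USA_TRO.TRO") -> "AC23"
    #   _parse_site_code("ABPO_TRO.TRO")                                     -> "ABPO"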
def _parse_file(self, file_path):
"""解析单个文件"""
try:
# 获取站点代码
filename = os.path.basename(file_path)
site_code = self._parse_site_code(filename)
if not site_code:
print(f"跳过无法解析站点的文件: {file_path}")
return None
            # Read the station coordinates
            coordinates = self._get_coordinates(file_path)
            if not coordinates:
                print(f"Skipped (no valid coordinates): {file_path}")
                return None
            # Coordinate transformation (ECEF → lat/lon/height)
            lat, lon, alt = self.transformer.transform(
                coordinates['x'], coordinates['y'], coordinates['z']
            )
            # pyproj signals a failed transform with inf, not None
            if not all(np.isfinite([lat, lon, alt])):
                return None
            # Read the observation records
            records = self._read_observations(file_path, site_code)
            if len(records) < 10:
                print(f"Skipped (too few records): {file_path}")
                return None
            # Build the DataFrame
df = pd.DataFrame(records)
df['lat'] = lat
df['lon'] = lon
df['alt'] = alt
return df
        except Exception as e:
            print(f"Failed to parse file [{file_path}]: {e}")
            return None
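    # Assumed TROP/STA_COORDINATES record layout for the parser below (typical
    # SINEX_TRO style; verify against your own files):
    #   *SITE PT SOLN T  STA_X______ STA_Y______ STA_Z______ SYSTEM REMRK
    # which is why parts[4], parts[5], parts[6] are read as geocentric X/Y/Z (m).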
def _get_coordinates(self, file_path):
"""获取站点坐标"""
if file_path in self.site_cache:
return self.site_cache[file_path]
coordinates = None
try:
with open(file_path, 'r') as f:
current_section = None
for line in f:
line = line.strip()
if line.startswith('+'):
current_section = line[1:]
elif line.startswith('-'):
current_section = None
elif current_section == 'TROP/STA_COORDINATES' and not line.startswith('*'):
parts = line.split()
if len(parts) >= 7:
coordinates = {
'x': float(parts[4]),
'y': float(parts[5]),
'z': float(parts[6])
}
break
        except Exception as e:
            print(f"Failed to parse coordinates: {e}")
self.site_cache[file_path] = coordinates
return coordinates
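    # Assumed TROP/SOLUTION record layout for the parser below (verify against
    # your own files):
    #   *SITE  EPOCH_______  TROTOT STDDEV  TGNTOT STDDEV  TGETOT STDDEV
    # hence parts[2]/[3] = zenith total delay and its sigma (mm), parts[4] = the
    # north gradient, and parts[6] = the east gradient (parts[5] is skipped).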
def _read_observations(self, file_path, site_code):
"""读取观测数据"""
records = []
try:
with open(file_path, 'r') as f:
current_section = None
for line in f:
line = line.strip()
if line.startswith('+'):
current_section = line[1:]
elif line.startswith('-'):
current_section = None
elif current_section == 'TROP/SOLUTION' and not line.startswith('*'):
parts = line.split()
if len(parts) >= 7:
records.append({
'epoch': parts[1],
'trotot': float(parts[2]),
'stddev': float(parts[3]),
'tgntot': float(parts[4]),
'tgetot': float(parts[6]),
'site': site_code
})
        except Exception as e:
            print(f"Failed to read observations: {e}")
return records
def load_all_data(self):
"""加载所有数据"""
all_dfs = []
for file_path in glob.glob(os.path.join(self.data_root, '**', '*.TRO'), recursive=True):
df = self._parse_file(file_path)
if df is not None:
all_dfs.append(df)
print(f"成功加载: {file_path} 记录数: {len(df)}")
return pd.concat(all_dfs) if all_dfs else pd.DataFrame()
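# Minimal usage sketch for the loader (illustrative; './data' is a placeholder):
#   loader = EnhancedTropoLoader('./data')
#   df = loader.load_all_data()
#   print(df[['site', 'epoch', 'trotot', 'lat', 'lon', 'alt']].head())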
# ==============================
# 2. Time-series dataset
# ==============================
class TemporalDataset(Dataset):
    def __init__(self, data, feature_names, window_size=6, augment=True):
        self.window_size = window_size
        # Bug fix: feature_names was never stored, which crashed the windowing
        # loop below; it is now passed in explicitly (e.g. loader.feature_names).
        self.feature_names = feature_names
        self.augment = augment
        self.site_to_id = {site: idx for idx, site in enumerate(data['site'].unique())}
        # Sort by site and time
        data = data.sort_values(['site', 'time'])
        # Build sliding-window sequences
        self.sequences = []
        self.targets = []
        self.site_labels = []
        self.timestamps = []
        for site, group in data.groupby('site'):
            values = group[self.feature_names].values
            # Epochs as Unix seconds, so timestamps survive the round trip
            # through torch tensors at evaluation time
            unix_times = group['time'].values.astype('datetime64[s]').astype(np.float64)
for i in range(len(values) - self.window_size):
self.sequences.append(values[i:i + self.window_size])
self.targets.append(values[i + self.window_size][0])
self.site_labels.append(self.site_to_id[site])
self.timestamps.append(unix_times[i + self.window_size])
self.num_samples = len(self.sequences)
def __len__(self):
return self.num_samples
    def __getitem__(self, idx):
        seq = torch.FloatTensor(self.sequences[idx])
        if self.augment:
            # Gaussian-noise augmentation (disabled at evaluation time)
            seq = seq + torch.randn_like(seq) * 0.01
        return (
            seq,
            torch.FloatTensor([self.targets[idx]]),
            torch.tensor(self.site_labels[idx], dtype=torch.long),
            # float64: float32 loses second-level precision on Unix timestamps
            torch.tensor([self.timestamps[idx]], dtype=torch.float64)
        )
# ==============================
# 3. Improved LSTM model
# ==============================
class EnhancedLSTM(nn.Module):
def __init__(self, input_size, num_sites, hidden_size=128):
super().__init__()
self.embedding = nn.Embedding(num_sites, 16)
self.lstm = nn.LSTM(
input_size, hidden_size,
num_layers=3,
bidirectional=True,
batch_first=True,
dropout=0.4
)
self.attention = nn.Sequential(
nn.Linear(hidden_size * 2, 32),
nn.Tanh(),
nn.Linear(32, 1),
nn.Softmax(dim=1)
)
self.regressor = nn.Sequential(
nn.Linear(hidden_size * 2 + 16, 64),
nn.LayerNorm(64),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(64, 32),
nn.LayerNorm(32),
nn.ReLU(),
nn.Linear(32, 1)
)
    def forward(self, x, site_ids):
        lstm_out, _ = self.lstm(x)                           # (batch, seq, hidden*2)
        attn_weights = self.attention(lstm_out)              # (batch, seq, 1), softmax over seq
        context = torch.sum(attn_weights * lstm_out, dim=1)  # attention-pooled context
        site_emb = self.embedding(site_ids)                  # per-site embedding
        combined = torch.cat([context, site_emb], dim=1)
        return self.regressor(combined)
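# Shape sanity check (illustrative values: 8 features, 5 sites, batch of 2):
#   model = EnhancedLSTM(input_size=8, num_sites=5)
#   out = model(torch.randn(2, 6, 8), torch.tensor([0, 3]))   # out.shape -> (2, 1)
# The bidirectional LSTM emits hidden_size*2 channels per step; the attention
# head softmaxes over the 6 steps to pool them into one context vector, which
# is concatenated with the 16-dim site embedding before regression.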
# ==============================
# 4. Training manager
# ==============================
class TrainingManager:
def __init__(self, data_root):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.loader = EnhancedTropoLoader(data_root)
self.scaler = StandardScaler()
def _preprocess(self, raw_df):
"""数据预处理"""
# 时间解析
raw_df['time'] = raw_df['epoch'].apply(
lambda x: pd.to_datetime(
f"20{x.split(':')[0]}-{x.split(':')[1]}",
format='%Y-%j'
) + pd.to_timedelta(int(x.split(':')[2]), unit='s')
)
raw_df = raw_df.dropna(subset=['time'])
# 特征工程
raw_df['hour'] = raw_df['time'].dt.hour
raw_df['doy_sin'] = np.sin(2 * np.pi * raw_df['time'].dt.dayofyear / 365)
raw_df['doy_cos'] = np.cos(2 * np.pi * raw_df['time'].dt.dayofyear / 365)
# 标准化
raw_df[self.loader.feature_names] = self.scaler.fit_transform(
raw_df[self.loader.feature_names]
)
return raw_df
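    # Epoch format handled above: SINEX 'YY:DDD:SSSSS' (two-digit year, day of
    # year, seconds of day), e.g. '25:001:43200' parses to 2025-01-01 12:00:00.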
    def train(self, window_size=6, epochs=200, batch_size=64):
        # Load data
        raw_df = self.loader.load_all_data()
        if raw_df.empty:
            raise ValueError("No valid data was loaded")
        # Preprocess
        processed_df = self._preprocess(raw_df)
        # Build the dataset (feature_names must now be passed in explicitly)
        full_dataset = TemporalDataset(processed_df, self.loader.feature_names, window_size)
        print(f"Dataset size: {len(full_dataset)} samples")
        # Train/test split
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(
full_dataset, [train_size, test_size],
generator=torch.Generator().manual_seed(42)
)
        # Initialize the model
model = EnhancedLSTM(
input_size=len(self.loader.feature_names),
num_sites=len(full_dataset.site_to_id),
hidden_size=128
).to(self.device)
        # DataLoaders (created before the scheduler, which needs steps_per_epoch)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(test_dataset, batch_size=256)
        # Training configuration
        optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-3,
            steps_per_epoch=len(train_loader),
            epochs=epochs
        )
        criterion = nn.MSELoss()
        # Training loop
best_loss = float('inf')
history = {'train': [], 'val': []}
for epoch in range(epochs):
model.train()
train_loss = 0
for seq, target, site, _ in train_loader:
seq = seq.to(self.device)
target = target.to(self.device)
site = site.to(self.device)
optimizer.zero_grad()
pred = model(seq, site)
loss = criterion(pred, target)
loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()  # OneCycleLR is stepped once per batch, not per epoch
                train_loss += loss.item()
            # Validation
            model.eval()
            val_loss = 0
            with torch.no_grad():
for seq, target, site, _ in val_loader:
seq = seq.to(self.device)
target = target.to(self.device)
site = site.to(self.device)
pred = model(seq, site)
val_loss += criterion(pred, target).item()
avg_train = train_loss / len(train_loader)
avg_val = val_loss / len(val_loader)
history['train'].append(avg_train)
history['val'].append(avg_val)
            # Save the best model
if avg_val < best_loss:
best_loss = avg_val
torch.save(model.state_dict(), 'best_model.pth')
print(f"Epoch {epoch + 1:03d} | Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}")
return model, history
    def evaluate(self, model, output_dir='results'):
        """Evaluate the model and save the results"""
        os.makedirs(output_dir, exist_ok=True)
        # Reload and preprocess the data
        raw_df = self.loader.load_all_data()
        processed_df = self._preprocess(raw_df)
        # No noise augmentation at evaluation time
        full_dataset = TemporalDataset(processed_df, self.loader.feature_names,
                                       window_size=6, augment=False)
        # Prediction over the full dataset
        model.eval()
        # Reverse mapping from integer id back to site code
        id_to_site = {idx: code for code, idx in full_dataset.site_to_id.items()}
        results = []
        with torch.no_grad():
            test_loader = DataLoader(full_dataset, batch_size=256)
            for seq, target, site, timestamp in test_loader:
                seq = seq.to(self.device)
                pred = model(seq, site.to(self.device)).cpu().numpy().flatten()
                true = target.numpy().flatten()
                sites = site.numpy()
                times = pd.to_datetime(timestamp.numpy().flatten(), unit='s')
                for p, t, s, ts in zip(pred, true, sites, times):
                    results.append({
                        'site': id_to_site[int(s)],
                        'timestamp': ts,
                        'true': t,
                        'pred': p
                    })
        # Undo standardization: 'trotot' is feature column 0 of the fitted scaler,
        # so embed the values in a zero matrix and inverse-transform column 0
result_df = pd.DataFrame(results)
dummy = np.zeros((len(result_df), len(self.loader.feature_names)))
dummy[:, 0] = result_df['true']
result_df['true'] = self.scaler.inverse_transform(dummy)[:, 0]
dummy[:, 0] = result_df['pred']
result_df['pred'] = self.scaler.inverse_transform(dummy)[:, 0]
        # Save results
self._save_results(result_df, output_dir)
return result_df
    def _save_results(self, df, output_dir):
        """Save per-site results and visualizations"""
        # Save one folder per site
for site, group in df.groupby('site'):
site_dir = os.path.join(output_dir, site)
os.makedirs(site_dir, exist_ok=True)
            # Save the predictions
csv_path = os.path.join(site_dir, f'{site}_predictions.csv')
group.to_csv(csv_path, index=False)
            # Generate the visualization
self._plot_predictions(group, site, site_dir)
        # Save the combined table
        df.to_csv(os.path.join(output_dir, 'all_predictions.csv'), index=False)
        print(f"Results saved to {output_dir}")
    def _plot_predictions(self, data, site, save_dir):
        """Plot observed vs. predicted delays for one site"""
        plt.figure(figsize=(16, 9))
        plt.plot(data['timestamp'], data['true'], label='Observed', linewidth=1.5)
        plt.plot(data['timestamp'], data['pred'], label='Predicted', linestyle='--', alpha=0.8)
        plt.title(f'Site {site} tropospheric delay prediction (MAE: {np.mean(np.abs(data["true"] - data["pred"])):.2f} mm)')
        plt.xlabel('Time')
        plt.ylabel('Delay (mm)')
plt.legend()
plt.grid(True)
plt.gcf().autofmt_xdate()
plot_path = os.path.join(save_dir, f'{site}_comparison.png')
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.close()
# ==============================
# Main program
# ==============================
if __name__ == "__main__":
try:
trainer = TrainingManager(data_root='./data')
model, history = trainer.train(epochs=200)
results = trainer.evaluate(model)
        # Generate a per-site statistics report
        report = results.groupby('site').apply(lambda x: pd.Series({
            'MAE(mm)': np.mean(np.abs(x['true'] - x['pred'])),
            'Max_True': x['true'].max(),
            'Min_True': x['true'].min(),
            'Max_Pred': x['pred'].max(),
            'Min_Pred': x['pred'].min(),
            'Samples': len(x)
        })).reset_index()
        print("\nPer-site prediction performance report:")
        print(report.to_markdown(index=False))
        # Plot the training curves
        plt.figure(figsize=(12, 6))
        plt.plot(history['train'], label='Training loss')
        plt.plot(history['val'], label='Validation loss')
        plt.title('Training history')
        plt.xlabel('Epoch')
        plt.ylabel('MSE Loss')
        plt.legend()
plt.savefig('training_history.png', bbox_inches='tight')
    except Exception as e:
        print(f"Runtime error: {e}")
"D:\Pycharm 2024\idle\pythonProject1\.venv\Scripts\python.exe" D:\idle\test-lstm\LSTM_TROP_TEST.py
Loaded: ./data\IGS0OPSFIN_20250010000_01D_05M_ABMF00GLP_TRO.TRO records: 288
Skipped (could not parse site code): ./data\IGS0OPSFIN_20250010000_01D_05M_AC2300USA_TRO.TRO
Loaded: ./data\IGS0OPSFIN_20250020000_01D_05M_ABMF00GLP_TRO.TRO records: 288
Loaded: ./data\IGS0OPSFIN_20250020000_01D_05M_ABPO00MDG_TRO.TRO records: 288
Runtime error: 'TemporalDataset' object has no attribute 'feature_names'
Process finished with exit code 0