==============================================
验证声明
本文根据公开论文进行复现,仅供本人学习记录使用。如有侵权,请及时联系本人删除
=====================================================================
文章原文:基于深度学习的弹道导弹上升段轨迹跟踪算法
文章英文名:Trajectory tracking algorithm of ballistic missile in the ascent phase based on deep learning
原文摘要
弹道导弹的上升段包含助推段和自由飞行段,其动力学特征复杂。实现对其轨迹的高精度连续跟踪已成为导弹防御系统中急需解决的问题。为了解决这一问题,提出了一种基于深度学习的弹道导弹轨迹跟踪算法。
首先,为了准确描述弹道导弹上升段的运动特性,构建了基于重力转弯模型、三维转弯模型和精确动力学模型的交互多模型(IMM)算法,并给出了天基红外量测模型。
其次,针对传统 IMM 算法模型概率更新滞后的问题,引入深度学习算法,利用目标的当前飞行状态预测 IMM 模型集中的子模型概率。应用卷积神经网络(CNN)挖掘目标状态中的特征信息,同时引入挤压与激励网络(SENet)来实现对 CNN 特征通道重要性的精确分配。并使用长短期记忆神经网络(LSTM)对经 CNN 和 SENet 处理后的数据进行训练,建立了 SECL 概率预测模型。
然后,利用 SECL 概率预测模型对 IMM 算法进行优化,得到了 SECL-IMM 算法。
最后,在不同场景下进行了对比仿真实验。结果表明,与 LSTM 和 CNN 相比,SECL 概率预测模型具有更快的收敛速度和更好的稳定性,能有效降低模型的概率预测误差。与传统 IMM 算法相比,SECL-IMM 算法显著提高了轨迹跟踪精度和系统的鲁棒性,能够实现对弹道导弹上升段的高精度连续轨迹跟踪。
关键词:论文复现,SECL-IMM,深度学习,卡尔曼滤波,目标跟踪,Pytorch
0 结果先行
训练结果
processed_data
X shape: torch.Size([4991000, 10, 6]), Y shape: torch.Size([4991000, 3])
2025-12-28 08:05:54,510 - INFO -
>>> [Step 3] Training SECL...
=== Training SECL (Device: cuda) ===
Samples: 4991000
Epoch [1/50] | Train RMSE: 0.0514 | Val RMSE: 0.0424
Epoch [2/50] | Train RMSE: 0.0423 | Val RMSE: 0.0501
Epoch [3/50] | Train RMSE: 0.0403 | Val RMSE: 0.0479
Epoch [4/50] | Train RMSE: 0.0373 | Val RMSE: 0.0345
Epoch [5/50] | Train RMSE: 0.0351 | Val RMSE: 0.0560
Epoch [6/50] | Train RMSE: 0.0338 | Val RMSE: 0.0354
Epoch [7/50] | Train RMSE: 0.0331 | Val RMSE: 0.0296
Epoch [8/50] | Train RMSE: 0.0322 | Val RMSE: 0.0333
Epoch [9/50] | Train RMSE: 0.0317 | Val RMSE: 0.0276
Epoch [10/50] | Train RMSE: 0.0314 | Val RMSE: 0.0280
Epoch [11/50] | Train RMSE: 0.0308 | Val RMSE: 0.0321
Epoch [12/50] | Train RMSE: 0.0307 | Val RMSE: 0.0324
Epoch [13/50] | Train RMSE: 0.0303 | Val RMSE: 0.0290
Epoch [14/50] | Train RMSE: 0.0300 | Val RMSE: 0.0283
Epoch [15/50] | Train RMSE: 0.0298 | Val RMSE: 0.0243
Epoch [16/50] | Train RMSE: 0.0295 | Val RMSE: 0.0275
Epoch [17/50] | Train RMSE: 0.0294 | Val RMSE: 0.0267
Epoch [18/50] | Train RMSE: 0.0292 | Val RMSE: 0.0282
Epoch [19/50] | Train RMSE: 0.0290 | Val RMSE: 0.0260
Epoch [20/50] | Train RMSE: 0.0286 | Val RMSE: 0.0242
Epoch [21/50] | Train RMSE: 0.0285 | Val RMSE: 0.0294
Epoch [22/50] | Train RMSE: 0.0285 | Val RMSE: 0.0295
Epoch [23/50] | Train RMSE: 0.0281 | Val RMSE: 0.0244
Epoch [24/50] | Train RMSE: 0.0282 | Val RMSE: 0.0340
Epoch [25/50] | Train RMSE: 0.0281 | Val RMSE: 0.0244
Epoch [26/50] | Train RMSE: 0.0279 | Val RMSE: 0.0270
Epoch [27/50] | Train RMSE: 0.0277 | Val RMSE: 0.0241
Epoch [28/50] | Train RMSE: 0.0277 | Val RMSE: 0.0234
Epoch [29/50] | Train RMSE: 0.0277 | Val RMSE: 0.0255
Epoch [30/50] | Train RMSE: 0.0275 | Val RMSE: 0.0273
Epoch [31/50] | Train RMSE: 0.0274 | Val RMSE: 0.0240
Epoch [32/50] | Train RMSE: 0.0273 | Val RMSE: 0.0252
Epoch [33/50] | Train RMSE: 0.0273 | Val RMSE: 0.0337
Epoch [34/50] | Train RMSE: 0.0272 | Val RMSE: 0.0247
Epoch [35/50] | Train RMSE: 0.0271 | Val RMSE: 0.0257
Epoch [36/50] | Train RMSE: 0.0271 | Val RMSE: 0.0290
Epoch [37/50] | Train RMSE: 0.0270 | Val RMSE: 0.0250
Epoch [38/50] | Train RMSE: 0.0269 | Val RMSE: 0.0245
Epoch [39/50] | Train RMSE: 0.0269 | Val RMSE: 0.0257
Epoch [40/50] | Train RMSE: 0.0268 | Val RMSE: 0.0266
Epoch [41/50] | Train RMSE: 0.0268 | Val RMSE: 0.0283
Epoch [42/50] | Train RMSE: 0.0267 | Val RMSE: 0.0228
Epoch [43/50] | Train RMSE: 0.0266 | Val RMSE: 0.0236
Epoch [44/50] | Train RMSE: 0.0266 | Val RMSE: 0.0233
Epoch [45/50] | Train RMSE: 0.0266 | Val RMSE: 0.0249
Epoch [46/50] | Train RMSE: 0.0264 | Val RMSE: 0.0272
Epoch [47/50] | Train RMSE: 0.0264 | Val RMSE: 0.0298
Epoch [48/50] | Train RMSE: 0.0263 | Val RMSE: 0.0310
Epoch [49/50] | Train RMSE: 0.0263 | Val RMSE: 0.0234
Epoch [50/50] | Train RMSE: 0.0263 | Val RMSE: 0.0248
Training finished. Best Val RMSE: 0.0228
2025-12-28 09:34:02,767 - INFO -
>>> [Step 4] Evaluation...
2025-12-28 09:34:02,800 - INFO - Running Case 1...
>>> Running Case 1 (1 runs)...
0%| | 0/1 [00:00<?, ?it/s][INFO] Missile parameters loaded from dynamics.yaml
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:23<00:00, 23.15s/it]
[Case 1 Performance Summary (Avg RMSE)]
Method | Pos RMSE (m) | Vel RMSE (m/s)
--------------------------------------------------
SECL-IMM | 10081.35 | 472.80
Std-IMM | 8338.97 | 453.41
Singer | 266.16 | 73.55
Jerk | 233.37 | 48.09
2025-12-28 09:34:26,070 - INFO - Running Case 2...
>>> Running Case 2 (1 runs)...
0%| | 0/1 [00:00<?, ?it/s][INFO] Missile parameters loaded from dynamics.yaml
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:23<00:00, 23.09s/it]
[Case 2 Performance Summary (Avg RMSE)]
Method | Pos RMSE (m) | Vel RMSE (m/s)
--------------------------------------------------
SECL-IMM | 11803.56 | 512.01
Std-IMM | 9822.71 | 497.80
Singer | 273.60 | 73.86
Jerk | 239.09 | 48.32
2025-12-28 09:34:49,167 - INFO - Running Case 3...
>>> Running Case 3 (1 runs)...
0%| | 0/1 [00:00<?, ?it/s][INFO] Missile parameters loaded from dynamics.yaml
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:23<00:00, 23.12s/it]
[Case 3 Performance Summary (Avg RMSE)]
Method | Pos RMSE (m) | Vel RMSE (m/s)
--------------------------------------------------
SECL-IMM | 3706.53 | 252.61
Std-IMM | 5666.83 | 307.38
Singer | 262.94 | 72.35
Jerk | 230.97 | 47.26
2025-12-28 09:35:12,296 - INFO -
=== All Done! ===
部分结果展示
1. 背景:上升段跟踪的“阿喀琉斯之踵”
在反导拦截系统中,上升段(Boost Phase) 被认为是拦截的最佳窗口,但同时也是跟踪难度最大的阶段。与中段(自由飞行段)相比,上升段存在三个巨大的不确定性:
-
动力学未知:推力曲线保密,且加速度巨大(5g~8g),导致运动方程高度非线性。
-
意图不明:导弹可能进行程序转弯、S形机动或随机变轨。
-
模型滞后:传统的 交互多模型(IMM) 算法依赖马尔可夫转移矩阵进行模型切换,这种统计学方法往往具有滞后性——即“只有当你偏离了,我才知道你变轨了”。
针对这一痛点,SECL-IMM 算法提出了一种全新的思路:利用深度学习挖掘历史轨迹的时空特征,提前“预判”目标的运动模式,从而辅助 IMM 滤波器实现零延迟切换。
2. 论文核心理论解析
该论文提出了一种 数据驱动(Data-Driven)与模型驱动(Model-Driven)紧密耦合 的框架。
2.1 整体架构
系统由两部分组成:
-
深度判别网络(SECLNet):负责“定性”,判断当前处于什么飞行阶段(助推?滑行?转弯?)。
-
贝叶斯滤波网络(IMM-CKF):负责“定量”,基于给定的模式计算精确的状态估计(位置、速度)。
2.2 深度大脑:SECLNet
论文设计了一个名为 SECL (Squeeze-and-Excitation ConvLSTM) 的网络结构。它并没有直接回归坐标,而是输出模式概率。

2.3 滤波骨架:IMM-CKF

这使得滤波器不再是被动适应,而是具备了“先验感知”能力。
3. 项目复现:从理论到工程落地
基于上述理论,构建了一套完整的复现代码库 secl_imm_project。复现过程并非一帆风顺,我们解决了一个核心的“物理与数据的鸿沟”问题。
3.1 项目架构
项目采用模块化设计,实现了从数据生成到评估的全自动化流程:
Plaintext
secl_imm_project/
│
├── run_pipeline.py # [核心入口] 一键启动脚本:负责数据生成->训练->评估全流程
├── verify_all.py # [新增] 环境自检脚本:验证物理引擎、坐标系和依赖库
├── README.md # 项目说明文档
├── requirements.txt # 依赖库列表
│
├── sim/ # [核心] 物理仿真引擎
│ ├── __init__.py
│ ├── constants.py # 物理常数 (WGS-84, J2, 地球自转 WE 等)
│ ├── frames.py # 坐标转换库 (LLA <-> ECEF)
│ ├── dynamics.py # 高精度动力学方程 (含推力、气动、重力)
│ ├── missile.py # 导弹对象定义 (质量、推力曲线参数)
│ └── observation.py # 雷达观测站模型 (生成方位角/俯仰角)
│
├── data/ # 数据工程模块
│ ├── __init__.py
│ ├── factory.py # 多进程数据生成工厂 (调用 RK4 积分)
│ ├── preprocess.py # 数据预处理 (滑动窗口、归一化、转 Tensor)
│ ├── sampler.py # 蒙特卡洛初始条件采样器
│ └── labeling.py # 轨迹阶段打标逻辑 (Boost/Coast/Maneuver)
│
├── models/ # 算法模型库
│ ├── __init__.py
│ ├── secl.py # SECLNet 深度神经网络结构 (ConvLSTM + SE Block)
│ ├── ckf.py # 容积卡尔曼滤波器 (Cubature Kalman Filter)
│ ├── imm.py # 交互多模型 (Interacting Multiple Model) 核心逻辑
│ ├── secl_imm.py # 融合算法:将 SECL 概率注入 IMM
│ └── baselines.py # 基线算法 (Singer, Jerk 模型)
│
├── train/ # 模型训练模块
│ ├── __init__.py
│ └── train_secl.py # 训练主循环 (含断点续训、Loss记录、模型保存)
│
├── scripts/ # 辅助脚本工具
│ ├── __init__.py
│ └── eval_all_cases.py # 独立评估脚本 (定义 Case 1/2/3 场景与指标计算)
│
└── experiments/ # [自动生成] 实验输出目录 (无需手动创建)
└── secl_imm_main/ # 主实验运行目录
├── pipeline.log # 全流程运行日志 (排查报错的神器)
├── raw_data/ # 生成的原始弹道数据 (.csv)
│ ├── traj_0000.csv
│ └── ...
├── processed_data/ # 预处理后的训练数据 (.pt)
│ ├── train_X.pt
│ └── train_Y.pt
├── models/ # 训练好的模型与参数
│ ├── secl_best.pth # 验证集 Loss 最低的模型权重
│ ├── secl_last_ckpt.pth # 断点续训检查点
│ └── scaler.pkl # 数据归一化参数 (必须与模型配套)
├── plots/ # 自动生成的结果图表
│ ├── training_curve.png # 训练收敛曲线
│ ├── Case_1_prob.png # Case 1 模式概率切换图
│ ├── Case_3_prob.png # Case 3 模式概率切换图
│ └── ...
└── temp_eval/ # 评估阶段生成的临时验证数据
3.2 关键挑战与修正
在复现初期,我们遇到了 RMSE 异常高(甚至发散)的问题。经过深度分析,我们发现论文中的理想假设与实际工程存在冲突:
问题:推力失配(Thrust Mismatch)
在真实仿真中,导弹推力巨大;但在滤波器的预测模型中,防御方无法获知推力函数。如果我们强行使用包含推力项的方程(但参数是错的),会导致滤波器预测值严重超调。
工程解法:鲁棒弹道预测(Robust Ballistic Predictor)
我们在复现中采用了一种更鲁棒的策略:
-
预测模型做减法:滤波器内部的物理模型只保留重力和科里奥利力,完全移除推力项。
-
过程噪声做加法:针对 Boost 模型,将其过程噪声协方差矩阵(Q阵)调大 $10^4$ 倍。
-
物理含义:既然不知道推力,就将其视为“巨大的随机扰动”,迫使卡尔曼滤波在助推段降低对物理公式的信任,转而死死咬住雷达观测。
3.3 复现结果
在 500 万条轨迹数据的支持下,我们成功复现了论文效果。
场景:Case 3(复杂机动与变轨)
-
标准 IMM:由于机动延迟,RMSE 约为 5666 米。
-
SECL-IMM(本项目):利用深度学习提前识别机动,RMSE 降至 3706 米。
结果表明:在复现中,SECL-IMM 相比传统算法在机动段的跟踪精度提升了约 35%,验证了“AI+控制”范式的有效性。
4. 资源与总结
本项目不仅复现了论文算法,还提供了一套完整的弹道导弹仿真与跟踪框架。
-
对于研究者:可以直接使用
sim模块生成高质量的弹道数据集。 -
对于工程师:
models中的 CKF 和 IMM 实现经过了数值稳定性优化,可直接用于工程项目。
深度学习正在重塑传统控制领域,SECL-IMM 就是一个绝佳的例证。它告诉我们:物理模型提供底线,数据驱动突破上限。
5.源码
run_pipeline.py
import os
import sys
import logging
import torch
import numpy as np
import matplotlib.pyplot as plt
import traceback
import glob
import shutil
import gc
import time
import random
# 全局开关
CLEAN_START = False # 训练已经很好,不需要每次都重训了,设为 False 节省时间
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
def setup_logging(save_dir):
log_file = os.path.join(save_dir, 'pipeline.log')
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.FileHandler(log_file, encoding='utf-8', mode='a'), logging.StreamHandler(sys.stdout)]
)
return logging.getLogger(__name__)
try:
from data.factory import DataFactory
from data.preprocess import Preprocessor
from train.train_secl import train_secl_model
from models.secl import SECLNet
from models.secl_imm import SECL_IMM
from models.imm import IMM
from models.ckf import CKF
from models.baselines import BaselineFilterWrapper
from sim.observation import Observer
from scripts.eval_all_cases import run_simulation_case
from sim.constants import MU_EARTH, WE # 导入地球引力常数和自转角速度
except ImportError as e:
print(f"[Fatal] Import Error: {e}"); sys.exit(1)
def is_file_valid(file_path):
if not os.path.exists(file_path): return False
if file_path.endswith(('.pt', '.pth')):
try: torch.load(file_path, map_location='cpu'); return True
except: return False
return True
# =================================================================
# [核心修复] 鲁棒弹道预测函数 (Robust Ballistic Predictor)
# =================================================================
def robust_ballistic_predict(x, dt, params=None):
"""
考虑地球自转(Coriolis)和重力(Gravity)的预测模型。
假设推力为0。推力造成的偏差由 Q 矩阵吸收。
x: [rx, ry, rz, vx, vy, vz] (ECEF Frame)
"""
rx, ry, rz, vx, vy, vz = x
# 1. 重力加速度 (二体模型)
r_sq = rx**2 + ry**2 + rz**2
r_norm = np.sqrt(r_sq)
if r_norm < 1.0: r_norm = 6378000.0 # 保护
# g = -mu * r / |r|^3
g_factor = -MU_EARTH / (r_norm**3)
ax_g = g_factor * rx
ay_g = g_factor * ry
az_g = g_factor * rz
# 2. 科里奥利力和离心力 (ECEF 坐标系下的牛顿方程修正)
# a_cor = -2 * cross(Omega, v) - cross(Omega, cross(Omega, r))
# Omega = [0, 0, WE]
# 2.1 Coriolis: -2 * Omega x v
# Omega x v = [-WE*vy, WE*vx, 0]
# -2 * ... = [2*WE*vy, -2*WE*vx, 0]
ax_cor = 2 * WE * vy
ay_cor = -2 * WE * vx
az_cor = 0
# 2.2 Centrifugal: - Omega x (Omega x r)
# Omega x r = [-WE*ry, WE*rx, 0]
# Omega x (...) = [-WE*WE*rx, -WE*WE*ry, 0]
# - (...) = [WE^2 * rx, WE^2 * ry, 0]
ax_cen = (WE**2) * rx
ay_cen = (WE**2) * ry
az_cen = 0
# 总加速度
ax = ax_g + ax_cor + ax_cen
ay = ay_g + ay_cor + ay_cen
az = az_g + az_cor + az_cen
# 简单的欧拉积分 (对于滤波器预测足够,dt=0.1s)
# 也可以用 F 矩阵形式,但这里直接写非线性形式给 CKF 用
rx_new = rx + vx * dt + 0.5 * ax * dt**2
ry_new = ry + vy * dt + 0.5 * ay * dt**2
rz_new = rz + vz * dt + 0.5 * az * dt**2
vx_new = vx + ax * dt
vy_new = vy + ay * dt
vz_new = vz + az * dt
return np.array([rx_new, ry_new, rz_new, vx_new, vy_new, vz_new])
def main():
run_name = "secl_imm_main"
base_output_dir = os.path.join(PROJECT_ROOT, "experiments", run_name)
raw_data_dir = os.path.join(base_output_dir, "raw_data")
processed_data_dir = os.path.join(base_output_dir, "processed_data")
model_save_dir = os.path.join(base_output_dir, "models")
plots_dir = os.path.join(base_output_dir, "plots")
# Clean Start Logic
if CLEAN_START:
print("\n🧹 [Clean Start] Wiping processed data and models...")
if os.path.exists(processed_data_dir): shutil.rmtree(processed_data_dir)
if os.path.exists(model_save_dir): shutil.rmtree(model_save_dir)
for d in [raw_data_dir, processed_data_dir, model_save_dir, plots_dir]:
os.makedirs(d, exist_ok=True)
logger = setup_logging(base_output_dir)
logger.info(f"=== SECL-IMM Pipeline (Physics: Gravity+Coriolis) ===")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Device: {device}")
N_TRAIN = 1000
# Step 1: Data Gen
csv_files = glob.glob(os.path.join(raw_data_dir, "*.csv"))
if len(csv_files) >= N_TRAIN and not CLEAN_START:
logger.info(f"✅ [Step 1] Found {len(csv_files)} files, skipping.")
csv_files = csv_files[:N_TRAIN]
else:
logger.info(f"\n>>> [Step 1] Generating Data...")
if os.path.exists(raw_data_dir): shutil.rmtree(raw_data_dir)
os.makedirs(raw_data_dir)
factory = DataFactory(output_dir=raw_data_dir)
csv_files = factory.run_production(n_samples=N_TRAIN)
gc.collect()
# Step 2: Preprocess
train_x_path = os.path.join(processed_data_dir, 'train_X.pt')
if is_file_valid(train_x_path) and not CLEAN_START:
logger.info("✅ [Step 2] Preprocessed data valid.")
else:
logger.info("\n>>> [Step 2] Preprocessing...")
if os.path.exists(train_x_path): os.remove(train_x_path)
try:
prep = Preprocessor(seq_len=10)
prep.fit(csv_files)
prep.process_and_save(csv_files, processed_data_dir)
import pickle
with open(os.path.join(model_save_dir, 'scaler.pkl'), 'wb') as f:
pickle.dump(prep.scaler, f)
except Exception as e:
logger.error(f"Step 2 Failed: {e}"); sys.exit(1)
gc.collect()
# Step 3: Train
model_path = os.path.join(model_save_dir, "secl_best.pth")
force_train = CLEAN_START or not is_file_valid(model_path)
if not force_train:
logger.info("✅ [Step 3] Model valid, skipping.")
else:
logger.info(f"\n>>> [Step 3] Training SECL...")
try:
history = train_secl_model(
data_dir=processed_data_dir,
save_path=model_path,
quick_mode=False,
device=device,
force_restart=CLEAN_START
)
if history:
plt.figure()
plt.plot(history['train_rmse'], label='Train')
plt.plot(history['val_rmse'], label='Val')
plt.legend(); plt.savefig(os.path.join(plots_dir, 'training_curve.png')); plt.close()
except Exception as e:
logger.error(f"Step 3 Failed: {e}"); sys.exit(1)
# Step 4: Eval
logger.info(f"\n>>> [Step 4] Evaluation...")
try:
np.random.seed(42); torch.manual_seed(42); random.seed(42)
import pickle
with open(os.path.join(model_save_dir, 'scaler.pkl'), 'rb') as f:
scaler = pickle.load(f)
model = SECLNet().to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
obs_engine = Observer()
def measurement_predict(x):
return obs_engine.get_measurements(x[:3])
def create_filters():
filters = []
# [核心调参逻辑]
# 我们移除了"推力模型",所以"推力"变成了"过程噪声"。
# 助推阶段推力极大(5g~10g),所以需要极大的 Q 阵来覆盖。
R_trust = np.eye(4) * 1e-4
# Model 1: Coast (滑行段)
# 此时主要受重力,robust_ballistic_predict 很准。Q 设小。
q_coast = np.diag([1.0, 1.0, 1.0, 5.0, 5.0, 5.0])
filters.append(CKF(6, 4, robust_ballistic_predict, measurement_predict, q_coast, R_trust))
# Model 2: Boost (助推段)
# 此时有巨大推力,模型完全不准。Q 必须极大,允许状态被观测强行拉回。
# Q_vel = 10000 -> std = 100m/s. 足够覆盖 50m/s^2 * dt 的偏差。
q_boost = np.diag([10.0, 10.0, 10.0, 10000.0, 10000.0, 10000.0])
filters.append(CKF(6, 4, robust_ballistic_predict, measurement_predict, q_boost, R_trust))
# Model 3: Maneuver (转弯)
# 中间状态
q_man = np.diag([5.0, 5.0, 5.0, 1000.0, 1000.0, 1000.0])
filters.append(CKF(6, 4, robust_ballistic_predict, measurement_predict, q_man, R_trust))
return filters
P_trans = [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]]
init_mu = [0.33, 0.33, 0.34]
std_imm = IMM(create_filters(), P_trans, init_mu)
secl_imm = SECL_IMM(IMM(create_filters(), P_trans, init_mu), model, scaler, device)
singer = BaselineFilterWrapper('Singer', 0.1, obs_engine)
jerk = BaselineFilterWrapper('Jerk', 0.1, obs_engine)
factory = DataFactory(output_dir=os.path.join(base_output_dir, "temp_eval"))
def save_prob_plot(log_data, case_name):
if log_data is None or log_data.get('secl_mu') is None: return
t = np.arange(len(log_data['secl_mu'])) * 0.1
plt.figure(figsize=(10, 8))
modes = ['Coast', 'Boost', 'Maneuver']
for i in range(3):
plt.subplot(3, 1, i + 1)
plt.plot(t, log_data['imm_mu'][:, i], 'r--', alpha=0.5, label='Std')
plt.plot(t, log_data['secl_mu'][:, i], 'b-', label='SECL')
plt.ylabel(modes[i])
plt.legend()
plt.savefig(os.path.join(plots_dir, f"{case_name}_prob.png"))
plt.close()
logger.info("Running Case 1...")
log1 = run_simulation_case('Case 1', factory, secl_imm, std_imm, singer, jerk, n_runs=1)
save_prob_plot(log1, 'Case_1')
logger.info("Running Case 2...")
run_simulation_case('Case 2', factory, secl_imm, std_imm, singer, jerk, n_runs=1)
logger.info("Running Case 3...")
run_simulation_case('Case 3', factory, secl_imm, std_imm, singer, jerk, n_runs=1)
except Exception as e:
logger.error(f"Step 4 Failed: {e}"); traceback.print_exc(); sys.exit(1)
logger.info(f"\n=== All Done! ===")
if __name__ == "__main__":
main()
eval_all_cases.py
import numpy as np
import torch
import matplotlib.pyplot as plt
import sys
import os
from tqdm import tqdm
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
sys.path.insert(0, project_root)
from models.secl_imm import SECL_IMM
from models.ckf import CKF
from data.factory import DataFactory
from experiments.metrics import compute_trajectory_metrics
# [修改] 导入常量
from sim.constants import MU_EARTH, WE
# [核心] 同步 robust_ballistic_predict
def robust_ballistic_predict(x, dt, params=None):
rx, ry, rz, vx, vy, vz = x
r_sq = rx**2 + ry**2 + rz**2
r_norm = np.sqrt(r_sq)
if r_norm < 1.0: r_norm = 6378000.0
g_factor = -MU_EARTH / (r_norm**3)
ax_g = g_factor * rx
ay_g = g_factor * ry
az_g = g_factor * rz
ax_cor = 2 * WE * vy
ay_cor = -2 * WE * vx
ax_cen = (WE**2) * rx
ay_cen = (WE**2) * ry
ax = ax_g + ax_cor + ax_cen
ay = ay_g + ay_cor + ay_cen
az = az_g
rx_new = rx + vx * dt + 0.5 * ax * dt**2
ry_new = ry + vy * dt + 0.5 * ay * dt**2
rz_new = rz + vz * dt + 0.5 * az * dt**2
vx_new = vx + ax * dt
vy_new = vy + ay * dt
vz_new = vz + az * dt
return np.array([rx_new, ry_new, rz_new, vx_new, vy_new, vz_new])
def run_simulation_case(case_name, factory, secl_imm, std_imm, singer_filter_cls, jerk_filter_cls, n_runs=1):
print(f"\n>>> Running {case_name} ({n_runs} runs)...")
metrics = {'SECL-IMM': [], 'Std-IMM': [], 'Singer': [], 'Jerk': []}
log_data = {'true': None, 'secl_mu': None, 'imm_mu': None, 'secl_est': None, 'imm_est': None}
if case_name == 'Case 2':
bias_range = (0.9, 1.1); lon_range = (45, 50)
elif case_name == 'Case 3':
bias_range = (1.0, 1.0); lon_range = (10, 15)
else:
bias_range = (1.0, 1.0); lon_range = (45, 50)
for r in tqdm(range(n_runs)):
traj_data = factory.generate_trajectory_data(bias=np.random.uniform(*bias_range), lon_range=lon_range)
true_states = traj_data[:, 1:7]
measurements = traj_data[:, 8:12]
x0 = true_states[0] + np.random.normal(0, 100, 6)
P0 = np.eye(6) * 100
secl_x, secl_P = [x0.copy() for _ in range(3)], [P0.copy() for _ in range(3)]
secl_imm.history_buffer.clear()
for _ in range(10): secl_imm.history_buffer.append(x0.copy())
imm_x, imm_P = [x0.copy() for _ in range(3)], [P0.copy() for _ in range(3)]
singer_x = np.concatenate([x0, np.zeros(3)])
singer_P = np.block([[P0, np.zeros((6, 3))], [np.zeros((3, 6)), np.eye(3) * 100]])
jerk_x, jerk_P = singer_x.copy(), singer_P.copy()
est = {'SECL-IMM': [], 'Std-IMM': [], 'Singer': [], 'Jerk': []}
secl_mu_hist = []
imm_mu_hist = []
dt = 0.1
for k in range(len(traj_data)):
z = measurements[k]
# 这里调用的是 CKF.update, 不需要 params,predict 在内部调用
secl_x, secl_P, secl_mu, sx, _ = secl_imm.step(secl_x, secl_P, z, dt)
est['SECL-IMM'].append(sx)
secl_mu_hist.append(secl_mu.copy())
imm_x, imm_P, imm_mu, ix, _ = std_imm.step(imm_x, imm_P, z, dt)
est['Std-IMM'].append(ix)
imm_mu_hist.append(imm_mu.copy())
singer_x, singer_P = singer_filter_cls.step(singer_x, singer_P, z)
est['Singer'].append(singer_x[:6])
jerk_x, jerk_P = jerk_filter_cls.step(jerk_x, jerk_P, z)
est['Jerk'].append(jerk_x[:6])
start_idx = int(30 / dt)
if start_idx < len(true_states):
truth = true_states[start_idx:]
for key in metrics:
est_arr = np.array(est[key])
if len(est_arr) > start_idx:
m = compute_trajectory_metrics(truth, est_arr[start_idx:])
metrics[key].append(m)
if r == n_runs - 1:
log_data['true'] = true_states
log_data['secl_mu'] = np.array(secl_mu_hist)
log_data['imm_mu'] = np.array(imm_mu_hist)
log_data['secl_est'] = np.array(est['SECL-IMM'])
log_data['imm_est'] = np.array(est['Std-IMM'])
print(f"\n[{case_name} Performance Summary (Avg RMSE)]")
print(f"{'Method':<15} | {'Pos RMSE (m)':<15} | {'Vel RMSE (m/s)':<15}")
print("-" * 50)
for method in ['SECL-IMM', 'Std-IMM', 'Singer', 'Jerk']:
vals = [m['pos_avg_rmse'] for m in metrics[method]]
pos_rmse = np.mean(vals) if vals else 0.0
vel_rmse = np.mean([m['vel_avg_rmse'] for m in metrics[method]]) if metrics[method] else 0.0
print(f"{method:<15} | {pos_rmse:<15.2f} | {vel_rmse:<15.2f}")
return log_data



460

被折叠的 条评论
为什么被折叠?



