Time Pyramid Transformer (TPT): Complete Code


This code implements an enhanced Time Pyramid Transformer (TPT) framework for forecasting sparse and non-uniformly sampled time series, built around the following key innovations:

  1. Parallel time-pyramid encoding: three branches (Full / Half / Quarter) run in parallel within one network, applying independent Transformer encoders to the sequence at full scale, 2x downsampling, and 4x downsampling. The branch outputs are aligned by upsampling and combined through a learnable fusion stage (Dense + a small Transformer block). This lets the architecture capture long-term trends, mid-scale smooth variation, and fast local fluctuations at once, addressing the weak multi-scale sensitivity of single-scale models.

  2. Continuous-time embedding (Time2Vec): a learnable Time2Vec layer appends explicit continuous-time position features to every time step. The model therefore natively supports interpolated or non-integer time points (e.g. 1990.5) and irregularly sampled data, so fractional-year samples produced by LHS augmentation can be used effectively, improving extrapolation and robustness to non-uniform spacing.

  3. Engineered polarity-aware linear attention: Q and K are decomposed into positive and negative components and linearized with a kernel feature map (phi = ELU + 1). This keeps the attention sensitive to sign/polarity changes while reducing complexity from O(N²·d) to O(N·d²), balancing efficiency and expressiveness on small-sample, long-sequence workloads.

  4. Data augmentation and robust fusion pipeline: an LHS (Latin hypercube sampling) time-interpolation augmentation strategy, together with the multi-scale upsample/crop alignment scheme, keeps training stable even when samples are extremely scarce (the augmentation itself is not part of the script below; a sketch follows this list). The fusion module learns per-scale weights, supporting interpretable ablations such as quantifying each scale's contribution to accuracy.
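
Item 4's LHS time-interpolation augmentation is described above but not included in the script below. Here is a minimal sketch of one plausible version, assuming a `data` array sampled at integer time indices; the function name `lhs_time_augment` and the use of `scipy.stats.qmc.LatinHypercube` are illustrative choices, not the author's original implementation:

    import numpy as np
    from scipy.stats import qmc  # SciPy >= 1.7

    def lhs_time_augment(data, n_new, seed=0):
        """Draw n_new fractional time points with a Latin hypercube design
        and linearly interpolate each feature column at those points."""
        n, n_feat = data.shape
        sampler = qmc.LatinHypercube(d=1, seed=seed)
        # Stratified samples in [0, 1), scaled to fractional indices in [0, n-1]
        t_frac = np.sort(qmc.scale(sampler.random(n_new), 0.0, n - 1.0).ravel())
        t_grid = np.arange(n, dtype=float)
        aug = np.column_stack([np.interp(t_frac, t_grid, data[:, j])
                               for j in range(n_feat)])
        # t_frac (fractional times) can feed Time2Vec; aug joins the training set
        return t_frac, aug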

    # -*- coding: utf-8 -*-
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    import tensorflow as tf
    from tensorflow.keras.layers import Input, Dense, Concatenate, LayerNormalization, Dropout, GlobalAveragePooling1D, \
        Conv1D, AveragePooling1D, UpSampling1D, Lambda
    from tensorflow.keras.models import Model
    from tensorflow.keras.optimizers import Adam
    from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
    from sklearn.preprocessing import MinMaxScaler
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import mean_absolute_error, mean_squared_error
    import os
    import time
    
    # Font setup (SimHei covered the original CJK labels; any available font
    # works with the English labels used below)
    plt.rcParams["font.family"] = ["SimHei"]
    plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
    
    # ----------------------------
    # 1. Data loading and preprocessing
    # ----------------------------
    def load_data(file_path):
        """Load the flight-data CSV file."""
        try:
            df = pd.read_csv(file_path)
            required_columns = ['ISO time', 'Longitude', 'Latitude', 'Altitude']
            for col in required_columns:
                if col not in df.columns:
                    raise ValueError(f"Missing required column in CSV: {col}")

            # Sort chronologically
            df['ISO time'] = pd.to_datetime(df['ISO time'])
            df = df.sort_values('ISO time')

            # Extract the key features
            data = df[['Longitude', 'Latitude', 'Altitude']].values
            return data, df
        except Exception as e:
            print(f"Error loading data: {e}")
            return None, None
    
    def create_sequences(data, seq_length, pred_length):
        """Build sliding-window input and target sequences."""
        X, y = [], []
        for i in range(len(data) - seq_length - pred_length + 1):
            X.append(data[i:i + seq_length])
            y.append(data[i + seq_length:i + seq_length + pred_length])
        return np.array(X), np.array(y)
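
    # Illustrative shape check (an added note, not in the original script): with
    # data of shape (100, 3), seq_length=20 and pred_length=1, create_sequences
    # returns X of shape (80, 20, 3) and y of shape (80, 1, 3).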
    
    def normalize_data(data):
        """Scale each feature to [0, 1] with MinMaxScaler."""
        scaler = MinMaxScaler()
        data_reshaped = data.reshape(-1, 3)
        data_scaled = scaler.fit_transform(data_reshaped)
        return data_scaled.reshape(data.shape), scaler
    
    # ----------------------------
    # 2. Improved Transformer layers
    #    - PositionalEncoding: kept from the original script (unused below)
    #    - PolarityAwareLinearAttention: new engineered linear attention
    #    - transformer_block_improved: polarity-aware linear attn + FFN + Pre-LN
    # ----------------------------
    
    class PositionalEncoding(tf.keras.layers.Layer):
        """位置编码层(保留)"""
        def __init__(self, position, d_model, **kwargs):
            super(PositionalEncoding, self).__init__(**kwargs)
            self.position = position
            self.d_model = d_model
            self.pos_encoding = self.positional_encoding(position, d_model)
    
        def get_angles(self, position, i, d_model):
            angles = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(d_model, tf.float32))
            return position * angles
    
        def positional_encoding(self, position, d_model):
            angle_rads = self.get_angles(
                position=tf.range(position, dtype=tf.float32)[:, tf.newaxis],
                i=tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
                d_model=d_model
            )
            sines = tf.math.sin(angle_rads[:, 0::2])
            cosines = tf.math.cos(angle_rads[:, 1::2])
            pos_encoding = tf.concat([sines, cosines], axis=-1)
            pos_encoding = pos_encoding[tf.newaxis, ...]
            return tf.cast(pos_encoding, tf.float32)
    
        def call(self, inputs):
            # inputs shape: (batch, seq_len, d_model)
            return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :]
    
        def get_config(self):
            config = super(PositionalEncoding, self).get_config()
            config.update({
                'position': self.position,
                'd_model': self.d_model
            })
            return config
    
    class PolarityAwareLinearAttention(tf.keras.layers.Layer):
        """
        工程化的 Polarity-aware Linear Attention(简化版、可训练)。
        (保留你原来的实现)
        """
        def __init__(self, dim, proj_dim=None, dropout=0.0, eps=1e-6, **kwargs):
            super(PolarityAwareLinearAttention, self).__init__(**kwargs)
            self.dim = dim
            self.proj_dim = proj_dim if proj_dim is not None else dim
            self.dropout = dropout
            self.eps = eps
    
            # Q/K/V projections
            self.q_proj = Dense(self.proj_dim, use_bias=False)
            self.k_proj = Dense(self.proj_dim, use_bias=False)
            self.v_proj = Dense(self.proj_dim, use_bias=False)
            # output projection
            self.out_proj = Dense(self.dim, use_bias=False)
            self.dropout_layer = Dropout(dropout)
    
        def build(self, input_shape):
            super(PolarityAwareLinearAttention, self).build(input_shape)
    
        def kernel_feature_map(self, x):
            """Kernel feature map phi(x) = elu(x) + 1 (ensures non-negativity)."""
            return tf.nn.elu(x) + 1.0
    
        def split_pos_neg(self, x):
            """Split a tensor into positive and negative parts (polarity-aware handling)."""
            pos = tf.nn.relu(x)
            neg = tf.nn.relu(-x)
            return pos, neg
    
        def call(self, inputs, mask=None, training=False):
            # inputs: (batch, seq_len, dim)
            q = self.q_proj(inputs)  # (b, n, d')
            k = self.k_proj(inputs)
            v = self.v_proj(inputs)
    
            # Split into positive and negative components
            q_pos, q_neg = self.split_pos_neg(q)
            k_pos, k_neg = self.split_pos_neg(k)
            v_pos, v_neg = self.split_pos_neg(v)
    
            # kernel feature map (non-negative, as required for linear attention)
            q_pos_phi = self.kernel_feature_map(q_pos)
            k_pos_phi = self.kernel_feature_map(k_pos)
    
            q_neg_phi = self.kernel_feature_map(q_neg)
            k_neg_phi = self.kernel_feature_map(k_neg)
    
            # core linear-attention computation (batch-wise einsum)
            kv_pos = tf.einsum('bnd,bne->bde', k_pos_phi, v_pos)  # (b, d', d')
            denom_pos = tf.einsum('bnd,bd->bn', q_pos_phi, tf.reduce_sum(k_pos_phi, axis=1) + self.eps)  # (b, n)
            numer_pos = tf.einsum('bnd,bde->bne', q_pos_phi, kv_pos)  # (b, n, d')
    
            kv_neg = tf.einsum('bnd,bne->bde', k_neg_phi, v_neg)
            denom_neg = tf.einsum('bnd,bd->bn', q_neg_phi, tf.reduce_sum(k_neg_phi, axis=1) + self.eps)
            numer_neg = tf.einsum('bnd,bde->bne', q_neg_phi, kv_neg)
    
            denom = tf.expand_dims(denom_pos + denom_neg + self.eps, axis=-1)  # (b, n, 1)
            context = (numer_pos + numer_neg) / denom  # (b, n, d')
    
            context = self.dropout_layer(context, training=training)
            out = self.out_proj(context)  # project back to the original dimension (b, n, dim)
            return out
    
        def get_config(self):
            cfg = super(PolarityAwareLinearAttention, self).get_config()
            cfg.update({
                'dim': self.dim,
                'proj_dim': self.proj_dim,
                'dropout': self.dropout,
                'eps': self.eps
            })
            return cfg
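
    # Sanity-check sketch (an addition, not part of the original script): the
    # linearized computation above is algebraically identical to the explicit
    # O(N^2) kernel attention phi(Q)(phi(K)^T V) / (phi(Q) phi(K)^T 1). The
    # helper below verifies this on random tensors; the training pipeline
    # never calls it.
    def _check_linear_attention_equivalence(n=8, d=4, seed=0):
        rng = np.random.default_rng(seed)
        q = tf.constant(rng.normal(size=(1, n, d)), tf.float32)
        k = tf.constant(rng.normal(size=(1, n, d)), tf.float32)
        v = tf.constant(rng.normal(size=(1, n, d)), tf.float32)
        phi = lambda x: tf.nn.elu(x) + 1.0
        # explicit O(N^2) path: full attention matrix, then row-normalize
        scores = tf.einsum('bnd,bmd->bnm', phi(q), phi(k))
        explicit = tf.einsum('bnm,bme->bne', scores, v) / tf.reduce_sum(
            scores, axis=-1, keepdims=True)
        # linearized path: aggregate K^T V first, O(N * d^2)
        kv = tf.einsum('bnd,bne->bde', phi(k), v)
        denom = tf.einsum('bnd,bd->bn', phi(q), tf.reduce_sum(phi(k), axis=1))
        linear = tf.einsum('bnd,bde->bne', phi(q), kv) / denom[..., None]
        np.testing.assert_allclose(explicit.numpy(), linear.numpy(), rtol=1e-4)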
    
    def transformer_block_improved(x, d_model, ff_dim, proj_dim=None, dropout=0.1):
        """
        Improved Transformer block (Pre-LN + PolarityAwareLinearAttention + FFN).
        """
        # Pre-LN
        x_norm = LayerNormalization(epsilon=1e-6)(x)
        attn_out = PolarityAwareLinearAttention(dim=d_model, proj_dim=proj_dim, dropout=dropout)(x_norm)
        # residual connection
        x = attn_out + x
    
        # Feed-forward
        x_norm2 = LayerNormalization(epsilon=1e-6)(x)
        ff = Dense(ff_dim, activation='gelu')(x_norm2)
        ff = Dropout(dropout)(ff)
        ff = Dense(d_model)(ff)
        ff = Dropout(dropout)(ff)
        out = x + ff
        return out
    
    # ----------------------------
    # Time2Vec layer (continuous-time embedding)
    # ----------------------------
    class Time2Vec(tf.keras.layers.Layer):
        """
        Minimal learnable Time2Vec: output shape (batch, seq_len, k).
        Follows the original Time2Vec idea: one linear component plus
        k-1 periodic (sine) components.
        """
        def __init__(self, k=16, **kwargs):
            super(Time2Vec, self).__init__(**kwargs)
            self.k = k
    
        def build(self, input_shape):
            # input_shape: (batch, seq_len, 1)
            self.w = self.add_weight(shape=(self.k,), initializer='glorot_uniform', name='w')
            self.b = self.add_weight(shape=(self.k,), initializer='zeros', name='b')
            super(Time2Vec, self).build(input_shape)
    
        def call(self, t):
            # t: (batch, seq_len, 1), one scalar time value per step
            # linear component (the first weight/bias pair acts as the linear term)
            lin = t * self.w[:1] + self.b[:1]
            # periodic components
            periodic = tf.math.sin(t * self.w[1:] + self.b[1:])
            return tf.concat([lin, periodic], axis=-1)  # (batch, seq_len, k)
    
        def get_config(self):
            cfg = super(Time2Vec, self).get_config()
            cfg.update({'k': self.k})
            return cfg
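
    # Usage sketch (illustrative addition): Time2Vec maps a scalar time value
    # per step to one learnable linear feature plus k-1 periodic features, so
    # fractional time points such as 1990.5 (normalized into [0, 1]) are
    # handled natively:
    #   t = tf.reshape(tf.constant([0.0, 0.25, 0.5]), (1, 3, 1))
    #   Time2Vec(k=4)(t).shape  # -> TensorShape([1, 3, 4])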
    
    # ----------------------------
    # Helper: upsample, then pad/crop to the target length
    # ----------------------------
    def upsample_and_crop(x, target_len, factor):
        """
        x: tensor (batch, seq_small, dim); factor: integer upsampling ratio.
        Returns a tensor of shape (batch, target_len, dim).
        """
        # UpSampling1D repeats each time step `factor` times
        x_up = UpSampling1D(size=factor)(x)  # (batch, seq_small*factor, dim)

        # Pad on the right if the upsampled length falls short (possible when
        # 'valid' pooling truncated an odd length), then crop to target_len
        def pad_and_slice(y):
            pad = tf.maximum(target_len - tf.shape(y)[1], 0)
            y = tf.pad(y, [[0, 0], [0, pad], [0, 0]])
            return y[:, :target_len, :]
        return Lambda(pad_and_slice)(x_up)
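
    # Shape walkthrough (illustrative): with seq_length=20, the half branch
    # pools to length 10 and upsamples x2 back to 20; the quarter branch pools
    # to 5 and upsamples x4 back to 20. The padding guard only matters for
    # lengths that 'valid' pooling does not divide evenly
    # (e.g. 21 -> 10 -> 20 -> pad to 21).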
    
    # ----------------------------
    # 3. Model: Time Pyramid Transformer (TPT)
    #    - Three parallel scales: Full / Half / Quarter
    #    - Time2Vec features injected at every time step
    #    - Multi-scale encodings upsampled to the input length and fused
    # ----------------------------
    def build_model(seq_length, pred_length, n_features=3,
                    d_model=64, num_layers_full=3, num_layers_half=2, num_layers_quarter=1,
                    ff_dim=128, proj_dim=64, time2vec_k=16, dropout=0.1):
        """
        构建 Time Pyramid Transformer(TPT)模型:
        输入 (seq_length, n_features) -> 输出 (pred_length, n_features)
        设计要点:
          - 输入投影 -> time2vec concat -> 线性投影回 d_model
          - full / half / quarter 三分支各自堆叠若干 transformer_block_improved
          - half/quarter 使用 avg pooling 下采样,处理后上采样并裁切到 seq_length
          - 融合 (concat -> Dense -> Transformer block) -> 全局池化 -> MLP -> 输出
        """
        inputs = Input(shape=(seq_length, n_features), name="inputs")  # (b, seq, feat)
    
        # Input projection (token embedding)
        x_proj = Dense(d_model, activation=None, name="input_projection")(inputs)  # (b, seq, d_model)
    
        # Time2Vec: build normalized time positions in [0, 1] -> (b, seq, 1)
        def build_time_tensor(inp):
            # inp shape: (b, seq, feat)
            b = tf.shape(inp)[0]
            seq = tf.shape(inp)[1]
            # create linspace of length seq
            rr = tf.linspace(0.0, 1.0, seq)
            rr = tf.reshape(rr, (1, seq, 1))
            rr = tf.tile(rr, [b, 1, 1])
            return rr
    
        t_pos = Lambda(build_time_tensor, name="time_positions")(inputs)  # (b, seq, 1)
        t2v = Time2Vec(k=time2vec_k, name="time2vec")(t_pos)  # (b, seq, k)
    
        # concat time features to x_proj
        x_cat = Concatenate(axis=-1)([x_proj, t2v])  # (b, seq, d_model + k)
    
        # project back to d_model
        x = Dense(d_model, activation='relu', name="project_back")(x_cat)  # (b, seq, d_model)
    
        # ----------------- Full scale branch -----------------
        xf = x
        for i in range(num_layers_full):
            xf = transformer_block_improved(xf, d_model=d_model, ff_dim=ff_dim, proj_dim=proj_dim, dropout=dropout)
        enc_full = xf  # (b, seq, d_model)
    
        # ----------------- Half scale branch -----------------
        if seq_length >= 2:
            x_half = AveragePooling1D(pool_size=2, strides=2, padding='valid')(x)  # (b, seq//2, d_model)
            for i in range(num_layers_half):
                x_half = transformer_block_improved(x_half, d_model=d_model, ff_dim=ff_dim, proj_dim=proj_dim, dropout=dropout)
            # upsample back to seq_length
            enc_half_up = upsample_and_crop(x_half, target_len=seq_length, factor=2)  # (b, seq, d_model)
        else:
            enc_half_up = enc_full
    
        # ----------------- Quarter scale branch -----------------
        if seq_length >= 4:
            x_q = AveragePooling1D(pool_size=4, strides=4, padding='valid')(x)  # (b, seq//4, d_model)
            for i in range(num_layers_quarter):
                x_q = transformer_block_improved(x_q, d_model=d_model, ff_dim=ff_dim, proj_dim=proj_dim, dropout=dropout)
            # upsample back to seq_length
            enc_q_up = upsample_and_crop(x_q, target_len=seq_length, factor=4)  # (b, seq, d_model)
        else:
            enc_q_up = enc_full
    
        # ----------------- Fuse multi-scale features -----------------
        fused = Concatenate(axis=-1)([enc_full, enc_half_up, enc_q_up])  # (b, seq, d_model*3)
        fused = Dense(d_model, activation='relu', name="fusion_dense")(fused)  # (b, seq, d_model)
        # small transformer to mix fused features
        fused = transformer_block_improved(fused, d_model=d_model, ff_dim=ff_dim, proj_dim=proj_dim, dropout=dropout)
    
        # ----------------- Prediction head -----------------
        pooled = GlobalAveragePooling1D(name="global_avg_pool")(fused)  # (b, d_model)
        feat = Dense(d_model, activation='relu', name="final_mlp")(pooled)
        feat = Dropout(dropout)(feat)
        outputs = Dense(pred_length * n_features, activation='linear', name="output_dense")(feat)
        outputs = tf.keras.layers.Reshape((pred_length, n_features), name="output_reshape")(outputs)
    
        model = Model(inputs=inputs, outputs=outputs, name="TPT_Improved")
        model.compile(optimizer=Adam(learning_rate=1e-3), loss='mse')
        return model
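
    # Quick smoke test (illustrative, never executed by the pipeline):
    #   m = build_model(seq_length=20, pred_length=1)
    #   m.predict(np.zeros((2, 20, 3))).shape  # -> (2, 1, 3)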
    
    # ----------------------------
    # 4. Training / evaluation / plotting (unchanged except for the new model)
    # ----------------------------
    def train_model(model, X_train, y_train, X_val, y_val, epochs=100, batch_size=32):
        """Train the model."""
        early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
        # Save weights only: full-model HDF5 saving is fragile with the custom
        # layers and Lambda closures used in this script
        checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss',
                                     save_best_only=True, save_weights_only=True)
        history = model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=[early_stopping, checkpoint]
        )
        return model, history
    
    def evaluate_model(model, X_data, y_data, scaler, dataset_name="data"):
        """Evaluate model performance (unchanged from the original script)."""
        y_pred = model.predict(X_data)
        y_data_reshaped = y_data.reshape(-1, 3)
        y_pred_reshaped = y_pred.reshape(-1, 3)
        y_data_inv = scaler.inverse_transform(y_data_reshaped)
        y_pred_inv = scaler.inverse_transform(y_pred_reshaped)
        mae = mean_absolute_error(y_data_inv, y_pred_inv)
        rmse = np.sqrt(mean_squared_error(y_data_inv, y_pred_inv))
        mae_lon = mean_absolute_error(y_data_inv[:, 0], y_pred_inv[:, 0])
        mae_lat = mean_absolute_error(y_data_inv[:, 1], y_pred_inv[:, 1])
        mae_alt = mean_absolute_error(y_data_inv[:, 2], y_pred_inv[:, 2])
        rmse_lon = np.sqrt(mean_squared_error(y_data_inv[:, 0], y_pred_inv[:, 0]))
        rmse_lat = np.sqrt(mean_squared_error(y_data_inv[:, 1], y_pred_inv[:, 1]))
        rmse_alt = np.sqrt(mean_squared_error(y_data_inv[:, 2], y_pred_inv[:, 2]))
    
        print(f"\n{dataset_name}评价指标:")
        print(f"平均绝对误差 (MAE): {mae:.6f}")
        print(f"均方根误差 (RMSE): {rmse:.6f}")
        print(f"各维度指标:")
        print(f"经度 MAE: {mae_lon:.6f}, RMSE: {rmse_lon:.6f}")
        print(f"纬度 MAE: {mae_lat:.6f}, RMSE: {rmse_lat:.6f}")
        print(f"高度 MAE: {mae_alt:.6f}, RMSE: {rmse_alt:.6f}")
    
        return y_pred, y_data_inv, y_pred_inv
    
    # Plotting helpers (unchanged from the original script)
    def plot_trajectory_comparison(history_data, pred_data, true_data=None):
        fig = plt.figure(figsize=(15, 12))
        ax = fig.add_subplot(111, projection='3d')
        ax.plot(history_data[:, 0], history_data[:, 1], history_data[:, 2], 'b-', label='History', linewidth=2)
        ax.plot(pred_data[:, 0], pred_data[:, 1], pred_data[:, 2], 'r--', label='Prediction', linewidth=2)
        if true_data is not None:
            ax.plot(true_data[:, 0], true_data[:, 1], true_data[:, 2], 'g-.', label='Ground truth', linewidth=2)
        ax.scatter(history_data[0, 0], history_data[0, 1], history_data[0, 2], color='green', s=100, label='Start')
        ax.scatter(history_data[-1, 0], history_data[-1, 1], history_data[-1, 2], color='blue', s=100, label='End of history')
        ax.scatter(pred_data[-1, 0], pred_data[-1, 1], pred_data[-1, 2], color='red', s=100, label='Predicted endpoint')
        ax.set_xlabel('Longitude', fontsize=12)
        ax.set_ylabel('Latitude', fontsize=12)
        ax.set_zlabel('Altitude', fontsize=12)
        ax.set_title('Flight trajectory prediction', fontsize=15)
        ax.legend(fontsize=10)
        ax.view_init(elev=30, azim=45)
        plt.tight_layout()
        return fig
    
    def plot_error_analysis(true_data, pred_data, title_suffix=""):
        fig, axes = plt.subplots(3, 1, figsize=(12, 15))
        axes[0].plot(true_data[:, 0], label='True longitude')
        axes[0].plot(pred_data[:, 0], label='Predicted longitude')
        axes[0].plot(np.abs(true_data[:, 0] - pred_data[:, 0]), 'r--', label='Error')
        axes[0].set_title(f'Longitude prediction{title_suffix}')
        axes[0].set_xlabel('Time step')
        axes[0].set_ylabel('Longitude')
        axes[0].legend()

        axes[1].plot(true_data[:, 1], label='True latitude')
        axes[1].plot(pred_data[:, 1], label='Predicted latitude')
        axes[1].plot(np.abs(true_data[:, 1] - pred_data[:, 1]), 'r--', label='Error')
        axes[1].set_title(f'Latitude prediction{title_suffix}')
        axes[1].set_xlabel('Time step')
        axes[1].set_ylabel('Latitude')
        axes[1].legend()

        axes[2].plot(true_data[:, 2], label='True altitude')
        axes[2].plot(pred_data[:, 2], label='Predicted altitude')
        axes[2].plot(np.abs(true_data[:, 2] - pred_data[:, 2]), 'r--', label='Error')
        axes[2].set_title(f'Altitude prediction{title_suffix}')
        axes[2].set_xlabel('Time step')
        axes[2].set_ylabel('Altitude')
        axes[2].legend()
        plt.tight_layout()
        return fig
    
    def plot_all_trajectories(true_data, pred_data, title="All-trajectory comparison"):
        fig = plt.figure(figsize=(20, 16))
        ax1 = fig.add_subplot(221, projection='3d')
        ax1.plot(true_data[:, 0], true_data[:, 1], true_data[:, 2], 'b-', label='Ground truth', alpha=0.7, linewidth=1.5)
        ax1.plot(pred_data[:, 0], pred_data[:, 1], pred_data[:, 2], 'r--', label='Prediction', alpha=0.7, linewidth=1.5)
        ax1.set_xlabel('Longitude'); ax1.set_ylabel('Latitude'); ax1.set_zlabel('Altitude'); ax1.set_title('3D trajectories'); ax1.legend(); ax1.view_init(elev=30, azim=45)

        ax2 = fig.add_subplot(222)
        ax2.plot(true_data[:, 0], true_data[:, 1], 'b-', label='Ground truth', alpha=0.7, linewidth=1.5)
        ax2.plot(pred_data[:, 0], pred_data[:, 1], 'r--', label='Prediction', alpha=0.7, linewidth=1.5)
        ax2.set_xlabel('Longitude'); ax2.set_ylabel('Latitude'); ax2.set_title('Longitude-latitude projection'); ax2.legend(); ax2.grid(True)

        ax3 = fig.add_subplot(223)
        ax3.plot(true_data[:, 0], true_data[:, 2], 'b-', label='Ground truth', alpha=0.7, linewidth=1.5)
        ax3.plot(pred_data[:, 0], pred_data[:, 2], 'r--', label='Prediction', alpha=0.7, linewidth=1.5)
        ax3.set_xlabel('Longitude'); ax3.set_ylabel('Altitude'); ax3.set_title('Longitude-altitude projection'); ax3.legend(); ax3.grid(True)

        ax4 = fig.add_subplot(224)
        ax4.plot(true_data[:, 1], true_data[:, 2], 'b-', label='Ground truth', alpha=0.7, linewidth=1.5)
        ax4.plot(pred_data[:, 1], pred_data[:, 2], 'r--', label='Prediction', alpha=0.7, linewidth=1.5)
        ax4.set_xlabel('Latitude'); ax4.set_ylabel('Altitude'); ax4.set_title('Latitude-altitude projection'); ax4.legend(); ax4.grid(True)
    
        plt.suptitle(title, fontsize=16, y=0.95)
        plt.tight_layout()
        return fig
    
    def plot_residual_distribution(true_data, pred_data, title_suffix=""):
        residuals = true_data - pred_data
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))
        axes[0].hist(residuals[:, 0], bins=30, alpha=0.7)
        axes[0].axvline(x=0, color='red', linestyle='--')
        axes[0].set_title(f'Longitude residuals (mean: {np.mean(residuals[:, 0]):.6f}){title_suffix}')
        axes[0].set_xlabel('Residual'); axes[0].set_ylabel('Count'); axes[0].grid(True, alpha=0.3)

        axes[1].hist(residuals[:, 1], bins=30, alpha=0.7)
        axes[1].axvline(x=0, color='red', linestyle='--')
        axes[1].set_title(f'Latitude residuals (mean: {np.mean(residuals[:, 1]):.6f}){title_suffix}')
        axes[1].set_xlabel('Residual'); axes[1].set_ylabel('Count'); axes[1].grid(True, alpha=0.3)

        axes[2].hist(residuals[:, 2], bins=30, alpha=0.7)
        axes[2].axvline(x=0, color='red', linestyle='--')
        axes[2].set_title(f'Altitude residuals (mean: {np.mean(residuals[:, 2]):.6f}){title_suffix}')
        axes[2].set_xlabel('Residual'); axes[2].set_ylabel('Count'); axes[2].grid(True, alpha=0.3)
    
        plt.tight_layout()
        return fig
    
    # ----------------------------
    # 5. Main program (unchanged except for the new build_model)
    # ----------------------------
    def main():
        seq_length = 20
        pred_length = 1
        epochs = 100
        batch_size = 128
    
        file_path = 'csv格式/14螺旋上升.csv'
        if not os.path.exists(file_path):
            print(f"Error: file '{file_path}' does not exist")
            return

        data, df = load_data(file_path)
        if data is None:
            return
        print(f"Loaded {len(data)} records")
    
        data_scaled, scaler = normalize_data(data)
        X, y = create_sequences(data_scaled, seq_length, pred_length)
    
        # Chronological splits: shuffle=False preserves temporal order
        X_train_val, X_test, y_train_val, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, shuffle=False)
        X_train, X_val, y_train, y_val = train_test_split(
            X_train_val, y_train_val, test_size=0.25, random_state=42, shuffle=False)

        print(f"Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")
    
        # Build the TPT (improved Transformer) model
        model = build_model(seq_length=seq_length, pred_length=pred_length, n_features=3,
                            d_model=96, num_layers_full=3, num_layers_half=2, num_layers_quarter=1,
                            ff_dim=192, proj_dim=64, time2vec_k=16, dropout=0.1)
        model.summary()
    
        print("\n开始训练模型...")
        start_time = time.time()
        model, history = train_model(model, X_train, y_train, X_val, y_val, epochs=epochs, batch_size=batch_size)
        train_time = time.time() - start_time
        print(f"模型训练完成,耗时: {train_time:.2f} 秒")
    
        # Plot training loss curves
        plt.figure(figsize=(10, 6))
        plt.plot(history.history['loss'], label='Training loss')
        plt.plot(history.history['val_loss'], label='Validation loss')
        plt.title('Training loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.tight_layout()
        plt.show()
    
        # Evaluate on the test set
        y_pred_test, y_test_inv, y_pred_test_inv = evaluate_model(model, X_test, y_test, scaler, "Test set")
        # Evaluate on the training set
        print("\nEvaluating on the training set...")
        y_pred_train, y_train_inv, y_pred_train_inv = evaluate_model(model, X_train, y_train, scaler, "Training set")

        # Plot full trajectory comparison on the training set
        fig_train_all = plot_all_trajectories(y_train_inv, y_pred_train_inv, title='True vs. predicted trajectories')
        plt.show()
    
    if __name__ == "__main__":
        physical_devices = tf.config.list_physical_devices('GPU')
        if physical_devices:
            try:
                tf.config.experimental.set_memory_growth(physical_devices[0], True)
                print("GPU memory growth enabled")
            except Exception:
                print("Failed to configure the GPU; falling back to CPU")
        else:
            print("No GPU detected; using CPU")
        main()
    
