import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.patches import ConnectionPatch
def _simulate_trajectory():
    """Build a toy 50-point trajectory: a straight leg followed by a half-circle turn.

    Returns:
        (x, y): two 1-D arrays of length 50. Points 0-19 run along y=10 from
        x=0 to x=10; points 20-49 trace a half circle of radius 5 centred at
        (10, 5), going from (10, 10) to (10, 0).
    """
    t_straight = np.linspace(0, 10, 20)
    x_straight = t_straight
    y_straight = np.full_like(t_straight, 10)  # straight segment
    t_turn = np.linspace(0, np.pi, 30)
    x_turn = 10 + 5 * np.sin(t_turn)
    y_turn = 5 + 5 * np.cos(t_turn)  # half-circle turn
    return np.concatenate([x_straight, x_turn]), np.concatenate([y_straight, y_turn])


def _simulate_hidden_states(hidden_dim, rng):
    """Fabricate four hidden-state snapshots that mimic how an encoder's memory evolves.

    These are hand-crafted stand-ins, not real model outputs; in a real model
    each h_t would come from the LSTM/GRU internals.

    Args:
        hidden_dim: length of each hidden-state vector.
        rng: a ``np.random.RandomState`` supplying the noise terms. Draws are
            consumed in a fixed order, so a given seed always yields the same states.

    Returns:
        List ``[h_start, h_mid_straight, h_turning, h_final]`` of 1-D arrays of
        length ``hidden_dim``.
    """
    dims = np.arange(hidden_dim)
    # Start of the sequence: a small random state that carries little information.
    h_start = rng.randn(hidden_dim) * 0.1
    # Straight leg: a stable periodic pattern builds up on top of the start state.
    straight_pattern = np.sin(dims / 5)
    h_mid_straight = h_start + straight_pattern * 0.5 + rng.randn(hidden_dim) * 0.1
    # Turning: the old state is mostly forgotten (x0.3) and a new pattern dominates.
    turn_pattern = np.cos(dims / 3) * -1
    h_turning = h_mid_straight * 0.3 + turn_pattern * 0.8 + rng.randn(hidden_dim) * 0.2
    # End of the sequence: blend the straight pattern back in, then standardise
    # (zero mean, unit std) so the pattern stands out in the heatmap.
    h_final = h_turning + straight_pattern * 0.4 + rng.randn(hidden_dim) * 0.1
    h_final = (h_final - h_final.mean()) / h_final.std()
    return [h_start, h_mid_straight, h_turning, h_final]


def visualize_encoder_process(hidden_dim=64, *, h_pad=3.0, w_pad=4.0, figsize=(13, 16)):
    """Visualise how an encoder (LSTM/GRU) processes a trajectory and updates its hidden state.

    Draws a 4x2 grid. Each row shows one time step (t=1, 20, 35, 50). The left
    panel shows the trajectory walked so far and the current point. The right
    panel shows the simulated hidden-state vector at that step as a one-row
    heatmap. A dashed arrow links the two panels. The hidden states are
    simulated, not taken from a trained model.

    Args:
        hidden_dim: length of the simulated hidden-state vectors.
        h_pad: vertical padding between subplot rows, in units of the font
            size (forwarded to ``tight_layout``). Larger values spread rows apart.
        w_pad: horizontal padding between the trajectory column and the
            heatmap column, in units of the font size (forwarded to ``tight_layout``).
        figsize: figure size in inches as ``(width, height)``.
    """
    trajectory_x, trajectory_y = _simulate_trajectory()
    # A private RandomState reproduces exactly the draws that np.random.seed(42)
    # would give, without mutating NumPy's global RNG as a side effect.
    rng = np.random.RandomState(42)
    hidden_states = _simulate_hidden_states(hidden_dim, rng)

    fig, axes = plt.subplots(4, 2, figsize=figsize, gridspec_kw={'width_ratios': [1.5, 2]})
    fig.suptitle("Encoder in Action: How Trajectory Representation is Built Step-by-Step", fontsize=18, weight='bold')

    # 0-based indices into the 50-point trajectory, i.e. steps t=1, t=20, t=35, t=50.
    time_points = [0, 19, 34, 49]
    titles = [
        "t=1: Journey Begins",
        "t=20: Cruising Straight",
        "t=35: Making a Sharp Turn",
        "t=50: Final Representation (Context Vector)"
    ]
    colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]

    for (ax_traj, ax_vec), t, h_t, title, color in zip(axes, time_points, hidden_states, titles, colors):
        # Left: the full path as a grey dashed line, the part seen so far in
        # colour, and the current point in red.
        ax_traj.plot(trajectory_x, trajectory_y, color='lightgray', linestyle='--', linewidth=2, zorder=1)
        ax_traj.plot(trajectory_x[:t + 1], trajectory_y[:t + 1], color=color, linewidth=3, zorder=2)
        ax_traj.scatter(trajectory_x[t], trajectory_y[t], color='red', s=100, zorder=3, label=f"Current Point (t={t+1})")
        ax_traj.set_title(title, fontsize=12)
        ax_traj.set_xlim(-2, 20)
        ax_traj.set_ylim(-2, 12)
        ax_traj.legend(loc="upper left")
        ax_traj.grid(True, linestyle=':', alpha=0.6)

        # Right: the hidden-state vector rendered as a single-row heatmap.
        sns.heatmap([h_t], ax=ax_vec, cmap="viridis", cbar=False, xticklabels=False, yticklabels=False)
        ax_vec.set_title(f"Hidden State Vector h_{t+1} (Memory at this moment)", fontsize=12)

        # Dashed arrow from the current trajectory point (data coordinates) to
        # the middle of the heatmap's left edge (axes coordinates).
        con = ConnectionPatch(xyA=(trajectory_x[t], trajectory_y[t]), xyB=(0, 0.5),
                              coordsA=ax_traj.transData, coordsB=ax_vec.transAxes,
                              axesA=ax_traj, axesB=ax_vec,
                              arrowstyle="->", shrinkB=5, color='gray', linestyle='dashed')
        fig.add_artist(con)

    axes[3, 1].set_xlabel(f"Dimension of Hidden State Vector ({hidden_dim}-dim)", fontsize=12)
    # rect reserves the top 4% of the figure for the suptitle. h_pad/w_pad
    # widen the gaps between subplots so the panels are not crammed together.
    plt.tight_layout(rect=[0, 0, 1, 0.96], h_pad=h_pad, w_pad=w_pad)
    plt.show()
if __name__ == "__main__":
    # When executed as a script, render the demo figure with a 64-dim hidden state.
    visualize_encoder_process(hidden_dim=64)
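# NOTE: 修改一下这个图形中子图的间隙,不要太密集 — i.e. increase the spacing between the subplots in this figure so it looks less crowded.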