1. 问题设置
1.1 5×5复杂网格世界
网格布局:
+---+---+---+---+---+
| S | 1 | 2 | T | 4 |
+---+---+---+---+---+
| 5 | X | 7 | 8 | 9 |
+---+---+---+---+---+
|10 |11 | X |13 |14 |
+---+---+---+---+---+
|15 |16 |17 |18 | G |
+---+---+---+---+---+
|20 |21 |22 |23 |24 |
+---+---+---+---+---+
图例:
S = 起始状态 (0,0)
G = 目标状态 (3,4) 奖励 +100
T = 陷阱状态 (0,3) 奖励 -50
X = 障碍物 (1,1) 和 (2,2) 不可通过
数字 = 普通状态,奖励 -1
1.2 状态编码
状态按行优先编码(状态编号 = 行号 × 5 + 列号,转换示例见下方代码):
- 状态0: (0,0) - 起始状态S
- 状态1: (0,1)
- 状态2: (0,2)
- 状态3: (0,3) - 陷阱T
- 状态4: (0,4)
- …
- 状态19: (3,4) - 目标G
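下面用一小段代码演示这种行优先编码与坐标的相互转换(仅作示意;函数名为自拟,与后文环境类中的同名方法思路一致):
GRID_SIZE = 5

def state_to_coord(state):
    """状态编号 -> (行, 列)"""
    return state // GRID_SIZE, state % GRID_SIZE

def coord_to_state(row, col):
    """(行, 列) -> 状态编号"""
    return row * GRID_SIZE + col

print(state_to_coord(19))    # (3, 4),即目标G
print(coord_to_state(0, 3))  # 3,即陷阱T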
1.3 动作空间
- 动作0: 上 (↑)
- 动作1: 下 (↓)
- 动作2: 左 (←)
- 动作3: 右 (→)
2. 环境参数
# 环境参数
GRID_SIZE = 5
START_STATE = 0 # (0,0)
GOAL_STATE = 19 # (3,4)
TRAP_STATE = 3 # (0,3)
OBSTACLES = [6, 12] # (1,1) 和 (2,2)
# 奖励设置
REWARD_GOAL = 100
REWARD_TRAP = -50
REWARD_STEP = -1
REWARD_WALL = -1 # 撞墙奖励
# 算法参数
GAMMA = 0.9 # 折扣因子
ALPHA = 0.1 # 学习率
EPSILON = 0.1 # ε-贪心参数
MAX_EPISODES = 3 # 示例展示3个回合
MAX_STEPS = 50 # 每回合最大步数
3. SARSA算法详细执行
3.1 初始化
# Q表初始化 (25个状态 × 4个动作)
Q_sarsa = np.zeros((25, 4))
# 目标状态和陷阱状态的Q值保持为0
Q_sarsa[GOAL_STATE, :] = 0
Q_sarsa[TRAP_STATE, :] = 0
3.2 ε-贪心策略函数
def epsilon_greedy(Q, state, epsilon):
    if np.random.random() < epsilon:
        return np.random.randint(4)   # 随机动作
    else:
        return np.argmax(Q[state])    # 贪心动作
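一个简单的使用示例(Q_demo为演示用的全零Q表,属于自拟变量;另外注意 np.argmax 在多个动作Q值并列最大时总是返回编号最小的动作,而下文的手算示例假设在并列值之间随机打破平局,这只影响具体轨迹,不影响算法本身):
import numpy as np

np.random.seed(0)                      # 仅为演示可复现
Q_demo = np.zeros((25, 4))             # 与上文相同的全零Q表
a = epsilon_greedy(Q_demo, state=0, epsilon=0.1)
print(a)  # 0 (↑):该种子下第一次随机数约0.55 > ε=0.1,走贪心分支,argmax返回第一个并列动作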
3.3 回合1详细执行
步骤1
当前状态: S₀ = 0 (起始状态)
选择动作: A₀ = ε-贪心选择
- Q(0, ↑) = 0, Q(0, ↓) = 0, Q(0, ←) = 0, Q(0, →) = 0
- 所有动作Q值相等,随机选择:A₀ = 3 (→)
执行动作: 从(0,0)向右移动到(0,1)
观察结果: R₁ = -1, S₁ = 1
选择下一动作: A₁ = ε-贪心选择
- Q(1, ↑) = 0, Q(1, ↓) = 0, Q(1, ←) = 0, Q(1, →) = 0
- 随机选择:A₁ = 1 (↓)
SARSA更新:
δ = R₁ + γ × Q(S₁, A₁) - Q(S₀, A₀)
δ = -1 + 0.9 × Q(1, ↓) - Q(0, →)
δ = -1 + 0.9 × 0 - 0 = -1
Q(S₀, A₀) ← Q(S₀, A₀) + α × δ
Q(0, →) ← 0 + 0.1 × (-1) = -0.1
更新后Q表:
Q(0, ↑) = 0, Q(0, ↓) = 0, Q(0, ←) = 0, Q(0, →) = -0.1
Q(1, ↑) = 0, Q(1, ↓) = 0, Q(1, ←) = 0, Q(1, →) = 0
步骤2
当前状态: S₁ = 1, A₁ = 1 (↓)
执行动作: 从(0,1)向下移动到(1,1) - 但这是障碍物!
实际结果: 撞上障碍物,停留在(0,1)(环境中撞障碍物与撞边界同样给 -1 奖励)
观察结果: R₂ = -1, S₂ = 1
选择下一动作: A₂ = ε-贪心选择
- 当前Q(1, ↑) = 0, Q(1, ↓) = 0, Q(1, ←) = 0, Q(1, →) = 0
- 随机选择:A₂ = 3 (→)
SARSA更新:
δ = R₂ + γ × Q(S₂, A₂) - Q(S₁, A₁)
δ = -1 + 0.9 × Q(1, →) - Q(1, ↓)
δ = -1 + 0.9 × 0 - 0 = -1
Q(1, ↓) ← 0 + 0.1 × (-1) = -0.1
更新后Q表:
Q(0, →) = -0.1
Q(1, ↓) = -0.1
步骤3
当前状态: S₂ = 1, A₂ = 3 (→)
执行动作: 从(0,1)向右移动到(0,2)
观察结果: R₃ = -1, S₃ = 2
选择下一动作: A₃ = ε-贪心选择
- Q(2, ↑) = 0, Q(2, ↓) = 0, Q(2, ←) = 0, Q(2, →) = 0
- 随机选择:A₃ = 3 (→)
SARSA更新:
δ = -1 + 0.9 × 0 - 0 = -1
Q(1, →) ← 0 + 0.1 × (-1) = -0.1
步骤4
当前状态: S₃ = 2, A₃ = 3 (→)
执行动作: 从(0,2)向右移动到(0,3) - 陷阱!
观察结果: R₄ = -50, S₄ = 3 (陷阱状态,回合结束)
SARSA更新:
δ = R₄ + γ × Q(终止状态) - Q(S₃, A₃)
δ = -50 + 0.9 × 0 - 0 = -50
Q(2, →) ← 0 + 0.1 × (-50) = -5.0
回合1结束后的Q表:
状态0: Q(0, ↑)=0, Q(0, ↓)=0, Q(0, ←)=0, Q(0, →)=-0.1
状态1: Q(1, ↑)=0, Q(1, ↓)=-0.1, Q(1, ←)=0, Q(1, →)=-0.1
状态2: Q(2, ↑)=0, Q(2, ↓)=0, Q(2, ←)=0, Q(2, →)=-5.0
其他状态: 全部为0
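为了核对上面的手算结果,下面给出一段最小的验证代码草图(沿用前文的 γ=0.9、α=0.1,按回合1实际出现的四个转移依次做SARSA更新;打印结果应与上表一致):
import numpy as np

GAMMA, ALPHA = 0.9, 0.1
Q = np.zeros((25, 4))
# (S, A, R, S', A') 依次为回合1的四步;最后一步进入陷阱,自举项为0
transitions = [
    (0, 3, -1,  1, 1),     # 步骤1: 状态0 --→--> 状态1
    (1, 1, -1,  1, 3),     # 步骤2: 撞上障碍物,停留在状态1
    (1, 3, -1,  2, 3),     # 步骤3: 状态1 --→--> 状态2
    (2, 3, -50, 3, None),  # 步骤4: 进入陷阱,回合结束
]
for s, a, r, s_next, a_next in transitions:
    bootstrap = 0.0 if a_next is None else Q[s_next, a_next]
    Q[s, a] += ALPHA * (r + GAMMA * bootstrap - Q[s, a])
print(Q[0, 3], Q[1, 1], Q[1, 3], Q[2, 3])  # 预期输出: -0.1 -0.1 -0.1 -5.0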
3.4 回合2详细执行
步骤1
当前状态: S₀ = 0
选择动作: ε-贪心选择
- Q(0, ↑)=0, Q(0, ↓)=0, Q(0, ←)=0, Q(0, →)=-0.1
- 最大Q值为0(上、下、左),随机选择:A₀ = 1 (↓)
执行动作: 从(0,0)向下移动到(1,0)
观察结果: R₁ = -1, S₁ = 5
选择下一动作: A₁ = ε-贪心选择
- Q(5, 所有动作) = 0,随机选择:A₁ = 3 (→)
SARSA更新:
δ = -1 + 0.9 × 0 - 0 = -1
Q(0, ↓) ← 0 + 0.1 × (-1) = -0.1
步骤2
当前状态: S₁ = 5, A₁ = 3 (→)
执行动作: 从(1,0)向右移动到(1,1) - 障碍物!
实际结果: 撞上障碍物,停留在(1,0)
观察结果: R₂ = -1, S₂ = 5
选择下一动作: A₂ = 1 (↓)
SARSA更新:
δ = -1 + 0.9 × 0 - 0 = -1
Q(5, →) ← 0 + 0.1 × (-1) = -0.1
继续执行…
随着步数增加,负奖励不断累积到相应的Q值上,撞障碍物、走向陷阱的动作逐渐被抑制;在足够多的回合之后,智能体才能较为可靠地绕开障碍物到达目标。
3.5 回合3详细执行
随着学习的进行,Q值逐渐累积对障碍物和陷阱的负向估计,贪心动作开始避开这些坏选择。
回合3结束后的部分Q表(示例数值,具体结果取决于随机种子与实际轨迹):
状态0: Q(0, ↑)=0, Q(0, ↓)=-0.19, Q(0, ←)=0, Q(0, →)=-0.19
状态1: Q(1, ↑)=0, Q(1, ↓)=-0.19, Q(1, ←)=-0.1, Q(1, →)=-0.28
状态2: Q(2, ↑)=0, Q(2, ↓)=-0.1, Q(2, ←)=-0.1, Q(2, →)=-5.0
状态5: Q(5, ↑)=-0.1, Q(5, ↓)=-0.1, Q(5, ←)=0, Q(5, →)=-0.19
...
4. Q-learning算法详细执行
4.1 初始化
# Q表初始化 (与SARSA相同)
Q_qlearning = np.zeros((25, 4))
Q_qlearning[GOAL_STATE, :] = 0
Q_qlearning[TRAP_STATE, :] = 0
4.2 回合1详细执行
步骤1
当前状态: S₀ = 0
选择动作: A₀ = 3 (→) (与SARSA相同的随机选择)
执行动作: 从(0,0)向右移动到(0,1)
观察结果: R₁ = -1, S₁ = 1
Q-learning更新 (关键差异):
δ = R₁ + γ × max_a Q(S₁, a) - Q(S₀, A₀)
δ = -1 + 0.9 × max{Q(1, ↑), Q(1, ↓), Q(1, ←), Q(1, →)} - Q(0, →)
δ = -1 + 0.9 × max{0, 0, 0, 0} - 0 = -1
Q(0, →) ← 0 + 0.1 × (-1) = -0.1
注意: Q-learning使用max操作,而不是实际选择的下一动作,这正是它作为离策略(off-policy)方法的核心差别。
步骤2
当前状态: S₁ = 1
选择动作: A₁ = 1 (↓) (行为策略选择)
执行动作: 向下撞上障碍物(1,1),停留在状态1
观察结果: R₂ = -1, S₂ = 1
Q-learning更新:
δ = -1 + 0.9 × max{Q(1, ↑), Q(1, ↓), Q(1, ←), Q(1, →)} - Q(1, ↓)
δ = -1 + 0.9 × 0 - 0 = -1
Q(1, ↓) ← 0 + 0.1 × (-1) = -0.1
步骤3和步骤4
类似地执行,但在更新时始终使用max操作。
回合1结束后的Q表 (与SARSA完全相同——此时所有下一状态的Q值仍为0,max Q(S₁, ·) 与实际所选动作的 Q(S₁, A₁) 没有区别):
状态0: Q(0, →) = -0.1
状态1: Q(1, ↓) = -0.1, Q(1, →) = -0.1
状态2: Q(2, →) = -5.0
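"与SARSA完全相同"这一点也可以用一小段代码验证:只要下一状态的所有Q值都还是0,max操作与"实际下一动作"的Q值就没有区别,两种更新目标完全一致(沿用上文的 γ=0.9,仅作演示):
import numpy as np

GAMMA = 0.9
Q_next = np.zeros(4)            # 回合1中下一状态的Q值全为0
reward = -1
target_sarsa     = reward + GAMMA * Q_next[1]        # 使用实际选择的动作(↓)
target_qlearning = reward + GAMMA * np.max(Q_next)   # 使用最大Q值
print(target_sarsa == target_qlearning)              # True:两者都等于 -1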
4.3 关键差异体现
随着学习进行,Q-learning和SARSA的差异会逐渐显现:
Q-learning特点(off-policy,离策略):
- 更新目标始终使用下一状态的最大Q值 max_a Q(S′, a)
- 学习最优策略的价值,更新目标不依赖行为策略实际选择的动作(但访问哪些状态仍由行为策略决定)
- 在有"陷阱"的环境中表现更激进
SARSA特点(on-policy,在策略):
- 更新目标使用实际选择的下一动作对应的 Q(S′, A′)
- 学习当前行为策略(含ε探索)的价值
- 在有"陷阱"的环境中表现更保守
5. Python伪代码实现
5.1 环境类
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

class GridWorld:
    def __init__(self):
        self.grid_size = 5
        self.start_state = 0
        self.goal_state = 19
        self.trap_state = 3
        self.obstacles = [6, 12]
        # 奖励设置
        self.reward_goal = 100
        self.reward_trap = -50
        self.reward_step = -1
        self.reward_wall = -1
        # 动作映射
        self.actions = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}  # 上下左右
        self.action_names = ['↑', '↓', '←', '→']

    def state_to_coord(self, state):
        """状态编号转坐标"""
        return (state // self.grid_size, state % self.grid_size)

    def coord_to_state(self, row, col):
        """坐标转状态编号"""
        return row * self.grid_size + col

    def is_valid_state(self, state):
        """检查状态是否有效"""
        if state < 0 or state >= self.grid_size * self.grid_size:
            return False
        if state in self.obstacles:
            return False
        return True

    def get_next_state(self, state, action):
        """获取执行动作后的下一状态"""
        row, col = self.state_to_coord(state)
        d_row, d_col = self.actions[action]
        new_row, new_col = row + d_row, col + d_col
        # 检查边界
        if new_row < 0 or new_row >= self.grid_size or new_col < 0 or new_col >= self.grid_size:
            return state  # 撞墙,停留原地
        new_state = self.coord_to_state(new_row, new_col)
        # 检查障碍物
        if new_state in self.obstacles:
            return state  # 撞障碍物,停留原地
        return new_state

    def get_reward(self, state, action, next_state):
        """获取奖励"""
        if next_state == self.goal_state:
            return self.reward_goal
        elif next_state == self.trap_state:
            return self.reward_trap
        elif next_state == state:  # 撞墙或障碍物
            return self.reward_wall
        else:
            return self.reward_step

    def is_terminal(self, state):
        """检查是否为终止状态"""
        return state == self.goal_state or state == self.trap_state

    def reset(self):
        """重置环境"""
        return self.start_state

    def step(self, state, action):
        """执行一步"""
        next_state = self.get_next_state(state, action)
        reward = self.get_reward(state, action, next_state)
        done = self.is_terminal(next_state)
        return next_state, reward, done
5.2 SARSA算法实现
class SARSAAgent:
    def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.env = env
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        # 初始化Q表
        self.Q = np.zeros((env.grid_size * env.grid_size, 4))
        # 记录学习过程
        self.episode_rewards = []
        self.episode_steps = []
        self.q_history = []

    def epsilon_greedy(self, state):
        """ε-贪心策略"""
        if np.random.random() < self.epsilon:
            return np.random.randint(4)
        else:
            return np.argmax(self.Q[state])

    def update_q(self, state, action, reward, next_state, next_action):
        """SARSA Q值更新"""
        td_error = reward + self.gamma * self.Q[next_state, next_action] - self.Q[state, action]
        self.Q[state, action] += self.alpha * td_error
        return td_error

    def train_episode(self, max_steps=50, verbose=False):
        """训练一个回合"""
        state = self.env.reset()
        action = self.epsilon_greedy(state)
        episode_reward = 0
        steps = 0
        if verbose:
            print(f"\n=== SARSA 回合开始 ===")
            print(f"初始状态: {state} {self.env.state_to_coord(state)}")
        while steps < max_steps:
            # 执行动作
            next_state, reward, done = self.env.step(state, action)
            episode_reward += reward
            steps += 1
            if verbose:
                print(f"\n步骤 {steps}:")
                print(f"  状态: {state} -> 动作: {self.env.action_names[action]} -> 下一状态: {next_state}")
                print(f"  奖励: {reward}")
            if done:
                # 终止状态更新
                td_error = self.update_q(state, action, reward, next_state, 0)
                if verbose:
                    print(f"  TD误差: {td_error:.3f}")
                    print(f"  Q({state}, {self.env.action_names[action]}) 更新为: {self.Q[state, action]:.3f}")
                break
            else:
                # 选择下一动作
                next_action = self.epsilon_greedy(next_state)
                # SARSA更新
                td_error = self.update_q(state, action, reward, next_state, next_action)
                if verbose:
                    print(f"  下一动作: {self.env.action_names[next_action]}")
                    print(f"  TD误差: {td_error:.3f}")
                    print(f"  Q({state}, {self.env.action_names[action]}) 更新为: {self.Q[state, action]:.3f}")
                # 移动到下一状态
                state = next_state
                action = next_action
        self.episode_rewards.append(episode_reward)
        self.episode_steps.append(steps)
        if verbose:
            print(f"\n回合结束: 总奖励 = {episode_reward}, 总步数 = {steps}")
        return episode_reward, steps

    def train(self, episodes, verbose_episodes=None):
        """训练多个回合"""
        if verbose_episodes is None:
            verbose_episodes = []
        for episode in range(episodes):
            verbose = episode in verbose_episodes
            if verbose:
                print(f"\n{'='*50}")
                print(f"SARSA 训练回合 {episode + 1}")
                print(f"{'='*50}")
            reward, steps = self.train_episode(verbose=verbose)
            # 记录Q表历史
            if episode in verbose_episodes:
                self.q_history.append(self.Q.copy())
                self.print_q_table()

    def print_q_table(self, states=None):
        """打印Q表"""
        if states is None:
            states = [0, 1, 2, 5]  # 打印关键状态
        print("\n当前Q表 (部分):")
        for state in states:
            if state < len(self.Q):
                coord = self.env.state_to_coord(state)
                print(f"状态 {state} {coord}: ", end="")
                for action in range(4):
                    action_name = self.env.action_names[action]
                    q_value = self.Q[state, action]
                    print(f"Q({action_name})={q_value:.3f} ", end="")
                print()
5.3 Q-learning算法实现
class QLearningAgent:
    def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.env = env
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        # 初始化Q表
        self.Q = np.zeros((env.grid_size * env.grid_size, 4))
        # 记录学习过程
        self.episode_rewards = []
        self.episode_steps = []
        self.q_history = []

    def epsilon_greedy(self, state):
        """ε-贪心策略"""
        if np.random.random() < self.epsilon:
            return np.random.randint(4)
        else:
            return np.argmax(self.Q[state])

    def update_q(self, state, action, reward, next_state):
        """Q-learning Q值更新"""
        max_next_q = np.max(self.Q[next_state]) if not self.env.is_terminal(next_state) else 0
        td_error = reward + self.gamma * max_next_q - self.Q[state, action]
        self.Q[state, action] += self.alpha * td_error
        return td_error

    def train_episode(self, max_steps=50, verbose=False):
        """训练一个回合"""
        state = self.env.reset()
        episode_reward = 0
        steps = 0
        if verbose:
            print(f"\n=== Q-learning 回合开始 ===")
            print(f"初始状态: {state} {self.env.state_to_coord(state)}")
        while steps < max_steps:
            # 选择动作
            action = self.epsilon_greedy(state)
            # 执行动作
            next_state, reward, done = self.env.step(state, action)
            episode_reward += reward
            steps += 1
            if verbose:
                print(f"\n步骤 {steps}:")
                print(f"  状态: {state} -> 动作: {self.env.action_names[action]} -> 下一状态: {next_state}")
                print(f"  奖励: {reward}")
            # Q-learning更新
            td_error = self.update_q(state, action, reward, next_state)
            if verbose:
                max_next_q = np.max(self.Q[next_state]) if not done else 0
                print(f"  max Q(s', a') = {max_next_q:.3f}")
                print(f"  TD误差: {td_error:.3f}")
                print(f"  Q({state}, {self.env.action_names[action]}) 更新为: {self.Q[state, action]:.3f}")
            if done:
                break
            state = next_state
        self.episode_rewards.append(episode_reward)
        self.episode_steps.append(steps)
        if verbose:
            print(f"\n回合结束: 总奖励 = {episode_reward}, 总步数 = {steps}")
        return episode_reward, steps

    def train(self, episodes, verbose_episodes=None):
        """训练多个回合"""
        if verbose_episodes is None:
            verbose_episodes = []
        for episode in range(episodes):
            verbose = episode in verbose_episodes
            if verbose:
                print(f"\n{'='*50}")
                print(f"Q-learning 训练回合 {episode + 1}")
                print(f"{'='*50}")
            reward, steps = self.train_episode(verbose=verbose)
            # 记录Q表历史
            if episode in verbose_episodes:
                self.q_history.append(self.Q.copy())
                self.print_q_table()

    def print_q_table(self, states=None):
        """打印Q表"""
        if states is None:
            states = [0, 1, 2, 5]  # 打印关键状态
        print("\n当前Q表 (部分):")
        for state in states:
            if state < len(self.Q):
                coord = self.env.state_to_coord(state)
                print(f"状态 {state} {coord}: ", end="")
                for action in range(4):
                    action_name = self.env.action_names[action]
                    q_value = self.Q[state, action]
                    print(f"Q({action_name})={q_value:.3f} ", end="")
                print()
5.4 主程序
def main():
    # 创建环境
    env = GridWorld()
    # 设置随机种子以便复现
    np.random.seed(42)
    # 创建智能体
    sarsa_agent = SARSAAgent(env)
    qlearning_agent = QLearningAgent(env)
    print("5x5 复杂网格世界 SARSA vs Q-learning 详细示例")
    print("="*60)
    # 打印环境信息
    print("\n环境设置:")
    print(f"起始状态: {env.start_state} {env.state_to_coord(env.start_state)}")
    print(f"目标状态: {env.goal_state} {env.state_to_coord(env.goal_state)} (奖励: +{env.reward_goal})")
    print(f"陷阱状态: {env.trap_state} {env.state_to_coord(env.trap_state)} (奖励: {env.reward_trap})")
    print(f"障碍物: {env.obstacles} {[env.state_to_coord(obs) for obs in env.obstacles]}")
    print(f"普通移动奖励: {env.reward_step}")
    # 训练SARSA (显示前3个回合的详细过程)
    print("\n" + "="*60)
    print("SARSA 算法训练")
    print("="*60)
    sarsa_agent.train(episodes=3, verbose_episodes=[0, 1, 2])
    # 重置随机种子
    np.random.seed(42)
    # 训练Q-learning (显示前3个回合的详细过程)
    print("\n" + "="*60)
    print("Q-learning 算法训练")
    print("="*60)
    qlearning_agent.train(episodes=3, verbose_episodes=[0, 1, 2])
    # 比较结果
    print("\n" + "="*60)
    print("算法对比")
    print("="*60)
    print("\n回合奖励对比:")
    for i in range(3):
        sarsa_reward = sarsa_agent.episode_rewards[i]
        qlearning_reward = qlearning_agent.episode_rewards[i]
        print(f"回合 {i+1}: SARSA = {sarsa_reward:6.1f}, Q-learning = {qlearning_reward:6.1f}")
    print("\n最终Q表对比 (关键状态):")
    states_to_compare = [0, 1, 2, 5, 10]
    for state in states_to_compare:
        coord = env.state_to_coord(state)
        print(f"\n状态 {state} {coord}:")
        print("  SARSA   : ", end="")
        for action in range(4):
            action_name = env.action_names[action]
            q_value = sarsa_agent.Q[state, action]
            print(f"Q({action_name})={q_value:6.3f} ", end="")
        print()
        print("  Q-learn : ", end="")
        for action in range(4):
            action_name = env.action_names[action]
            q_value = qlearning_agent.Q[state, action]
            print(f"Q({action_name})={q_value:6.3f} ", end="")
        print()

if __name__ == "__main__":
    main()
6. 算法对比分析
6.1 更新公式对比
SARSA更新:
# 使用实际选择的下一动作
td_error = reward + gamma * Q[next_state, next_action] - Q[state, action]
Q[state, action] += alpha * td_error
Q-learning更新:
# 使用下一状态的最大Q值
max_next_q = np.max(Q[next_state])
td_error = reward + gamma * max_next_q - Q[state, action]
Q[state, action] += alpha * td_error
6.2 数值差异示例
假设在某个状态下:
- Q(next_state, ↑) = -0.2
- Q(next_state, ↓) = -0.1 ← 实际选择的动作
- Q(next_state, ←) = -0.3
- Q(next_state, →) = -0.05 ← 最大Q值
SARSA使用: Q(next_state, ↓) = -0.1
Q-learning使用: max Q(next_state, ·) = -0.05
这导致Q-learning的更新更加"乐观",因为它总是假设下一步会选择最优动作。
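用一小段代码复现上面的数值差异(γ=0.9 沿用前文设置;即时奖励在上例中未给出,这里假设为 -1,仅作演示):
import numpy as np

GAMMA = 0.9
reward = -1                                   # 假设:这一步的即时奖励为 -1
Q_next = np.array([-0.2, -0.1, -0.3, -0.05])  # 上例中下一状态的四个Q值(↑ ↓ ← →)
a_actual = 1                                  # 实际选择的动作是 ↓

target_sarsa     = reward + GAMMA * Q_next[a_actual]  # -1 + 0.9*(-0.1)  ≈ -1.09
target_qlearning = reward + GAMMA * np.max(Q_next)    # -1 + 0.9*(-0.05) ≈ -1.045
print(target_sarsa, target_qlearning)                 # Q-learning的目标值更高(更"乐观")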
6.3 行为差异
在我们的5×5网格世界中:
SARSA行为特点:
- 更保守,倾向于远离陷阱
- 学习过程中会考虑探索的风险
- 最终策略可能不是严格最优,但更安全
Q-learning行为特点:
- 更激进,敢于走险路
- 不受探索策略影响
- 最终策略趋向于真正的最优策略
6.4 收敛性对比
理论收敛:
- 在所有状态-动作对被无限次访问、学习率满足Robbins-Monro条件的前提下,两种算法都能收敛
- Q-learning收敛到最优动作价值函数Q*
- SARSA在ε固定时收敛到其ε-贪心行为策略对应的Q^π;若ε按GLIE方式逐渐衰减,同样收敛到Q*
实际收敛速度:
- 在本例这类小规模问题上两者相差不大;Q-learning的max更新目标通常使其更早逼近最优值
- SARSA的收敛速度和最终表现依赖于探索策略(ε的大小与衰减方式)
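若希望SARSA也逼近最优策略,常见做法是让ε随回合数衰减(GLIE)。下面是一个简单的衰减调度草图(函数名与参数均为示例自拟,并非正文使用的设置):
def epsilon_schedule(episode, eps_start=1.0, eps_min=0.05, decay=0.995):
    """按指数方式衰减ε,并保留一个下限;参数均为示例值"""
    return max(eps_min, eps_start * decay ** episode)

print(epsilon_schedule(0))     # 1.0
print(epsilon_schedule(500))   # 约0.08
print(epsilon_schedule(2000))  # 已降到下限0.05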
7. 收敛过程可视化
7.1 Q值演化图
def plot_q_evolution(sarsa_agent, qlearning_agent):
    """绘制Q值演化过程"""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    # 选择关键状态-动作对进行可视化
    key_state_actions = [(0, 3), (1, 1), (2, 3), (5, 1)]  # (state, action)
    for i, (state, action) in enumerate(key_state_actions):
        row, col = i // 2, i % 2
        ax = axes[row, col]
        # 提取Q值历史
        sarsa_q_history = [q_table[state, action] for q_table in sarsa_agent.q_history]
        qlearning_q_history = [q_table[state, action] for q_table in qlearning_agent.q_history]
        episodes = range(1, len(sarsa_q_history) + 1)
        ax.plot(episodes, sarsa_q_history, 'b-o', label='SARSA', linewidth=2, markersize=6)
        ax.plot(episodes, qlearning_q_history, 'r-s', label='Q-learning', linewidth=2, markersize=6)
        coord = sarsa_agent.env.state_to_coord(state)
        action_name = sarsa_agent.env.action_names[action]
        ax.set_title(f'Q({state}{coord}, {action_name}) 演化', fontsize=12)
        ax.set_xlabel('回合')
        ax.set_ylabel('Q值')
        ax.legend()
        ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('q_evolution.png', dpi=300, bbox_inches='tight')
    plt.show()

def plot_reward_comparison(sarsa_agent, qlearning_agent):
    """绘制奖励对比"""
    episodes = range(1, len(sarsa_agent.episode_rewards) + 1)
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(episodes, sarsa_agent.episode_rewards, 'b-o', label='SARSA', linewidth=2)
    plt.plot(episodes, qlearning_agent.episode_rewards, 'r-s', label='Q-learning', linewidth=2)
    plt.title('每回合奖励对比')
    plt.xlabel('回合')
    plt.ylabel('累积奖励')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.subplot(1, 2, 2)
    plt.plot(episodes, sarsa_agent.episode_steps, 'b-o', label='SARSA', linewidth=2)
    plt.plot(episodes, qlearning_agent.episode_steps, 'r-s', label='Q-learning', linewidth=2)
    plt.title('每回合步数对比')
    plt.xlabel('回合')
    plt.ylabel('步数')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('reward_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
7.2 策略可视化
def visualize_policy(agent, title):
    """可视化学习到的策略"""
    policy = np.zeros((agent.env.grid_size, agent.env.grid_size), dtype=int)
    for row in range(agent.env.grid_size):
        for col in range(agent.env.grid_size):
            state = agent.env.coord_to_state(row, col)
            if state not in agent.env.obstacles and not agent.env.is_terminal(state):
                policy[row, col] = np.argmax(agent.Q[state])
            else:
                policy[row, col] = -1  # 标记障碍物和终止状态
    # 创建可视化
    fig, ax = plt.subplots(figsize=(8, 8))
    # 绘制网格
    for i in range(agent.env.grid_size + 1):
        ax.axhline(i - 0.5, color='black', linewidth=1)
        ax.axvline(i - 0.5, color='black', linewidth=1)
    # 绘制状态和策略
    arrow_symbols = ['↑', '↓', '←', '→']
    colors = ['lightblue', 'lightgreen', 'lightcoral', 'lightyellow']
    for row in range(agent.env.grid_size):
        for col in range(agent.env.grid_size):
            state = agent.env.coord_to_state(row, col)
            if state == agent.env.start_state:
                ax.add_patch(plt.Rectangle((col-0.5, row-0.5), 1, 1, facecolor='green', alpha=0.7))
                ax.text(col, row, 'S', ha='center', va='center', fontsize=16, fontweight='bold')
            elif state == agent.env.goal_state:
                ax.add_patch(plt.Rectangle((col-0.5, row-0.5), 1, 1, facecolor='gold', alpha=0.7))
                ax.text(col, row, 'G', ha='center', va='center', fontsize=16, fontweight='bold')
            elif state == agent.env.trap_state:
                ax.add_patch(plt.Rectangle((col-0.5, row-0.5), 1, 1, facecolor='red', alpha=0.7))
                ax.text(col, row, 'T', ha='center', va='center', fontsize=16, fontweight='bold')
            elif state in agent.env.obstacles:
                ax.add_patch(plt.Rectangle((col-0.5, row-0.5), 1, 1, facecolor='black', alpha=0.8))
                ax.text(col, row, 'X', ha='center', va='center', fontsize=16, fontweight='bold', color='white')
            else:
                action = policy[row, col]
                ax.add_patch(plt.Rectangle((col-0.5, row-0.5), 1, 1, facecolor=colors[action], alpha=0.5))
                ax.text(col, row, arrow_symbols[action], ha='center', va='center', fontsize=20, fontweight='bold')
    ax.set_xlim(-0.5, agent.env.grid_size - 0.5)
    ax.set_ylim(-0.5, agent.env.grid_size - 0.5)
    ax.set_aspect('equal')
    ax.invert_yaxis()  # 让(0,0)在左上角
    ax.set_title(f'{title} 学习到的策略', fontsize=16)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout()
    plt.savefig(f'{title.lower()}_policy.png', dpi=300, bbox_inches='tight')
    plt.show()

# 使用示例
def visualize_results(sarsa_agent, qlearning_agent):
    """可视化所有结果"""
    plot_q_evolution(sarsa_agent, qlearning_agent)
    plot_reward_comparison(sarsa_agent, qlearning_agent)
    visualize_policy(sarsa_agent, 'SARSA')
    visualize_policy(qlearning_agent, 'Q-learning')
7.3 收敛分析
def analyze_convergence(sarsa_agent, qlearning_agent, episodes=1000):
    """分析收敛性"""
    print("\n收敛性分析")
    print("="*50)
    # 训练更多回合
    print("训练1000个回合以分析收敛性...")
    # 继续训练SARSA
    for _ in range(episodes - len(sarsa_agent.episode_rewards)):
        sarsa_agent.train_episode()
    # 继续训练Q-learning
    for _ in range(episodes - len(qlearning_agent.episode_rewards)):
        qlearning_agent.train_episode()
    # 计算移动平均奖励
    window = 50
    sarsa_ma = np.convolve(sarsa_agent.episode_rewards, np.ones(window)/window, mode='valid')
    qlearning_ma = np.convolve(qlearning_agent.episode_rewards, np.ones(window)/window, mode='valid')
    # 绘制收敛曲线
    plt.figure(figsize=(12, 8))
    plt.subplot(2, 1, 1)
    episodes_range = range(len(sarsa_agent.episode_rewards))
    plt.plot(episodes_range, sarsa_agent.episode_rewards, 'b-', alpha=0.3, label='SARSA (原始)')
    plt.plot(episodes_range, qlearning_agent.episode_rewards, 'r-', alpha=0.3, label='Q-learning (原始)')
    ma_episodes = range(window-1, len(sarsa_agent.episode_rewards))
    plt.plot(ma_episodes, sarsa_ma, 'b-', linewidth=2, label=f'SARSA (MA-{window})')
    plt.plot(ma_episodes, qlearning_ma, 'r-', linewidth=2, label=f'Q-learning (MA-{window})')
    plt.title('收敛过程:每回合奖励')
    plt.xlabel('回合')
    plt.ylabel('累积奖励')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.subplot(2, 1, 2)
    plt.plot(episodes_range, sarsa_agent.episode_steps, 'b-', alpha=0.3, label='SARSA')
    plt.plot(episodes_range, qlearning_agent.episode_steps, 'r-', alpha=0.3, label='Q-learning')
    plt.title('收敛过程:每回合步数')
    plt.xlabel('回合')
    plt.ylabel('步数')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('convergence_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()
    # 打印收敛统计
    print(f"\n最后100回合平均奖励:")
    print(f"  SARSA: {np.mean(sarsa_agent.episode_rewards[-100:]):.2f}")
    print(f"  Q-learning: {np.mean(qlearning_agent.episode_rewards[-100:]):.2f}")
    print(f"\n最后100回合平均步数:")
    print(f"  SARSA: {np.mean(sarsa_agent.episode_steps[-100:]):.2f}")
    print(f"  Q-learning: {np.mean(qlearning_agent.episode_steps[-100:]):.2f}")
总结
本示例通过5×5复杂网格世界详细展示了SARSA和Q-learning算法的:
- 完整的数值计算过程:每一步的Q值更新都有明确的数学计算
- 算法实现细节:包含完整的Python代码实现
- 行为差异对比:通过具体数值展示两种算法的不同特性
- 可视化分析:策略演化、收敛过程的图形化展示
关键洞察:
- SARSA更保守,适合安全关键应用
- Q-learning更激进,追求最优性能
- 两种算法的差异在复杂环境中更加明显
- 理解算法原理有助于选择合适的方法解决实际问题
