SARSA and Q-learning: A Worked Example

Contents

  1. Problem Setup
  2. Environment Parameters
  3. SARSA: Step-by-Step Execution
  4. Q-learning: Step-by-Step Execution
  5. Python Implementation
  6. Algorithm Comparison
  7. Visualizing Convergence

1. Problem Setup

1.1 A 5×5 Grid World

Grid layout:
+---+---+---+---+---+
| S | 1 | 2 | T | 4 |
+---+---+---+---+---+
| 5 | X | 7 | 8 | 9 |
+---+---+---+---+---+
|10 |11 | X |13 |14 |
+---+---+---+---+---+
|15 |16 |17 |18 | G |
+---+---+---+---+---+
|20 |21 |22 |23 |24 |
+---+---+---+---+---+

Legend:
S  = start state (0,0)
G  = goal state (3,4), reward +100
T  = trap state (0,3), reward -50
X  = obstacles at (1,1) and (2,2), impassable
numbers = ordinary states, reward -1 per step

1.2 State Encoding

States are numbered in row-major order (a quick conversion sketch follows the list):

  • State 0: (0,0) - start S
  • State 1: (0,1)
  • State 2: (0,2)
  • State 3: (0,3) - trap T
  • State 4: (0,4)
  • ...
  • State 19: (3,4) - goal G
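
A minimal sketch of this row-major conversion (standalone; the full GridWorld class in Section 5 implements the same mapping):

GRID_SIZE = 5

def coord_to_state(row, col):
    """Row-major encoding: state = row * GRID_SIZE + col."""
    return row * GRID_SIZE + col

def state_to_coord(state):
    """Inverse mapping back to (row, col)."""
    return divmod(state, GRID_SIZE)

assert coord_to_state(0, 3) == 3     # trap T
assert coord_to_state(3, 4) == 19    # goal G
assert state_to_coord(12) == (2, 2)  # one of the obstacles X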

1.3 Action Space

  • Action 0: up (↑)
  • Action 1: down (↓)
  • Action 2: left (←)
  • Action 3: right (→)

2. Environment Parameters

# Environment parameters
GRID_SIZE = 5
START_STATE = 0      # (0,0)
GOAL_STATE = 19      # (3,4)
TRAP_STATE = 3       # (0,3)
OBSTACLES = [6, 12]  # (1,1) and (2,2)

# Rewards
REWARD_GOAL = 100
REWARD_TRAP = -50
REWARD_STEP = -1
REWARD_WALL = -1     # reward for bumping into a wall or obstacle

# Algorithm parameters
GAMMA = 0.9          # discount factor
ALPHA = 0.1          # learning rate
EPSILON = 0.1        # ε-greedy exploration rate
MAX_EPISODES = 3     # the walkthrough shows 3 episodes
MAX_STEPS = 50       # maximum steps per episode

3. SARSA: Step-by-Step Execution

3.1 Initialization

# Q-table initialization (25 states × 4 actions)
Q_sarsa = np.zeros((25, 4))
# The goal and trap are terminal, so their Q-values stay at 0
Q_sarsa[GOAL_STATE, :] = 0
Q_sarsa[TRAP_STATE, :] = 0

3.2 ε-Greedy Policy

def epsilon_greedy(Q, state, epsilon):
    if np.random.random() < epsilon:
        return np.random.randint(4)  # explore: uniform random action
    else:
        # exploit: greedy action, breaking ties at random
        # (plain np.argmax would always pick the lowest-index action on ties)
        best = np.flatnonzero(Q[state] == Q[state].max())
        return np.random.choice(best)

3.3 Episode 1 in Detail

Step 1

Current state: S₀ = 0 (start)
Action selection: A₀ chosen ε-greedily

  • Q(0, ↑) = 0, Q(0, ↓) = 0, Q(0, ←) = 0, Q(0, →) = 0
  • All Q-values are tied, so the tie is broken at random: A₀ = 3 (→)

Execute action: move right from (0,0) to (0,1)
Observation: R₁ = -1, S₁ = 1

Select next action: A₁ chosen ε-greedily

  • Q(1, ↑) = 0, Q(1, ↓) = 0, Q(1, ←) = 0, Q(1, →) = 0
  • Tie broken at random: A₁ = 1 (↓)

SARSA update:

δ = R₁ + γ × Q(S₁, A₁) - Q(S₀, A₀)
δ = -1 + 0.9 × Q(1, ↓) - Q(0, →)
δ = -1 + 0.9 × 0 - 0 = -1

Q(S₀, A₀) ← Q(S₀, A₀) + α × δ
Q(0, →) ← 0 + 0.1 × (-1) = -0.1

Q-table after this update:

Q(0, ↑) = 0,    Q(0, ↓) = 0,    Q(0, ←) = 0,    Q(0, →) = -0.1
Q(1, ↑) = 0,    Q(1, ↓) = 0,    Q(1, ←) = 0,    Q(1, →) = 0

Step 2

Current state: S₁ = 1, with A₁ = 1 (↓) already chosen
Execute action: attempt to move down from (0,1) to (1,1) - but that cell is an obstacle!
Actual outcome: blocked, the agent stays at (0,1)
Observation: R₂ = -1, S₂ = 1

Select next action: A₂ chosen ε-greedily

  • Q(1, ↑) = 0, Q(1, ↓) = 0, Q(1, ←) = 0, Q(1, →) = 0 at selection time
  • Tie broken at random: A₂ = 3 (→)

SARSA update:

δ = R₂ + γ × Q(S₂, A₂) - Q(S₁, A₁)
δ = -1 + 0.9 × Q(1, →) - Q(1, ↓)
δ = -1 + 0.9 × 0 - 0 = -1

Q(1, ↓) ← 0 + 0.1 × (-1) = -0.1

Nonzero Q-values so far:

Q(0, →) = -0.1
Q(1, ↓) = -0.1

Step 3

Current state: S₂ = 1, with A₂ = 3 (→) already chosen
Execute action: move right from (0,1) to (0,2)
Observation: R₃ = -1, S₃ = 2

Select next action: A₃ chosen ε-greedily

  • Q(2, ↑) = 0, Q(2, ↓) = 0, Q(2, ←) = 0, Q(2, →) = 0
  • Tie broken at random: A₃ = 3 (→)

SARSA update:

δ = -1 + 0.9 × 0 - 0 = -1
Q(1, →) ← 0 + 0.1 × (-1) = -0.1

Step 4

Current state: S₃ = 2, with A₃ = 3 (→) already chosen
Execute action: move right from (0,2) to (0,3) - the trap!
Observation: R₄ = -50, S₄ = 3 (trap state, episode ends)

SARSA update:

δ = R₄ + γ × Q(terminal state) - Q(S₃, A₃)      (the Q-value of a terminal state is 0)
δ = -50 + 0.9 × 0 - 0 = -50

Q(2, →) ← 0 + 0.1 × (-50) = -5.0

Q-table after episode 1 (a short verification sketch follows the table):

State 0: Q(0, ↑)=0,    Q(0, ↓)=0,    Q(0, ←)=0,    Q(0, →)=-0.1
State 1: Q(1, ↑)=0,    Q(1, ↓)=-0.1, Q(1, ←)=0,    Q(1, →)=-0.1
State 2: Q(2, ↑)=0,    Q(2, ↓)=0,    Q(2, ←)=0,    Q(2, →)=-5.0
All other states: 0
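
The four updates above can be reproduced directly from the update rule (a standalone check; the trajectory is hard-coded from this walkthrough):

import numpy as np

ALPHA, GAMMA = 0.1, 0.9
Q = np.zeros((25, 4))
# (S, A, R, S', A') tuples for episode 1; for the final transition into the trap
# the choice of A' does not matter because Q[3, :] is (and stays) 0.
trajectory = [(0, 3, -1, 1, 1),
              (1, 1, -1, 1, 3),
              (1, 3, -1, 2, 3),
              (2, 3, -50, 3, 0)]
for s, a, r, s2, a2 in trajectory:
    Q[s, a] += ALPHA * (r + GAMMA * Q[s2, a2] - Q[s, a])
print(Q[0, 3], Q[1, 1], Q[1, 3], Q[2, 3])   # -0.1 -0.1 -0.1 -5.0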

3.4 Episode 2 in Detail

Step 1

Current state: S₀ = 0
Action selection: ε-greedy

  • Q(0, ↑)=0, Q(0, ↓)=0, Q(0, ←)=0, Q(0, →)=-0.1
  • The maximum Q-value is 0 (up, down, left); tie broken at random: A₀ = 1 (↓)

Execute action: move down from (0,0) to (1,0)
Observation: R₁ = -1, S₁ = 5

Select next action: A₁ chosen ε-greedily

  • Q(5, a) = 0 for every action; tie broken at random: A₁ = 3 (→)

SARSA update:

δ = -1 + 0.9 × 0 - 0 = -1
Q(0, ↓) ← 0 + 0.1 × (-1) = -0.1

Step 2

Current state: S₁ = 5, with A₁ = 3 (→) already chosen
Execute action: attempt to move right from (1,0) to (1,1) - an obstacle!
Actual outcome: blocked, the agent stays at (1,0)
Observation: R₂ = -1, S₂ = 5

Select next action: A₂ = 1 (↓)

SARSA update:

δ = -1 + 0.9 × 0 - 0 = -1
Q(5, →) ← 0 + 0.1 × (-1) = -0.1

And so on...

After many more steps, the agent learns to avoid the obstacles and the trap, and may eventually reach the goal.

3.5 Episode 3

As learning continues, the Q-table gradually converges and the agent behaves more sensibly. The greedy policy implied by a Q-table can be read off with an argmax per state, as sketched after the table below.

Partial Q-table after episode 3:

State 0: Q(0, ↑)=0,     Q(0, ↓)=-0.19, Q(0, ←)=0,     Q(0, →)=-0.19
State 1: Q(1, ↑)=0,     Q(1, ↓)=-0.19, Q(1, ←)=-0.1,  Q(1, →)=-0.28
State 2: Q(2, ↑)=0,     Q(2, ↓)=-0.1,  Q(2, ←)=-0.1,  Q(2, →)=-5.0
State 5: Q(5, ↑)=-0.1,  Q(5, ↓)=-0.1,  Q(5, ←)=0,     Q(5, →)=-0.19
...
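
Reading the greedy policy off a Q-table (a minimal sketch; Q_sarsa is the 25×4 table from Section 3.1, and on still-all-zero rows the argmax simply falls back to the first action):

import numpy as np

ACTION_NAMES = np.array(['↑', '↓', '←', '→'])
greedy_actions = np.argmax(Q_sarsa, axis=1)             # best action index per state
policy_grid = ACTION_NAMES[greedy_actions].reshape(5, 5)
print(policy_grid)                                      # one arrow per grid cell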

4. Q-learning: Step-by-Step Execution

4.1 Initialization

# Q-table initialization (same as SARSA)
Q_qlearning = np.zeros((25, 4))
Q_qlearning[GOAL_STATE, :] = 0
Q_qlearning[TRAP_STATE, :] = 0

4.2 Episode 1 in Detail

Step 1

Current state: S₀ = 0
Action selection: A₀ = 3 (→) (the same random choice as in the SARSA walkthrough)

Execute action: move right from (0,0) to (0,1)
Observation: R₁ = -1, S₁ = 1

Q-learning update (the key difference):

δ = R₁ + γ × max_a Q(S₁, a) - Q(S₀, A₀)
δ = -1 + 0.9 × max{Q(1, ↑), Q(1, ↓), Q(1, ←), Q(1, →)} - Q(0, →)
δ = -1 + 0.9 × max{0, 0, 0, 0} - 0 = -1

Q(0, →) ← 0 + 0.1 × (-1) = -0.1

Note: Q-learning bootstraps from the max over next-state actions, not from the action that will actually be taken next.

Step 2

Current state: S₁ = 1
Action selection: A₁ = 1 (↓) (chosen by the behavior policy)

Execute action: blocked by the obstacle, the agent stays in state 1
Observation: R₂ = -1, S₂ = 1

Q-learning update:

δ = -1 + 0.9 × max{Q(1, ↑), Q(1, ↓), Q(1, ←), Q(1, →)} - Q(1, ↓)
δ = -1 + 0.9 × 0 - 0 = -1

Q(1, ↓) ← 0 + 0.1 × (-1) = -0.1

Steps 3 and 4

These proceed as in the SARSA walkthrough, except that every update bootstraps through the max operation.

Q-table after episode 1 (identical to SARSA's; see the check after the table):

State 0: Q(0, →) = -0.1
State 1: Q(1, ↓) = -0.1, Q(1, →) = -0.1
State 2: Q(2, →) = -5.0
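
Why the two tables coincide here: every Q-value being bootstrapped from is still 0, so max_a Q(S', a) equals Q(S', A') no matter which A' SARSA happened to pick. A one-line check:

import numpy as np

Q = np.zeros((25, 4))
s_next, a_next = 1, 1   # next state and next action from step 1 of the walkthrough
print(np.max(Q[s_next]) == Q[s_next, a_next])   # True: identical TD targets in episode 1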

4.3 Where the Difference Shows Up

As learning continues, Q-learning and SARSA gradually diverge:

Q-learning:

  • Always bootstraps from the maximum Q-value of the next state
  • Learns the optimal policy, independent of the current behavior policy
  • More aggressive in environments with traps

SARSA:

  • Bootstraps from the next action actually selected
  • Learns the value of the current (ε-greedy) policy
  • More conservative in environments with traps

5. Python Implementation

5.1 Environment Class

import numpy as np
import matplotlib.pyplot as plt

class GridWorld:
    def __init__(self):
        self.grid_size = 5
        self.start_state = 0
        self.goal_state = 19
        self.trap_state = 3
        self.obstacles = [6, 12]

        # Rewards
        self.reward_goal = 100
        self.reward_trap = -50
        self.reward_step = -1
        self.reward_wall = -1

        # Action mapping: (row, col) offsets
        self.actions = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}  # up, down, left, right
        self.action_names = ['↑', '↓', '←', '→']

    def state_to_coord(self, state):
        """Convert a state index to (row, col)."""
        return (state // self.grid_size, state % self.grid_size)

    def coord_to_state(self, row, col):
        """Convert (row, col) to a state index."""
        return row * self.grid_size + col

    def is_valid_state(self, state):
        """Check whether a state index is inside the grid and not an obstacle."""
        if state < 0 or state >= self.grid_size * self.grid_size:
            return False
        if state in self.obstacles:
            return False
        return True

    def get_next_state(self, state, action):
        """Return the state reached after taking `action` in `state`."""
        row, col = self.state_to_coord(state)
        d_row, d_col = self.actions[action]
        new_row, new_col = row + d_row, col + d_col

        # Boundary check
        if new_row < 0 or new_row >= self.grid_size or new_col < 0 or new_col >= self.grid_size:
            return state  # hit a wall, stay put

        new_state = self.coord_to_state(new_row, new_col)

        # Obstacle check
        if new_state in self.obstacles:
            return state  # hit an obstacle, stay put

        return new_state

    def get_reward(self, state, action, next_state):
        """Return the reward for the transition."""
        if next_state == self.goal_state:
            return self.reward_goal
        elif next_state == self.trap_state:
            return self.reward_trap
        elif next_state == state:  # bumped into a wall or obstacle
            return self.reward_wall
        else:
            return self.reward_step

    def is_terminal(self, state):
        """Check whether a state is terminal."""
        return state == self.goal_state or state == self.trap_state

    def reset(self):
        """Reset the environment and return the start state."""
        return self.start_state

    def step(self, state, action):
        """Take one step; return (next_state, reward, done)."""
        next_state = self.get_next_state(state, action)
        reward = self.get_reward(state, action, next_state)
        done = self.is_terminal(next_state)
        return next_state, reward, done
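
A quick sanity check of the environment class (this mirrors step 4 of SARSA episode 1, where moving right from state 2 falls into the trap):

env = GridWorld()
state = env.reset()                         # 0, the start state
next_state, reward, done = env.step(2, 3)   # from state 2, action 3 (→) leads into the trap
print(next_state, reward, done)             # 3 -50 True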

5.2 SARSA Implementation

class SARSAAgent:
    def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.env = env
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

        # Initialize the Q-table
        self.Q = np.zeros((env.grid_size * env.grid_size, 4))

        # Learning history
        self.episode_rewards = []
        self.episode_steps = []
        self.q_history = []

    def epsilon_greedy(self, state):
        """ε-greedy policy with random tie-breaking."""
        if np.random.random() < self.epsilon:
            return np.random.randint(4)
        else:
            # Greedy action; break ties at random instead of always taking the lowest index
            best = np.flatnonzero(self.Q[state] == self.Q[state].max())
            return np.random.choice(best)

    def update_q(self, state, action, reward, next_state, next_action):
        """SARSA Q-value update."""
        td_error = reward + self.gamma * self.Q[next_state, next_action] - self.Q[state, action]
        self.Q[state, action] += self.alpha * td_error
        return td_error

    def train_episode(self, max_steps=50, verbose=False):
        """Run one training episode."""
        state = self.env.reset()
        action = self.epsilon_greedy(state)

        episode_reward = 0
        steps = 0

        if verbose:
            print(f"\n=== SARSA episode start ===")
            print(f"Initial state: {state} {self.env.state_to_coord(state)}")

        while steps < max_steps:
            # Take the action
            next_state, reward, done = self.env.step(state, action)
            episode_reward += reward
            steps += 1

            if verbose:
                print(f"\nStep {steps}:")
                print(f"  state: {state} -> action: {self.env.action_names[action]} -> next state: {next_state}")
                print(f"  reward: {reward}")

            if done:
                # Terminal update: Q of a terminal state is 0, so any next_action index works
                td_error = self.update_q(state, action, reward, next_state, 0)
                if verbose:
                    print(f"  TD error: {td_error:.3f}")
                    print(f"  Q({state}, {self.env.action_names[action]}) updated to: {self.Q[state, action]:.3f}")
                break
            else:
                # Choose the next action
                next_action = self.epsilon_greedy(next_state)

                # SARSA update
                td_error = self.update_q(state, action, reward, next_state, next_action)

                if verbose:
                    print(f"  next action: {self.env.action_names[next_action]}")
                    print(f"  TD error: {td_error:.3f}")
                    print(f"  Q({state}, {self.env.action_names[action]}) updated to: {self.Q[state, action]:.3f}")

                # Move on to the next state
                state = next_state
                action = next_action

        self.episode_rewards.append(episode_reward)
        self.episode_steps.append(steps)

        if verbose:
            print(f"\nEpisode finished: total reward = {episode_reward}, steps = {steps}")

        return episode_reward, steps

    def train(self, episodes, verbose_episodes=None):
        """Train for several episodes."""
        if verbose_episodes is None:
            verbose_episodes = []

        for episode in range(episodes):
            verbose = episode in verbose_episodes
            if verbose:
                print(f"\n{'='*50}")
                print(f"SARSA training episode {episode + 1}")
                print(f"{'='*50}")

            reward, steps = self.train_episode(verbose=verbose)

            # Record a Q-table snapshot for the episodes shown in detail
            if verbose:
                self.q_history.append(self.Q.copy())
                self.print_q_table()

    def print_q_table(self, states=None):
        """Print part of the Q-table."""
        if states is None:
            states = [0, 1, 2, 5]  # a few key states

        print("\nCurrent Q-table (partial):")
        for state in states:
            if state < len(self.Q):
                coord = self.env.state_to_coord(state)
                print(f"State {state} {coord}: ", end="")
                for action in range(4):
                    action_name = self.env.action_names[action]
                    q_value = self.Q[state, action]
                    print(f"Q({action_name})={q_value:.3f} ", end="")
                print()

5.3 Q-learning Implementation

class QLearningAgent:
    def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.env = env
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

        # Initialize the Q-table
        self.Q = np.zeros((env.grid_size * env.grid_size, 4))

        # Learning history
        self.episode_rewards = []
        self.episode_steps = []
        self.q_history = []

    def epsilon_greedy(self, state):
        """ε-greedy policy with random tie-breaking."""
        if np.random.random() < self.epsilon:
            return np.random.randint(4)
        else:
            # Greedy action; break ties at random instead of always taking the lowest index
            best = np.flatnonzero(self.Q[state] == self.Q[state].max())
            return np.random.choice(best)

    def update_q(self, state, action, reward, next_state):
        """Q-learning Q-value update."""
        max_next_q = np.max(self.Q[next_state]) if not self.env.is_terminal(next_state) else 0
        td_error = reward + self.gamma * max_next_q - self.Q[state, action]
        self.Q[state, action] += self.alpha * td_error
        return td_error

    def train_episode(self, max_steps=50, verbose=False):
        """Run one training episode."""
        state = self.env.reset()

        episode_reward = 0
        steps = 0

        if verbose:
            print(f"\n=== Q-learning episode start ===")
            print(f"Initial state: {state} {self.env.state_to_coord(state)}")

        while steps < max_steps:
            # Choose an action
            action = self.epsilon_greedy(state)

            # Take the action
            next_state, reward, done = self.env.step(state, action)
            episode_reward += reward
            steps += 1

            if verbose:
                print(f"\nStep {steps}:")
                print(f"  state: {state} -> action: {self.env.action_names[action]} -> next state: {next_state}")
                print(f"  reward: {reward}")

            # Q-learning update
            td_error = self.update_q(state, action, reward, next_state)

            if verbose:
                max_next_q = np.max(self.Q[next_state]) if not done else 0
                print(f"  max Q(s', a') = {max_next_q:.3f}")
                print(f"  TD error: {td_error:.3f}")
                print(f"  Q({state}, {self.env.action_names[action]}) updated to: {self.Q[state, action]:.3f}")

            if done:
                break

            state = next_state

        self.episode_rewards.append(episode_reward)
        self.episode_steps.append(steps)

        if verbose:
            print(f"\nEpisode finished: total reward = {episode_reward}, steps = {steps}")

        return episode_reward, steps

    def train(self, episodes, verbose_episodes=None):
        """Train for several episodes."""
        if verbose_episodes is None:
            verbose_episodes = []

        for episode in range(episodes):
            verbose = episode in verbose_episodes
            if verbose:
                print(f"\n{'='*50}")
                print(f"Q-learning training episode {episode + 1}")
                print(f"{'='*50}")

            reward, steps = self.train_episode(verbose=verbose)

            # Record a Q-table snapshot for the episodes shown in detail
            if verbose:
                self.q_history.append(self.Q.copy())
                self.print_q_table()

    def print_q_table(self, states=None):
        """Print part of the Q-table."""
        if states is None:
            states = [0, 1, 2, 5]  # a few key states

        print("\nCurrent Q-table (partial):")
        for state in states:
            if state < len(self.Q):
                coord = self.env.state_to_coord(state)
                print(f"State {state} {coord}: ", end="")
                for action in range(4):
                    action_name = self.env.action_names[action]
                    q_value = self.Q[state, action]
                    print(f"Q({action_name})={q_value:.3f} ", end="")
                print()

5.4 Main Program

def main():
    # Create the environment
    env = GridWorld()

    # Fix the random seed for reproducibility
    np.random.seed(42)

    # Create the agents
    sarsa_agent = SARSAAgent(env)
    qlearning_agent = QLearningAgent(env)

    print("5x5 grid world: SARSA vs Q-learning, detailed example")
    print("="*60)

    # Print environment info
    print("\nEnvironment setup:")
    print(f"Start state: {env.start_state} {env.state_to_coord(env.start_state)}")
    print(f"Goal state: {env.goal_state} {env.state_to_coord(env.goal_state)} (reward: +{env.reward_goal})")
    print(f"Trap state: {env.trap_state} {env.state_to_coord(env.trap_state)} (reward: {env.reward_trap})")
    print(f"Obstacles: {env.obstacles} {[env.state_to_coord(obs) for obs in env.obstacles]}")
    print(f"Step reward: {env.reward_step}")

    # Train SARSA (show the first 3 episodes in detail)
    print("\n" + "="*60)
    print("SARSA training")
    print("="*60)
    sarsa_agent.train(episodes=3, verbose_episodes=[0, 1, 2])

    # Reset the random seed
    np.random.seed(42)

    # Train Q-learning (show the first 3 episodes in detail)
    print("\n" + "="*60)
    print("Q-learning training")
    print("="*60)
    qlearning_agent.train(episodes=3, verbose_episodes=[0, 1, 2])

    # Compare the results
    print("\n" + "="*60)
    print("Algorithm comparison")
    print("="*60)

    print("\nEpisode rewards:")
    for i in range(3):
        sarsa_reward = sarsa_agent.episode_rewards[i]
        qlearning_reward = qlearning_agent.episode_rewards[i]
        print(f"Episode {i+1}: SARSA = {sarsa_reward:6.1f}, Q-learning = {qlearning_reward:6.1f}")

    print("\nFinal Q-tables (key states):")
    states_to_compare = [0, 1, 2, 5, 10]

    for state in states_to_compare:
        coord = env.state_to_coord(state)
        print(f"\nState {state} {coord}:")
        print("  SARSA   : ", end="")
        for action in range(4):
            action_name = env.action_names[action]
            q_value = sarsa_agent.Q[state, action]
            print(f"Q({action_name})={q_value:6.3f} ", end="")
        print()
        print("  Q-learn : ", end="")
        for action in range(4):
            action_name = env.action_names[action]
            q_value = qlearning_agent.Q[state, action]
            print(f"Q({action_name})={q_value:6.3f} ", end="")
        print()

if __name__ == "__main__":
    main()

6. Algorithm Comparison

6.1 Update Rules Side by Side

SARSA update:

# uses the next action actually selected
td_error = reward + gamma * Q[next_state, next_action] - Q[state, action]
Q[state, action] += alpha * td_error

Q-learning update:

# uses the maximum Q-value of the next state
max_next_q = np.max(Q[next_state])
td_error = reward + gamma * max_next_q - Q[state, action]
Q[state, action] += alpha * td_error

6.2 A Numerical Example of the Difference

Suppose that in some state:

  • Q(next_state, ↑) = -0.2
  • Q(next_state, ↓) = -0.1 ← the action actually selected
  • Q(next_state, ←) = -0.3
  • Q(next_state, →) = -0.05 ← the maximum Q-value

SARSA bootstraps from: Q(next_state, ↓) = -0.1
Q-learning bootstraps from: max Q(next_state, ·) = -0.05

So the Q-learning update is more "optimistic": it always assumes the best action will be taken next.
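
Plugging these numbers into the two update targets (a standalone sketch; reward -1 and γ = 0.9 as in Section 2):

GAMMA = 0.9
reward = -1
Q_next = {'↑': -0.2, '↓': -0.1, '←': -0.3, '→': -0.05}
next_action = '↓'   # the action the behavior policy actually picked

sarsa_target = reward + GAMMA * Q_next[next_action]       # -1 + 0.9 * (-0.1)  = -1.09
qlearning_target = reward + GAMMA * max(Q_next.values())  # -1 + 0.9 * (-0.05) = -1.045
print(sarsa_target, qlearning_target)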

6.3 Behavioral Differences

In this 5×5 grid world:

SARSA:

  1. More conservative; tends to keep its distance from the trap
  2. Accounts for the risk of exploration while learning
  3. The final policy may not be strictly optimal, but it is safer

Q-learning:

  1. More aggressive; willing to take the risky route
  2. Unaffected by the exploration policy
  3. The final policy tends toward the true optimal policy

6.4 Convergence

Theoretical convergence:

  • Both algorithms converge under standard conditions (see the schedule sketch after this list)
  • Q-learning converges to the optimal action-value function Q*
  • SARSA converges to Q^π, the action-value function of the current policy

Practical convergence speed:

  • Q-learning usually converges faster
  • SARSA's convergence speed depends on the exploration policy
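
The standard conditions require, roughly, that every state-action pair is visited infinitely often and that the learning rate decays appropriately; in practice this is often approximated with decaying ε and α schedules. A minimal sketch of such schedules (the decay constants here are arbitrary illustrative choices, not taken from the example above):

def decayed_epsilon(episode, eps_start=1.0, eps_min=0.05, decay=0.995):
    """Exponentially decaying exploration rate with a floor."""
    return max(eps_min, eps_start * decay ** episode)

def decayed_alpha(episode, alpha_start=0.5, alpha_min=0.01, decay=0.995):
    """Exponentially decaying learning rate with a floor."""
    return max(alpha_min, alpha_start * decay ** episode)

for ep in (0, 100, 500):
    print(ep, round(decayed_epsilon(ep), 3), round(decayed_alpha(ep), 3))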

7. Visualizing Convergence

7.1 Q-value Evolution

def plot_q_evolution(sarsa_agent, qlearning_agent):
    """Plot how selected Q-values evolve over episodes."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Key (state, action) pairs to visualize
    key_state_actions = [(0, 3), (1, 1), (2, 3), (5, 1)]

    for i, (state, action) in enumerate(key_state_actions):
        row, col = i // 2, i % 2
        ax = axes[row, col]

        # Extract the Q-value history
        sarsa_q_history = [q_table[state, action] for q_table in sarsa_agent.q_history]
        qlearning_q_history = [q_table[state, action] for q_table in qlearning_agent.q_history]

        episodes = range(1, len(sarsa_q_history) + 1)

        ax.plot(episodes, sarsa_q_history, 'b-o', label='SARSA', linewidth=2, markersize=6)
        ax.plot(episodes, qlearning_q_history, 'r-s', label='Q-learning', linewidth=2, markersize=6)

        coord = sarsa_agent.env.state_to_coord(state)
        action_name = sarsa_agent.env.action_names[action]
        ax.set_title(f'Evolution of Q({state}{coord}, {action_name})', fontsize=12)
        ax.set_xlabel('Episode')
        ax.set_ylabel('Q-value')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('q_evolution.png', dpi=300, bbox_inches='tight')
    plt.show()

def plot_reward_comparison(sarsa_agent, qlearning_agent):
    """Plot per-episode reward and step-count comparisons."""
    episodes = range(1, len(sarsa_agent.episode_rewards) + 1)

    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.plot(episodes, sarsa_agent.episode_rewards, 'b-o', label='SARSA', linewidth=2)
    plt.plot(episodes, qlearning_agent.episode_rewards, 'r-s', label='Q-learning', linewidth=2)
    plt.title('Reward per episode')
    plt.xlabel('Episode')
    plt.ylabel('Cumulative reward')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    plt.plot(episodes, sarsa_agent.episode_steps, 'b-o', label='SARSA', linewidth=2)
    plt.plot(episodes, qlearning_agent.episode_steps, 'r-s', label='Q-learning', linewidth=2)
    plt.title('Steps per episode')
    plt.xlabel('Episode')
    plt.ylabel('Steps')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('reward_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

7.2 Policy Visualization

def visualize_policy(agent, title):
    """Visualize the greedy policy learned by an agent."""
    policy = np.zeros((agent.env.grid_size, agent.env.grid_size), dtype=int)

    for row in range(agent.env.grid_size):
        for col in range(agent.env.grid_size):
            state = agent.env.coord_to_state(row, col)
            if state not in agent.env.obstacles and not agent.env.is_terminal(state):
                policy[row, col] = np.argmax(agent.Q[state])
            else:
                policy[row, col] = -1  # mark obstacles and terminal states

    # Build the figure
    fig, ax = plt.subplots(figsize=(8, 8))

    # Draw the grid lines
    for i in range(agent.env.grid_size + 1):
        ax.axhline(i - 0.5, color='black', linewidth=1)
        ax.axvline(i - 0.5, color='black', linewidth=1)

    # Draw each cell and its greedy action
    arrow_symbols = ['↑', '↓', '←', '→']
    colors = ['lightblue', 'lightgreen', 'lightcoral', 'lightyellow']

    for row in range(agent.env.grid_size):
        for col in range(agent.env.grid_size):
            state = agent.env.coord_to_state(row, col)

            if state == agent.env.start_state:
                ax.add_patch(plt.Rectangle((col-0.5, row-0.5), 1, 1, facecolor='green', alpha=0.7))
                ax.text(col, row, 'S', ha='center', va='center', fontsize=16, fontweight='bold')
            elif state == agent.env.goal_state:
                ax.add_patch(plt.Rectangle((col-0.5, row-0.5), 1, 1, facecolor='gold', alpha=0.7))
                ax.text(col, row, 'G', ha='center', va='center', fontsize=16, fontweight='bold')
            elif state == agent.env.trap_state:
                ax.add_patch(plt.Rectangle((col-0.5, row-0.5), 1, 1, facecolor='red', alpha=0.7))
                ax.text(col, row, 'T', ha='center', va='center', fontsize=16, fontweight='bold')
            elif state in agent.env.obstacles:
                ax.add_patch(plt.Rectangle((col-0.5, row-0.5), 1, 1, facecolor='black', alpha=0.8))
                ax.text(col, row, 'X', ha='center', va='center', fontsize=16, fontweight='bold', color='white')
            else:
                action = policy[row, col]
                ax.add_patch(plt.Rectangle((col-0.5, row-0.5), 1, 1, facecolor=colors[action], alpha=0.5))
                ax.text(col, row, arrow_symbols[action], ha='center', va='center', fontsize=20, fontweight='bold')

    ax.set_xlim(-0.5, agent.env.grid_size - 0.5)
    ax.set_ylim(-0.5, agent.env.grid_size - 0.5)
    ax.set_aspect('equal')
    ax.invert_yaxis()  # put (0,0) in the top-left corner
    ax.set_title(f'{title}: learned policy', fontsize=16)
    ax.set_xticks([])
    ax.set_yticks([])

    plt.tight_layout()
    plt.savefig(f'{title.lower()}_policy.png', dpi=300, bbox_inches='tight')
    plt.show()

# Usage
def visualize_results(sarsa_agent, qlearning_agent):
    """Produce all plots."""
    plot_q_evolution(sarsa_agent, qlearning_agent)
    plot_reward_comparison(sarsa_agent, qlearning_agent)
    visualize_policy(sarsa_agent, 'SARSA')
    visualize_policy(qlearning_agent, 'Q-learning')

7.3 Convergence Analysis

def analyze_convergence(sarsa_agent, qlearning_agent, episodes=1000):
    """Train for more episodes and analyze convergence."""
    print("\nConvergence analysis")
    print("="*50)

    # Train additional episodes
    print(f"Training up to {episodes} episodes to analyze convergence...")

    # Continue training SARSA
    for _ in range(episodes - len(sarsa_agent.episode_rewards)):
        sarsa_agent.train_episode()

    # Continue training Q-learning
    for _ in range(episodes - len(qlearning_agent.episode_rewards)):
        qlearning_agent.train_episode()

    # Moving-average rewards
    window = 50
    sarsa_ma = np.convolve(sarsa_agent.episode_rewards, np.ones(window)/window, mode='valid')
    qlearning_ma = np.convolve(qlearning_agent.episode_rewards, np.ones(window)/window, mode='valid')

    # Plot the convergence curves
    plt.figure(figsize=(12, 8))

    plt.subplot(2, 1, 1)
    episodes_range = range(len(sarsa_agent.episode_rewards))
    plt.plot(episodes_range, sarsa_agent.episode_rewards, 'b-', alpha=0.3, label='SARSA (raw)')
    plt.plot(episodes_range, qlearning_agent.episode_rewards, 'r-', alpha=0.3, label='Q-learning (raw)')

    ma_episodes = range(window-1, len(sarsa_agent.episode_rewards))
    plt.plot(ma_episodes, sarsa_ma, 'b-', linewidth=2, label=f'SARSA (MA-{window})')
    plt.plot(ma_episodes, qlearning_ma, 'r-', linewidth=2, label=f'Q-learning (MA-{window})')

    plt.title('Convergence: reward per episode')
    plt.xlabel('Episode')
    plt.ylabel('Cumulative reward')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(2, 1, 2)
    plt.plot(episodes_range, sarsa_agent.episode_steps, 'b-', alpha=0.3, label='SARSA')
    plt.plot(episodes_range, qlearning_agent.episode_steps, 'r-', alpha=0.3, label='Q-learning')
    plt.title('Convergence: steps per episode')
    plt.xlabel('Episode')
    plt.ylabel('Steps')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('convergence_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Print convergence statistics
    print(f"\nAverage reward over the last 100 episodes:")
    print(f"  SARSA: {np.mean(sarsa_agent.episode_rewards[-100:]):.2f}")
    print(f"  Q-learning: {np.mean(qlearning_agent.episode_rewards[-100:]):.2f}")

    print(f"\nAverage steps over the last 100 episodes:")
    print(f"  SARSA: {np.mean(sarsa_agent.episode_steps[-100:]):.2f}")
    print(f"  Q-learning: {np.mean(qlearning_agent.episode_steps[-100:]):.2f}")
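
analyze_convergence is not wired into main() above; one way to run it after the three verbose episodes is sketched below (relies only on the classes and functions defined in this article):

env = GridWorld()
np.random.seed(42)
sarsa_agent = SARSAAgent(env)
qlearning_agent = QLearningAgent(env)
sarsa_agent.train(episodes=3, verbose_episodes=[0, 1, 2])
qlearning_agent.train(episodes=3, verbose_episodes=[0, 1, 2])
analyze_convergence(sarsa_agent, qlearning_agent, episodes=1000)
visualize_results(sarsa_agent, qlearning_agent)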

Summary

This example uses a 5×5 grid world to walk through SARSA and Q-learning in detail:

  1. Complete numerical calculations: every Q-value update is worked out explicitly
  2. Implementation details: full Python implementations of both algorithms
  3. Behavioral comparison: concrete numbers showing how the two algorithms differ
  4. Visual analysis: plots of the learned policies and the convergence process

Key Takeaways

  • SARSA is more conservative and suits safety-critical applications
  • Q-learning is more aggressive and aims for optimal performance
  • The difference between the two grows with the complexity of the environment
  • Understanding how the algorithms work helps in choosing the right one for a given problem