强化学习系列(2):深度Q网络(DQN)与经验回放
一、Q-Learning的局限
- 定义局限:传统的Q-Learning在面对复杂的状态空间和动作空间时,往往会面临“维度灾难”的问题。因为其需要为每个状态-动作对维护一个Q值,当状态和动作数量庞大时,所需的存储空间和学习时间会急剧增加。
- 收敛速度局限:在大规模问题中,Q-Learning的收敛速度可能会变得很慢,导致学习效率低下,难以快速找到较优的策略。
二、深度Q网络(DQN)原理
神经网络引入思路
将神经网络引入到Q-Learning中,用神经网络来近似Q函数,也就是用网络的输出来表示给定状态下各个动作对应的Q值。这样可以处理高维的状态空间,通过网络的学习能力自动提取状态特征,从而避免手动去设计复杂的状态表示。
目标网络的作用
DQN采用了目标网络(Target Network)的机制。它与训练网络分离,训练网络用于选择动作并更新Q值,而目标网络用于提供目标Q值来计算损失函数。这样做可以使训练更加稳定,避免Q值的估计在更新过程中产生过大的波动,有助于收敛。
损失函数设计
通常采用均方误差(MSE)损失函数,即计算预测的Q值(来自训练网络)和目标Q值(来自目标网络)之间差值的平方的均值,公式如下:
loss = torch.mean((Q_predicted - Q_target) ** 2)
通过最小化这个损失函数来训练神经网络,使其输出的Q值越来越接近真实的最优Q值。
三、经验回放(Experience Replay)机制
基本概念
经验回放是DQN中一个重要的技巧,它是指智能体在与环境交互过程中,将经历的状态、动作、奖励、下一个状态等信息(即经验元组)存储到一个回放缓冲区(Replay Buffer)中。然后在训练时,从这个缓冲区中随机采样一批经验来进行学习,而不是仅仅使用当前最新的经验。
优点
- 打破数据相关性:智能体在环境中连续交互产生的经验是具有相关性的,直接使用会导致训练不稳定。经验回放通过随机采样打破了这种相关性,让训练数据更接近独立同分布,有利于提升训练效果。
- 提高数据利用率:每个经验可以被多次使用,避免了一些经验只被使用一次就丢弃的情况,使得数据能够被更充分地利用,尤其对于一些比较难得的经验数据。
代码示例(Python简单示意)
import random
import torch
# 定义经验回放缓冲区类
class ReplayBuffer:
def __init__(self, capacity):
self.capacity = capacity
self.buffer = []
self.position = 0
def push(self, state, action, reward, next_state, done):
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward, next_state, done)
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size):
return random.sample(self.buffer, batch_size)
def __len__(self):
return len(self.buffer)
# 使用示例
buffer = ReplayBuffer(1000) # 假设缓冲区容量为1000
# 在智能体与环境交互过程中,不断往缓冲区添加经验
state = env.reset()
action = env.action_space.sample()
next_state, reward, done, _ = env.step(action)
buffer.push(state, action, reward, next_state, done)
# 训练时采样一批经验
batch_size = 32
batch = buffer.sample(batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
states = torch.tensor(states, dtype=torch.float)
actions = torch.tensor(actions).unsqueeze(1)
rewards = torch.tensor(rewards).unsqueeze(1)
next_states = torch.tensor(next_states, dtype=torch.float)
dones = torch.tensor(dones).unsqueeze(1)
四、用PyTorch实现Breakout游戏AI示例
环境准备
首先需要安装 gymnasium 以及 atari-py 库(用于加载Atari游戏环境),安装命令如下:
pip install gymnasium atari-py
网络结构搭建(简单示例)
import torch
import torch.nn as nn
class DQN(nn.Module):
def __init__(self, input_size, output_size):
super(DQN, self).__init__()
self.fc1 = nn.Linear(input_size, 64)
self.fc2 = nn.Linear(64, 64)
self.fc3 = nn.Linear(64, output_size)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
训练循环核心代码片段
import gymnasium as gym
# 超参数设置
learning_rate = 0.001
gamma = 0.99
batch_size = 32
num_episodes = 1000
# 创建环境和网络实例
env = gym.make('Breakout-v0')
input_size = env.observation_space.shape[0]
output_size = env.action_space.n
policy_net = DQN(input_size, output_size)
target_net = DQN(input_size, output_size)
target_net.load_state_dict(policy_net.state_dict())
optimizer = torch.optim.Adam(policy_net.parameters(), lr=learning_rate)
# 经验回放缓冲区实例化
buffer = ReplayBuffer(10000)
for episode in range(num_episodes):
state, _ = env.reset()
state = torch.tensor(state, dtype=torch.float).unsqueeze(0)
done = False
while not done:
# 根据策略网络选择动作(这里采用简单的epsilon-greedy策略)
if random.random() < epsilon:
action = env.action_space.sample()
else:
action = torch.argmax(policy_net(state)).item()
# 执行动作,获取下一个状态、奖励等信息
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
next_state = torch.tensor(next_state, dtype=torch.float).unsqueeze(0)
reward = torch.tensor([reward]).unsqueeze(0)
# 将经验存入缓冲区
buffer.push(state, action, reward, next_state, done)
# 从缓冲区采样一批经验进行训练(当缓冲区数据足够时)
if len(buffer) >= batch_size:
states, actions, rewards, next_states, dones = buffer.sample(batch_size)
# 计算目标Q值
q_targets = rewards + gamma * torch.max(target_net(next_states), dim=1)[0].unsqueeze(1) * (1 - dones)
# 计算预测Q值
q_preds = policy_net(states).gather(1, actions)
# 计算损失并更新网络
loss = torch.mean((q_preds - q_targets) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
state = next_state
# 定期更新目标网络
if episode % 10 == 0:
target_net.load_state_dict(policy_net.state_dict())
env.close()
五、常见问题及解决思路
1. 训练过程中网络不收敛怎么办?
- 检查超参数:比如学习率是否过大或者过小,过大可能导致无法收敛,过小则收敛速度极慢。可以尝试不同的学习率范围进行调试。
- 经验回放缓冲区设置:确保缓冲区大小合适,过小可能导致数据多样性不足,过大可能占用过多内存,同时要检查采样是否正确打破了数据相关性。
2. 如何提升AI在游戏中的表现?
- 网络结构优化:可以尝试增加网络的深度、宽度或者采用更先进的网络架构(如卷积神经网络来更好地处理图像类状态)。
- 探索策略调整:改进epsilon-greedy等探索策略,比如采用递减的epsilon值,前期多探索,后期多利用已学习到的知识来做决策。
六、下期预告:深度Q网络的改进(Double DQN、Dueling DQN等)
在**强化学习系列(3)**中,您将深入学习到:
- Double DQN的原理以及如何解决DQN中的高估问题。
- Dueling DQN的独特结构及其优势,能更好地对状态价值和动作优势进行分离学习。
- 对比不同改进版本的DQN在实际环境中的性能表现及应用场景。
欢迎继续关注本系列,在评论区分享您的实践过程和遇到的问题哦! 🔔
1457

被折叠的 条评论
为什么被折叠?



