强化学习系列(2):深度Q网络(DQN)与经验回放

部署运行你感兴趣的模型镜像

强化学习系列(2):深度Q网络(DQN)与经验回放

一、Q-Learning的局限

  • 定义局限:传统的Q-Learning在面对复杂的状态空间和动作空间时,往往会面临“维度灾难”的问题。因为其需要为每个状态-动作对维护一个Q值,当状态和动作数量庞大时,所需的存储空间和学习时间会急剧增加。
  • 收敛速度局限:在大规模问题中,Q-Learning的收敛速度可能会变得很慢,导致学习效率低下,难以快速找到较优的策略。

二、深度Q网络(DQN)原理

神经网络引入思路

将神经网络引入到Q-Learning中,用神经网络来近似Q函数,也就是用网络的输出来表示给定状态下各个动作对应的Q值。这样可以处理高维的状态空间,通过网络的学习能力自动提取状态特征,从而避免手动去设计复杂的状态表示。

目标网络的作用

DQN采用了目标网络(Target Network)的机制。它与训练网络分离,训练网络用于选择动作并更新Q值,而目标网络用于提供目标Q值来计算损失函数。这样做可以使训练更加稳定,避免Q值的估计在更新过程中产生过大的波动,有助于收敛。

损失函数设计

通常采用均方误差(MSE)损失函数,即计算预测的Q值(来自训练网络)和目标Q值(来自目标网络)之间差值的平方的均值,公式如下:

loss = torch.mean((Q_predicted - Q_target) ** 2)

通过最小化这个损失函数来训练神经网络,使其输出的Q值越来越接近真实的最优Q值。


三、经验回放(Experience Replay)机制

基本概念

经验回放是DQN中一个重要的技巧,它是指智能体在与环境交互过程中,将经历的状态、动作、奖励、下一个状态等信息(即经验元组)存储到一个回放缓冲区(Replay Buffer)中。然后在训练时,从这个缓冲区中随机采样一批经验来进行学习,而不是仅仅使用当前最新的经验。

优点

  • 打破数据相关性:智能体在环境中连续交互产生的经验是具有相关性的,直接使用会导致训练不稳定。经验回放通过随机采样打破了这种相关性,让训练数据更接近独立同分布,有利于提升训练效果。
  • 提高数据利用率:每个经验可以被多次使用,避免了一些经验只被使用一次就丢弃的情况,使得数据能够被更充分地利用,尤其对于一些比较难得的经验数据。

代码示例(Python简单示意)

import random
import torch

# 定义经验回放缓冲区类
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

# 使用示例
buffer = ReplayBuffer(1000)  # 假设缓冲区容量为1000
# 在智能体与环境交互过程中,不断往缓冲区添加经验
state = env.reset()
action = env.action_space.sample()
next_state, reward, done, _ = env.step(action)
buffer.push(state, action, reward, next_state, done)
# 训练时采样一批经验
batch_size = 32
batch = buffer.sample(batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
states = torch.tensor(states, dtype=torch.float)
actions = torch.tensor(actions).unsqueeze(1)
rewards = torch.tensor(rewards).unsqueeze(1)
next_states = torch.tensor(next_states, dtype=torch.float)
dones = torch.tensor(dones).unsqueeze(1)

四、用PyTorch实现Breakout游戏AI示例

环境准备

首先需要安装 gymnasium 以及 atari-py 库(用于加载Atari游戏环境),安装命令如下:

pip install gymnasium atari-py

网络结构搭建(简单示例)

import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

训练循环核心代码片段

import gymnasium as gym

# 超参数设置
learning_rate = 0.001
gamma = 0.99
batch_size = 32
num_episodes = 1000

# 创建环境和网络实例
env = gym.make('Breakout-v0')
input_size = env.observation_space.shape[0]
output_size = env.action_space.n
policy_net = DQN(input_size, output_size)
target_net = DQN(input_size, output_size)
target_net.load_state_dict(policy_net.state_dict())
optimizer = torch.optim.Adam(policy_net.parameters(), lr=learning_rate)

# 经验回放缓冲区实例化
buffer = ReplayBuffer(10000)

for episode in range(num_episodes):
    state, _ = env.reset()
    state = torch.tensor(state, dtype=torch.float).unsqueeze(0)
    done = False
    while not done:
        # 根据策略网络选择动作(这里采用简单的epsilon-greedy策略)
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = torch.argmax(policy_net(state)).item()

        # 执行动作,获取下一个状态、奖励等信息
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        next_state = torch.tensor(next_state, dtype=torch.float).unsqueeze(0)
        reward = torch.tensor([reward]).unsqueeze(0)

        # 将经验存入缓冲区
        buffer.push(state, action, reward, next_state, done)

        # 从缓冲区采样一批经验进行训练(当缓冲区数据足够时)
        if len(buffer) >= batch_size:
            states, actions, rewards, next_states, dones = buffer.sample(batch_size)
            # 计算目标Q值
            q_targets = rewards + gamma * torch.max(target_net(next_states), dim=1)[0].unsqueeze(1) * (1 - dones)
            # 计算预测Q值
            q_preds = policy_net(states).gather(1, actions)
            # 计算损失并更新网络
            loss = torch.mean((q_preds - q_targets) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        state = next_state

    # 定期更新目标网络
    if episode % 10 == 0:
        target_net.load_state_dict(policy_net.state_dict())

env.close()

五、常见问题及解决思路

1. 训练过程中网络不收敛怎么办?

  • 检查超参数:比如学习率是否过大或者过小,过大可能导致无法收敛,过小则收敛速度极慢。可以尝试不同的学习率范围进行调试。
  • 经验回放缓冲区设置:确保缓冲区大小合适,过小可能导致数据多样性不足,过大可能占用过多内存,同时要检查采样是否正确打破了数据相关性。

2. 如何提升AI在游戏中的表现?

  • 网络结构优化:可以尝试增加网络的深度、宽度或者采用更先进的网络架构(如卷积神经网络来更好地处理图像类状态)。
  • 探索策略调整:改进epsilon-greedy等探索策略,比如采用递减的epsilon值,前期多探索,后期多利用已学习到的知识来做决策。

六、下期预告:深度Q网络的改进(Double DQN、Dueling DQN等)

在**强化学习系列(3)**中,您将深入学习到:

  1. Double DQN的原理以及如何解决DQN中的高估问题。
  2. Dueling DQN的独特结构及其优势,能更好地对状态价值和动作优势进行分离学习。
  3. 对比不同改进版本的DQN在实际环境中的性能表现及应用场景。

欢迎继续关注本系列,在评论区分享您的实践过程和遇到的问题哦! 🔔

您可能感兴趣的与本文相关的镜像

Python3.9

Python3.9

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值