强化学习经典算法笔记(十四):双延迟深度确定性策略梯度算法TD3的PyTorch实现

本文深入解析了双延迟深度确定性策略梯度(TD3)算法,介绍了其如何解决过估计和高方差问题,并提供了详细的PyTorch实现代码。TD3作为DDPG的改进版,通过双重网络、目标策略平滑正则化等策略提高了策略学习的稳定性和效果。

强化学习经典算法笔记(十四):双延迟深度确定性策略梯度算法TD3的PyTorch实现

TD3算法简介

TD3是Twin Delayed Deep Deterministic policy gradient algorithm的简称,双延迟深度确定性策略梯度。从名字看出,TD3算法是DDPG的改进版本。TD3算法来自论文
Addressing Function Approximation Error in Actor-Critic Methods

TD3相对于DDPG,主要采用了以下重要改进。

  1. Double network
  2. Critic学习改进
  3. Actor学习改进
  4. target policy smoothing regularization

更详细的介绍请参考
https://zhuanlan.zhihu.com/p/111334500
https://zhuanlan.zhihu.com/p/88446488

详细介绍

解决两个问题,一个是过估计,overestimate,另一个是高方差现象,high variance。

对状态价值的过高估计是Value based方法经常遇到的问题。在Qnetwork不成熟时,对状态的估计有误差,对Q值取最大化操作时,会高于真实的最大Q值。累积下来的过高估计可能会使得算法陷入次优策略中,导致发散等行为。

TD3论文发现Actor-critic算法中也会出现overestimate现象。

解决过高估计问题的办法,文中提到了两个。一个是采取Double DQN的做法,使用target Q network和main Q network分别进行状态价值估计和选取动作,将两者解耦。另一个是Double Q-learning,即采用两个独立的Critic,分别对价值进行估计,取最小值。这个方法带来的高方差问题可以用更新时对梯度进行裁剪来改善。
y = r t + γ m i n i = 1 , 2 Q θ i ′ ( s ′ , π ϕ 1 ( s ′ ) ) y=r_t + \gamma min_{i=1,2} Q_{\theta '_i}(s',\pi_{\phi_1}(s')) y=rt+γmini=1,2Qθi(s,πϕ1(s))

解决高方差问题,文中采用了三个办法。
一个是target network。自从DQN起就在使用,通过降低critic的更新频率来降低方差。
第二个是降低Actor的更新频率,叫做Delaying policy updates。也就是将值函数和策略函数解耦。
第三个是target policy smoothing regularization,即一种正则化方法解决determinstic policy可能overfitting的问题。直观地讲,就是在计算Q值更新目标时,采用如下方式:
a ~ ← π ϕ ′ ( s ′ ) + ϵ , ϵ ∼ c l i p ( N ( 0 , σ ~ ) , − c , c ) \tilde{a} \leftarrow \pi_{\phi'}(s')+\epsilon,\quad \epsilon \sim clip(N(0,\tilde{\sigma}),-c,c) a~πϕ(s)+ϵ,ϵclip(N(0,σ~),c,c)

算法流程图

在这里插入图片描述

算法实现

import argparse
from collections import namedtuple
from itertools import count

import os, sys, random
import numpy as np

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
from tensorboardX import SummaryWriter

device = 'cuda' if torch.cuda.is_available() else 'cpu'
parser = argparse.ArgumentParser()

parser.add_argument('--mode', default='train', type=str)   # mode = 'train' or 'test'
parser.add_argument("--env_name", default="LunarLanderContinuous-v2")  # OpenAI gym environment name, BipedalWalker-v2  Pendulum-v0
parser.add_argument('--tau',  default=0.05, type=float)    # target smoothing coefficient
parser.add_argument('--target_update_interval', default=1, type=int)
parser.add_argument('--test_episode', default=50, type=int)
parser.add_argument('--epoch', default=10, type=int)       # buffer采样的数据训练几次
parser.add_argument('--learning_rate', default=3e-4, type=float)
parser
### TD3算法足机器人中的实现 TD3(Twin Delayed Deep Deterministic Policy Gradient)是一种改进自DDPG的强化学习算法,在处理连续动作空间方面表现出色。对于足机器人的控制任务而言,这类环境通常具备复杂的动态特性以及较高的维度状态和动作空间,因此非常适合采用TD3来解决。 #### 关键技术特点 为了克服原始DDPG中存在的价值函数估计偏差问题,TD3引入了三项关键技术: 1. **重批评家网络**:通过构建两个独立的Critic模型并取两者预测值较小者作为目标更新依据,有效降低了过高的价值评估风险[^2]。 2. **延迟策略更新**:减少Actor参数调整频率至每两次Critic迭代一次,有助于稳定训练过程。 3. **目标平滑正则化**:向下一时刻的状态转移过程中加入噪声扰动,增强探索能力的同时也缓解了泛化误差的影响。 这些改进措施共同作用下,使得TD3能够在更复杂多变的任务场景中获得更好的性能表现。 #### 实现案例分析 针对`BipedalWalkerHardcore-v3`这一经典的足行走模拟器环境,可以利用PyTorch框架快速搭建起基于TD3的学习系统。以下是简化版的核心代码片段展示如何初始化相关组件及定义主要逻辑流程: ```python import gymnasium as gym import torch.nn.functional as F from torch import optim from td3_agent import Agent # 假设已有一个实现TD3 agent的模块 env = gym.make('BipedalWalkerHardcore-v3') agent = Agent(state_size=env.observation_space.shape[0], action_size=env.action_space.shape[0]) def train(n_episodes=1000): scores_deque = deque(maxlen=100) for i_episode in range(1, n_episodes+1): state = env.reset()[0] score = 0 while True: action = agent.act(state) # 获取当前状态下采取的动作 next_state, reward, done, _, _ = env.step(action) agent.step(state, action, reward, next_state, done) # 更新经验池 state = next_state # 转移到下一个状态 score += reward # 累计得分 if done: # 当回合结束时记录成绩 break scores_deque.append(score) train() ``` 此段程序展示了基本的训练循环结构,其中包含了与环境交互获取观测数据、执行决策操作并向代理传递反馈信息的过程。值得注意的是实际项目开发还需要考虑更多细节配置比如超参调优等环节。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值