TD3 (Twin Delayed Deep Deterministic Policy Gradient): Detailed Explanation and Implementation

TD3 (Twin Delayed Deep Deterministic Policy Gradient) is a deep reinforcement learning algorithm designed for continuous action spaces. It is an improved version of DDPG that addresses DDPG's overestimation bias through several technical innovations.

1. Core Ideas of TD3

1.1 Key Innovations

  1. Twin critics: use two independent Q-function estimators and take the minimum of their estimates as the target value, reducing overestimation bias

  2. Delayed policy updates: the policy (actor) is updated less frequently than the Q-functions (critics)

  3. Target policy smoothing: add noise to the target action so the policy cannot overfit to sharp peaks in the Q-function

1.2 Main Differences from DDPG

Feature                 | DDPG              | TD3
------------------------|-------------------|------------------------------
Number of critics       | Single Q-network  | Twin Q-networks
Policy update frequency | Every iteration   | Delayed (every d critic updates)
Target policy noise     | None              | Smoothing noise added
Overestimation bias     | Significant       | Greatly reduced

2. Mathematical Foundations of TD3

2.1 Key Formulas

  1. Target Q-value (a numeric sketch follows these formulas)

    text

    y = r + γ * min(Q₁'(s', ã), Q₂'(s', ã))
    ã = clip(π'(s') + ε, a_low, a_high),  ε ~ clip(N(0, σ), -c, c)

  2. Critic loss

    text

    L(θᵢ) = E[(Qᵢ(s, a) - y)²],  i ∈ {1, 2}

  3. Policy gradient (the actor is updated using Q₁ only)

    text

    ∇ϕ J(ϕ) = E[ ∇a Q₁(s, a)|a=π(s) · ∇ϕ π(s) ]
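
To make the target computation in item 1 concrete, here is a minimal NumPy sketch for a single toy transition. The stand-in target critics, target policy, and all numeric values are hypothetical placeholders for illustration, not part of the implementation in Section 3.

python

import numpy as np

# Hypothetical stand-ins for Q₁', Q₂' and π' (illustration only)
q1_target = lambda s, a: -np.sum((a - 0.3 * s) ** 2)
q2_target = lambda s, a: -np.sum((a - 0.2 * s) ** 2)
pi_target = lambda s: np.tanh(s)

gamma, sigma, c = 0.99, 0.2, 0.5        # discount, noise std, noise clip
a_low, a_high = -1.0, 1.0

s_next = np.array([0.5, -0.1])          # toy next state s'
r = 1.0                                 # toy reward

# Target policy smoothing: ã = clip(π'(s') + ε, a_low, a_high), ε ~ clip(N(0, σ), -c, c)
eps = np.clip(np.random.normal(0.0, sigma, size=s_next.shape), -c, c)
a_tilde = np.clip(pi_target(s_next) + eps, a_low, a_high)

# Clipped double-Q target: y = r + γ * min(Q₁'(s', ã), Q₂'(s', ã))
y = r + gamma * min(q1_target(s_next, a_tilde), q2_target(s_next, a_tilde))
print(y)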

2.2 Algorithm Pseudocode

text

Initialize critic networks Qθ1, Qθ2 and actor network πϕ
Initialize target networks Qθ1', Qθ2', πϕ'
Initialize replay buffer D

for episode = 1 to M do
    Initialize state s
    for t = 1 to T do
        Select action a = π(s) + exploration noise
        Execute a, observe r, s'
        Store (s, a, r, s') in D
        
        Sample a mini-batch {(s, a, r, s')} ~ D
        ã = π'(s') + clipped noise
        y = r + γ * min(Qθ1'(s', ã), Qθ2'(s', ã))
        
        Update Qθ1 and Qθ2 by minimizing (Q - y)²
        
        if t mod d == 0 then
            Update πϕ by maximizing Qθ1(s, π(s))
            Soft-update the target networks:
            θ' ← τθ + (1-τ)θ'
            ϕ' ← τϕ + (1-τ)ϕ'
        end if
    end for
end for

3. PyTorch Implementation of TD3

The complete TD3 implementation is shown below:

python

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        # Cast rewards and done flags to float32 so they convert cleanly to FloatTensor
        return (np.stack(state), np.stack(action), np.array(reward, dtype=np.float32),
                np.stack(next_state), np.array(done, dtype=np.float32))
    
    def __len__(self):
        return len(self.buffer)

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, action_dim)
        self.max_action = max_action
    
    def forward(self, state):
        a = F.relu(self.l1(state))
        a = F.relu(self.l2(a))
        return self.max_action * torch.tanh(self.l3(a))

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        # Q1 architecture
        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)
        
        # Q2 architecture
        self.l4 = nn.Linear(state_dim + action_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)
    
    def forward(self, state, action):
        sa = torch.cat([state, action], 1)
        
        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        
        q2 = F.relu(self.l4(sa))
        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        return q1, q2
    
    def Q1(self, state, action):
        sa = torch.cat([state, action], 1)
        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        return q1

class TD3:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = Actor(state_dim, action_dim, max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = Adam(self.actor.parameters(), lr=3e-4)
        
        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = Adam(self.critic.parameters(), lr=3e-4)
        
        self.max_action = max_action
        self.replay_buffer = ReplayBuffer(1000000)
        
        self.policy_noise = 0.2 * max_action
        self.noise_clip = 0.5 * max_action
        self.policy_freq = 2
        self.tau = 0.005
        self.gamma = 0.99
        self.batch_size = 256
        self.total_it = 0
    
    def select_action(self, state, noise=True):
        state = torch.FloatTensor(state.reshape(1, -1))
        action = self.actor(state).cpu().data.numpy().flatten()
        
        if noise:
            noise = np.random.normal(0, 0.1 * self.max_action, size=action.shape)
            action = (action + noise).clip(-self.max_action, self.max_action)
        return action
    
    def train(self):
        self.total_it += 1
        
        # Sample a mini-batch from the replay buffer
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)
        state = torch.FloatTensor(state)
        action = torch.FloatTensor(action)
        reward = torch.FloatTensor(reward).unsqueeze(1)
        next_state = torch.FloatTensor(next_state)
        done = torch.FloatTensor(done).unsqueeze(1)
        
        with torch.no_grad():
            # Target policy smoothing
            noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)
            
            # Compute the target Q-value
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (1 - done) * self.gamma * target_Q
        
        # Update the critics
        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        
        # Delayed policy update
        if self.total_it % self.policy_freq == 0:
            # Update the actor
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
            
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()
            
            # Soft-update the target networks
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            
            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
    
    def save(self, filename):
        torch.save({
            'actor': self.actor.state_dict(),
            'actor_target': self.actor_target.state_dict(),
            'critic': self.critic.state_dict(),
            'critic_target': self.critic_target.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict(),
            'total_it': self.total_it
        }, filename)
    
    def load(self, filename):
        checkpoint = torch.load(filename)
        self.actor.load_state_dict(checkpoint['actor'])
        self.actor_target.load_state_dict(checkpoint['actor_target'])
        self.critic.load_state_dict(checkpoint['critic'])
        self.critic_target.load_state_dict(checkpoint['critic_target'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer'])
        self.total_it = checkpoint['total_it']
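
As a quick smoke test of the class above, the sketch below runs a few steps on a simple continuous-control task. It assumes the TD3 class is defined in (or imported into) the same file, the classic gym package with the Pendulum-v1 environment, and the pre-0.26 reset/step API used throughout this article; adapt the environment name and API calls to your setup.

python

import gym
import numpy as np

env = gym.make('Pendulum-v1')  # any continuous-action environment works
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

agent = TD3(state_dim, action_dim, max_action)

state = env.reset()
for step in range(5000):
    action = agent.select_action(np.array(state), noise=True)
    next_state, reward, done, _ = env.step(action)
    agent.replay_buffer.push(state, action, reward, next_state, float(done))
    state = next_state if not done else env.reset()
    # Start gradient updates once the buffer holds at least one batch
    if len(agent.replay_buffer) > agent.batch_size:
        agent.train()
env.close()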

4. Applying TD3 to Robot Control

4.1 Implementing a Robot Grasping Task

python

import gym
from td3 import TD3
import numpy as np

# Create the environment (e.g., an OpenAI Gym robotics environment, old Gym API)
env = gym.make('FetchReach-v1')  # or a custom environment
# The policy below is fed the concatenation of observation, achieved_goal and desired_goal,
# so the state dimension must cover all three parts
state_dim = sum(env.observation_space[k].shape[0]
                for k in ('observation', 'achieved_goal', 'desired_goal'))
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# Initialize TD3
policy = TD3(state_dim, action_dim, max_action)

# Training parameters
max_episodes = 1000
max_timesteps = 1000
exploration_noise = 0.1

# Training loop
for episode in range(1, max_episodes+1):
    state = env.reset()
    episode_reward = 0
    
    for t in range(max_timesteps):
        # Select an action with exploration noise
        action = policy.select_action(np.concatenate([
            state['observation'],
            state['achieved_goal'],
            state['desired_goal']
        ]), noise=True)
        
        # Execute the action
        next_state, reward, done, _ = env.step(action)
        
        # Store the transition
        policy.replay_buffer.push(
            np.concatenate([
                state['observation'],
                state['achieved_goal'],
                state['desired_goal']
            ]),
            action,
            reward,
            np.concatenate([
                next_state['observation'],
                next_state['achieved_goal'],
                next_state['desired_goal']
            ]),
            done
        )
        
        state = next_state
        episode_reward += reward
        
        # Train once the buffer holds at least one batch
        if len(policy.replay_buffer) > policy.batch_size:
            policy.train()
        
        if done:
            break
    
    # Print progress
    print(f"Episode: {episode}, Reward: {episode_reward:.2f}")
    
    # Periodically save the model
    if episode % 50 == 0:
        policy.save(f"td3_robot_grasping_{episode}.pth")

env.close()
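
To measure progress independently of exploration, it is common to run periodic evaluation episodes with the noise turned off (select_action(..., noise=False)). The sketch below assumes the same FetchReach-style environment as above and that its info dict exposes an 'is_success' flag, as the standard Fetch environments do; treat both as assumptions for custom environments.

python

import numpy as np

def flatten_obs(obs_dict):
    # Same observation layout as in the training loop above
    return np.concatenate([obs_dict['observation'],
                           obs_dict['achieved_goal'],
                           obs_dict['desired_goal']])

def evaluate(policy, env, episodes=10, max_timesteps=1000):
    """Run noise-free episodes and report average return and success rate."""
    returns, successes = [], []
    for _ in range(episodes):
        obs = env.reset()
        ep_ret, success = 0.0, 0.0
        for _ in range(max_timesteps):
            action = policy.select_action(flatten_obs(obs), noise=False)
            obs, reward, done, info = env.step(action)
            ep_ret += reward
            success = float(info.get('is_success', success))
            if done:
                break
        returns.append(ep_ret)
        successes.append(success)
    return np.mean(returns), np.mean(successes)

# Example: avg_return, success_rate = evaluate(policy, env)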

4.2 Practical Tuning Tips

  1. State representation (see the sketch after this list)

    • Use the relative position of the target object and the end-effector

    • Include force/torque sensor readings

    • Consider including a short history of past states

  2. Reward function design

    python

    def compute_reward(achieved_goal, desired_goal, info):
        # Distance-based sparse reward
        d = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
        return -(d > 0.05).astype(np.float32)  # success threshold: 5 cm
        
        # Or a denser shaped reward:
        # return -d + 10 * (d < 0.05)  # bonus on success

  3. Curriculum learning

    • Start from simple scenarios (e.g., grasping at a fixed position)

    • Gradually increase difficulty (varied positions, varied objects)

    • Dynamically adjust the exploration noise
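
As referenced in the state-representation item above, one common choice is to feed the policy the end-effector-to-object offset rather than raw absolute positions. The helper below is a hypothetical sketch: the dictionary keys ('grip_pos', 'object_pos', 'force_torque') are illustrative and must be mapped to whatever your simulator actually provides.

python

import numpy as np

def build_state(raw_obs):
    """Assemble a compact state vector from hypothetical raw sensor fields."""
    grip_pos = np.asarray(raw_obs['grip_pos'])          # end-effector position
    object_pos = np.asarray(raw_obs['object_pos'])      # target object position
    force_torque = np.asarray(raw_obs['force_torque'])  # wrist F/T sensor reading

    relative_pos = object_pos - grip_pos                # relative-position feature
    return np.concatenate([grip_pos, relative_pos, force_torque]).astype(np.float32)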

5. Improvements and Variants of TD3

5.1 Common Improvements

  1. Adaptive noise

    python

    # Anneal the target-policy noise as performance improves
    if episode_reward > threshold:
        policy.policy_noise *= 0.9  # reduce noise
        policy.noise_clip *= 0.9

  2. Prioritized experience replay (a fuller buffer sketch follows this list)

    python

    # Use the TD error as the sampling priority
    td_error = (current_Q1 - target_Q).abs().cpu().data.numpy()
    replay_buffer.update_priority(indices, td_error + 1e-5)

  3. Mixed exploration strategy

    python

    # Combine OU noise and Gaussian noise
    ou_noise = OUProcess(action_dim)
    action = policy.select_action(state) + 0.6 * ou_noise() + 0.4 * np.random.normal(0, exploration_noise, size=action_dim)
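
The prioritized replay snippet above assumes a buffer that exposes an update_priority method. A minimal proportional-prioritization buffer along those lines might look like the sketch below (a plain array-based version without a sum-tree, so sampling is O(N)); it is a sketch, not a drop-in replacement for the ReplayBuffer class in Section 3.

python

import numpy as np

class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha                      # how strongly priorities shape sampling
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.pos = 0

    def push(self, transition):
        # New samples get the current maximum priority so they are seen at least once
        max_prio = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.pos] = transition
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        prios = self.priorities[:len(self.buffer)]
        probs = prios ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        # Importance-sampling weights correct the bias introduced by prioritization
        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights /= weights.max()
        samples = [self.buffer[i] for i in indices]
        return samples, indices, weights.astype(np.float32)

    def update_priority(self, indices, td_errors):
        self.priorities[indices] = np.abs(td_errors).flatten()

    def __len__(self):
        return len(self.buffer)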

5.2 More Advanced Variants

  1. SAC (Soft Actor-Critic)

    • Maximum-entropy framework

    • Automatically tuned temperature parameter

    • More stable learning

  2. REDQ (Randomized Ensembled Double Q-learning)

    • Uses a larger ensemble of Q-functions (typically 10)

    • Randomly samples a subset of them for target computation (see the sketch after this list)

    • Higher sample efficiency

  3. DrQ (Data-regularized Q)

    • Applies data augmentation to observations

    • Particularly suited to visual inputs

    • Improves sample efficiency
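
To illustrate the REDQ idea of computing the target with a random subset of a larger critic ensemble, here is a minimal sketch. It assumes a list of target critic networks (target_critics), each mapping (state, action) to a single Q-value tensor; the names and ensemble/subset sizes are illustrative, not a full REDQ implementation.

python

import random
import torch

def redq_target(reward, not_done, next_state, next_action,
                target_critics, gamma=0.99, subset_size=2):
    """Clipped target over a random subset of an ensemble of target critics."""
    # Randomly pick a small subset of the ensemble (REDQ typically uses M=2 of N=10)
    subset = random.sample(target_critics, subset_size)
    with torch.no_grad():
        qs = torch.stack([q(next_state, next_action) for q in subset], dim=0)
        min_q = qs.min(dim=0).values
        return reward + not_done * gamma * min_q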

6. Debugging and Performance Tuning

6.1 Troubleshooting Common Problems

  1. Unstable training

    • Reduce the learning rate (try 1e-4 to 3e-4)

    • Tune the target-network soft-update rate τ (try 0.001 to 0.01)

    • Increase the replay buffer size

  2. Policy converges to a local optimum

    • Increase the exploration noise

    • Try different network initializations

    • Use a curriculum learning strategy

  3. Exploding Q-values

    • Add gradient clipping

    python

    torch.nn.utils.clip_grad_norm_(policy.critic.parameters(), 1.0)

    • Rescale the rewards (a minimal sketch follows this list)
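
For the reward-rescaling suggestion above, one simple option is to scale (and optionally clip) rewards before they enter the replay buffer. The sketch below uses a fixed scale factor; the constants are placeholders to be tuned per task, not recommended defaults.

python

import numpy as np

def scale_reward(reward, scale=0.1, clip=10.0):
    """Shrink and clip rewards to keep Q-value targets in a moderate range."""
    return float(np.clip(reward * scale, -clip, clip))

# Usage when storing a transition:
# policy.replay_buffer.push(state, action, scale_reward(reward), next_state, done)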

6.2 Performance Evaluation Metrics

  1. Training metrics (see the logging sketch after this list)

    • Episode reward (smoothed)

    • Q-value magnitude (to monitor overestimation)

    • Policy entropy (degree of exploration)

  2. Evaluation metrics

    • Task success rate

    • Average completion time

    • Energy consumption (important for real robots)

  3. Robustness tests

    • Different initial conditions

    • Sensor noise

    • Environmental disturbances
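
As referenced in the training-metrics item above, a moving average of episode rewards together with the average Q-value of sampled batches is usually enough to spot divergence or overestimation early. A minimal logging sketch (the class name and window size are illustrative):

python

from collections import deque
import numpy as np

class TrainingMonitor:
    def __init__(self, window=100):
        self.rewards = deque(maxlen=window)   # recent episode returns
        self.q_values = deque(maxlen=window)  # recent average Q estimates

    def log_episode(self, episode_reward):
        self.rewards.append(episode_reward)

    def log_q(self, q_batch):
        # q_batch: a tensor of Q-values from a training step (e.g., current_Q1)
        self.q_values.append(float(q_batch.mean()))

    def summary(self):
        return {
            'smoothed_reward': float(np.mean(self.rewards)) if self.rewards else 0.0,
            'mean_q': float(np.mean(self.q_values)) if self.q_values else 0.0,
        }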

With the TD3 implementation and tuning techniques above, you can build an efficient reinforcement learning system for complex robot control problems such as grasping, manipulation, and navigation.
