TD3 (Twin Delayed Deep Deterministic Policy Gradient) is a modern deep reinforcement learning algorithm designed specifically for continuous action spaces. It is an improved version of DDPG that addresses DDPG's overestimation bias through several technical innovations.
1. Core Ideas of TD3
1.1 Key Innovations
- **Twin critics (double Q-networks):** two independent Q-function estimators are trained, and the smaller of the two estimates is used as the target value, reducing overestimation bias.
- **Delayed policy updates:** the actor is updated less frequently than the critics.
- **Target policy smoothing:** noise is added to the target action so the policy cannot overfit to sharp peaks in the Q-function.
1.2 Main Differences from DDPG
| Property | DDPG | TD3 |
|---|---|---|
| Number of critics | Single Q-network | Twin Q-networks |
| Policy update frequency | Every iteration | Delayed (every d-th critic update) |
| Target policy noise | None | Clipped smoothing noise |
| Overestimation bias | Pronounced | Significantly reduced |
2. Mathematical Principles of TD3
2.1 Key Formulas
- **Target Q-value** (see the sketch after this list):

```text
ã = clip(π'(s') + ε, a_low, a_high),  ε ~ clip(N(0, σ), -c, c)
y = r + γ · min(Q₁'(s', ã), Q₂'(s', ã))
```

- **Critic loss:**

```text
L(θᵢ) = E[(Qᵢ(s, a) - y)²],  i ∈ {1, 2}
```

- **Policy update (through Q₁ only):**

```text
∇J(ϕ) = E[ ∇aQ₁(s, a)|a=π(s) · ∇ϕπ(s) ]
```
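These three update rules map almost line-for-line onto PyTorch. The fragment below is a minimal sketch of the target and loss computation; the batch tensors (`state`, `action`, `reward`, `next_state`, `done`), the networks (`actor`, `critic`, `actor_target`, `critic_target`), and the hyperparameters (`sigma`, `c`, `gamma`, `max_action`) are assumed to be defined as in the full implementation in Section 3.

```python
import torch
import torch.nn.functional as F

with torch.no_grad():
    # ε ~ clip(N(0, σ), -c, c): clipped smoothing noise on the target action
    eps = (torch.randn_like(action) * sigma).clamp(-c, c)
    a_tilde = (actor_target(next_state) + eps).clamp(-max_action, max_action)
    # y = r + γ · min(Q₁', Q₂')
    q1_t, q2_t = critic_target(next_state, a_tilde)
    y = reward + gamma * (1 - done) * torch.min(q1_t, q2_t)

# L(θᵢ) = E[(Qᵢ(s, a) - y)²]
q1, q2 = critic(state, action)
critic_loss = F.mse_loss(q1, y) + F.mse_loss(q2, y)

# Deterministic policy gradient through Q₁ only
actor_loss = -critic.Q1(state, actor(state)).mean()
```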
2.2 Algorithm Pseudocode
```text
Initialize critic networks Qθ1, Qθ2 and actor network πϕ
Initialize target networks Qθ1', Qθ2', πϕ'
Initialize replay buffer D
for episode = 1 to M do
    Initialize state s
    for timestep = 1 to T do
        Select action a = π(s) + exploration noise
        Execute a, observe r, s'
        Store (s, a, r, s') in D
        Sample a mini-batch {(s, a, r, s')} ~ D
        ã = π'(s') + clipped noise
        y = r + γ * min(Qθ1'(s', ã), Qθ2'(s', ã))
        Update Qθ1 and Qθ2 by minimizing (Q - y)²
        if timestep mod d == 0 then
            Update πϕ by maximizing Qθ1(s, π(s))
            Soft-update the target networks:
                θ' ← τθ + (1 - τ)θ'
                ϕ' ← τϕ + (1 - τ)ϕ'
        end if
    end for
end for
```
3. PyTorch Implementation of TD3
The complete TD3 implementation is shown below:
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from collections import deque
import random


class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        # Cast rewards and done flags to float32 so they convert cleanly to tensors
        return (np.stack(state), np.stack(action),
                np.array(reward, dtype=np.float32),
                np.stack(next_state),
                np.array(done, dtype=np.float32))

    def __len__(self):
        return len(self.buffer)


class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, action_dim)
        self.max_action = max_action

    def forward(self, state):
        a = F.relu(self.l1(state))
        a = F.relu(self.l2(a))
        return self.max_action * torch.tanh(self.l3(a))


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        # Q1 architecture
        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)
        # Q2 architecture
        self.l4 = nn.Linear(state_dim + action_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    def forward(self, state, action):
        sa = torch.cat([state, action], 1)
        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        q2 = F.relu(self.l4(sa))
        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        return q1, q2

    def Q1(self, state, action):
        sa = torch.cat([state, action], 1)
        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        return q1


class TD3:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = Actor(state_dim, action_dim, max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = Adam(self.critic.parameters(), lr=3e-4)

        self.max_action = max_action
        self.replay_buffer = ReplayBuffer(1000000)
        self.policy_noise = 0.2 * max_action
        self.noise_clip = 0.5 * max_action
        self.policy_freq = 2
        self.tau = 0.005
        self.gamma = 0.99
        self.batch_size = 256
        self.total_it = 0

    def select_action(self, state, noise=True):
        state = torch.FloatTensor(state.reshape(1, -1))
        action = self.actor(state).cpu().data.numpy().flatten()
        if noise:
            noise = np.random.normal(0, 0.1 * self.max_action, size=action.shape)
            action = (action + noise).clip(-self.max_action, self.max_action)
        return action

    def train(self):
        self.total_it += 1

        # Sample a mini-batch from the replay buffer
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)
        state = torch.FloatTensor(state)
        action = torch.FloatTensor(action)
        reward = torch.FloatTensor(reward).unsqueeze(1)
        next_state = torch.FloatTensor(next_state)
        done = torch.FloatTensor(done).unsqueeze(1)

        with torch.no_grad():
            # Target policy smoothing
            noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)

            # Compute the target Q-value
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (1 - done) * self.gamma * target_Q

        # Update the critics
        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy update
        if self.total_it % self.policy_freq == 0:
            # Update the actor
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Soft-update the target networks
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, filename):
        torch.save({
            'actor': self.actor.state_dict(),
            'actor_target': self.actor_target.state_dict(),
            'critic': self.critic.state_dict(),
            'critic_target': self.critic_target.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict(),
            'total_it': self.total_it
        }, filename)

    def load(self, filename):
        checkpoint = torch.load(filename)
        self.actor.load_state_dict(checkpoint['actor'])
        self.actor_target.load_state_dict(checkpoint['actor_target'])
        self.critic.load_state_dict(checkpoint['critic'])
        self.critic_target.load_state_dict(checkpoint['critic_target'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer'])
        self.total_it = checkpoint['total_it']
```
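As a quick sanity check, the class above can be exercised on a simple continuous-control task before moving to robotics. The sketch below assumes the standard Gym `Pendulum-v1` environment and the pre-0.26 Gym API (`reset()` returns the observation, `step()` returns a 4-tuple); the episode and step budgets are illustrative, not tuned.

```python
import gym
import numpy as np

env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

agent = TD3(state_dim, action_dim, max_action)

for episode in range(200):                      # illustrative episode budget
    state, episode_reward = env.reset(), 0.0
    for t in range(200):
        action = agent.select_action(np.array(state), noise=True)
        next_state, reward, done, _ = env.step(action)
        agent.replay_buffer.push(state, action, reward, next_state, float(done))
        state, episode_reward = next_state, episode_reward + reward
        if len(agent.replay_buffer) > agent.batch_size:
            agent.train()                       # one gradient step per environment step
        if done:
            break
    print(f"Episode {episode}: reward {episode_reward:.1f}")
```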
4. Applying TD3 to Robot Control
4.1 Implementing a Robot Grasping Task
```python
import gym
import numpy as np
from td3 import TD3  # the TD3 class from Section 3

# Create the environment (e.g. an OpenAI Gym robotics environment)
env = gym.make('FetchReach-v1')  # or a custom environment

# The policy receives the concatenation of observation, achieved_goal and desired_goal,
# so state_dim must cover all three components
state_dim = (env.observation_space['observation'].shape[0]
             + env.observation_space['achieved_goal'].shape[0]
             + env.observation_space['desired_goal'].shape[0])
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# Initialize TD3
policy = TD3(state_dim, action_dim, max_action)

# Training parameters
max_episodes = 1000
max_timesteps = 1000
exploration_noise = 0.1

# Training loop
for episode in range(1, max_episodes + 1):
    state = env.reset()
    episode_reward = 0

    for t in range(max_timesteps):
        # Select an action with exploration noise
        action = policy.select_action(np.concatenate([
            state['observation'],
            state['achieved_goal'],
            state['desired_goal']
        ]), noise=True)

        # Step the environment
        next_state, reward, done, _ = env.step(action)

        # Store the transition
        policy.replay_buffer.push(
            np.concatenate([
                state['observation'],
                state['achieved_goal'],
                state['desired_goal']
            ]),
            action,
            reward,
            np.concatenate([
                next_state['observation'],
                next_state['achieved_goal'],
                next_state['desired_goal']
            ]),
            done
        )

        state = next_state
        episode_reward += reward

        # Train once enough samples have been collected
        if len(policy.replay_buffer) > policy.batch_size:
            policy.train()

        if done:
            break

    # Report progress
    print(f"Episode: {episode}, Reward: {episode_reward:.2f}")

    # Periodically save the model
    if episode % 50 == 0:
        policy.save(f"td3_robot_grasping_{episode}.pth")

env.close()
```
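After training, it is worth evaluating the policy deterministically (no exploration noise) and tracking success rate rather than raw reward. A minimal sketch, assuming the same `FetchReach-v1` observation layout and that the environment reports `info['is_success']` (as the Gym robotics suite does); call it periodically during training, e.g. alongside the checkpoint save and before `env.close()`:

```python
def evaluate(policy, env, episodes=20):
    """Run the deterministic policy and report the average success rate."""
    successes = 0.0
    for _ in range(episodes):
        state, info = env.reset(), {}
        for _ in range(50):                               # FetchReach episodes are short
            obs = np.concatenate([state['observation'],
                                  state['achieved_goal'],
                                  state['desired_goal']])
            action = policy.select_action(obs, noise=False)   # no exploration noise
            state, reward, done, info = env.step(action)
            if done:
                break
        successes += float(info.get('is_success', 0.0))
    return successes / episodes
```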
4.2 Practical Tuning Tips
- **State representation** (see the sketch after this list):
  - Use the relative position between the target object and the end effector
  - Include force/torque sensor readings
  - Consider stacking a short history of past states
- **Reward shaping:**

  ```python
  def compute_reward(achieved_goal, desired_goal, info):
      # Sparse reward based on distance to the goal
      d = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
      return -(d > 0.05).astype(np.float32)  # 5 cm success threshold
      # Or a denser, shaped reward:
      # return -d + 10 * (d < 0.05)  # bonus for reaching the goal
  ```

- **Curriculum learning:**
  - Start with simple scenarios (e.g., grasping from a fixed position)
  - Gradually increase the difficulty (varying positions and objects)
  - Anneal the exploration noise over time
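For the state-representation points above, a minimal sketch of how such features might be assembled is shown below. The field names (`object_pos`, `gripper_pos`, `ft_reading`) are hypothetical placeholders for whatever your environment or sensor driver actually exposes.

```python
from collections import deque
import numpy as np

history = deque(maxlen=3)  # keep the last 3 feature vectors

def build_state(obs):
    """Assemble a feature vector from (hypothetical) raw observation fields."""
    relative_pos = obs['object_pos'] - obs['gripper_pos']   # relative position feature
    features = np.concatenate([obs['gripper_pos'], relative_pos, obs['ft_reading']])
    history.append(features)
    # Pad with copies of the current frame until the history is full, then stack
    while len(history) < history.maxlen:
        history.appendleft(features)
    return np.concatenate(history)
```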
5. Improvements and Variants of TD3
5.1 Common Improvements
- **Adaptive noise:**

  ```python
  # Decay the smoothing noise as performance improves
  if episode_reward > threshold:
      policy.policy_noise *= 0.9  # reduce the noise
      policy.noise_clip *= 0.9
  ```

- **Prioritized experience replay** (a minimal buffer sketch follows this list):

  ```python
  # Use the TD error as the sampling priority
  td_error = (current_Q1 - target_Q).abs().cpu().data.numpy()
  replay_buffer.update_priority(indices, td_error + 1e-5)
  ```

- **Hybrid exploration** (assumes an `OUProcess` noise generator implemented elsewhere):

  ```python
  # Combine Ornstein-Uhlenbeck (OU) noise with Gaussian noise
  ou_noise = OUProcess(action_dim)
  action = (policy.select_action(state)
            + 0.6 * ou_noise()
            + 0.4 * np.random.normal(0, exploration_noise, size=action_dim))
  ```
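The `update_priority` call above assumes a buffer that tracks priorities, which the plain `ReplayBuffer` from Section 3 does not. A minimal, illustration-only sketch of a proportional prioritized buffer (no sum-tree, no importance-sampling weights) might look like this:

```python
import numpy as np

class PrioritizedReplayBuffer:
    """Proportional prioritized replay, simplified for illustration."""

    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha                     # how strongly priorities shape sampling
        self.data = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.pos = 0

    def push(self, transition):
        # New transitions get the current maximum priority so they are sampled at least once
        max_prio = self.priorities.max() if self.data else 1.0
        if len(self.data) < self.capacity:
            self.data.append(transition)
        else:
            self.data[self.pos] = transition
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        prios = self.priorities[:len(self.data)] ** self.alpha
        probs = prios / prios.sum()
        indices = np.random.choice(len(self.data), batch_size, p=probs)
        return [self.data[i] for i in indices], indices

    def update_priority(self, indices, td_errors):
        self.priorities[indices] = np.abs(np.asarray(td_errors)).flatten()
```

A production implementation would add importance-sampling weights and a sum-tree for efficient sampling.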
5.2 Advanced Variant Algorithms
- **SAC (Soft Actor-Critic):**
  - Maximum-entropy framework
  - Automatically tuned temperature parameter
  - More stable learning
- **REDQ (Randomized Ensembled Double Q-learning)** (see the target-computation sketch after this list):
  - Uses a larger critic ensemble (typically 10 Q-functions)
  - Takes the minimum over a random subset when computing targets
  - Higher sample efficiency
- **DrQ (Data-regularized Q):**
  - Applies data augmentation to observations
  - Particularly suited to visual inputs
  - Improves sample efficiency
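To make the REDQ idea concrete, the fragment below sketches how its target differs from TD3's twin minimum. It assumes `critic_targets` is a list of ten single-output target critics and that `next_state`, `next_action`, `reward`, `done`, and `gamma` are defined as in Section 3; this illustrates only the target mechanism, not the full REDQ algorithm (which also uses a high update-to-data ratio).

```python
import numpy as np
import torch

num_critics, subset_size = 10, 2               # REDQ typically uses N=10, M=2

with torch.no_grad():
    # Pick a random subset of the target critics for this update
    idx = np.random.choice(num_critics, subset_size, replace=False)
    subset_q = torch.stack([critic_targets[i](next_state, next_action) for i in idx])
    # Minimum over the random subset replaces TD3's min over two fixed critics
    target_q = reward + (1 - done) * gamma * subset_q.min(dim=0).values
```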
6. Debugging and Performance Optimization
6.1 Solving Common Problems
- **Unstable training:**
  - Lower the learning rate (try 1e-4 to 3e-4)
  - Tune the target-network soft-update coefficient τ (0.001 to 0.01)
  - Enlarge the replay buffer
- **Policy stuck in a local optimum:**
  - Increase the exploration noise
  - Try different network initializations
  - Use a curriculum learning strategy
- **Exploding Q-values:**
  - Add gradient clipping (see the placement sketch after this list):

    ```python
    torch.nn.utils.clip_grad_norm_(policy.critic.parameters(), 1.0)
    ```

  - Rescale the rewards
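For reference, the clipping call goes between `backward()` and `optimizer.step()`. A minimal sketch of the critic update from `TD3.train()` in Section 3 with clipping added (the actor update follows the same pattern):

```python
# Excerpt of TD3.train(): critic update with gradient clipping
self.critic_optimizer.zero_grad()
critic_loss.backward()
torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1.0)  # cap the global gradient norm at 1.0
self.critic_optimizer.step()
```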
6.2 Performance Evaluation Metrics
- **Training metrics** (a small logging sketch follows this list):
  - Episode reward (smoothed over a window)
  - Q-value magnitude (to monitor overestimation)
  - Policy entropy (degree of exploration)
- **Test metrics:**
  - Task success rate
  - Average completion time
  - Energy consumption (important for real robots)
- **Robustness tests:**
  - Different initial conditions
  - Sensor noise
  - Environmental disturbances
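A minimal sketch of how the training metrics might be tracked. It assumes `episode_reward` comes from the training loop in Section 4.1 and that the mean of `current_Q1` is exposed from the critic update (e.g. returned by `train()`); both hooks are assumptions, not part of the code above.

```python
from collections import deque
import numpy as np

reward_window = deque(maxlen=100)   # sliding window for a smoothed reward curve
q_window = deque(maxlen=100)        # track Q magnitude to spot overestimation

def log_metrics(episode, episode_reward, current_q1_mean):
    reward_window.append(episode_reward)
    q_window.append(current_q1_mean)
    print(f"Episode {episode}: "
          f"smoothed reward {np.mean(reward_window):.2f}, "
          f"mean Q1 {np.mean(q_window):.2f}")
```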
With the TD3 implementation and tuning techniques above, you can build an effective reinforcement learning system for complex robot control problems such as grasping, manipulation, and navigation.