SAC(Soft Actor-Critic)是一种深度强化学习算法,它结合了最大熵强化学习和基于策略梯度的方法。以下是对SAC算法的数学原理、网络架构及其PyTorch实现的详细阐述:
一、数学原理
SAC算法的核心思想是在最大化期望回报的同时,最大化策略的熵。熵是衡量策略随机性的指标,通过最大化熵,SAC算法鼓励智能体在探索和利用之间找到平衡,从而提高策略的稳定性和性能。
SAC算法的目标函数通常表示为:
J(π)=Eτ∼π[Σtr(st,at)+αH(π(⋅|st))]
其中,τ表示智能体与环境交互产生的轨迹,r(st,at)表示在状态st下执行动作at获得的奖励,H(π(⋅|st))表示策略π在状态st下的熵,α是控制熵重要性的超参数。
SAC算法通过策略梯度方法来优化上述目标函数。具体来说,它使用演员-评论家架构,其中演员网络负责生成动作策略,评论家网络评估动作价值。通过两个网络的协同优化,实现策略的逐步改进。
二、网络架构
SAC算法的网络架构通常包括以下几个部分:
演员网络:负责根据当前状态生成动作策略。对于连续动作空间,演员网络通常输出动作的均值和对数标准差,从而得到动作的分布及其对数概率。对数概率用于熵正则化,即目标函数中包含一个用于最大化概率分布广度(熵)的项,以促进智能体的探索行为。
评论家网络:负责评估动作价值。SAC算法通常包含两个评论家网络(Q1和Q2),它们分别输出给定状态和动作下的价值估计。这两个价值估计用于计算目标价值,并通过最小化贝尔曼误差来更新评论家网络的参数。
目标网络:为了稳定训练过程,SAC算法还引入了目标网络(包括目标演员网络和目标评论家网络)。目标网络是评论家网络的延迟副本,用于计算目标价值,从而减缓训练过程中的波动。
三、PyTorch实现
以下是一个简化的SAC算法的PyTorch实现示例:
python
复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
演员网络
class Actor(nn.Module):
def init(self, state_dim, action_dim):
super(Actor, self).init()
self.fc1 = nn.Linear(state_dim, 128)
self.fc2 = nn.Linear(128, 128)
self.mean_linear = nn.Linear(128, action_dim)
self.log_std_linear = nn.Linear(128, action_dim)
def forward(self, state):
x = torch.relu(self.fc1(state))
x = torch.relu(self.fc2(x))
mean = self.mean_linear(x)
log_std = self.log_std_linear(x)
log_std = torch.clamp(log_std, min=-20, max=2)
return mean, log_std
def sample(self, state):
mean, log_std = self.forward(state)
std = log_std.exp()
normal = Normal(mean, std)
z = normal.rsample() # 重参数化技巧
action = torch.tanh(z)
log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + 1e-6)
log_prob = log_prob.sum(dim=1, keepdim=True)
return action, log_prob
评论家网络
class Critic(nn.Module):
def init(self, state_dim, action_dim):
super(Critic, self).init()
self.q1_net = nn.Sequential(
nn.Linear(state_dim + action_dim, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 1),
)
self.q2_net = nn.Sequential(
nn.Linear(state_dim + action_dim, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 1),
)
def forward(self, state, action):
q1 = self.q1_net(torch.cat([state, action], dim=1))
q2 = self.q2_net(torch.cat([state, action], dim=1))
return q1, q2
SAC算法类
class SAC:
def init(self, state_dim, action_dim, lr_actor, lr_critic, gamma, alpha):
self.actor = Actor(state_dim, action_dim)
self.critic = Critic(state_dim, action_dim)
self.target_critic = Critic(state_dim, action_dim) # 目标评论家网络
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr_critic)
self.gamma = gamma # 折扣因子
self.alpha = alpha # 熵正则化系数
def update_critic(self, state, action, reward, next_state, done):
# 计算目标价值
with torch.no_grad():
next_action, next_log_prob = self.actor.sample(next_state)
next_q1, next_q2 = self.target_critic(next_state, next_action)
next_q_target = torch.min(next_q1, next_q2) - self.alpha * next_log_prob
target_q = reward + (1 - done) * self.gamma * next_q_target
# 计算当前价值并更新评论家网络
current_q1, current_q2 = self.critic(state, action)
critic_loss = nn.MSELoss()(current_q1, target_q) + nn.MSELoss()(current_q2, target_q)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
def update_actor(self, state):
# 采样动作并计算损失
action, log_prob = self.actor.sample(state)
q1, q2 = self.critic(state, action)
actor_loss = -torch.min(q1, q2) - self.alpha * log_prob
# 更新演员网络
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# 其他方法,如软更新目标网络等...
上述代码实现了一个简化的SAC算法,包括演员网络、评论家网络以及SAC算法类的基本框架。在实际应用中,还需要添加经验回放、目标网络软更新等机制,并根据具体问题调整超参数和网络结构。
请注意,这只是一个简化的示例,用于阐述SAC算法的基本原理和PyTorch实现方法。在实际应用中,可能需要更复杂的网络和训练策略来获得最佳性能。