强化学习经典算法笔记(十五):Soft Actor-Critic算法实现
算法简介
Soft Actor Critic,SAC算法是一种Off-policy算法,相比于PPO这种On-policy算法,sample efficiency有了提高,相比于DDPG及其变种D4PG,SAC又是一种随机策略算法。
SAC算法是在最大熵强化学习(Maximum Entropy Reinforcement Learning)的框架下构建起来的,目的是让策略随机化,好处是对于机器人控制问题非常友好,甚至可以在真实环境中使用。
策略的最大熵还意味着对策略空间、轨迹空间的探索比确定型算法要更充分,对于最优动作不止一个的状态来说,SAC就可以输出一个动作的概率分布而非确定的其中一个动作。
总结起来有三点:
- 学到的policy可以作为更复杂具体任务的初始化。
- 更强的exploration能力,这是显而易见的,能够更容易的在多模态reward (multimodal reward)下找到更好的模式。比如既要求机器人走的好,又要求机器人节约能源。
- 更robust鲁棒,更强的generalization。因为要从不同的方式来探索各种最优的可能性,也因此面对干扰的时候能够更容易做出调整。
对SAC算法的更详细解读可以参考
最前沿:深度解读Soft Actor-Critic算法。来龙去脉讲的非常详细。
PyTorch实现
import argparse
import pickle
from collections import namedtuple
from itertools import count
import os
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal,MultivariateNormal
from tensorboardX import SummaryWriter
'''
Implementation of soft actor critic, dual Q network version
Original paper: https://arxiv.org/abs/1801.01290
'''
device = 'cuda' if torch.cuda.is_available() else 'cpu'
parser = argparse.ArgumentParser()
parser.add_argument("--env_name", default="LunarLanderContinuous-v2") # OpenAI gym environment name Pendulum-v0
parser.add_argument('--tau', default=0.005, type=float) # target smoothing coefficient
parser.add_argument('--target_update_interval', default=1, type=int)
parser.add_argument('--epoch', default=1, type=int) # 每次sample batch训练几次
parser.add_argument('--learning_rate', default=3e-4, type=int)
parser.add_argument('--gamma', default=0.99, type=int) # discount gamma
parser.add_argument('--capacity', default=10000, type=int) # replay buffer size
parser.add_argument('--num_episode', default=2000, type=int) # num of games
parser.add_argument('--batch_size', default=128, type=int) # mini batch size
parser.add_argument('--max_frame', default=500, type=int) # max frame
parser.add_argument('--seed', default=1, type=int)
# optional parameters
parser.add_argument('--hidden_size', default=64, type=int)
parser.add_argument('--render', default=False, type=bool) # show UI or not
parser.add_argument('--log_interval', default=20, type=int) # 每20episode保存1次模型
parser.add_argument('--load', default=False, type=bool) # load model
args = parser.parse_args()
class NormalizedActions(gym.ActionWrapper):
def _action(self, action):
low = self.action_space.low
high = self.action_space.high
action = low + (action + 1.0) * 0.5 * (high - low)
action = np.clip(action, low, high)
return action
def _reverse_action(self, action):
low = self.action_space.low
high = self.action_space.high
action = 2 * (action - low) / (high - low) - 1
action = np.clip(action, low, high)
return action
# env = NormalizedActions(gym.make(args.env_name))
env = gym.make(args.env_name)
# Set seeds
env.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action =