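Classic Reinforcement Learning Algorithm Notes (12): Implementing Proximal Policy Optimization (PPO) on Top of A2C
This post implements a PPO algorithm built on the A2C framework and applies it to a continuous action space task.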
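For orientation, the objective PPO optimizes is the clipped surrogate. Below is a minimal sketch of it as a loss function, not the author's implementation. The function name `ppo_clip_loss` and the default `eps_clip=0.2` are illustrative assumptions.

```python
import torch

def ppo_clip_loss(new_logprobs, old_logprobs, advantages, eps_clip=0.2):
    """Clipped surrogate loss (to be minimized). Hypothetical helper."""
    # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log space.
    ratios = torch.exp(new_logprobs - old_logprobs.detach())
    surr1 = ratios * advantages
    # Clipping limits how much a large ratio change can increase the objective.
    surr2 = torch.clamp(ratios, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
    # PPO maximizes the elementwise minimum, so negate it to get a loss.
    return -torch.min(surr1, surr2).mean()
```

With a positive advantage, the ratio's contribution is capped at `1 + eps_clip`. With a negative advantage, the unclipped term is kept whenever it is the more pessimistic of the two.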
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
import gym
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class Memory:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear_memory(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]
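A buffer with `rewards` and `is_terminals` fields is commonly turned into discounted Monte Carlo returns before a policy update. The sketch below shows that step under stated assumptions. The helper `discounted_returns` and the `gamma` parameter are hypothetical and do not appear in the code above.

```python
def discounted_returns(rewards, is_terminals, gamma=0.99):
    """Discounted returns, restarting the running sum at episode ends."""
    returns = []
    running = 0.0
    # Walk backwards so each step adds its reward to the discounted future.
    for reward, done in zip(reversed(rewards), reversed(is_terminals)):
        if done:
            running = 0.0  # rewards after a terminal step belong to another episode
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()  # restore chronological order
    return returns

# Three stored steps; the second one ends an episode.
example = discounted_returns([1.0, 1.0, 1.0], [False, True, False], gamma=0.5)
print(example)  # [1.5, 1.0, 1.0]
```

The terminal flag is what keeps rewards from a later episode out of an earlier episode's return when several episodes are stored back to back.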
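The A2C implementation differs from the previous post's in how actions are selected. The actor outputs the mean vector of a multivariate Gaussian distribution, and the covariance matrix is fixed by hand, although the variance could also be learned.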
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, action_std):
        super(ActorCritic, self).__init__()
        # action mean range -1 to 1
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, action_dim),
            nn.Tanh()
        )
        # critic
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )
        self.action_var = torch.full((action_dim,), action_std*action_std).to(device)

    def forward(self):
        raise NotImplementedError

    def act(self, state, memory):
        action_mean = self.actor(state)
        cov_mat = torch.diag(self.action_var).to(device)
        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist
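The remaining steps of `act` are not shown here. As a standalone illustration of the action-selection idea, the sketch below draws an action from a multivariate Gaussian with a fixed diagonal covariance, then builds the learned-variance variant. The values of `action_dim` and `action_std`, the zero mean standing in for the actor's output, and the `log_std` parameter are all illustrative assumptions.

```python
import math
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
action_dim, action_std = 4, 0.5        # illustrative values
action_mean = torch.zeros(action_dim)  # stands in for the actor's output

# Fixed diagonal covariance: variance = std^2 on every action dimension.
action_var = torch.full((action_dim,), action_std * action_std)
dist = MultivariateNormal(action_mean, covariance_matrix=torch.diag(action_var))

action = dist.sample()           # one action vector of shape (action_dim,)
logprob = dist.log_prob(action)  # scalar joint log-density of that vector

# Learned-variance variant (assumption): store log-std as a parameter so the
# standard deviation exp(log_std) stays positive during optimization.
log_std = nn.Parameter(torch.full((action_dim,), math.log(action_std)))
learned_cov = torch.diag(torch.exp(2.0 * log_std))  # variance = exp(2 * log_std)
learned_dist = MultivariateNormal(action_mean, covariance_matrix=learned_cov)
```

Because the covariance is diagonal, the action dimensions are sampled independently, and `log_prob` returns a single joint log-density rather than one value per dimension.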

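This article walks through and implements Proximal Policy Optimization (PPO) on the A2C framework for continuous action space tasks, selecting actions from a multivariate Gaussian distribution and demonstrating the algorithm on the BipedalWalker-v2 environment.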





