强化学习经典算法笔记(八):LSTM加持的A2C算法解决POMDP问题

强化学习经典算法笔记(八):LSTM加持的A2C算法解决POMDP问题

最近用到LSTM构建Agent,找到了一个非常简明易读的示例代码。
https://github.com/HaiyinPiao/pytorch-a2clstm-DRQN

环境采用CartPole-v1。原本状态是一个4维向量,现删去第二维,即小车的速度,保留小车的位移,杆的角度和角速度,使问题从MDP问题变为POMDP(Partial Observable Markov Decision Process)问题。

代码如下:

导入必要package

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import math
import random
import os
import gym

参数设置

STATE_DIM = 4-1;    # 删去小车速度这一维度,使之成为POMDP
ACTION_DIM = 2;     # 动作空间大小
NUM_EPISODE = 5000; # 训练的Episode数量
EPISODE_LEN = 1000; # episode最大长度
A_HIDDEN = 40;      # Actor网络的隐层神经元数量
C_HIDDEN = 40;      # Critic网络的隐层神经元数量

ActorCritic网络

# ActorNet使用LSTM + MLP估计完整的状态
class ActorNetwork(nn.Module):

    def __init__(self,in_size,hidden_size,out_size):
        super(ActorNetwork, self).__init__()
        self.lstm = nn.LSTM(in_size, hidden_size, batch_first = True)
        self.fc = nn.Linear(hidden_size,out_size)

    def forward(self, x, hidden):
        x, hidden = self.lstm(x, hidden)
        x = self.fc(x)
        x = F.log_softmax(x,2)  # log(softmax(x))
        return x, hidden
class ValueNetwork(nn.Module):

    def __init__(self,in_size,hidden_size,out_size):
        super(ValueNetwork, self).__init__()
        self.lstm = nn.LSTM(in_size, hidden_size, batch_first = True)
        self.fc = nn.Linear(hidden_size,out_size)

    def forward(self,x, hidden):
        x, hidden = self.lstm(x, hidden)
        x = self.fc(x)
        return x, hidden

完成单episode交互并记录trajectory

def roll_out(actor_network,env,episode_len,value_network,init_state):
    '''
    rollout最长1000frames
    返回:
    状态序列,不包括终态
    动作序列,独热编码
    奖励序列,不包括终态奖励
    state:游戏环境初始化后的初始状态
    '''
    states = []
    actions = []
    rewards = []
    is_done = False
    final_r = 0
    state = init_state # 初始状态
    a_hx = torch
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值