最近几个月一直在看强化学习,把强化学习导论(Reinforcement Learning:An Introduction)跟着b站视频看了一遍,发现书上的大部分是理论,实践代码并不多,所以可以看看知乎David Silver强化学习公开课中文讲解及实践,看完之后呢就会发现,强化学习在沿着解决连续动作和深度卷积神经网络逼近值函数来解决问题。下面先写出目前所学到的知识,后期会继续更新。
1.Q-Learning
Q-Learning是强化学习算法中极其重要的算法,Q即为Q(s,a)就是在某一时刻的 s 状态下,采取 动作a 动作所获得的收益,环境会根据动作反馈相应的回报reward,所以Q-Learning算法的主要思想就是将状态与动作构建成一张Q-table来存储Q值,然后根据Q值来选取能够获得最大的收益的动作。
缺点:Q-learning需要一个Q table,在状态很多的情况下,Q table会很大,查找和存储都需要消耗大量的时间和空间。而且这些表中的动作都是离散的动作,不能是连续的动作。
问题:在实际问题中,环境中有许多状态,其每个状态有许多动作,这样使得Q-table显得没有那么的好。
解决方法:提出DQN
2. DQN
DQN对Q-learning的修改主要体现在以下三个方面:
(1)DQN利用深度卷积神经网络逼近值函数;
(2)DQN利用了经验回放对强化学习的学习过程进行训练;
(3)DQN独立设置了目标网络来单独处理时间差分算法中的TD偏差。
(1)DQN利用深度卷积神经网络:
网络的作用是通过输入状态和行为可直接输出这个行为的价值Q,不需要再记录表格。
(2)DQN利用经验回放:
在训练过程中,会维护一个序列样本池Dt={e1,…,et},其中et={st,at,rt,s(t+1)},et就是在状态st下,采取了动作at,转移到了状态s(t+1),得到回报rt,这样就形成了一个样本(经验)。回放的意思就是在训练中,比如让agent玩游戏,并不是把样本按照时间顺序喂给网络,而是在一局游戏未结束之前,把生成的样本(经验)都更新地扔到经验池中,从池中平均采样minBatch个作为训练样本。
好处:
这样回放机制就会减少应用于高度相关的状态序列。:因为前后样本存在关联导致的强化学习震荡和发散的问题。
(3)DQN独立设置目标网络
Off-policy是Q-Learning的特点,DQN中也延用了这一特点。而不同的是,Q-Learning中用来计算target和预测值的Q是同一个Q,也就是说使用了相同的神经网络。这样带来的一个问题就是,每次更新神经网络的时候,target也都会更新,这样会容易导致参数不收敛。回忆在有监督学习中,标签label都是固定的,不会随着参数的更新而改变。因此DQN在原来的Q网络的基础上又引入了一个target Q网络,即用来计算target的网络。它和Q网络结构一样,初始的权重也一样,只是Q网络每次迭代都会更新,而target Q网络是每隔一段时间才会更新。
最后附上DQN过程直观图:
DQN代码实现如下:(代码中有自己加上去的解释)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gym
# Hyper Parameters 参数
BATCH_SIZE = 32
LR = 0.01 # learning rate
EPSILON = 0.9 # greedy policy
GAMMA = 0.9 # reward discount
TARGET_REPLACE_ITER = 100 # target update frequency
MEMORY_CAPACITY = 2000
env = gym.make('CartPole-v0') #创造环境
env = env.unwrapped#据说不做这个动作会有很多限制,unwrapped是打开限制的意思,用env.unwrapped可以得到原始的类,原始类想step多久就多久,不会200步后失败:
N_ACTIONS = env.action_space.n #动作的个数
N_STATES = env.observation_space.shape[0] #查看这个环境中observation的特征即状态有多少个
ENV_A_SHAPE = 0 if isinstance(env.action_space.sample(), int) else env.action_space.sample().shape # to confirm the shape
class Net(nn.Module): #神经网络
def __init__(self, ):
super(Net, self).__init__()
self.fc1 = nn.Linear(N_STATES, 50) #输入状态
self.fc1.weight.data.normal_(0, 0.1) # initialization
self.out = nn.Linear(50, N_ACTIONS) #输出该状态下的所有动作的价值
self.out.weight.data.normal_(0, 0.1) # initialization
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
actions_value = self.out(x)
return actions_value #返回动作的价值
class DQN(object):
def __init__(self):
self.eval_net, self.target_net = Net(), Net() #两个相同的网络,参数不同,需要两个网络:target网络每隔段时间更新一次。避免参数一直更新
self.learn_step_counter = 0 #学了多少步 # for target updating
self.memory_counter = 0 #记忆库位置的计数 # for storing memory
self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2)) # 初始化记忆库
self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
self.loss_func = nn.MSELoss()
def choose_action(self, x): #根据观测值决定动作
x = torch.unsqueeze(torch.FloatTensor(x), 0) #输入观测者,用变量包起来传入神经网络
# input only one sample
if np.random.uniform() < EPSILON: # 如果概率小于一个随机数,采取贪婪算法,选取最大值
actions_value = self.eval_net.forward(x) #扔到网络中,输出行为价值
action = torch.max(actions_value, 1)[1].data.numpy() #选取最大的行为价值
# torch.max()[0]:只返回最大值的每个数
#torch.max()[1]:只返回最大值的每个索引
action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) # return the argmax index
else: #随机选取动作
action = np.random.randint(0, N_ACTIONS) #随机选取动作
action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) #返回的动作的下标(第几个)
return action
def store_transition(self, s, a, r, s_): #记忆库,记录之前学习的东西,就是回放机制
transition = np.hstack((s, [a, r], s_)) #所有的记忆捆到一起,并且存到相对应的位置
# replace the old memory with new memory
index = self.memory_counter % MEMORY_CAPACITY #新的覆盖旧的记忆
self.memory[index, :] = transition
self.memory_counter += 1
def learn(self): #学习
# target parameter update
if self.learn_step_counter % TARGET_REPLACE_ITER == 0: #隔多少步(学习多少次)target Q网络更新一下
self.target_net.load_state_dict(self.eval_net.state_dict()) #val_net的数值复制到target_net
self.learn_step_counter += 1
# sample batch transitions
sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
b_memory = self.memory[sample_index, :] #随机抽取记忆,并进行下面的分开打包
b_s = torch.FloatTensor(b_memory[:, :N_STATES