Dyna—Q算法
6.1 简介
在强化学习中,“模型”通常指与智能体交互的环境模型,即对环境的状态转移概率和奖励函数进行建模。根据是否有环境模型,强化学习算法分为两种:基于模型的强化学习(model-based reinforcement learning)和无模型的强化学习算法(model-free reinforcement learning)。无模型强化学习根据智能体与环境交互采样到的数据直接进行策略提升或者价值估计,上一章节讨论的时序差分算法,即 S a r s a Sarsa Sarsa和Q-learning算法,便是两种无模型的强化学习算法,本书在后续章节中将要介绍的方法也大都是无模型的强化学习算法。在基于模型的强化学习中,模型可以是事先知道的,也可以是根据智能体与环境交互采样到的数据学习得到的,然后用这个模型帮助策略提升或者价值估计。动态规划算法章节讨论的两种动态规划算法,即策略迭代和价值迭代,则是基于模型的强化学习算法,在这两种算法中环境模型是事先已知的。本章即将介绍的Dyna-Q算法也是非常基础的基于模型的强化学习算法,不过它的环境模型是通过采样数据估计得到的。
强化学习算法有两个重要的评价指标:一个是算法收敛后的策略在初始状态下的期望回报,另一个是样本复杂度,即算法达到收敛结果需要在真实环境中采样的样本数量。基于模型的强化学习算法由于具有一个环境模型,智能体可以额外和环境模型进行交互,对真实环境中样本的需求量往往就会减少,因此通常会比无模型的强化学习算法具有更低的样本复杂度。但是,环境模型可能不准确,不能完全代替真实环境,因此基于模型的强化学习算法收敛后其策略的期望回报可能不如无模型的强化学习算法。
6.2 Dyna-Q
Dyna-Q算法是一个经典的基于模型的强化学习算法。如图6-1所示,Dyna-Q使用一种叫做Q-palnning的方法来基于模型生成一些模拟数据,然后用模拟数据和真实数据一起改进策略。Q-planing每次选取一个曾经访问过的状态s,采取一个曾经在该状态下执行过的动作a,通过模型得到转移后的状态s‘以及奖励r,并根据这个模拟数据 ( s , a , r , s ′ ) (s,a,r,s' ) (s,a,r,s′) ,用Q-learning的更新方式来更新价值函数。
下面我们来看一下 Dyna-Q 算法的具体流程:
- 初始化 Q ( s , a ) Q(s,a) Q(s,a),初始化模型 M ( s , a ) M(s,a) M(s,a)
- for序列 e = 1 − − > E e=1-->E e=1−−>E do
- 得到初始状态s
- **for ** t = 1 − − > T t=1-->T t=1−−>T do
- 用 ϵ − \epsilon- ϵ−贪婪策略根据Q选择当前状态s下的动作a
- 得到环境反馈的 r , s ′ r,s' r,s′
- Q ( s , a ) < − − Q ( s , a ) + α [ r + γ m a x a ′ Q ( s ′ , a ′ ) − Q ( s , , a ) ] Q(s,a)<--Q(s,a)+\alpha [r+\gamma max_{a'}Q(s',a')-Q(s,,a)] Q(s,a)<−−Q(s,a)+α[r+γmaxa′Q(s′,a′)−Q(s,,a)]
- M ( s , a ) < − − r , s ′ M(s,a)<--r,s' M(s,a)<−−r,s′
- for次数 n = 1 − − > N n=1-->N n=1−−>Ndo
- 随机选择一个曾经访问过的状态 s m s_m sm
- 采取一个曾经在状态 s m s_m sm下执行过的动作 a m a_m am
- r m , s m ′ < − − M ( s m , a m ) r_m,s'_m<--M(s_m,a_m) rm,sm′<−−M(sm,am)
- Q ( s m , a m ) < − − Q ( s m , a m ) + α [ r m + γ m a x a ′ Q ( s m ′ , a ′ ) − Q ( s m , a m ) ] Q(s_m,a_m)<--Q(s_m,a_m)+\alpha[r_m+\gamma max_{a'}Q(s'_m,a')-Q(s_m,a_m)] Q(sm,am)<−−Q(sm,am)+α[rm+γmaxa′Q(sm′,a′)−Q(sm,am)]
- end for
- s < − − s ′ s<--s' s<−−s′
- end for
- end for
可以看到,在每次与环境进行交互执行一次Q-learning之后,Dyna-Q会做n次Q-planning。其中Q-learning的次数N是一个事先可以选择的超参数,当期为0时就是普通的Q-learning。值得注意的是,上述Dyna-Q算法是执行在一个离散并确定的环境中,所以当看到一条经验数据 ( s , a , r , s ′ ) (s,a,r,s') (s,a,r,s′)时,可以直接对模型做出更新,即 M ( s , a ) < − − r , s ′ M(s,a)<--r,s' M(s,a)<−−r,s′。
6.3 Dyna-Q 代码实践
我们在悬崖漫步环境中执行过 Q-learning 算法,现在也在这个环境中实现 Dyna-Q,以方便比较。首先仍需要实现悬崖漫步的环境代码,和 5.3 节一样。
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import random
import time
class CliffWalkingEnv:
def __init__(self, ncol, nrow):
self.nrow = nrow
self.ncol = ncol
self.x = 0 # 记录当前智能体位置的横坐标
self.y = self.nrow - 1 # 记录当前智能体位置的纵坐标
def step(self, action): # 外部调用这个函数来改变当前位置
# 4种动作, change[0]:上, change[1]:下, change[2]:左, change[3]:右。坐标系原点(0,0)
# 定义在左上角
change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
self.x = min(self.ncol - 1, max(0, self.x + change[action][0]))
self.y = min(self.nrow - 1, max(0, self.y + change[action][1]))
next_state = self.y * self.ncol + self.x
reward = -1
done = False
if self.y == self.nrow - 1 and self.x > 0: # 下一个位置在悬崖或者
# 目标
done = True
if self.x != self.ncol - 1:
reward = -100
return next_state, reward, done
def reset(self): # 回归初始状态,起点在左上角
self.x = 0
self.y = self.nrow - 1
return self.y * self.ncol + self.x
class DynaQ:
""" Dyna-Q算法 """
def __init__(self, ncol, nrow, epsilon, alpha, gamma, n_planning, n_action=4):
self.Q_table = np.zeros([nrow * ncol, n_action]) # 初始化Q(s,a)表格
self.n_action = n_action # 动作个数
self.alpha = alpha # 学习率
self.gamma = gamma # 折扣因子
self.epsilon = epsilon # epsilon-贪婪策略中的参数
self.n_planning = n_planning # 执行Q-planning的次数, 对应1次Q-learning
self.model = dict() # 环境模型
def take_action(self, state): # 选取下一步的操作
if np.random.random() < self.epsilon:
action = np.random.randint(self.n_action)
else:
action = np.argmax(self.Q_table[state])
return action
def q_learning(self, s0, a0, r, s1): # 在上一节中的Qlearning算法中 是根据贪婪策略进行选择的
td_error = r + self.gamma * self.Q_table[s1].max() - self.Q_table[s0, a0]
self.Q_table[s0, a0] += self.alpha * td_error
def update(self, s0, a0, r, s1, n):
self.q_learning(s0, a0, r, s1)
self.model[(s0, a0)] = r, s1 # 将数据添加到模型中
for _ in range(self.n_planning): # Q-planning循环
# 随机选择曾经遇到过的状态动作对
(s, a), (r, s_) = random.choice(list(self.model.items()))
self.q_learning(s, a, r, s_)
# print("第%d次", n)
# print(self.Q_table) # 状态价值表格
# print(self.model) # model的格式是字典 有当前状态下的序号和下一个动作 存储的内容是下一步状态和奖励(36, 0): (-1, 24)
def DynaQ_CliffWalking(n_planning):
ncol = 12
nrow = 4
env = CliffWalkingEnv(ncol, nrow)
epsilon = 0.01
alpha = 0.1
gamma = 0.9
agent = DynaQ(ncol, nrow, epsilon, alpha, gamma, n_planning)
num_episodes = 300 # 智能体在环境中运行多少条序列
return_list = [] # 记录每一条序列的回报
for i in range(10): # 显示10个进度条
# tqdm的进度条功能
with tqdm(total=int(num_episodes / 10),
desc='Iteration %d' % i) as pbar:
for i_episode in range(int(num_episodes / 10)): # 每个进度条的序列数
episode_return = 0
state = env.reset()
done = False
n = 0
while not done:
action = agent.take_action(state)
next_state, reward, done = env.step(action)
episode_return += reward # 这里回报的计算不进行折扣因子衰减
n = n+1
agent.update(state, action, reward, next_state, n)
state = next_state
return_list.append(episode_return)
if (i_episode + 1) % 10 == 0: # 每10条序列打印一下这10条序列的平均回报
pbar.set_postfix({
'episode': '%d' % (num_episodes / 10 * i + i_episode + 1),
'return': '%.3f' % np.mean(return_list[-10:])
})
pbar.update(1)
return return_list
np.random.seed(0)
random.seed(0)
n_planning_list = [0, 2, 20]
for n_planning in n_planning_list:
print('Q-planning步数为:%d' % n_planning)
time.sleep(0.5)
return_list = DynaQ_CliffWalking(n_planning)
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list,
return_list,
label=str(n_planning) + ' planning steps')
plt.legend()
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Dyna-Q on {}'.format('Cliff Walking'))
plt.show()
打印的model
{(36, 0): (-1, 24), (24, 0): (-1, 12), (12, 0): (-1, 0), (0, 0): (-1, 0), (0, 1): (-1, 12), (12, 1): (-1, 24), (24, 1): (-1, 36), (36, 1): (-1, 36), (36, 2): (-1, 36), (36, 3): (-100, 37), (24, 2): (-1, 24), (24, 3): (-1, 25), (25, 0): (-1, 13), (13, 0): (-1, 1), (1, 0): (-1, 1), (1, 1): (-1, 13), (13, 1): (-1, 25), (25, 1): (-100, 37), (25, 2): (-1, 24), (25, 3): (-1, 26), (26, 0): (-1, 14), (14, 0): (-1, 2), (2, 0): (-1, 2), (2, 1): (-1, 14), (14, 1): (-1, 26), (26, 1): (-100, 38), (12, 2): (-1, 12), (12, 3): (-1, 13), (13, 2): (-1, 12), (13, 3): (-1, 14), (14, 2): (-1, 13), (14, 3): (-1, 15), (15, 0): (-1, 3), (3, 0): (-1, 3), (3, 1): (-1, 15), (15, 1): (-1, 27), (27, 0): (-1, 15), (15, 2): (-1, 14), (15, 3): (-1, 16), (16, 0): (-1, 4), (4, 0): (-1, 4), (4, 1): (-1, 16), (16, 1): (-1, 28), (28, 0): (-1, 16), (16, 2): (-1, 15), (16, 3): (-1, 17), (17, 0): (-1, 5), (5, 0): (-1, 5), (5, 1): (-1, 17), (17, 1): (-1, 29), (29, 0): (-1, 17), (17, 2): (-1, 16), (17, 3): (-1, 18), (18, 0): (-1, 6), (6, 0): (-1, 6), (6, 1): (-1, 18), (18, 1): (-1, 30), (30, 0): (-1, 18), (18, 2): (-1, 17), (29, 1): (-100, 41), (0, 2): (-1, 0), (0, 3): (-1, 1), (1, 2): (-1, 0), (1, 3): (-1, 2), (2, 2): (-1, 1), (2, 3): (-1, 3), (3, 2): (-1, 2), (3, 3): (-1, 4), (4, 2): (-1, 3), (4, 3): (-1, 5), (5, 2): (-1, 4), (5, 3): (-1, 6), (6, 2): (-1, 5), (6, 3): (-1, 7), (7, 0): (-1, 7), (7, 1): (-1, 19), (19, 0): (-1, 7), (7, 2): (-1, 6), (7, 3): (-1, 8), (8, 0): (-1, 8), (8, 1): (-1, 20), (20, 0): (-1, 8), (8, 2): (-1, 7), (19, 1): (-1, 31), (31, 0): (-1, 19), (19, 2): (-1, 18), (18, 3): (-1, 19), (19, 3): (-1, 20), (20, 1): (-1, 32), (32, 0): (-1, 20), (20, 2): (-1, 19), (20, 3): (-1, 21), (21, 0): (-1, 9), (9, 0): (-1, 9), (9, 1): (-1, 21), (21, 1): (-1, 33), (33, 0): (-1, 21), (21, 2): (-1, 20), (31, 1): (-100, 43), (26, 2): (-1, 25), (26, 3): (-1, 27), (27, 1): (-100, 39), (27, 2): (-1, 26), (27, 3): (-1, 28), (28, 1): (-100, 40), (28, 2): (-1, 27), (28, 3): (-1, 29), (29, 2): (-1, 28), (29, 3): (-1, 30), (30, 1): (-100, 42), (30, 2): (-1, 29), (30, 3): (-1, 31), (31, 2): (-1, 30), (31, 3): (-1, 32), (32, 1): (-100, 44), (32, 2): (-1, 31), (32, 3): (-1, 33), (33, 1): (-100, 45), (33, 2): (-1, 32), (33, 3): (-1, 34), (34, 0): (-1, 22), (22, 0): (-1, 10), (10, 0): (-1, 10), (10, 1): (-1, 22), (22, 1): (-1, 34), (34, 1): (-100, 46), (34, 2): (-1, 33), (34, 3): (-1, 35), (35, 0): (-1, 23), (23, 0): (-1, 11), (11, 0): (-1, 11), (11, 1): (-1, 23), (23, 1): (-1, 35), (35, 1): (-1, 47), (35, 2): (-1, 34), (35, 3): (-1, 35), (21, 3): (-1, 22), (22, 2): (-1, 21), (22, 3): (-1, 23), (23, 2): (-1, 22), (23, 3): (-1, 23), (11, 2): (-1, 10), (10, 2): (-1, 9), (9, 2): (-1, 8), (8, 3): (-1, 9), (9, 3): (-1, 10), (10, 3): (-1, 11), (11, 3): (-1, 11)}
Q_table的打印输出:
[[ -7.94108868 -7.71232075 -7.94108868 -7.71232075]
[ -7.71232075 -7.45813417 -7.94108868 -7.45813417]
[ -7.45813417 -7.17570464 -7.71232075 -7.17570464]
[ -7.17570464 -6.86189404 -7.45813417 -6.86189404]
[ -6.86189404 -6.5132156 -7.17570464 -6.5132156 ]
[ -6.5132156 -6.12579511 -6.86189404 -6.12579511]
[ -6.12579511 -5.6953279 -6.5132156 -5.6953279 ]
[ -5.6953279 -5.217031 -6.12579511 -5.217031 ]
[ -5.217031 -4.68559 -5.6953279 -4.68559 ]
[ -4.68559 -4.0951 -5.217031 -4.0951 ]
[ -4.0951 -3.439 -4.68559 -3.439 ]
[ -3.439 -2.71 -4.0951 -3.439 ]
[ -7.94108868 -7.45813417 -7.71232075 -7.45813417]
[ -7.71232075 -7.17570464 -7.71232075 -7.17570464]
[ -7.45813417 -6.86189404 -7.45813417 -6.86189404]
[ -7.17570464 -6.5132156 -7.17570464 -6.5132156 ]
[ -6.86189404 -6.12579511 -6.86189404 -6.12579511]
[ -6.5132156 -5.6953279 -6.5132156 -5.6953279 ]
[ -6.12579511 -5.217031 -6.12579511 -5.217031 ]
[ -5.6953279 -4.68559 -5.6953279 -4.68559 ]
[ -5.217031 -4.0951 -5.217031 -4.0951 ]
[ -4.68559 -3.439 -4.68559 -3.439 ]
[ -4.0951 -2.71 -4.0951 -2.71 ]
[ -3.439 -1.9 -3.439 -2.71 ]
[ -7.71232075 -7.71232075 -7.45813417 -7.17570464]
[ -7.45813417 -100. -7.45813417 -6.86189404]
[ -7.17570464 -100. -7.17570464 -6.5132156 ]
[ -6.86189404 -100. -6.86189404 -6.12579511]
[ -6.5132156 -100. -6.5132156 -5.6953279 ]
[ -6.12579511 -100. -6.12579511 -5.217031 ]
[ -5.6953279 -100. -5.6953279 -4.68559 ]
[ -5.217031 -100. -5.217031 -4.0951 ]
[ -4.68559 -100. -4.68559 -3.439 ]
[ -4.0951 -100. -4.0951 -2.71 ]
[ -3.439 -100. -3.439 -1.9 ]
[ -2.71 -1. -2.71 -1.9 ]
[ -7.45813417 -7.71232075 -7.71232075 -100. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]
Q_planning步数为0
Iteration 0: 100%|███████████████████████████████████████| 30/30 [00:00<00:00, 527.73it/s, episode=30, return=-138.400]
Iteration 1: 100%|████████████████████████████████████████| 30/30 [00:00<00:00, 733.67it/s, episode=60, return=-64.100]
Iteration 2: 100%|████████████████████████████████████████| 30/30 [00:00<00:00, 884.11it/s, episode=90, return=-46.000]
Iteration 3: 100%|███████████████████████████████████████| 30/30 [00:00<00:00, 970.35it/s, episode=120, return=-38.000]
Iteration 4: 100%|██████████████████████████████████████| 30/30 [00:00<00:00, 1203.43it/s, episode=150, return=-28.600]
Iteration 5: 100%|██████████████████████████████████████| 30/30 [00:00<00:00, 1367.35it/s, episode=180, return=-25.300]
Iteration 6: 100%|██████████████████████████████████████| 30/30 [00:00<00:00, 1434.95it/s, episode=210, return=-23.600]
Iteration 7: 100%|██████████████████████████████████████| 30/30 [00:00<00:00, 2005.50it/s, episode=240, return=-20.100]
Iteration 8: 100%|██████████████████████████████████████| 30/30 [00:00<00:00, 1671.15it/s, episode=270, return=-17.100]
Iteration 9: 100%|██████████████████████████████████████| 30/30 [00:00<00:00, 1880.12it/s, episode=300, return=-16.500]
Q_planning步数为2
Iteration 0: 100%|████████████████████████████████████████| 30/30 [00:00<00:00, 385.62it/s, episode=30, return=-53.800]
Iteration 1: 100%|████████████████████████████████████████| 30/30 [00:00<00:00, 556.96it/s, episode=60, return=-37.100]
Iteration 2: 100%|████████████████████████████████████████| 30/30 [00:00<00:00, 771.19it/s, episode=90, return=-23.600]
Iteration 3: 100%|███████████████████████████████████████| 30/30 [00:00<00:00, 911.51it/s, episode=120, return=-18.500]
Iteration 4: 100%|███████████████████████████████████████| 30/30 [00:00<00:00, 940.01it/s, episode=150, return=-16.400]
Iteration 5: 100%|███████████████████████████████████████| 30/30 [00:00<00:00, 970.36it/s, episode=180, return=-16.400]
Iteration 6: 100%|██████████████████████████████████████| 30/30 [00:00<00:00, 1203.21it/s, episode=210, return=-13.400]
Iteration 7: 100%|██████████████████████████████████████| 30/30 [00:00<00:00, 1307.75it/s, episode=240, return=-13.200]
Iteration 8: 100%|██████████████████████████████████████| 30/30 [00:00<00:00, 1308.43it/s, episode=270, return=-13.200]
Iteration 9: 100%|██████████████████████████████████████| 30/30 [00:00<00:00, 1367.13it/s, episode=300, return=-13.500]
Q_planning步数20
Iteration 0: 100%|████████████████████████████████████████| 30/30 [00:00<00:00, 152.69it/s, episode=30, return=-18.500]
Iteration 1: 100%|████████████████████████████████████████| 30/30 [00:00<00:00, 261.54it/s, episode=60, return=-13.600]
Iteration 2: 100%|████████████████████████████████████████| 30/30 [00:00<00:00, 289.22it/s, episode=90, return=-13.000]
Iteration 3: 100%|███████████████████████████████████████| 30/30 [00:00<00:00, 262.54it/s, episode=120, return=-13.500]
Iteration 4: 100%|███████████████████████████████████████| 30/30 [00:00<00:00, 294.91it/s, episode=150, return=-13.500]
Iteration 5: 100%|███████████████████████████████████████| 30/30 [00:00<00:00, 294.91it/s, episode=180, return=-13.000]
Iteration 6: 100%|███████████████████████████████████████| 30/30 [00:00<00:00, 275.97it/s, episode=210, return=-22.000]
Iteration 7: 100%|███████████████████████████████████████| 30/30 [00:00<00:00, 278.51it/s, episode=240, return=-23.200]
Iteration 8: 100%|███████████████████████████████████████| 30/30 [00:00<00:00, 278.53it/s, episode=270, return=-13.000]
Iteration 9: 100%|███████████████████████████████████████| 30/30 [00:00<00:00, 271.00it/s, episode=300, return=-13.400]
从上述结果中我们可以很容易地看出,随着 Q-planning 步数的增多,Dyna-Q 算法的收敛速度也随之变快。当然,并不是在所有的环境中,都是 Q-planning 步数越大则算法收敛越快,这取决于环境是否是确定性的,以及环境模型的精度。在上述悬崖漫步环境中,状态的转移是完全确定性的,构建的环境模型的精度是最高的,所以可以通过增加 Q-planning 步数来直接降低算法的样本复杂度。
6.4 小结
本章讲解了一个经典的基于模型的强化学习算法 Dyna-Q,并且通过调整在悬崖漫步环境下的 Q-planning 步数,直观地展示了 Q-planning 步数对于收敛速度的影响。我们发现基于模型的强化学习算法 Dyna-Q 在以上环境中获得了很好的效果,但这些环境比较简单,模型可以直接通过经验数据得到。如果环境比较复杂,状态是连续的,或者状态转移是随机的而不是决定性的,如何学习一个比较准确的模型就变成非常重大的挑战,这直接影响到基于模型的强化学习算法能否应用于这些环境并获得比无模型的强化学习更好的效果。
法的样本复杂度。
6.4 小结
本章讲解了一个经典的基于模型的强化学习算法 Dyna-Q,并且通过调整在悬崖漫步环境下的 Q-planning 步数,直观地展示了 Q-planning 步数对于收敛速度的影响。我们发现基于模型的强化学习算法 Dyna-Q 在以上环境中获得了很好的效果,但这些环境比较简单,模型可以直接通过经验数据得到。如果环境比较复杂,状态是连续的,或者状态转移是随机的而不是决定性的,如何学习一个比较准确的模型就变成非常重大的挑战,这直接影响到基于模型的强化学习算法能否应用于这些环境并获得比无模型的强化学习更好的效果。