Sarsa
Here is the code:
from maze_env import Maze
from RL_brain import SarsaTable


def update():
    for episode in range(100):
        # initial observation
        observation = env.reset()

        # RL choose action based on observation
        action = RL.choose_action(str(observation))

        while True:
            # fresh env
            env.render()

            # RL take action and get next observation and reward
            observation_, reward, done = env.step(action)

            # RL choose action based on next observation
            action_ = RL.choose_action(str(observation_))

            # RL learn from this transition (s, a, r, s_, a_) ==> Sarsa
            RL.learn(str(observation), action, reward, str(observation_), action_)

            # swap observation and action
            observation = observation_
            action = action_

            # break while loop when end of this episode
            if done:
                break

    # end of game
    print('game over')
    env.destroy()


if __name__ == "__main__":
    env = Maze()
    RL = SarsaTable(actions=list(range(env.n_actions)))

    env.after(100, update)
    env.mainloop()
---------------------------------------------
def choose_action(self, observation):
    self.check_state_exist(observation)
    # action selection
    if np.random.rand() < self.epsilon:
        # choose best action
        state_action = self.q_table.loc[observation, :]
        # some actions may have the same value, randomly choose one of them
        action = np.random.choice(state_action[state_action == np.max(state_action)].index)
    else:
        # choose random action
        action = np.random.choice(self.actions)
    return action
-----------------------------------------------
class SarsaTable(RL):

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update
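Both choose_action and learn call check_state_exist, which lives in the RL parent class inside RL_brain.py and is not shown above. The sketch below is one common way that parent class is written, assuming the Q table is a pandas DataFrame with one row per state string; the actual RL_brain used here may differ in detail:

import numpy as np
import pandas as pd


class RL(object):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = actions      # list of action indices, e.g. [0, 1, 2, 3]
        self.lr = learning_rate     # learning rate (alpha)
        self.gamma = reward_decay   # discount factor
        self.epsilon = e_greedy     # probability of acting greedily
        # Q table: one row per state string, one column per action
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        # append an all-zero row the first time a state is visited
        if state not in self.q_table.index:
            self.q_table.loc[state] = [0.0] * len(self.actions)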
The idea behind Sarsa is largely the same as Q_learning; only a couple of details differ.
(1) The order in which actions are chosen is different, while the policy used to choose them (epsilon-greedy) is the same.
Q_learning: initial state s1 -- choose action a1 -- execute a1 -- observe next state s2 and reward R -- take the maximum over all actions at s2 -- update the Q value -- move to the next state -- choose an action -- ...
Sarsa: initial state s1 -- choose action a1 -- execute a1 -- observe next state s2 and reward R -- choose action a2 -- take the value of a2 at s2 -- update the Q value -- move to the next state -- execute a2 -- ...
As you can see, Sarsa chooses the next action before it updates; that is exactly where the difference lies, as the loop sketch below makes explicit.
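To make that ordering concrete, here is how the training loop of the matching Q_learning script usually looks. This is only a sketch: it assumes a QLearningTable class with the same choose_action/learn interface as SarsaTable, except that learn does not take the next action. Compare it with update() at the top of this post:

def update_qlearning():
    # same structure as update() above, but for Q_learning:
    # the action is chosen at the top of every iteration,
    # and learn() only needs (s, a, r, s_) -- the max over s_ is taken inside learn()
    for episode in range(100):
        observation = env.reset()
        while True:
            env.render()

            # choose an action from the current observation
            action = RL.choose_action(str(observation))

            # take the action, observe the next state and the reward
            observation_, reward, done = env.step(action)

            # learn from (s, a, r, s_); no a_ is passed
            RL.learn(str(observation), action, reward, str(observation_))

            # only the observation is carried over; the action is re-chosen next loop
            observation = observation_

            if done:
                break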
(2) The update rule:
Q_learning:  q_target = r + self.gamma * self.q_table.loc[s_, :].max()
Sarsa:       q_target = r + self.gamma * self.q_table.loc[s_, a_]
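Written as update equations, with alpha standing for self.lr and gamma for self.gamma, the two rules from the code above are:

Q_learning:  Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \right]

Sarsa:       Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma \, Q(s',a') - Q(s,a) \right]

where in Sarsa a' is the action that choose_action actually returned for s'.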
As you can see, Q_learning updates the Q value with the maximum over the next state's actions, so Q values tend to be larger; when actions are then chosen from those Q values, the agent heads straight for the most valuable route and pays no attention to the traps along the way, which makes it the "braver" of the two.
Sarsa is different: it updates with the value of the action that was actually chosen, and that action comes from the epsilon-greedy policy, which picks the best action with probability 0.9 and explores with probability 0.1. In other words, although the chosen action is very likely still the one with the largest Q value, there is at least a 0.1 chance that it is not, whereas Q_learning always updates with the maximum directly. That is where the difference lies; the little demo below shows the same thing with concrete numbers.
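As a quick illustration (a toy sketch with made-up Q values, not part of the maze code): suppose the Q values of the three actions at the next state s_ are [1.0, 0.2, -1.0]. Q_learning always bootstraps from 1.0, while Sarsa bootstraps from whichever action epsilon-greedy actually sampled:

import numpy as np

np.random.seed(0)

q_next = np.array([1.0, 0.2, -1.0])  # made-up Q values for the 3 actions at s_
reward, gamma, epsilon = 0.0, 0.9, 0.9

# Q_learning target: always uses the maximum, regardless of what is executed next
q_target_qlearning = reward + gamma * q_next.max()

# Sarsa target: uses the action that epsilon-greedy actually picks
if np.random.rand() < epsilon:
    a_ = int(np.argmax(q_next))          # greedy with probability 0.9
else:
    a_ = np.random.choice(len(q_next))   # explore with probability 0.1
q_target_sarsa = reward + gamma * q_next[a_]

print(q_target_qlearning)  # always 0.9
print(q_target_sarsa)      # usually 0.9, but 0.18 or -0.9 when exploration kicks in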