强化学习原理python篇06(拓展)——DQN拓展
拓展篇参考赵世钰老师的教材和Maxim Lapan 深度学习强化学习实践(第二版),请各位结合阅读,本合集只专注于数学概念的代码实现。
n-steps
假设在训练开始时,顺序地完成前面的更新,前两个更新是没有用的,因为当前Q(s2, a)和Q(s2, a)是不对的,并且只包含初始的随机值。唯一有用的更新是第3个更新,它将奖励r3正确地赋给终结状态前的状态s3。
现在来完成一次又一次的更新。在第2次迭代,正确的值被赋给了Q(s2, a),但是Q(s1, a)的更新还是不对的。只有在第3次迭代时才能给所有的Q赋上正确的值。所以,即使在1步的情况下,它也需要3步才能将正确的值传播给所有的状态。
为此,修改第四步
4)将转移过程(s, a, r, s’)存储在回放缓冲区中 r 用 n 步合计展示。
代码
修改ReplayBuffer和DQN中的calculate_y_hat_and_y实现
class ReplayBuffer:
def __init__(self, episode_size, replay_time):
# 存取 queue episode
self.queue = []
self.queue_size = episode_size
self.replay_time = replay_time
def get_batch_queue(self, env, action_trigger, batch_size, epsilon):
def insert_sample_to_queue(env):
state, info = env.reset()
stop = 0
episode = []
while True:
if np.random.uniform(0, 1, 1) > epsilon:
action = env.action_space.sample()
else:
action = action_trigger(state)
next_state, reward, terminated, truncated, info = env.step(action)
episode.append([state, action, next_state, reward, terminated])
state = next_state
if terminated:
state, info = env.reset()
self.queue.append(episode)
episode = []
stop += 1
continue
if stop >= replay_time:
self.queue.append(episode)
episode = []
break
def init_queue(env):
while True:
insert_sample_to_queue(env)
if len(self.queue) >= self.queue_size:
break
init_queue(env)
insert_sample_to_queue(env)
self.queue = self.queue[-self.queue_size :]
return random.sample(self.queue, batch_size)
class DQN:
def __init__(self, env, obs_size, hidden_size, q_table_size):
self.env = env
self.net = Net(obs_size, hidden_size, q_table_size)
self.tgt_net = Net(obs_size, hidden_size, q_table_size)
# 更新net参数
def update_net_parameters(self, update=True):
self.net.load_state_dict(self.tgt_net.state_dict())
def get_action_trigger(self, state):
state = torch.Tensor(state)
action = int(torch.argmax(self.tgt_net(state).detach()))