DPPO (Distributed PPO): Controlling the Inverted Pendulum with a Distributed Algorithm

Distributed PPO for the inverted pendulum (CartPole) task, implemented with a single process and multiple threads; the complete code is at the end.

The DPPO Algorithm

DPPO is a robust policy-gradient algorithm for high-dimensional discrete or continuous control problems, and it scales to larger domains through distributed computation.

As the distributed version of PPO, the algorithm is split into a chief (leader) and workers.

  • The chief maintains the single policy network and value network, and is responsible for centralized training and parameter updates.
  • Workers only use the global networks to interact with the environment and generate trajectories. Workers can be set up in one of two ways:
    1. Workers only collect data, i.e. interact with the environment; they do not compute gradients but pass the samples to the chief, which performs the update.
    2. Workers both collect data and run backpropagation to compute gradients; the gradients are then averaged and sent to the chief for the update (a sketch follows this list).
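For concreteness, here is a minimal, hypothetical sketch of how scheme 2 could average worker gradients before a single optimizer step; the names (global_net, worker_grads, grads_of) are illustrative only and are not part of the implementation in this article:

import torch
import torch.nn as nn

global_net = nn.Linear(4, 2)  # stand-in for the shared policy network
optimizer = torch.optim.AdamW(global_net.parameters(), lr=1e-4)

def grads_of(loss, net):
    # per-parameter gradients of `loss`, without touching .grad
    return torch.autograd.grad(loss, net.parameters())

# pretend each of two workers produced one scalar loss from its own trajectories
local_losses = [global_net(torch.randn(8, 4)).pow(2).mean() for _ in range(2)]
worker_grads = [grads_of(loss, global_net) for loss in local_losses]

# chief: average the gradients parameter-wise, write them back, then update once
optimizer.zero_grad()
for param, *grads in zip(global_net.parameters(), *worker_grads):
    param.grad = torch.stack(grads).mean(dim=0)
optimizer.step()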

This article uses the first scheme: a single process runs several Worker threads that interact with the environment, to solve the discrete-action inverted pendulum (CartPole) task.

Distributed PPO does not use replica networks, so while the network parameters are being updated, the threads interacting with the environment must be blocked, to prevent collecting experience samples that were not produced by the latest policy.

  • Synchronous blocking: workers run in parallel during data collection, but must wait at a synchronization point during global training, so that policy divergence does not cause a shift in the data distribution.

Like PPO, DPPO relies on importance sampling so that samples collected under an earlier policy can be reused across several update epochs; the PPO-clip variant is used for the experiments here.

Because Distributed PPO is a distributed design, how processes and threads are arranged depends on the concrete task.
Given how simple the CartPole task is, a single process with multiple threads is used for the experiments here.

Policy gradient of the PPO-clip algorithm

With importance sampling introduced, the policy gradient becomes:
$\nabla_\theta J(\theta) = E\left[\dfrac{\pi(A|S,\theta)}{\beta(A|S)}\nabla_\theta \ln \pi(A|S,\theta)\,\bigl(q_\pi(S,A)-v_\pi(S)\bigr)\right]$
where $\beta(A|S)$ is the old policy under which the experience samples were collected; here it can be written as $\pi(A|S,\theta')$.
A stochastic gradient is used to approximate this true gradient:
$\nabla_\theta J(\theta) \approx \dfrac{\pi(a_t|s_t,\theta_t)}{\beta(a_t|s_t)}\nabla_\theta \ln \pi(a_t|s_t,\theta_t)\,\bigl(q_\pi(s_t,a_t)-v_\pi(s_t)\bigr)$
However, the advantage, i.e. the difference between the action-value function and the state-value function, is still unknown and therefore has to be estimated.
The temporal-difference (TD) error can be used to estimate the advantage:
$q_t(s_t,a_t)-v_t(s_t) \approx r_{t+1}+\gamma v_t(s_{t+1})-v_t(s_t)$
With the TD-error approximation, a single neural network representing the state-value function is enough, i.e. only one Critic (state-value function) has to be maintained.
Early in training this approximation is inaccurate, but as the Critic keeps being updated (the iterative approximation of the advantage), the estimate converges to the true advantage value.
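As a quick numeric illustration with made-up values (the same computation appears later in update()), the TD-error advantage can be formed as:

import torch

gamma = 0.99
reward = torch.tensor([1.0, 1.0, 1.0])      # r_{t+1} for three transitions
v_s = torch.tensor([10.0, 9.5, 9.0])        # critic estimate v(s_t)
v_s_next = torch.tensor([9.5, 9.0, 0.0])    # critic estimate v(s_{t+1})
done = torch.tensor([0.0, 0.0, 1.0])        # terminal transition gets no bootstrap

td_target = reward + gamma * v_s_next * (1 - done)
advantage = td_target - v_s                  # TD-error approximation of q - v
print(advantage)                             # tensor([ 0.4050,  0.4100, -8.0000])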

  • Therefore the final PPO-clip Actor policy gradient can be written as:
    $\nabla_\theta J(\theta) \approx \dfrac{\pi(a_t|s_t,\theta_t)}{\beta(a_t|s_t)}\nabla_\theta \ln \pi(a_t|s_t,\theta_t)\,\bigl(q_t(s_t,a_t)-v_t(s_t)\bigr)$

In practice this policy gradient is usually simplified to:
$\nabla_\theta J(\theta) \approx \dfrac{1}{\beta(a_t|s_t)}\nabla_\theta \pi(a_t|s_t,\theta_t)\,\bigl(q_t(s_t,a_t)-v_t(s_t)\bigr)$

PPO-clip adds a clipping step:
$\nabla_\theta J(\theta) = \min\left(\nabla_\theta J(\theta),\ \mathrm{clip}\left(\dfrac{\nabla_\theta \pi(a_t|s_t)}{\beta(a_t|s_t)},\,1-\epsilon,\,1+\epsilon\right)\cdot A(s_t)\right)$

Note that the goal of the policy update is to maximize the objective $J(\theta)$, i.e. to maximize value, which calls for gradient ascent. When backpropagating with PyTorch, the policy objective therefore has to be negated, because PyTorch optimizers perform gradient descent (they minimize a loss).
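A minimal sketch of this sign convention with dummy tensors (the same pattern appears in update() below):

import torch

eps = 0.15
new_prob = torch.tensor([[0.6], [0.3]], requires_grad=True)  # pi(a|s, theta)
old_prob = torch.tensor([[0.5], [0.4]])                      # beta(a|s), fixed
advantage = torch.tensor([[1.2], [-0.7]])

ratio = new_prob / old_prob
s1 = ratio * advantage
s2 = torch.clamp(ratio, 1 - eps, 1 + eps) * advantage
actor_loss = -torch.mean(torch.min(s1, s2))  # minus sign: the optimizer minimizes
actor_loss.backward()                        # gradient now points toward ascent of J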

Value network (Critic) gradient

The Critic's gradient is:
$\nabla_w J(w)=\bigl(v_t(s_t,w_t) - (r_{t+1}+\gamma v_t(s_{t+1},w_t))\bigr)\nabla_w v_t(s_t,w_t)$
The Critic uses gradient descent to minimize the error between the approximate value and the target value, so no extra minus sign is needed when backpropagating with PyTorch.
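A minimal sketch of the corresponding Critic loss with dummy tensors, assuming the TD target is treated as a constant (detached):

import torch
import torch.nn.functional as F

v_s = torch.tensor([10.0, 9.5], requires_grad=True)  # critic output v(s_t, w)
td_target = torch.tensor([10.4, 9.9])                # r_{t+1} + gamma * v(s_{t+1}), detached

critic_loss = F.mse_loss(v_s, td_target)             # plain MSE, no sign flip
critic_loss.backward()                               # ordinary gradient descent follows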

DPPO code implementation

Actor

The actor is a neural network with two hidden layers.

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        result = self.net(x)
        return result
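
A quick sanity check of the Actor, assuming CartPole's 4-dimensional state and 2 discrete actions; the Softmax output feeds a categorical distribution:

import torch

actor = Actor(state_dim=4, action_dim=2)
probs = actor(torch.randn(4))                    # action probabilities, sum to 1
action = torch.distributions.Categorical(probs).sample()
print(probs.shape, action.item())                # torch.Size([2]) and 0 or 1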
Critic

The critic is a neural network with two hidden layers.

class Critic(nn.Module):
    def __init__(self, state_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        result = self.net(x)
        return result.squeeze(-1)
Worker

A Worker's job is to interact with the environment and push the experience samples to the main thread for updating.

class Worker(Thread):
    def __init__(self, id, agent: DPPO, gamma, device, episode, train_queue: Queue, update_data, return_queue:Queue):
        super().__init__()
        self.env = gym.make('CartPole-v1')
        self.agent = agent
        self.device = device
        self.episode = episode
        self.train_queue = train_queue
        self.update_data = update_data
        self.gamma = gamma
        self.id = id
        self.return_queue = return_queue

    def run(self):
        return_list = []
        for num in range(self.episode):
            state, info = self.env.reset()
            done = False
            reward_list = []
            step = 0
            while not done and step < 1500:
                action = self.agent.take_action(state)
                next_state, reward, done, _, __ = self.env.step(action)
                self.train_queue.put((state, action, reward, next_state, done))
                reward_list.append(reward)
                state = next_state  # remember to advance the state
                step += 1
                if self.train_queue.qsize() > self.update_data:
                    # enough samples collected: block interaction, wake the update thread
                    INTERACT_FLAG.clear()
                    UPDATE_FLAG.set()
                    INTERACT_FLAG.wait()
            res = 0
            for i in range(len(reward_list)-1, -1, -1):
                res = self.gamma * res + reward_list[i]
            print(f'Thread {self.id}, episode {num}: return {res}, steps {step}')
            return_list.append(res)
        self.return_queue.put(return_list)

To keep a well-trained policy from running a CartPole episode forever, the number of interactions per episode is capped at 1500.

DPPO main thread

The main thread performs action selection and network updates.

class DPPO:
    def __init__(self, state_dim, action_dim, workers_num, lr, gamma, epochs, episode, eps, device, train_queue, update_data, return_queue):
        super().__init__()
        self.actor = Actor(state_dim, action_dim).to(device)
        self.critic = Critic(state_dim).to(device)
        self.optimizer_actor = torch.optim.AdamW(self.actor.parameters(), lr=lr)
        self.optimizer_critic = torch.optim.AdamW(self.critic.parameters(), lr=1.5*lr)
        self.device = device
        self.epochs = epochs
        self.gamma = gamma
        self.eps = eps
        self.workers_list = []
        self.workers_num = workers_num
        self.episode = episode
        self.train_queue = train_queue
        self.update_data = update_data
        self.return_queue = return_queue
        self.initialize()

    def take_action(self, state):
        state = torch.tensor(state, dtype=torch.float32).to(self.device)
        prob = self.actor(state)
        dist = torch.distributions.Categorical(prob)
        action = dist.sample()
        return action.cpu().item()

    def update(self, state, action, reward, next_state, done):
        state = torch.tensor(state, dtype=torch.float32).to(self.device)
        action = torch.tensor(action, dtype=torch.int64).to(self.device)
        reward = torch.tensor(reward, dtype=torch.float32).to(self.device)
        next_state = torch.tensor(
            next_state, dtype=torch.float32).to(self.device)
        done = torch.tensor(done, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            state_value = self.critic(state)
            next_state_value = self.critic(next_state)
            TD_target = reward + self.gamma * next_state_value * (1 - done)
            advantage = (TD_target - state_value).detach().unsqueeze(1)  # (N, 1): matches the shape of new_prob
            old_prob = self.actor(state).gather(
                1, action.unsqueeze(1)).detach()

        for _ in range(self.epochs):
            new_prob = self.actor(state).gather(1, action.unsqueeze(1))
            ratio = new_prob / old_prob
            s1 = ratio * advantage
            s2 = torch.clamp(ratio, 1-self.eps, 1+self.eps) * advantage
            actor_loss = - torch.mean(torch.min(s1, s2)) # + 0.002 * torch.sum(new_prob * torch.log(new_prob + 1e-7), dim=1, keepdim=True).mean()
            critic_loss = torch.mean(F.mse_loss(self.critic(state), TD_target))
            self.optimizer_actor.zero_grad()
            self.optimizer_critic.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.optimizer_actor.step()
            self.optimizer_critic.step()

    def main(self):
        for index in range(self.workers_num):
            self.workers_list.append(
                Worker(index, self, self.gamma, self.device, self.episode,
                    self.train_queue, self.update_data, self.return_queue)
            )
            self.workers_list[-1].start()

        for i in range(self.workers_num):
            self.workers_list[-1].join()
            self.workers_list.pop()

    def initialize(self):
        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
        self.actor.apply(init_weights)
        self.critic.apply(init_weights)

Complete code

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from queue import Queue
import threading as td
from threading import Thread
import gym
import matplotlib.pyplot as plt

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        result = self.net(x)
        return result

class Critic(nn.Module):
    def __init__(self, state_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        result = self.net(x)
        return result.squeeze(-1)

class DPPO:
    def __init__(self, state_dim, action_dim, workers_num, lr, gamma, epochs, episode, eps, device, train_queue, update_data, return_queue):
        super().__init__()
        self.actor = Actor(state_dim, action_dim).to(device)
        self.critic = Critic(state_dim).to(device)
        self.optimizer_actor = torch.optim.AdamW(self.actor.parameters(), lr=lr)
        self.optimizer_critic = torch.optim.AdamW(self.critic.parameters(), lr=1.5*lr)
        self.device = device
        self.epochs = epochs
        self.gamma = gamma
        self.eps = eps
        self.workers_list = []
        self.workers_num = workers_num
        self.episode = episode
        self.train_queue = train_queue
        self.update_data = update_data
        self.return_queue = return_queue
        self.initialize()

    def take_action(self, state):
        state = torch.tensor(state, dtype=torch.float32).to(self.device)
        prob = self.actor(state)
        dist = torch.distributions.Categorical(prob)
        action = dist.sample()
        return action.cpu().item()

    def update(self, state, action, reward, next_state, done):
        state = torch.tensor(state, dtype=torch.float32).to(self.device)
        action = torch.tensor(action, dtype=torch.int64).to(self.device)
        reward = torch.tensor(reward, dtype=torch.float32).to(self.device)
        next_state = torch.tensor(
            next_state, dtype=torch.float32).to(self.device)
        done = torch.tensor(done, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            state_value = self.critic(state)
            next_state_value = self.critic(next_state)
            TD_target = reward + self.gamma * next_state_value * (1 - done)
            advantage = (TD_target - state_value).detach().unsqueeze(1)  # (N, 1): matches the shape of new_prob
            old_prob = self.actor(state).gather(
                1, action.unsqueeze(1)).detach()

        for _ in range(self.epochs):
            new_prob = self.actor(state).gather(1, action.unsqueeze(1))
            ratio = new_prob / old_prob
            s1 = ratio * advantage
            s2 = torch.clamp(ratio, 1-self.eps, 1+self.eps) * advantage
            actor_loss = - torch.mean(torch.min(s1, s2)) # + 0.002 * torch.sum(new_prob * torch.log(new_prob + 1e-7), dim=1, keepdim=True).mean()
            critic_loss = torch.mean(F.mse_loss(self.critic(state), TD_target))
            self.optimizer_actor.zero_grad()
            self.optimizer_critic.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.optimizer_actor.step()
            self.optimizer_critic.step()

    def main(self):
        for index in range(self.workers_num):
            self.workers_list.append(
                Worker(index, self, self.gamma, self.device, self.episode,
                    self.train_queue, self.update_data, self.return_queue)
            )
            self.workers_list[-1].start()

        for i in range(self.workers_num):
            self.workers_list[-1].join()
            self.workers_list.pop()

    def initialize(self):
        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
        self.actor.apply(init_weights)
        self.critic.apply(init_weights)

class Worker(Thread):
    def __init__(self, id, agent: DPPO, gamma, device, episode, train_queue: Queue, update_data, return_queue:Queue):
        super().__init__()
        self.env = gym.make('CartPole-v1')
        self.agent = agent
        self.device = device
        self.episode = episode
        self.train_queue = train_queue
        self.update_data = update_data
        self.gamma = gamma
        self.id = id
        self.return_queue = return_queue

    def run(self):
        return_list = []
        for num in range(self.episode):
            state, info = self.env.reset()
            done = False
            reward_list = []
            step = 0
            while not done and step < 1500:
                action = self.agent.take_action(state)
                next_state, reward, done, _, __ = self.env.step(action)
                self.train_queue.put((state, action, reward, next_state, done))
                reward_list.append(reward)
                state = next_state  # remember to advance the state
                step += 1
                if self.train_queue.qsize() > self.update_data:
                    # enough samples collected: block interaction threads, wake the update thread
                    INTERACT_FLAG.clear()
                    UPDATE_FLAG.set()
                    INTERACT_FLAG.wait()
            res = 0
            for i in range(len(reward_list)-1, -1, -1):
                res = self.gamma * res + reward_list[i]
            print(f'Thread {self.id}, episode {num}: return {res}, steps {step}')
            return_list.append(res)
        self.return_queue.put(return_list)
        


def dppo_update(agent: DPPO, train_queue: Queue, stop_flag: td.Event):
    while not stop_flag.is_set():
        UPDATE_FLAG.wait(timeout=5)
        traject = []
        while not train_queue.empty():
            traject.append(train_queue.get())
        if not traject:  # wait() timed out with no new samples: skip this round
            continue
        state, action, reward, next_state, done = zip(*traject)
        agent.update(np.array(state), np.array(action), np.array(
            reward), np.array(next_state), np.array(done))
        # block the update thread, wake the interaction threads
        UPDATE_FLAG.clear()
        INTERACT_FLAG.set()


if __name__ == '__main__':
    os.system('cls')  # clear the console (Windows; use 'clear' elsewhere)
    torch.autograd.set_detect_anomaly(True)

    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    env.close()

    train_queue = Queue(maxsize=1000)  # queue of training samples
    return_queue = Queue(maxsize=10)  # queue for collecting episode returns

    gamma = 0.99  # discount factor
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    lr = 5e-5  # learning rate
    worker_num = 2  # number of workers
    update_data = 64  # trigger an update every update_data collected samples
    epochs = 15  # number of optimization epochs per update
    episode = 350  # number of episodes each worker plays
    eps = 0.15  # clipping threshold

    stop_flag = td.Event()  # signals that training has finished
    INTERACT_FLAG = td.Event()  # whether workers may keep interacting with the environment
    UPDATE_FLAG = td.Event()  # whether the update thread should start an update

    agent = DPPO(state_dim, action_dim, worker_num, lr, gamma, epochs,
                 episode, eps, device, train_queue, update_data, return_queue)
    # start the update thread
    td.Thread(target=dppo_update, args=(
        agent, train_queue, stop_flag), daemon=True).start()

    # start the Worker threads
    agent.main()
    stop_flag.set()  # signal that all training threads are done
    while not return_queue.empty():
        return_list = return_queue.get()
        average = [np.mean(return_list[i:i+9])  # 9-episode moving average
                   for i in range(0, len(return_list)-8)]
        epi = [x for x in range(len(average))]
        plt.plot(epi, average)
        plt.show()

To implement fixed-batch updates and synchronous blocking, threading.Event() is used. It is Python's mechanism for inter-thread communication: an internal flag (initially False) controls when threads block and when they are woken.

  • After a certain amount of experience has been collected, the interaction threads are blocked and the network-update thread is woken to update the networks;
  • once the update finishes, the update thread is blocked and the interaction threads are woken again, collecting data with the new policy.
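A stripped-down sketch of this handshake with the same two flags (a standalone toy, separate from the training code):

import threading
import time

UPDATE_FLAG = threading.Event()
INTERACT_FLAG = threading.Event()

def updater():
    UPDATE_FLAG.wait()       # sleep until a worker signals "enough data"
    print('updating networks...')
    UPDATE_FLAG.clear()
    INTERACT_FLAG.set()      # wake the workers back up

threading.Thread(target=updater).start()
time.sleep(0.1)              # pretend the worker collected enough samples
UPDATE_FLAG.set()            # worker: hand control to the updater
INTERACT_FLAG.wait()         # worker: block until the update finishes
print('interacting again')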

The return curves of the two Workers are shown below.
(figures: moving-average return curves for the two Workers)
