PyMARL框架学习(一)——runners & envs & controllers & utils & components 以及环境部署和训练问题

系列文章目录

第一章 PyMARL框架学习 runners & envs & controllers & utils & components 以及环境部署和训练问题
第二章 PyMARL框架学习 main.py & run.py & learners & modules



前言

PyMARL是一个多智能体强化学习(Multi-Agent Reinforcement learning, MARL)框架,专门用于研究和实现多智能体环境下的强化学习算法。该框架由WhiRL开发,支持多个经典的多智能体强化学习算法,如QMIX、VDN、COMA、QTRAN 等。PyMARL广泛应用于模拟多个智能体的合作和竞争场景。PyMARL是用PyTorch编写的,并使用SMAC作为其环境。

本文主要用于博主的自学、自查,源代码来自GitHub-oxwhirl/PyMARL
文章内容恐有纰漏,仅供参考。若有大佬发现问题请您指正,感谢!

一、ruuners

runners模块的作用是负责MARL算法中环境交互、数据收集和训练流程等。其核心功能是管理多智能体与环境的交互过程,并收集每个时间步的状态、动作、奖励等数据,用于后续的训练和分析。此模块下实现了两种Runner:ParallelRunner和EpisodeRunner,分别用于并行环境和单一环境的运行,下面将对两个代码进行简单的分析。

episode_runner.py

class EpisodeRunner:

    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        self.batch_size = self.args.batch_size_run
        assert self.batch_size == 1                                    # batch_size必须等于1,确保每次训练或测试只处理一个环境
        self.env = env_REGISTRY[self.args.env](**self.args.env_args)   # 从env_REGISTRY注册表中获取环境对象。env_REGISTRY[self.args.env]是一个环境类(或构造函数),**self.args.env_args解包传递给环境构造函数的参数
        self.episode_limit = self.env.episode_limit                    # 获取环境的回合限制
        self.t = 0                                                     # 初始化为0,用于跟踪当前时间步
        self.t_env = 0                                                 # 初始化为0,用于跟踪当前环境的时间步
        self.train_returns = []                                        # 分别用于存储训练和测试的返回值
        self.test_returns = []
        self.train_stats = {
   
   }                                          # 分别用于存储训练和测试的统计信息
        self.test_stats = {
   
   }

        # Log the first run
        self.log_train_stats_t = -1000000                              # 用于记录最后一次训练统计信息的时间步骤,初始化为一个非常小的值,以便确保第一次训练时能够记录

    def setup(self, scheme, groups, preprocess, mac):                  # 设置批量生成和多智能体控制的相关内容
        self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1, preprocess=preprocess, device=self.args.device)
        # 使用partial函数创建一个部分应用的EpisodeBatch函数。EpisodeBatch用于创建新的批量数据,partial预设了函数的部分参数,剩下的参数(scheme, groups, preprocess, device)会在函数调用时提供
        self.mac = mac
 
    def get_env_info(self):                                            # 返回环境信息
        return self.env.get_env_info()

    def save_replay(self):                                             # 存储回放信息
        self.env.save_replay()

    def close_env(self):                                               # 关闭环境
        self.env.close()

    def reset(self):                                                   # 重置类的状态
        self.batch = self.new_batch()
        self.env.reset()
        self.t = 0

    def run(self, test_mode=False):
        self.reset()
        terminated = False                                             # 标志变量,指示当前回合是否结束
        episode_return = 0                                             # 累积当前回合的总回报
        self.mac.init_hidden(batch_size=self.batch_size)               # 初始化多智能体控制器MAC的隐藏状态
 
        while not terminated:                                          # 循环体,直到回合结束,即terminated为True时停止循环
            pre_transition_data = {
   
                                       # 收集在当前时间步的环境状态、可用动作和观察数据
                "state": [self.env.get_state()],
                "avail_actions": [self.env.get_avail_actions()],
                "obs": [self.env.get_obs()]
            }

            self.batch.update(pre_transition_data, ts=self.t)          # 更新数据,将pre_transition_data添加到当前时间步self.t的数据中 

            # Pass the entire batch of experiences up till now to the agents
            # Receive the actions for each agent at this timestep in a batch of size 1
            actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode)
            # 使用多智能体控制器MAC选择当前时间步的动作。select_actions方法会从当前批量数据中选择动作,test_mode参数决定是否处于测试模式
            reward, terminated, env_info = self.env.step(actions[0])
            # 执行环境中的一步操作,传入的actions[0]是第一个智能体的动作。
            episode_return += reward

            post_transition_data = {
   
   
                "actions": actions,
                "reward": [(reward,)],
                "terminated": [(terminated != env_info.get("episode_limit", False),)],
            }
            # 收集在当前时间步执行动作后的数据
            self.batch.update(post_transition_data, ts=self.t)
            # 更新批量数据,将post_transition_data添加到当前时间步self.t的数据中
            self.t += 1                                                  # 增加时间步的值

        last_data = {
   
   
            "state": [self.env.get_state()],
            "avail_actions": [self.env.get_avail_actions()],
            "obs": [self.env.get_obs()]
        }
        self.batch.update(last_data, ts=self.t)
        # 回合结束时,收集最后的环境状态、可用动作和观察数据。
        # 更新批量数据,将last_data添加到当前时间步self.t的数据中

        # Select actions in the last stored state
        actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode)
        self.batch.update({
   
   "actions": actions}, ts=self.t)
        # 在回合结束时,基于最后的状态选择动作,并更新批量数据

        cur_stats = self.test_stats if test_mode else self.train_stats
        cur_returns = self.test_returns if test_mode else self.train_returns
        log_prefix = "test_" if test_mode else ""
        # 根据test_mode确定当前的统计数据和返回值列表,并设置日志前缀
        cur_stats.update({
   
   k: cur_stats.get(k, 0) + env_info.get(k, 0) for k in set(cur_stats) | set(env_info)})
        cur_stats["n_episodes"] = 1 + cur_stats.get("n_episodes", 0)
        cur_stats["ep_length"] = self.t + cur_stats.get("ep_length", 0)
        # 更新统计数据 cur_stats,包括环境信息中的所有键,累计当前回合的信息,回合数+1,回合长度(ep_length)加上当前时间部署self.t
        if not test_mode:                                                 # 如果不是测试模式,增加环境时间步self.t_env
            self.t_env += self.t

        cur_returns.append(episode_return)                                # 将当前回合的返回值episode_return添加到返回值列表cur_returns中

        if test_mode and (len(self.test_returns) == self.args.test_nepisode):
            self._log(cur_returns, cur_stats, log_prefix)
        elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval:
            self._log(cur_returns, cur_stats, log_prefix)
            if hasattr(self.mac.action_selector, "epsilon"):
                self.logger.log_stat("epsilon", self.mac.action_selector
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值