PyMARL框架学习(二)——main.py & run.py & learners & modules

PyMARL框架学习系列目录

第一章 PyMARL框架学习 runners & envs & controllers & utils & components以及环境部署和训练问题
第二章 PyMARL框架学习 main.py & run.py & learners & modules




一、main.py

main.py是PyMARL框架的项目入口点,负责设置实验配置并启动实验流程。
其主要功能包括:

  1. 初始化配置和日志记录:通过sacred库管理实验流程,创建日志记录器,并定义结果保存路径。该文件会捕获和管理实验输出,确保标准输出和日志信息正确记录。
  2. 读取并合并配置文件:从多个配置文件中加载算法和环境的配置,使用 _get_config函数读取YAML文件中的内容,并通过recursive_dict_update合并默认配置、环境配置和算法配置。
  3. 设置随机种子:为了保证实验的可重复性,main.py中设置了NumPy和PyTorch的随机种子。
  4. 启动实验:调用ex.run_commandline(params)启动实验流程,调用run()函数执行主要的MARL训练过程。

1.引入库

import numpy as np
import os
import collections
from os.path import dirname, abspath
from copy import deepcopy
from sacred import Experiment, SETTINGS
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds
import sys
import torch as th
from utils.logging import get_logger
import yaml
from run import run
  1. numpy(np):numpy是经典的用于科学计算的库,尤其在处理多维数组和矩阵。在main.py中主要用于生成随机数np.random.seed()和其他数值计算相关的操作。
  2. os:用于操作系统相关的功能,例如路径操作、文件管理等。在main.py中,os.path.join用于构建文件路径,os.path.dirname和os.path.abspath分别用于获取文件的目录名和绝对路径。
  3. collections:包含高效的数据结构。在main.py的recursive_dict_update()方法中,collections.Mapping被用来判断字典的嵌套层次,用于递归地更新配置。
  4. copy:用于深拷贝复杂的对象,包括嵌套的列表和字典,确保拷贝后的对象和原对象之间没有引用关系。在main.py中,用于复制配置字典(config_copy)。
  5. sacred:sacred库函数不会在本博客详细介绍,感兴趣的朋友可以访问以下链接,参考大佬的介绍:Sacred 教程;同样,有关Python装饰器的内容也不在此论述,参考大佬的介绍:Python 装饰器
    。我们只需要简单知道Sacred是一个Python库, 可以帮助我们配置、组织、记录和复制实验。其作用是保存实验中的一系列重要信息与结果。Experiment用于定义实验对象,SETTINGS用于配置捕获实验输出的模式。
  6. sacred.observers.FileStorageObserver:是sacred的一个观察者类,用于将实验结果和日志保存到磁盘,确保实验过程中的数据可以被持久化。
  7. sacred.utils.apply_backspaces_and_linefeeds:是sacred中处理日志的类。处理日志中的回车和换行符,确保输出日志格式整洁,避免字符错乱。
  8. sys:提供与Python解释器交互的接口,常用于处理命令行参数。在main.py中,它通过sys.argv获取命令行参数列表,并传递给实验运行函数。
  9. torch:torch是PyTorch的核心库,用于深度学习和张量计算。在main.py中,th.manual_seed()被用来设置PyTorch的随机种子,确保实验结果的可重复性。
  10. utils.logging.get_logger:从自定义的utils.logging模块中导入get_logger函数,用于初始化日志记录器。它负责输出实验中的日志信息。
  11. yaml:yaml库用于解析YAML格式的配置文件。在main.py中,用于加载实验的默认配置、算法配置和环境配置

2.创建sacred对象

SETTINGS['CAPTURE_MODE'] = "fd"                # "fd" 表示通过文件描述符来捕获标准输出(stdout)和标准错误(stderr)。
                                               # 这样,所有的打印输出和错误信息将被写入文件,而不是直接显示在控制台。如果设置为 "no",则标准输出和标准错误将直接显示在控制台。
logger = get_logger()                          # 创建一个日志记录器实例,get_logger() 是一个自定义的函数,用于获取配置好的日志记录器。这个记录器负责在实验过程中记录重要的日志信息。记录器的配置通常包括日志级别、日志格式、日志输出位置(例如控制台或文件),以及其他日志相关的设置
ex = Experiment("pymarl")                      # 创建一个 Sacred 实验对象。Experiment 类用于定义和管理实验。"pymarl"是实验的名称,用于标识实验。
ex.logger = logger                             # 将之前的将之前创建的日志记录器关联到Sacred实验对象
ex.captured_out_filter = apply_backspaces_and_linefeeds # 设置输出过滤器。apply_backspaces_and_linefeeds是一个工具函数,用于处理标准输出中的回车和换行符。它确保日志输出的格式在文件中不会因为控制字符而混乱,使得输出更加清晰可读。
results_path = os.path.join(dirname(dirname(abspath(__file__))), "results") # 定义实验结果的存储路径,

3.my_main函数

@ex.main                           # @ex.main是sacred框架的装饰器,表明这是实验的主函数,实验开始时会调用此函数
def my_main(_run, _config, _log):  # my_main是使用sacred框架装饰的主要实验函数。当实验开始时,这个函数被执行,负责设置随机种子并启动实验流程。
    config = config_copy(_config)  # 这个语句将实验的配置_config深拷贝到config。使用config_copy的目的是确保修改config时不会影响原始配置对象_config,因为_config可能会在其他地方被使用。
    np.random.seed(config["seed"]) # 设置numpy的随机种子,使得整个实验中的随机数生成是可控且可重复的。config["seed"]是从配置中获取的随机种子
    th.manual_seed(config["seed"]) # 同样,设置PyTorch的随机种子。
    config['env_args']['seed'] = config["seed"] # 为环境参数设置相同的随机种子,确保环境初始化时使用的随机数也是可控的。
    run(_run, config, _log)        # 运行实验框架

4._get_config函数

def _get_config(params, arg_name, subfolder):  # _get_config函数用于从命令行参数中提取特定的配置文件,并将其加载为字典形式返回。
    config_name = None                         # 初始化config_name为None,用于存储命令行参数中指定的配置文件名。
    for _i, _v in enumerate(params):           # 循环遍历params列表,逐个检查参数。
        if _v.split("=")[0] == arg_name:       # _v.split("=")用于将参数按等号拆分,检查参数是否符合arg_name。
            config_name = _v.split("=")[1]     # 若参数的名字与arg_name相符,则提取等号右边的值作为config_name,即配置文件名。
            del params[_i]                     # 删除已经处理的参数,避免在后续处理中重复使用。
            break

    if config_name is not None:                # 如果找到了有效的配置文件名,则继续读取配置文件
        with open(os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name)), "r") as f:
            try:
                config_dict = yaml.load(f)     # 打开指定路径的YAML配置文件,路径由config_name和subfolder共同组成;使用yaml库解析YAML配置文件,返回一个字典格式的配置。
            except yaml.YAMLError as exc:
                assert False, "{}.yaml error: {}".format(config_name, exc) # 如果读取YAML文件出错,则通过assert抛出错误并终止程序运行,输出相应的错误信息。
        return config_dict

5.recursive_dict_update函数

def recursive_dict_update(d, u):                           # 递归地合并两个字典d和u,如果遇到嵌套字典会继续递归更新。
    for k, v in u.items():                                 # 遍历字典u中的键值对,k是键,v是对应的值。
        if isinstance(v, collections.Mapping):             # 检查值v是否是一个字典(即是否是映射类型)。collections.Mapping是Python中的抽象基类,表示字典类型。
            d[k] = recursive_dict_update(d.get(k, {
   }), v)  # 如果v是字典,对字典d的相应键k执行递归更新。如果d中不存在键k,则返回一个空字典{}。这一步递归地合并字典d和u。
        else: 
            d[k] = v                                       # 如果v不是字典,直接将v赋值给d[k],覆盖或新增键k的值。
    return d                                               # 返回合并后的字典d。

6.config_copy函数

def config_copy(config):                                       # 递归深拷贝配置对象,确保字典或列表中的每个元素都被独立复制,不共享原对象的引用
    if isinstance(config, dict):                               # 检查config是否为字典类型
        return {
   k: config_copy(v) for k, v in config.items()}  # 如果config是字典,递归地对每个键值对调用config_copy,创建一个新的字典,其中的值是递归复制的
    elif isinstance(config, list):                             # 如果config是列表,归地对列表中的每个元素调用config_copy,创建一个新的列表
        return [config_copy(v) for v in config]
    else:
        return deepcopy(config)                                # 如果config既不是字典也不是列表,则直接返回config的深拷贝

6.主程序_main_

if __name__ == '__main__':
    params = deepcopy(sys.argv)                                                             # 深拷贝命令行参数列表sys.argv,params是一个复制的列表,防止对命令行参数的修改影响原始参数
    with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f: # 打开default.yaml配置文件,路径为当前文件所在目录下的config文件夹中的default.yaml文件
        try:
            config_dict = yaml.load(f)                                                      # 使用yaml库加载default.yaml文件,将其解析为字典config_dict
        except yaml.YAMLError as exc:                                                       # 如果解析yaml文件时发生错误,捕获异常并终止程序,报告错误信息
            assert False, "default.yaml error: {}".format(exc)

    env_config = _get_config(params, "--env-config", "envs")         # 通过_get_config函数从命令行参数params中提取环境配置--env-config,加载位于config/envs/文件夹中的.yaml配置文件,返回为字典。
    alg_config = _get_config(params, "--config", "algs")             # 从命令行参数中提取算法配置--config,加载位于config/algs/文件夹中的.yaml文件
    config_dict = recursive_dict_update(config_dict, env_config)     # 递归地将env_config中的配置合并到config_dict中
    config_dict = recursive_dict_update(config_dict, alg_config)     # 接着将alg_config的配置合并到config_dict

    ex.add_config(config_dict)                                       # 将合并后的配置config_dict添加到sacred的实验对象ex中。

    logger.info("Saving to FileStorageObserver in results/sacred.")  # 记录日志信息,表示实验结果将保存在results/sacred/目录中
    file_obs_path = os.path.join(results_path, "sacred")             # 构建保存路径,实验结果将存储在results/sacred目录中
    ex.observers.append(FileStorageObserver.create(file_obs_path))   # 将FileStorageObserver添加到sacred的实验观察者中,用于将实验数据和日志保存到磁盘
    ex.run_commandline(params)                                       # 运行 acred实验,并传递命令行参数params,启动整个实验流程

二、run.py

1.引入库

import datetime
import os
import pprint
import time
import threading
import torch as th
from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timehelper import time_left, time_str
from os.path import dirname, abspath

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from components.episode_buffer import ReplayBuffer
from components.transforms import OneHot
  1. datetime:用于处理日期和时间,提供了日期和时间的类,如datetime.datetime、datetime.date等,便于进行时间计算、格式转换等操作。
  2. pprint:用于操作系统相关的功能,例如路径操作、文件管理等。在main.py中,os.path.join用于构建文件路径,os.path.dirname和os.path.abspath分别用于获取文件的目录名和绝对路径。
  3. time:提供时间相关的功能,如获取当前时间、测量程序运行时间、暂停执行等。
  4. threading:用于创建和管理线程,实现多线程并发执行。
  5. SimpleNamespace:SimpleNamespace是types模块中的一个简单类,用于创建具有动态属性的对象。用于存储和访问配置参数或其他需要动态属性的对象,提供类似于JavaScript对象的属性访问方式
  6. Logger:从PyMARL项目中的utils.logging模块中导入Logger类。
  7. time_left & time_str:从utils.timehelper模块中导入time_left和 time_str函数。time_left函数用于估算剩余的时间。它根据已经经过的时间和当前的进度,来推测剩余的任务所需时间。time_str函数将时间(秒数)转换为易读的字符串格式。
  8. REGISTRY as le_REGISTRY:从learners模块中引入注册表REGISTRY。管理和存储不同类型的学习器(learners)。
  9. REGISTRY as r_REGISTRY:从 runners 模块中引入注册表 REGISTRY。管理和存储不同类型的运行器(runners)。
  10. REGISTRY as mac_REGISTRY:从 controllers 模块中引入注册表 REGISTRY。管理和存储不同的控制器(controllers)。
  11. ReplayBuffer:从episode_buffer模块中引入ReplayBuffer。用于存储环境与智能体交互经验的缓冲区。
  12. OneHot:从transforms模块中引入OneHot变换。OneHot 是用于将分类变量转换为独热编码的工具,通常用于处理离散动作或状态。

2.run函数

def run(_run, _config, _log):
    _config = args_sanity_check(_config, _log)               # 对配置参数进行合理性检查。通过args_sanity_check函数验证_config的各项参数是否设置正确,以避免运行时发生错误
    args = SN(**_config)                                     # 将_config转换为SimpleNamespace对象,允许通过点操作符访问配置参数。                    
    args.device = "cuda" if args.use_cuda else "cpu"         # 根据配置中的use_cuda参数,决定是否使用GPU(cuda)还是CPU

    logger = Logger(_log)                                    # 创建一个Logger对象
    _log.info("Experiment Parameters:")                      # 记录实验的配置信息
    experiment_params = pprint.pformat(_config, indent=4, width=1) # 美化配置信息,使其在日志输出中更加易读
    _log.info("\n\n" + experiment_params + "\n")             # 将配置信息记录到日志中

    unique_token = "{}__{}".format(args.name, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) # 生成一个唯一的实验标识符unique_token,由实验名称和当前时间组成
    args.unique_token = unique_token                         # 将其保存到args中,用于标识实验
    if args.use_tensorboard:                                 # 如果配置中开启了Tensorboard(args.use_tensorboard),则
        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs") # 定义Tensorboard日志的存储路径tb_logs_direc和具体实验目录tb_exp_direc
        tb_exp_direc = os.path.join(tb_logs_direc, "{}").format(unique_token)              
        logger.setup_tb(tb_exp_direc)                                                           # 使用logger.setup_tb(tb_exp_direc)初始化Tensorboard日志记录器

    logger.setup_sacred(_run)                                # 配置Sacred框架,使用logger.setup_sacred(_run),让Sacred记录日志等实验信息

    run_sequential(args=args, logger=logger)                 # 调用run_sequential函数,传入配置参数args和日志记录器logger

    print("Exiting Main")                                    # 训练结束后的清理操作。首先打印一条信息表示主程序即将退出。
    print("Stopping all threads")                            # 遍历所有正在运行的线程,确保除主线程外的其他线程都安全地关闭。对于每个非主线程,使用join方法等待其结束,避免程序强制退出时导致线程不安全退出
    for t in threading.enumerate():
        if t.name != "MainThread":
            print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
            t.join(timeout=1)
            print("Thread joined") 

    print("Exiting script")                                  # 打印脚本即将退出的提示。

    os._exit(os.EX_OK)                                       # 使用os._exit(os.EX_OK)进行强制退出,确保框架和所有线程都彻底结束

3.evaluate_sequential函数

def evaluate_sequential(args, runner):
    for _ in range(args.test_nepisode):                      # 使用for循环,遍历args.test_nepisode次(即进行指定数量的测试episode)
        runner.run(test_mode=True)                           # 调用runner.run()方法,每次调用runner.run()都是在测试模式下进行的。test_mode=True通常意味着这次运行不会影响模型的训练过程,比如不会更新参数或进行探索性动作,只会用于评估模型的表现
    if args.save_replay:                                     # 检查args.save_replay是否为True,是否需要保存测试期间的回放
        runner.save_replay()                                 # 如果save_replay为True,调用runner.save_replay()方法,保存测试的回放数据(例如动作、状态、奖励等)
    runner.close_env()                                       # 在所有测试episode完成后,调用runner.close_env()关闭环境,释放资源

4.run_sequential函数

def run_sequential(args, logger):
    runner = r_REGISTRY[args.runner](args=args, logger=logger) # 初始化runner。r_REGISTRY是一个注册表,存储了不同的runner类型,根据args.runner指定的类型创建合适的runner对象
    env_info = runner.get_env_info()                            
    args.n_agents = env_info["n_agents"]                        
    args.n_actions = env_info["n_actions"]                      
    args.state_shape = env_info["state_shape"]                 # 获取环境的相关信息env_info,如智能体数量、动作空间大小、状态空间大小等,并将这些信息存储在args中供后续使用

    # Default/Base scheme
    scheme = {
   
        "state": {
   "vshape": env_info["state_shape"]},
        "obs": {
   "vshape": env_info["obs_shape"], "group": "agents"},
        "actions": {
   "vshape": (1,), "group": "agents", "dtype": th.long},
        "avail_actions": {
   "vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
        "reward": {
   "vshape": (1,)},
        "terminated": {
   "vshape": (1,), "dtype": th.uint8},
    }
    groups = {
   
        "agents": args.n_agents
    }
    preprocess = {
   
        "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)])
    }
    # 定义scheme,用于描述不同数据的形状(例如状态、观测、动作、奖励等)和数据类型。groups用于定义智能体组(即多智能体的数量),preprocess负责对动作进行预处理(例如将动作进行one-hot编码)

    buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                          preprocess=preprocess,
                          device="cpu" if args.buffer_cpu_only else args.device)
    # 初始化ReplayBuffer,用于存储智能体与环境交互的数据。这个缓冲区将记录每个episode的状态、动作、奖励等,并支持后续的批量采样以进行训练

    mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args)
    # 初始化多智能体控制器mac,从注册表mac_REGISTRY中根据args.mac加载控制器。控制器负责为各个智能体选择动作,并根据策略进行决策

    runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac)
    # 将数据的scheme和预处理步骤传递给runner

    learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args)
    if args.use_cuda:
        learner.cuda()
    # 始化learner,用于实际执行模型的训练过程。它从注册表le_REGISTRY中根据args.learner创建训练组件。如果使用CUDA,则将learner转移到GPU上

    if args.checkpoint_path != "": # 检查是否有加载模型的路径
        timesteps = []
        timestep_to_load = 0
        if not os.path.isdir(args.checkpoint_path):
            logger.console_logger.info("Checkpoint directiory {} doesn't exist".format(args.checkpoint_path))
            return
        # 如果检查点目录不存在,记录错误信息并退出函数

        for name in os.listdir(args.checkpoint_path):
            full_name = os.path.join(args.checkpoint_path, name)
            # Check if they are dirs the names of which are numbers
            if os.path.isdir(full_name) and name.isdigit():
                timesteps.append(int(name))
        # 遍历检查点目录,找出保存的模型文件夹,并记录这些文件夹代表的时间步数

        if args.load_step == 0:
            # choose the max timestep
            timestep_to_load = max(timesteps)
        else:
            # choose the timestep closest to load_step
            timestep_to_load = min(timesteps, key=lambda x: abs(x - args.load_step))
        # 根据args.load_step决定加载哪个时间步的模型。如果load_step == 0,则加载最新的模型,否则加载离指定步数最近的模型。

        model_path = os.path.join(args.checkpoint_path, str(timestep_to_load))
        logger.console_logger.info("Loading model from {}".format(model_path))
        learner.load_models(model_path)
        runner.t_env = timestep_to_load

        if args.evaluate or args.save_replay:
            evaluate_sequential(args, runner)
            return
        # 加载指定时间步的模型并开始评估或保存回放

    # start training
    episode = 0
    last_test_T = -args.test_interval - 1
    last_log_T = 0
    model_save_time = 0

    start_time = time.time()
    last_time = start_time

    logger.console_logger.info("Beginning training for {} timesteps".format(args.t_max))
    # 初始化一些变量以控制训练过程,包括当前episode编号、上次测试和日志的时间步数、模型保存的时间步数等,并记录训练开始时间


    while runner.t_env <= args.t_max: # 开始训练循环,直到达到最大时间步 t_max
        episode_batch = runner.run(test_mode=False)
        buffer.insert_episode_batch(episode_batch)
        # 运行一个完整的episode,并将episode数据插入缓冲区

        if buffer.can_sample(args.batch_size)
### 安装 PyMARL 框架 为了成功安装并配置 PyMARL 框架,在 Windows 系统下的 Anaconda 环境中操作是一个推荐的选择。以下是详细的安装指南: #### 创建 Conda 虚拟环境 创建一个新的 conda 环境来隔离 PyMARL 的依赖关系,这有助于避免与其他项目发生冲突。 ```bash conda create -n pymarl python=3.8 conda activate pymarl ``` #### 安装基础库 确保安装必要的 Python 库和支持工具,这些对于 PyMARL 和 StarCraft II API 都是必需的。 ```bash pip install numpy scipy matplotlib scikit-image pandas seaborn opencv-python h5py tensorboardX torch torchvision pyyaml ``` #### 下载并设置 SMAC (StarCraft Multi-Agent Challenge) SMAC 是专门为多智能体研究设计的一个挑战平台,它基于暴雪娱乐开发的游戏《星际争霸II》构建而成。 ```bash git clone https://github.com/oxwhirl/smac.git cd smac pip install . ``` #### 获取 PyMARL 仓库 通过 Git 克隆官方 GitHub 上托管的 PyMARL 存储库到本地计算机上。 ```bash cd .. git clone https://github.com/oxwhirl/pymarl.git cd pymarl/src/ ``` #### 修改配置文件 根据需求调整 `default.yaml` 文件内的参数设定,特别是关于地图名称的部分可以按照如下方式指定[^2]。 ```python with open(os.path.join(os.path.dirname(__file__), &quot;config&quot;, &quot;default.yaml&quot;), &quot;r&quot;) as f: try: config_dict = yaml.load(f, Loader=yaml.FullLoader) except yaml.YAMLError as exc: assert False, &quot;default.yaml error: {}&quot;.format(exc) # 将 map_name 设置为 &#39;2s3z&#39; config_dict[&#39;env_args&#39;][&#39;map_name&#39;] = &#39;2s3z&#39; ``` #### 运行实验脚本 最后一步是在命令提示符窗口执行训练模型所需的指令。如果当前工作目录不是 PyMARL 所在位置,则需提供完整的路径至 main.py 文件前缀[^1]。 ```bash python src/main.py --config=qmix --env-config=sc2 with env_args.map_name=2s3z ``` 以上步骤涵盖了从准备环境到最后启动实验所需的一切准备工作。遵循上述指导应该能够顺利完成 PyMARL 及其关联组件的部署过程[^3]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值