此文章与 DP3复现代码运行逻辑全流程(二)——部署全流程代码eval_policy.sh 紧密相连
在 eval_policy.sh 中定义了相关参数,然后在 eval_policy.py 中进一步运行
该脚本使用 Hydra 框架管理配置的评估脚本。从配置文件中读取参数,设置项目的根目录,创建一个工作区实例,并调用其评估方法
接下来分析一下其运行逻辑
目录
1 函数入口 if __name__ == "__main__":
1 函数入口 if __name__ == "__main__":
Python 标准入口点,确保只有在脚本被直接执行时才会运行
2 导入库函数
import sys, import os, import pathlib: sys 用于访问系统参数和函数,os 用于与操作系统交互,pathlib 用于处理路径
import hydra: Hydra 用于配置管理
import torch: PyTorch 用于深度学习计算
import dill: Dill 用于序列化对象
from omegaconf import OmegaConf: OmegaConf 用于处理配置文件
from train import TrainDP3Workspace: TrainDP3Workspace 类负责管理训练和评估
3 设置根目录
ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent): 获取当前文件的上上级目录
sys.path.append(ROOT_DIR): 将根目录添加到系统路径,允许在项目的其他模块中导入
os.chdir(ROOT_DIR): 更改当前工作目录为项目的根目录
4 注册新的 OmegaConf 解析器:
OmegaConf.register_new_resolver("eval", eval, replace=True): 注册一个新的解析器,用于在配置文件中处理名为 "eval" 的表达式
5 定义主函数
采用了 Hydra 这个库,基础知识可以参见 DP3复现基础知识(一)—— Hydra 库
@hydra.main(
version_base=None,
config_path=str(pathlib.Path(__file__).parent.joinpath(
'diffusion_policy_3d', 'config'))
)
装饰器,用于定义 Hydra 应用的入口点。指定了配置文件的位置
此处展开说明一下:
这个装饰器告诉 Hydra 使用指定的 config_path 路径中的配置文件来加载默认配置
当命令行中传入参数(如 task=${task_name})时,Hydra 会自动覆盖相应的配置项,并生成一个包含所有配置的 cfg 对象
config_path 则指定了配置文件所在的文件夹路径
例如,config_name = dp3,在 diffusion_policy_3d/config 文件夹下,Hydra 会自动找到它
def main(cfg):
workspace = TrainDP3Workspace(cfg)
workspace.eval()
定义主函数,接收配置对象 cfg
创建工作区并执行评估
实例化 TrainDP3Workspace 类,将配置传递给它。这个类负责管理训练和评估的相关操作
调用工作区中的评估方法,执行模型的评估过程
此处展开说明一下:
Hydra 装饰器将会自动加载配置并将其传递给 main(cfg) 函数中的 cfg 参数
传入的命令行参数会自动覆盖配置中的默认值,并合并成 cfg 对象。通过 cfg,可以访问传入的所有参数
cfg 是 OmegaConf 的一个对象,包含了命令行传入的所有配置项和默认配置。代码中的 TrainDP3Workspace(cfg) 将 cfg 传递给 TrainDP3Workspace 类
假设 TrainDP3Workspace 使用 cfg 的属性(如 cfg.task、cfg.training.device),则这些属性的值已经被命令行传入的参数覆盖
TrainDP3Workspace 类 和 eval() 函数均在 train.py 脚本中
所以此脚本的主要作用,仍是导入库函数及设置工作目录,并没有详细的策略执行函数。最终详细函数与训练过程相似,均在 train.py 脚本中,这很好理解,执行过程只是训练过程的单步前向执行
6 举例
如果执行样例:
bash scripts/eval_policy.sh dp3 adroit_hammer 0112 0 0
在 eval_policy.sh 文件中,task_name = adroit_hammer,alg_name = config_name = dp3
Hydra 会加载 config_name.yaml (dp3.yaml)配置文件,默认配置会加载进 cfg 对象
adroit_hammer 和 training.device="cuda:0" 会覆盖配置文件中的相应值
最终,main(cfg) 函数中的 cfg 对象包含了所有的配置项,可以在 TrainDP3Workspace(cfg) 内部使用
因此,虽然在代码中没有显式看到传参的地方,但 @hydra.main 装饰器自动处理了这一过程,将参数传递并合并到 cfg 对象中供 main(cfg) 使用