Any-point Trajectory Modeling for Policy Learning 使用教程

Any-point Trajectory Modeling for Policy Learning 使用教程

ATM Official codebase for "Any-point Trajectory Modeling for Policy Learning" ATM 项目地址: https://gitcode.com/gh_mirrors/atm9/ATM

1. 项目介绍

本项目是"Any-point Trajectory Modeling for Policy Learning"的官方代码库,旨在为机器人政策学习提供一个基于任意点轨迹建模的框架。该框架通过轨迹变换器(Track Transformer)进行预训练,并进一步指导策略训练,以实现更高效、更准确的机器人控制策略。

2. 项目快速启动

环境搭建

首先,确保您的系统中已安装了Git和Conda。然后执行以下命令克隆项目仓库并创建环境:

git clone --recursive https://github.com/Large-Trajectory-Model/ATM.git
cd ATM/
conda env create -f environment.yml
conda activate atm

依赖安装

在项目环境中,安装必要的第三方库:

pip install -e third_party/robosuite/
pip install -e third_party/robomimic/

数据集预处理

下载LIBERO数据集并预处理:

mkdir data
python -m scripts.download_libero_datasets
python -m scripts.preprocess_libero --suite libero_spatial
python -m scripts.preprocess_libero --suite libero_object
python -m scripts.preprocess_libero --suite libero_goal
python -m scripts.preprocess_libero --suite libero_10
python -m scripts.preprocess_libero --suite libero_90

接着,将数据集分为训练集和验证集:

python -m scripts.split_libero_dataset

下载预训练模型

为了复现论文中的实验结果,需要下载预训练模型:

mkdir results
unzip -o atm_release_checkpoints.zip -d results/
rm atm_release_checkpoints.zip

模型训练

模型训练分为两个阶段:轨迹变换器预训练和基于轨迹的策略训练。

轨迹变换器预训练
python -m scripts.train_libero_track_transformer --suite $SUITE_NAME

其中$SUITE_NAME可以是libero_spatiallibero_objectlibero_goallibero_100

基于轨迹的策略训练

对于简单的行为克隆(BC)基线,使用以下命令:

python -m scripts.train_libero_policy_bc --suite $SUITE_NAME

对于我们的Track-guided策略,使用以下命令:

python -m scripts.train_libero_policy_atm --suite $SUITE_NAME --tt $PATH_TO_TT

其中$PATH_TO_TT是第一阶段预训练的轨迹变换器的路径。

模型评估

评估模型性能,使用以下命令:

python -m scripts.eval_libero_policy --suite $SUITE_NAME --exp-dir $PATH_TO_EXP

其中$SUITE_NAME是测试的套件名称,$PATH_TO_EXP是训练策略文件夹的路径。

3. 应用案例和最佳实践

  • 案例1:使用轨迹变换器进行预训练,然后使用BC策略训练进行简单的机器人控制。
  • 最佳实践:在进行策略训练之前,确保轨迹变换器的预训练损失收敛到一个较低的值。

4. 典型生态项目

  • Robosuite:一个用于模拟和测试机器人算法的开源机器人模拟器。
  • Robomimic:一个用于模仿学习的数据集和工具箱。

ATM Official codebase for "Any-point Trajectory Modeling for Policy Learning" ATM 项目地址: https://gitcode.com/gh_mirrors/atm9/ATM

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

霍妲思

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值