OpenPI LoRA 微调操作指南
适用场景:RTX 4090(24GB 显存)用户的完整训练流程
负一、核心概念
LoRA 微调 vs 全量微调
| 对比项 | LoRA 微调 | 全量微调 |
|---|---|---|
| 显存需求 | 22.5 GB ✅ | 70 GB ❌ |
| 训练参数 | 仅 1-2% | 100% |
| 训练速度 | 快(6-12h) | 慢(24h+) |
| RTX 4090 | 可用 | 不可用 |
LoRA 原理:冻结原始模型参数,只训练小型适配器矩阵
# 原始: W [1024×2048] = 2,097,152 参数(冻结)
# LoRA: A [1024×8] + B [8×2048] = 24,768 参数(训练)
# 输出 = Input × (W + A×B)
零、重要路径说明
1. 数据集存储位置
LIBERO 数据集(自动下载):
Windows: C:\Users\<用户名>\.cache\huggingface\datasets\physical-intelligence___libero\
Linux/Mac: ~/.cache/huggingface/datasets/physical-intelligence___libero/
结构:
├── default/
│ ├── 0.0.0/
│ │ ├── dataset_info.json
│ │ └── parquet files
│ └── ...
自定义数据集:
~/.cache/huggingface/datasets/<你的用户名>___<数据集名>/
2. Checkpoint 保存位置
openpi\checkpoints\<配置名称>\<实验名称>\<步数>\
示例:
openpi\checkpoints\my_lora_finetune\test_run_1\
├── 1000/
│ ├── params/ # 模型权重(包含 LoRA 参数)
│ │ ├── manifest.ocdbt # 元数据
│ │ └── state # 训练状态
│ └── assets/
│ └── physical-intelligence_libero/
│ └── norm_stats.json # 归一化统计(推理必需)
├── 2000/
├── ...
└── 30000/ # 最终模型
3. 预训练权重位置
云端(自动下载):
gs://openpi-assets/checkpoints/pi0_base/params
本地缓存:
~/.cache/openpi/checkpoints/pi0_base/
一、训练配置
1.1 编辑配置文件
打开 src/openpi/training/config.py,滚动到文件末尾的 _CONFIGS 列表(约第 850 行),添加以下配置:
先介绍两个微调过程中的关键配置参数:
config-name:配置名称,唯一标识exp-name:实验名称,用于区分不同实验
# filepath: openpi\src\openpi\training\config.py
# 在文件末尾的 _CONFIGS 列表中添加(约第 850 行)
TrainConfig(
# ============ 基础配置 ============
name="my_lora_finetune", # 唯一配置名称
exp_name="run_v1", # 实验名称(命令行可覆盖)
# ============ 模型配置(LoRA 变体)============
model=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora", # 视觉-语言模型用 LoRA
action_expert_variant="gemma_300m_lora" # 动作专家用 LoRA
),
# ============ 数据配置 ============
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero", # 数据集 ID
base_config=DataConfig(
prompt_from_task=True, # 从数据集加载任务提示词
),
extra_delta_transform=True, # LIBERO 需要
),
# ============ 预训练权重 ============
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_base/params"
),
# ============ 冻结参数(关键!)============
freeze_filter=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora"
).get_freeze_filter(), # 自动冻结非 LoRA 参数
# ============ 训练超参数 ============
num_train_steps=30_000,
batch_size=32, # RTX 4090 可用 32-64
# 学习率调度器类名
lr_schedule=_optimizer.CosineDecaySchedule(
warmup_steps=1_000,
peak_lr=3e-4, # LoRA 可以用更高的学习率
decay_steps=30_000,
decay_lr=1e-6,
),
optimizer=_optimizer.AdamW(
clip_gradient_norm=1.0,
),
# ============ 关闭 EMA 节省显存 ============
ema_decay=None, # LoRA 不需要指数移动平均
# ============ 其他配置 ============
log_interval=100,
save_interval=1000,
keep_period=5000,
wandb_enabled=True,
)
1.2 配置文件详解
1.2.1 模型配置参数
model=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora", # 选项:gemma_2b, gemma_2b_lora
action_expert_variant="gemma_300m_lora" # 选项:gemma_300m, gemma_300m_lora
)
变体说明:
| 变体名称 | 参数量 | 用途 | 显存需求 |
|---|---|---|---|
gemma_2b | 2B | 全量微调 | 70+ GB |
gemma_2b_lora | 2B (训练 ~20M) | LoRA 微调 | 22.5 GB |
gemma_300m | 300M | 全量微调 | 70+ GB |
gemma_300m_lora | 300M (训练 ~3M) | LoRA 微调 | 22.5 GB |
源码位置:src/openpi/models/pi0_config.py 第 17-78 行
1.2.2 数据配置参数
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
base_config=DataConfig(
prompt_from_task=True,
),
extra_delta_transform=True,
)
参数说明:
| 参数 | 类型 | 作用 | 示例值 |
|---|---|---|---|
repo_id | str | HuggingFace 数据集 ID | "physical-intelligence/libero" |
prompt_from_task | bool | 是否从数据集的 task 字段加载提示词 | True |
extra_delta_transform | bool | 是否额外进行 delta 转换(LIBERO 特有) | True |
源码位置:src/openpi/training/config.py 第 374-452 行
📖 关于 repo_id 的详细解释
1. 为什么叫 “repo_id” 而不是 “dataset_id”?
repo_id 这个命名来自 HuggingFace Hub 的设计理念:
- HuggingFace Hub 使用 Git 仓库(Repository) 来管理数据集、模型和空间(Spaces)
- 每个数据集本质上是一个 Git 仓库,包含:
- 数据文件(Parquet、Arrow、CSV 等)
- 元数据文件(
dataset_info.json、README.md) - 版本控制历史(Git commits)
- 因此使用
repo_id强调了数据集的仓库属性
命名格式:<用户名或组织名>/<数据集名称>
# 示例
repo_id = "physical-intelligence/libero" # 组织:physical-intelligence
repo_id = "lerobot/aloha_sim_transfer_cube" # 组织:lerobot
repo_id = "your_username/my_aloha_dataset" # 个人用户
2. HuggingFace Hub 上有哪些可用的机器人数据集?
访问 https://huggingface.co/datasets 并搜索关键词:
| 关键词 | 数据集类型 | 示例 repo_id |
|---|---|---|
lerobot | LeRobot 格式数据集 | lerobot/aloha_sim_transfer_cube_human |
libero | LIBERO 任务数据集 | physical-intelligence/libero |
aloha | ALOHA 机器人数据集 | physical-intelligence/aloha_pen_uncap_diverse |
droid | DROID 机器人数据集 | 需要自行转换后上传 |
robomimic | Robomimic 数据集 | eai-lab/robomimic |
1.2.3 学习率调度器
lr_schedule=_optimizer.CosineDecaySchedule(
warmup_steps=1_000, # 预热步数
peak_lr=3e-4, # 峰值学习率
decay_steps=30_000, # 衰减步数
decay_lr=1e-6, # 最终学习率
)
可用的学习率调度器:
| 调度器类 | 特点 | 适用场景 |
|---|---|---|
CosineDecaySchedule | 余弦衰减 + 预热 | 推荐用于 LoRA |
LinearWarmupSchedule | 线性预热 | 简单任务 |
源码位置:src/openpi/training/optimizer.py 第 1-80 行
1.2.4 冻结参数机制
freeze_filter=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora"
).get_freeze_filter()
工作原理:
# src/openpi/models/pi0_config.py 第 80-106 行
def get_freeze_filter(self):
"""返回需要冻结的参数过滤器"""
filters = []
if "lora" in self.paligemma_variant:
# 冻结 Gemma 的所有参数
filters.append(nnx_utils.PathRegex(".*llm.*"))
if "lora" in self.action_expert_variant:
# 冻结动作专家的所有参数
filters.append(nnx_utils.PathRegex(".*llm.*_1.*"))
# 排除所有 LoRA 参数(允许训练)
filters.append(nnx.Not(nnx_utils.PathRegex(".*lora.*")))
return nnx.All(*filters)
效果示意:
# 模型参数示例
model.llm.layer_0.weight # ❌ 冻结(~2GB)
model.llm.layer_0.lora_a # ✅ 训练(~8MB)
model.llm.layer_0.lora_b # ✅ 训练(~8MB)
model.action_expert.layer_1.weight # ❌ 冻结(~300MB)
model.action_expert.layer_1.lora_a # ✅ 训练(~3MB)
1.2.5 EMA(指数移动平均)
ema_decay=None # LoRA 微调建议关闭
什么是 EMA:
EMA t = α ⋅ θ t + ( 1 − α ) ⋅ EMA t − 1 \text{EMA}_t = \alpha \cdot \theta_t + (1-\alpha) \cdot \text{EMA}_{t-1} EMAt=α⋅θt+(1−α)⋅EMAt−1
- 维护一份参数的移动平均版本
- 推理时使用 EMA 参数,通常更稳定
为什么 LoRA 关闭 EMA:
- 节省显存:EMA 需要额外存储一份完整参数(~10GB)
- LoRA 本身已足够稳定:少量参数不易震荡
源码位置:src/openpi/training/config.py 第 481 行
二、归一化详解
2.1 为什么需要归一化
问题场景:
# 原始数据(未归一化)
state = [
0.5, # 关节 1 角度(弧度)
-2.3, # 关节 2 角度(弧度)
1500.0, # 力传感器读数(牛顿)
0.1, # 夹爪开合度(0-1)
]
问题:
- 数值跨度巨大(0.1 到 1500)
- 力传感器会主导梯度计算
- 训练不稳定,难以收敛
归一化后:
# Z-score 归一化:(x - mean) / std
state_normalized = [
-0.23, # (0.5 - 0.6) / 0.4
0.15, # (-2.3 - (-2.5)) / 0.5
0.08, # (1500 - 1480) / 250
-0.10, # (0.1 - 0.15) / 0.05
]
效果:
- ✅ 均值接近 0,方差接近 1
- ✅ 各维度贡献均衡
- ✅ 梯度下降更高效
2.2 两种归一化方法
方法 1:Z-score 归一化(Pi0)
# config.py 第 211 行
use_quantile_norm=False # 使用 Z-score
# 公式
normalized = (x - mean) / std
特点:
- 假设数据服从正态分布
- 对离群值敏感
方法 2:Quantile 归一化(Pi0.5)
# config.py 第 211 行
use_quantile_norm=True # 使用 Quantile
# 公式
normalized = (x - median) / IQR
# IQR = Q75 - Q25(四分位距)
特点:
- 基于分位数,更鲁棒
- 对离群值不敏感
源码位置:src/openpi/shared/normalize.py 第 1-150 行
2.3 归一化统计的加载
# config.py 第 209-223 行
def _load_norm_stats(
self,
assets_dir: epath.Path,
asset_id: str | None
) -> dict[str, _transforms.NormStats] | None:
"""从 assets 目录加载归一化统计"""
if asset_id is None:
return None
try:
# 构建路径:assets/<asset_id>/norm_stats.json
data_assets_dir = str(assets_dir / asset_id)
norm_stats = _normalize.load(_download.maybe_download(data_assets_dir))
logging.info(f"Loaded norm stats from {data_assets_dir}")
return norm_stats
except FileNotFoundError:
logging.info(f"Norm stats not found in {data_assets_dir}, skipping.")
return None
加载路径:
checkpoints/my_lora_finetune/assets/physical-intelligence_libero/norm_stats.json
文件格式:
{
"state": {
"mean": [0.1, -0.3, 0.5, ...],
"std": [0.4, 0.6, 0.8, ...],
"min": [-1.2, -2.5, -0.3, ...],
"max": [1.5, 2.1, 1.8, ...]
},
"actions": {
"mean": [0.0, 0.0, 0.0, ...],
"std": [0.2, 0.3, 0.1, ...],
"min": [-0.5, -0.6, -0.3, ...],
"max": [0.5, 0.6, 0.3, ...]
}
}
2.4 重用预训练的归一化统计
场景:您的机器人与预训练数据集中的机器人相同
TrainConfig(
name="my_aloha_lora",
data=LeRobotAlohaDataConfig(
repo_id="your_username/my_aloha_dataset",
assets=AssetsConfig(
# 重用 pi0_base 的 Trossen 机器人归一化统计
assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
asset_id="trossen", # 关键:指定机器人类型
),
),
# ...
)
可用的预训练归一化统计:
| 机器人 | asset_id | 适用模型 |
|---|---|---|
| Trossen (ALOHA) | trossen | pi0_base, pi0_fast_base |
| DROID Franka | droid | pi0_base, pi05_base |
| UR5e | ur5e | pi0_base |
| Panda | panda | pi0_base |
源码位置:docs/norm_stats.md 第 9-50 行
三、 完整训练流程
步骤 1:添加配置
编辑 src/openpi/training/config.py,添加上述配置。
步骤 2:计算归一化统计
uv run scripts/compute_norm_stats.py --config-name my_lora_finetune
步骤 3:启动训练
3.1 直接启动(前台运行)
# Windows PowerShell
$env:XLA_PYTHON_CLIENT_MEM_FRACTION="0.9"
uv run scripts/train.py my_lora_finetune --exp-name=test_run_1 --overwrite
# Linux/macOS
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
uv run scripts/train.py my_lora_finetune \
--exp-name=test_run_1 \
--overwrite
缺点:关闭终端窗口,训练会中断 ❌
3.2 使用 tmux 后台运行(推荐,Linux/macOS)
📦 安装 tmux(如果还没有)
# Ubuntu/Debian
sudo apt-get install tmux
# CentOS/RHEL
sudo yum install tmux
# macOS
brew install tmux
🚀 使用 tmux 启动训练
# 1. 创建名为 "openpi_train" 的 tmux 会话
tmux new -s openpi_train
# 2. 在 tmux 会话中设置环境变量并启动训练
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
cd /path/to/openpi-main
uv run scripts/train.py my_lora_finetune --exp-name=test_run_1 --overwrite
# 3. 分离会话(训练在后台继续运行)
# 按键:Ctrl + B,然后按 D
🔍 管理 tmux 会话
查看所有会话:
tmux ls
# 输出示例:
# openpi_train: 1 windows (created Thu Jan 15 10:30:00 2025)
重新连接到会话:
# 连接到 openpi_train 会话
tmux attach -t openpi_train
# 或简写
tmux a -t openpi_train
再次分离会话:
按键组合:Ctrl + B,然后按 D
杀死会话(训练完成后):
# 方法 1:在会话内部
exit
# 方法 2:从外部杀死
tmux kill-session -t openpi_train
💡 tmux 常用快捷键
| 快捷键 | 功能 |
|---|---|
Ctrl + B → D | 分离会话(detach) |
Ctrl + B → C | 创建新窗口 |
Ctrl + B → N | 切换到下一个窗口 |
Ctrl + B → P | 切换到上一个窗口 |
Ctrl + B → % | 垂直分割窗口 |
Ctrl + B → " | 水平分割窗口 |
Ctrl + B → 方向键 | 在分割的窗格间切换 |
Ctrl + B → [ | 进入滚动模式(查看历史输出) |
步骤 4:监控训练进度
4.1 本地终端日志
训练过程中,终端会实时显示:
[2025/01/15 10:30:20] Step 0: loss=2.345, lr=0.0
[2025/01/15 10:30:25] Step 100: loss=1.892, lr=1.2e-4
[2025/01/15 10:30:30] Step 200: loss=1.567, lr=2.1e-4
...
4.2 使用 Weights & Biases 可视化监控
训练启动后,Wandb 会自动记录数据。
4.3 检查本地 Checkpoint
# 查看已保存的 checkpoint
ls checkpoints\my_lora_finetune\test_run_1\
# 输出示例:
# 1000/ (约 500 MB)
# 2000/ (约 500 MB)
# ...
# 30000/ (约 500 MB - 最终模型)
每个 checkpoint 的内容:
30000/
├── params/ # 模型权重(包含 LoRA 参数)
│ ├── manifest.ocdbt # TensorStore 元数据
│ └── state # 训练状态(优化器、步数等)
└── assets/
└── physical-intelligence_libero/
└── norm_stats.json # 归一化统计(推理时必需)
步骤 5:使用训练好的模型推理
训练完成后,您可以:
- 启动策略服务器进行远程推理
- 在仿真环境中测试模型
- 录制 demo 视频展示效果
5.1 启动策略服务器(HTTP API)
基础启动
uv run scripts/serve_policy.py policy:checkpoint \
--policy.config=my_lora_finetune \
--policy.dir=checkpoints/my_lora_finetune/test_run_1/30000 \
--port 8000
启动输出示例:
[2025/01/15 18:00:00] Loading checkpoint from checkpoints/my_lora_finetune/test_run_1/30000
[2025/01/15 18:00:05] Loading norm stats from checkpoints/.../norm_stats.json
[2025/01/15 18:00:10] Model loaded successfully!
[2025/01/15 18:00:10] Starting server on http://0.0.0.0:8000
[2025/01/15 18:00:10] Health check: http://localhost:8000/healthz
验证服务器
# 健康检查
Invoke-WebRequest http://localhost:8000/healthz
# 输出:StatusCode: 200, Content: OK
API 使用示例
Python 客户端:
import requests
import numpy as np
from PIL import Image
# 1. 加载图像
image = Image.open("camera_view.jpg")
image_array = np.array(image) # shape: [H, W, 3]
# 2. 准备机器人状态
state = [0.1, -0.5, 0.3, 0.2, 0.8, 1.0, 0.0] # 7 维状态
# 3. 发送请求
response = requests.post("http://localhost:8000/act", json={
"observation": {
"images": [image_array.tolist()], # 转为列表
"state": state,
},
"prompt": "pick up the red cube", # 任务描述
})
# 4. 获取动作
result = response.json()
actions = np.array(result["action"]) # shape: [10, 7](10 步预测)
print(f"下一步动作: {actions[0]}")
print(f"完整动作序列: {actions.shape}")
JavaScript 客户端:
const response = await fetch('http://localhost:8000/act', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
observation: {
images: [imageArray], // 3D array
state: [0.1, -0.5, 0.3, 0.2, 0.8, 1.0, 0.0]
},
prompt: "pick up the red cube"
})
});
const result = await response.json();
console.log('Next action:', result.action[0]);
5.2 在仿真环境中测试模型
查找可用的仿真任务
查看 gym-aloha GitHub 仓库
- 访问 https://github.com/huggingface/gym-aloha
- 查看
gym_aloha/__init__.py文件,查找register()函数调用 - 查看
gym_aloha/env.py文件中的_make_env_task()方法
根据 gym-aloha 源码 和 env.py,当前支持以下任务:
| 任务 ID | 任务名称 | 描述 | 最大步数 | 奖励机制 |
|---|---|---|---|---|
gym_aloha/AlohaTransferCube-v0 | 转移方块任务 | 右臂抓取红色方块并转移到左臂 | 300 | 1分:右臂抓住方块 2分:方块被抬起 3分:转移到左臂 4分:成功转移且不触桌面 |
gym_aloha/AlohaInsertion-v0 | 插入任务 | 左右臂分别抓取插座和插针,然后在空中插入 | 300 | 1分:触摸插针和插座 2分:抓住两者 3分:对齐并接触 4分:成功插入 |
🚀 运行仿真环境测试
使用 examples/aloha_sim/run_task.py 脚本:
# examples/aloha_sim/run_task.py
import numpy as np
import imageio
import tqdm
from openpi_client import websocket_client_policy
from examples.aloha_sim import env as _env
# --- 1. 配置 ---
# 服务器地址和端口
SERVER_HOST = "localhost"
SERVER_PORT = 8000
# 仿真环境配置
SIM_TASK_NAME = "gym_aloha/AlohaTransferCube-v0"
SEED = 42 # 固定随机种子以保证可复现性
# 视频保存配置
OUTPUT_VIDEO_PATH = "data/aloha_sim/multi_prompt_task.mp4"
VIDEO_FPS = 15
# --- 2. 定义多阶段任务和自定义提示词 ---
# 每个阶段包含:提示词、执行的步数
task_phases = [
{"prompt": "pick up the cube", "steps": 200}
]
def main():
"""
主函数:连接服务器,运行多阶段任务,并保存视频。
"""
print(">> 正在初始化策略客户端...")
try:
client = websocket_client_policy.WebsocketClientPolicy(
host=SERVER_HOST, port=SERVER_PORT
)
print(f">> 成功连接到策略服务器 at {SERVER_HOST}:{SERVER_PORT}")
except Exception as e:
print(f"!! 无法连接到服务器: {e}")
print("!! 请确保 'serve_policy.py' 正在运行。")
return
print(f">> 正在创建仿真环境: {SIM_TASK_NAME}")
env = _env.AlohaSimEnvironment(task=SIM_TASK_NAME, seed=SEED)
env.reset()
print(">> 仿真环境已重置。")
# 用于存储视频帧的列表
video_frames = []
# 计算总步数用于进度条
total_steps = sum(phase["steps"] for phase in task_phases)
with tqdm.tqdm(total=total_steps, desc="执行任务") as pbar:
# --- 3. 循环执行每个任务阶段 ---
for phase in task_phases:
current_prompt = phase["prompt"]
steps_for_phase = phase["steps"]
pbar.set_description(f"阶段: {current_prompt}")
for _ in range(steps_for_phase):
# a. 从环境中获取观测
observation = env.get_observation()
# b. 添加当前的自定义提示词
observation["prompt"] = current_prompt
# c. 发送观测到服务器并获取动作
try:
output = client.infer(observation)
action = output["actions"][0] # 取动作序列的第一个动作
except Exception as e:
print(f"\n!! 推理失败: {e}")
print("!! 检查服务器连接和状态。")
return
# d. 将动作应用到仿真环境
env.apply_action({"actions": action})
# e. 渲染并保存当前帧用于录制视频
# 注意:我们需要 'cam_high' 图像,并将其从 CHW 转换为 HWC 格式
frame = observation["images"]["cam_high"]
frame_hwc = np.transpose(frame, (1, 2, 0))
video_frames.append(frame_hwc)
pbar.update(1)
print(f"\n>> 任务完成。总共录制了 {len(video_frames)} 帧。")
# --- 4. 保存视频 ---
print(f">> 正在将视频保存到: {OUTPUT_VIDEO_PATH}")
try:
imageio.mimwrite(OUTPUT_VIDEO_PATH, video_frames, fps=VIDEO_FPS, quality=8)
print(">> 视频保存成功!")
except Exception as e:
print(f"!! 视频保存失败: {e}")
if __name__ == "__main__":
main()
完整步骤:
# 1. 启动策略服务器(在一个终端)
uv run scripts/serve_policy.py policy:checkpoint \
--policy.config=my_lora_finetune \
--policy.dir=checkpoints/my_lora_finetune/test_run_1/30000 \
--port 8000
# 2. 运行仿真脚本(在另一个终端)
cd examples/aloha_sim
# 运行转移方块任务
uv run python run_task.py
# 或者直接指定任务(修改 run_task.py 中的 SIM_TASK_NAME)
自定义多阶段任务示例:
您可以修改 run_task.py 中的 task_phases 列表来创建自定义的多阶段任务:
# run_task.py 中的自定义提示词示例
task_phases = [
{"prompt": "pick up the cube", "steps": 70},
{"prompt": "move the cube to the right side", "steps": 80},
{"prompt": "place the cube into the bowl", "steps": 60},
{"prompt": "move your arms back to a neutral position", "steps": 40},
]
5.3 评估模型性能
💡 关于评估代码
OpenPI 项目已包含官方评估脚本,位于:
- LIBERO 评估:
examples/libero/main.py - ALOHA 仿真评估:
examples/aloha_sim/main.py
这些脚本是官方提供的标准评估工具,无需自己编写。下面将介绍如何使用它们。
使用官方评估脚本(推荐)⭐
方法 1:LIBERO 环境评估
适用场景:评估在 LIBERO 数据集上训练的模型
步骤:
# 1. 启动策略服务器(在一个终端)
uv run scripts/serve_policy.py policy:checkpoint \
--policy.config=my_lora_finetune \
--policy.dir=checkpoints/my_lora_finetune/test_run_1/30000 \
--port 8000
# 2. 运行评估脚本(在另一个终端)
cd examples/libero
# 评估 libero_spatial 任务套件(10 个任务,每个任务 50 次)
uv run main.py \
--host 0.0.0.0 \
--port 8000 \
--task-suite-name libero_spatial \
--num-trials-per-task 50 \
--video-out-path ../../data/libero/eval_videos
# 其他可用任务套件:
# --task-suite-name libero_object # 物体操作任务
# --task-suite-name libero_goal # 目标导向任务
# --task-suite-name libero_10 # LIBERO-10 基准测试
# --task-suite-name libero_90 # LIBERO-90 基准测试
输出示例:
Task suite: libero_spatial
Task: pick up the red block and put it in the drawer
Starting episode 1...
Success: True
# episodes completed so far: 1
# successes: 1 (100.0%)
Task: pick up the red block and put it in the drawer
Starting episode 2...
Success: True
# episodes completed so far: 2
# successes: 2 (100.0%)
...
Current task success rate: 0.86 # 当前任务成功率 86%
Current total success rate: 0.84 # 总体成功率 84%
Total success rate: 0.84 # 最终成功率
Total episodes: 500 # 总测试回合数
结果文件:
data/libero/eval_videos/
├── rollout_pick_up_the_red_block_success.mp4
├── rollout_pick_up_the_red_block_failure.mp4
├── rollout_stack_the_blue_cube_success.mp4
└── ...
方法 2:ALOHA 仿真评估
适用场景:评估在 ALOHA 仿真环境训练的模型
评估脚本代码 (examples/aloha_sim/main.py):
import dataclasses
import logging
import pathlib
import env as _env
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import saver as _saver
import tyro
@dataclasses.dataclass
class Args:
"""ALOHA 仿真评估参数配置"""
# 视频保存目录
out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")
# 仿真任务名称
task: str = "gym_aloha/AlohaTransferCube-v0"
# 随机种子
seed: int = 0
# 动作序列长度(必须与训练时一致)
action_horizon: int = 10
# 策略服务器地址
host: str = "0.0.0.0"
port: int = 8000
# 是否显示仿真画面
display: bool = False
def main(args: Args) -> None:
"""运行 ALOHA 仿真评估
工作流程:
1. 创建仿真环境 (AlohaSimEnvironment)
2. 连接到策略服务器 (WebsocketClientPolicy)
3. 使用动作分块代理执行策略 (ActionChunkBroker)
4. 保存评估视频 (VideoSaver)
"""
runtime = _runtime.Runtime(
# 初始化 ALOHA 仿真环境
environment=_env.AlohaSimEnvironment(
task=args.task,
seed=args.seed,
),
# 初始化策略代理
agent=_policy_agent.PolicyAgent(
policy=action_chunk_broker.ActionChunkBroker(
# 连接到远程策略服务器
policy=_websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
),
action_horizon=args.action_horizon,
)
),
# 添加视频保存器
subscribers=[
_saver.VideoSaver(args.out_dir),
],
# 控制频率 50Hz
max_hz=50,
)
# 运行评估
runtime.run()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, force=True)
tyro.cli(main)
运行步骤:
# 1. 启动策略服务器(在一个终端)
uv run scripts/serve_policy.py policy:checkpoint \
--policy.config=pi0_aloha_sim_lora \
--policy.dir=checkpoints/pi0_aloha_sim_lora/run_1/20000 \
--port 8000
# 2. 运行评估脚本(在另一个终端)
cd examples/aloha_sim
uv run main.py \
--host 0.0.0.0 \
--port 8000 \
--task gym_aloha/AlohaTransferCube-v0 \
--out-dir ../../data/aloha_sim/eval_videos
可用任务:
| 任务名称 | 描述 |
|---|---|
gym_aloha/AlohaTransferCube-v0 | 转移方块 |
gym_aloha/AlohaInsertion-v0 | 插入任务 |
gym_aloha/AlohaThreadVelcro-v0 | 穿线任务 |
5.4 部署到真实机器人
对于真实机器人部署,参考:
- ALOHA 真机:
examples/aloha_real/README.md - DROID 机器人:
examples/droid/README.md - UR5 机械臂:
examples/ur5/README.md
基本流程:
- 启动策略服务器(在 GPU 服务器上)
- 连接机器人硬件
- 运行控制脚本(读取相机 → 调用策略 → 发送动作)
# 真实机器人推理示例(伪代码)
import robot_interface # 机器人控制库
robot = robot_interface.connect()
camera = robot_interface.Camera()
while True:
# 1. 获取观测
image = camera.capture()
state = robot.get_state()
# 2. 调用策略
response = requests.post("http://gpu-server:8000/act", json={
"observation": {"images": [image], "state": state},
"prompt": "pick up the object"
})
action = response.json()["action"][0]
# 3. 执行动作
robot.execute(action)
time.sleep(0.1) # 控制频率 10Hz
四、 不同机器人平台的训练配置总结
本章节根据 config.py 中的 TrainConfig 配置,总结了各种机器人平台的训练配置方案。
4.1 LIBERO 仿真环境
方案 A:Pi0 全量微调(需要 70GB+ 显存)
TrainConfig(
name="pi0_libero",
# 全量微调模型(所有参数都训练)
model=pi0_config.Pi0Config(),
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
base_config=DataConfig(
prompt_from_task=True, # 从数据集加载任务提示词
),
extra_delta_transform=True, # LIBERO 特有的 delta 转换
),
# 加载 pi0_base 预训练权重
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_base/params"
),
num_train_steps=30_000,
# 其他参数使用默认值
)
特点:
- ✅ 训练效果最好
- ❌ 需要 70GB+ 显存(A100/H100)
- ❌ RTX 4090 无法使用
方案 B:Pi0 LoRA 微调(推荐 RTX 4090)⭐
TrainConfig(
name="pi0_libero_low_mem_finetune",
# LoRA 变体:只训练 LoRA 参数
model=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora", # 视觉-语言模型用 LoRA
action_expert_variant="gemma_300m_lora" # 动作专家用 LoRA
),
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
base_config=DataConfig(prompt_from_task=True),
extra_delta_transform=True,
),
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_base/params"
),
num_train_steps=30_000,
# LoRA 专用设置
freeze_filter=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora"
).get_freeze_filter(), # 冻结非 LoRA 参数
ema_decay=None, # 关闭 EMA 节省显存
)
特点:
- ✅ 显存需求 22.5GB(RTX 4090 可用)
- ✅ 训练速度快
- ✅ 推荐配置
使用命令:
# 1. 计算归一化统计
uv run scripts/compute_norm_stats.py --config-name pi0_libero_low_mem_finetune
# 2. 启动训练
uv run scripts/train.py pi0_libero_low_mem_finetune --exp-name=my_run_1 --overwrite
方案 C:Pi0-FAST 全量微调
TrainConfig(
name="pi0_fast_libero",
# Pi0-FAST 模型:更快的推理速度
model=pi0_fast.Pi0FASTConfig(
action_dim=7, # LIBERO 动作维度
action_horizon=10, # 动作序列长度
max_token_len=180, # 最大 token 长度(单臂机器人)
),
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
base_config=DataConfig(prompt_from_task=True),
extra_delta_transform=True,
),
# 加载 pi0_fast_base 预训练权重
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_fast_base/params"
),
num_train_steps=30_000,
)
特点:
- ✅ 推理速度更快(适合实时控制)
- ❌ 需要 70GB+ 显存
- ⚠️ 需要调整
max_token_len:单臂 ~180,双臂 ~250
方案 D:Pi0-FAST LoRA 微调(推荐 RTX 4090)⭐
TrainConfig(
name="pi0_fast_libero_low_mem_finetune",
# Pi0-FAST LoRA 变体
model=pi0_fast.Pi0FASTConfig(
action_dim=7,
action_horizon=10,
max_token_len=180,
paligemma_variant="gemma_2b_lora" # FAST 只需冻结 PaliGemma
),
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
base_config=DataConfig(prompt_from_task=True),
extra_delta_transform=True,
),
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_fast_base/params"
),
num_train_steps=30_000,
# LoRA 专用设置
freeze_filter=pi0_fast.Pi0FASTConfig(
action_dim=7,
action_horizon=10,
max_token_len=180,
paligemma_variant="gemma_2b_lora"
).get_freeze_filter(),
ema_decay=None,
)
特点:
- ✅ 低显存(22.5GB)+ 快速推理
- ✅ 适合实时机器人控制
- ✅ RTX 4090 最佳选择
方案 E:Pi0.5 微调(最新架构)
TrainConfig(
name="pi05_libero",
# Pi0.5 模型(改进的架构)
model=pi0_config.Pi0Config(
pi05=True, # 启用 Pi0.5
action_horizon=10,
discrete_state_input=False, # 使用连续状态输入
),
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
base_config=DataConfig(prompt_from_task=True),
extra_delta_transform=False, # Pi0.5 不需要
),
batch_size=256, # 更大的批次
# 不同的学习率调度
lr_schedule=_optimizer.CosineDecaySchedule(
warmup_steps=10_000,
peak_lr=5e-5, # 较低的学习率
decay_steps=1_000_000,
decay_lr=5e-5,
),
optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
ema_decay=0.999, # Pi0.5 使用 EMA
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi05_base/params"
),
num_train_steps=30_000,
)
特点:
- ✅ 最新架构,性能更好
- ❌ 需要大显存
- ⚠️ 需要 PyTorch 权重路径(如果转换)
4.2 ALOHA 真实机器人
方案 A:Pi0 ALOHA 微调(自定义数据集)
TrainConfig(
name="pi0_aloha_pen_uncap",
model=pi0_config.Pi0Config(),
data=LeRobotAlohaDataConfig(
# 替换为您自己的 ALOHA 数据集
repo_id="physical-intelligence/aloha_pen_uncap_diverse",
# 重用 pi0_base 的 Trossen 归一化统计
assets=AssetsConfig(
assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
asset_id="trossen", # Trossen 机器人配置
),
default_prompt="uncap the pen", # 默认任务提示词
# 数据重打包(适配 LeRobot 格式)
repack_transforms=_transforms.Group(
inputs=[
_transforms.RepackTransform({
"images": {
"cam_high": "observation.images.cam_high",
"cam_left_wrist": "observation.images.cam_left_wrist",
"cam_right_wrist": "observation.images.cam_right_wrist",
},
"state": "observation.state",
"actions": "action",
})
]
),
),
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_base/params"
),
num_train_steps=20_000, # ALOHA 数据集较小
)
适用场景:
- ✅ ALOHA 真机数据集
- ✅ 使用 Trossen 机器人
- ⚠️ 需要先转换为 LeRobot 格式(参考
examples/aloha_real/README.md)
数据转换命令:
cd examples/aloha_real
# 转换 ALOHA HDF5 数据到 LeRobot 格式
python convert_aloha_data_to_lerobot.py \
--input-dir /path/to/aloha_data \
--output-repo your_hf_username/my_aloha_dataset
方案 B:Pi0 ALOHA LoRA 微调(RTX 4090)
TrainConfig(
name="pi0_aloha_lora",
# LoRA 变体
model=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora"
),
data=LeRobotAlohaDataConfig(
repo_id="your_hf_username/my_aloha_dataset",
assets=AssetsConfig(
assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
asset_id="trossen",
),
default_prompt="your task description",
repack_transforms=_transforms.Group(
inputs=[
_transforms.RepackTransform({
"images": {
"cam_high": "observation.images.cam_high",
"cam_left_wrist": "observation.images.cam_left_wrist",
"cam_right_wrist": "observation.images.cam_right_wrist",
},
"state": "observation.state",
"actions": "action",
})
]
),
),
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_base/params"
),
num_train_steps=20_000,
# LoRA 设置
freeze_filter=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora"
).get_freeze_filter(),
ema_decay=None,
)
特点:
- ✅ RTX 4090 可用(22.5GB)
- ✅ 适合小数据集快速迭代
- ⚠️ 需要调整
default_prompt为您的任务
方案 C:Pi0.5 ALOHA 微调
TrainConfig(
name="pi05_aloha_pen_uncap",
model=pi0_config.Pi0Config(pi05=True),
data=LeRobotAlohaDataConfig(
repo_id="physical-intelligence/aloha_pen_uncap_diverse",
# 使用 pi05_base 的归一化统计
assets=AssetsConfig(
assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets",
asset_id="trossen",
),
default_prompt="uncap the pen",
repack_transforms=_transforms.Group(
inputs=[
_transforms.RepackTransform({
"images": {
"cam_high": "observation.images.cam_high",
"cam_left_wrist": "observation.images.cam_left_wrist",
"cam_right_wrist": "observation.images.cam_right_wrist",
},
"state": "observation.state",
"actions": "action",
})
]
),
),
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi05_base/params"
),
num_train_steps=20_000,
batch_size=64, # Pi0.5 可以用更大的批次
)
4.3 ALOHA 仿真环境
TrainConfig(
name="pi0_aloha_sim",
model=pi0_config.Pi0Config(),
data=LeRobotAlohaDataConfig(
# LeRobot 官方 ALOHA 仿真数据集
repo_id="lerobot/aloha_sim_transfer_cube_human",
default_prompt="Transfer cube",
# 仿真环境不使用 delta 动作
use_delta_joint_actions=False,
),
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_base/params"
),
num_train_steps=20_000,
)
适用场景:
- ✅ 快速原型验证
- ✅ 算法开发和测试
- ⚠️ 不需要真实机器人硬件
LoRA 版本:
TrainConfig(
name="pi0_aloha_sim_lora",
model=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora"
),
data=LeRobotAlohaDataConfig(
repo_id="lerobot/aloha_sim_transfer_cube_human",
default_prompt="Transfer cube",
use_delta_joint_actions=False,
),
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_base/params"
),
num_train_steps=20_000,
freeze_filter=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora"
).get_freeze_filter(),
ema_decay=None,
)
4.4 DROID 机器人(大规模数据集)
方案 A:Pi0-FAST 全量微调(完整 DROID 数据集)
TrainConfig(
name="pi0_fast_full_droid_finetune",
# Pi0-FAST 适合 DROID 的实时控制
model=pi0_fast.Pi0FASTConfig(
action_dim=8, # DROID 动作维度
action_horizon=16, # 更长的动作序列
max_token_len=180,
),
# 使用 RLDS 数据加载器(大规模数据集)
data=RLDSDroidDataConfig(
repo_id="droid",
# ⚠️ 设置为您的 DROID RLDS 数据集路径
rlds_data_dir="<path_to_droid_rlds_dataset>",
action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
),
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_fast_base/params"
),
# 更低的学习率(大数据集)
lr_schedule=_optimizer.CosineDecaySchedule(
warmup_steps=1_000,
peak_lr=5e-5,
decay_steps=1_000_000,
decay_lr=5e-5,
),
num_train_steps=100_000, # 100k 步,约 2 天(8x H100)
batch_size=256,
log_interval=100,
save_interval=5000,
keep_period=20_000,
# ⚠️ 重要:RLDS 必须设置 num_workers=0
num_workers=0,
)
特点:
- ✅ 适合完整 DROID 数据集(数百小时)
- ✅ 使用 RLDS 高效数据加载
- ❌ 需要 8x H100 GPU
- ⚠️
num_workers=0是必需的
方案 B:Pi0.5 DROID 微调(完整数据集)
TrainConfig(
name="pi05_full_droid_finetune",
model=pi0_config.Pi0Config(
pi05=True,
action_dim=32, # Pi0.5 使用 32 维动作
action_horizon=16,
),
data=RLDSDroidDataConfig(
repo_id="droid",
rlds_data_dir="/mnt/pi-data/kevin", # 您的 RLDS 路径
action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
# 使用 pi05_base 的 DROID 归一化统计
assets=AssetsConfig(
assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets/",
asset_id="droid",
),
),
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi05_base/params"
),
lr_schedule=_optimizer.CosineDecaySchedule(
warmup_steps=1_000,
peak_lr=5e-5,
decay_steps=1_000_000,
decay_lr=5e-5,
),
num_train_steps=100_000,
batch_size=256,
log_interval=100,
save_interval=5000,
keep_period=10_000,
num_workers=0,
)
方案 C:Pi0.5 DROID 微调(小数据集,LeRobot 格式)
TrainConfig(
name="pi05_droid_finetune",
model=pi0_config.Pi0Config(
pi05=True,
action_dim=32,
action_horizon=16,
),
# 使用 LeRobot 数据格式(适合小数据集)
data=LeRobotDROIDDataConfig(
# ⚠️ 替换为您的 DROID 数据集
repo_id="your_hf_username/my_droid_dataset",
base_config=DataConfig(prompt_from_task=True),
# ⚠️ 重要:重用原始 DROID 归一化统计
assets=AssetsConfig(
assets_dir="gs://openpi-assets/checkpoints/pi05_droid/assets",
asset_id="droid",
),
),
# 加载 pi05_droid 检查点(已在 DROID 上预训练)
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi05_droid/params"
),
num_train_steps=20_000,
batch_size=32,
)
数据转换:
cd examples/droid
# 转换 DROID 数据到 LeRobot 格式
python convert_droid_data_to_lerobot.py \
--input-dir /path/to/droid_data \
--output-repo your_hf_username/my_droid_dataset
适用场景:
- ✅ 自定义 DROID 数据集(< 10 小时)
- ✅ RTX 4090 可用(batch_size=32)
- ✅ 重用 pi05_droid 预训练权重
4.5 配置对比表
LIBERO 数据集
| 配置名称 | 模型 | LoRA | 显存需求 | 推理速度 | RTX 4090 |
|---|---|---|---|---|---|
pi0_libero | Pi0 | ❌ | 70GB+ | 中等 | ❌ |
pi0_libero_low_mem_finetune | Pi0 | ✅ | 22.5GB | 中等 | ✅ ⭐ |
pi0_fast_libero | Pi0-FAST | ❌ | 70GB+ | 快 | ❌ |
pi0_fast_libero_low_mem_finetune | Pi0-FAST | ✅ | 22.5GB | 快 | ✅ ⭐⭐ |
pi05_libero | Pi0.5 | ❌ | 70GB+ | 中等 | ❌ |
ALOHA 数据集
| 配置名称 | 数据集类型 | LoRA | 显存需求 | RTX 4090 |
|---|---|---|---|---|
pi0_aloha_pen_uncap | 真机 | ❌ | 70GB+ | ❌ |
pi0_aloha_lora (自定义) | 真机 | ✅ | 22.5GB | ✅ ⭐ |
pi05_aloha_pen_uncap | 真机 | ❌ | 40GB+ | ❌ |
pi0_aloha_sim | 仿真 | ❌ | 70GB+ | ❌ |
pi0_aloha_sim_lora (自定义) | 仿真 | ✅ | 22.5GB | ✅ ⭐ |
DROID 数据集
| 配置名称 | 数据规模 | 数据格式 | LoRA | 显存需求 | 推荐 GPU |
|---|---|---|---|---|---|
pi0_fast_full_droid_finetune | 完整 | RLDS | ❌ | 70GB+ | 8x H100 |
pi05_full_droid_finetune | 完整 | RLDS | ❌ | 70GB+ | 8x H100 |
pi05_droid_finetune | 小型 | LeRobot | ❌ | 40GB+ | A100 |
4.6 RTX 4090 用户推荐配置
根据您的需求选择:
🎯 场景 1:LIBERO 仿真训练
推荐:pi0_fast_libero_low_mem_finetune
# 1. 添加配置到 config.py(已内置)
# 2. 计算归一化统计
uv run scripts/compute_norm_stats.py --config-name pi0_fast_libero_low_mem_finetune
# 3. 启动训练
uv run scripts/train.py pi0_fast_libero_low_mem_finetune --exp-name=my_run
理由:LoRA + 快速推理 + 22.5GB 显存
🎯 场景 2:ALOHA 真机数据集
推荐:创建自定义 pi0_aloha_lora 配置
# 添加到 config.py
TrainConfig(
name="pi0_aloha_lora",
model=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora"
),
data=LeRobotAlohaDataConfig(
repo_id="your_hf_username/my_aloha_dataset",
assets=AssetsConfig(
assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
asset_id="trossen",
),
default_prompt="your task",
repack_transforms=_transforms.Group(
inputs=[
_transforms.RepackTransform({
"images": {
"cam_high": "observation.images.cam_high",
"cam_left_wrist": "observation.images.cam_left_wrist",
"cam_right_wrist": "observation.images.cam_right_wrist",
},
"state": "observation.state",
"actions": "action",
})
]
),
),
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_base/params"
),
num_train_steps=20_000,
freeze_filter=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora"
).get_freeze_filter(),
ema_decay=None,
)
🎯 场景 3:ALOHA 仿真训练
推荐:创建 pi0_aloha_sim_lora 配置(参考上文 9.3 节)
🎯 场景 4:小规模 DROID 数据集
不推荐在 RTX 4090 上训练 DROID(即使是小数据集,也需要 32 维动作空间,显存可能不足)
如果必须尝试,使用 batch_size=16 并监控显存。
9.7 快速创建自定义配置模板
# filepath: f:\codespace\openpi-main\src\openpi\training\config.py
# 添加到 _CONFIGS 列表末尾
TrainConfig(
# ============ 基础信息 ============
name="<YOUR_PLATFORM>_<YOUR_TASK>_lora", # 例如:aloha_pick_cube_lora
# ============ 模型配置(LoRA)============
model=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora",
# 如果是 Pi0-FAST,使用:
# model=pi0_fast.Pi0FASTConfig(
# action_dim=<YOUR_ACTION_DIM>,
# action_horizon=<YOUR_ACTION_HORIZON>,
# max_token_len=180, # 单臂 180,双臂 250
# paligemma_variant="gemma_2b_lora"
# ),
),
# ============ 数据配置 ============
# LIBERO:
data=LeRobotLiberoDataConfig(
repo_id="<YOUR_REPO_ID>",
base_config=DataConfig(prompt_from_task=True),
extra_delta_transform=True,
),
# 或 ALOHA:
# data=LeRobotAlohaDataConfig(
# repo_id="<YOUR_REPO_ID>",
# assets=AssetsConfig(
# assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
# asset_id="trossen",
# ),
# default_prompt="<YOUR_TASK_PROMPT>",
# repack_transforms=_transforms.Group(
# inputs=[_transforms.RepackTransform({
# "images": {
# "cam_high": "observation.images.cam_high",
# "cam_left_wrist": "observation.images.cam_left_wrist",
# "cam_right_wrist": "observation.images.cam_right_wrist",
# },
# "state": "observation.state",
# "actions": "action",
# })]
# ),
# ),
# ============ 预训练权重 ============
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_base/params"
# 或 pi0_fast_base: "gs://openpi-assets/checkpoints/pi0_fast_base/params"
# 或 pi05_base: "gs://openpi-assets/checkpoints/pi05_base/params"
),
# ============ 冻结参数(LoRA 必需)============
freeze_filter=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora"
).get_freeze_filter(),
# 如果是 Pi0-FAST LoRA:
# freeze_filter=pi0_fast.Pi0FASTConfig(
# action_dim=<YOUR_ACTION_DIM>,
# action_horizon=<YOUR_ACTION_HORIZON>,
# max_token_len=180,
# paligemma_variant="gemma_2b_lora"
# ).get_freeze_filter(),
# ============ 训练超参数 ============
num_train_steps=20_000, # ALOHA/小数据集:20k,LIBERO:30k
batch_size=32, # RTX 4090 推荐 32
lr_schedule=_optimizer.CosineDecaySchedule(
warmup_steps=1_000,
peak_lr=3e-4,
decay_steps=20_000,
decay_lr=1e-6,
),
optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
# ============ LoRA 专用 ============
ema_decay=None, # 关闭 EMA
# ============ 日志 ============
log_interval=100,
save_interval=1000,
keep_period=5000,
wandb_enabled=True,
)
附录:完整配置模板
# filepath: f:\codespace\openpi-main\src\openpi\training\config.py
# 复制此模板并修改相应参数
TrainConfig(
# ============ 基础信息 ============
name="<YOUR_CONFIG_NAME>", # 唯一配置名称
exp_name="<EXPERIMENT_NAME>", # 实验名称
project_name="openpi", # WandB 项目名
# ============ 模型配置 ============
model=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora",
action_dim=<ACTION_DIM>, # 您的动作维度
action_horizon=<ACTION_HORIZON>, # 动作序列长度
),
# ============ 数据配置 ============
data=LeRobotLiberoDataConfig( # 或您自己的 DataConfig
repo_id="<YOUR_DATASET_REPO_ID>",
base_config=DataConfig(
prompt_from_task=True,
),
),
# ============ 预训练权重 ============
weight_loader=weight_loaders.CheckpointWeightLoader(
"gs://openpi-assets/checkpoints/pi0_base/params"
),
# ============ 冻结参数 ============
freeze_filter=pi0_config.Pi0Config(
paligemma_variant="gemma_2b_lora",
action_expert_variant="gemma_300m_lora"
).get_freeze_filter(),
# ============ 训练超参数 ============
num_train_steps=<NUM_STEPS>,
batch_size=<BATCH_SIZE>,
lr_schedule=_optimizer.CosineDecaySchedule(
warmup_steps=<WARMUP_STEPS>,
peak_lr=<PEAK_LR>,
decay_steps=<DECAY_STEPS>,
decay_lr=<DECAY_LR>,
),
optimizer=_optimizer.AdamW(
clip_gradient_norm=1.0,
),
ema_decay=None,
# ============ 日志和保存 ============
log_interval=100,
save_interval=1000,
keep_period=5000,
wandb_enabled=True,
)
17万+

被折叠的 条评论
为什么被折叠?



