OpenPI(π0) LoRA 微调操作指南

部署运行你感兴趣的模型镜像

OpenPI LoRA 微调操作指南

适用场景:RTX 4090(24GB 显存)用户的完整训练流程


负一、核心概念

LoRA 微调 vs 全量微调

对比项LoRA 微调全量微调
显存需求22.5 GB ✅70 GB ❌
训练参数仅 1-2%100%
训练速度快(6-12h)慢(24h+)
RTX 4090可用不可用

LoRA 原理:冻结原始模型参数,只训练小型适配器矩阵

# 原始: W [1024×2048] = 2,097,152 参数(冻结)
# LoRA: A [1024×8] + B [8×2048] = 24,768 参数(训练)
# 输出 = Input × (W + A×B)

零、重要路径说明

1. 数据集存储位置

LIBERO 数据集(自动下载):

Windows: C:\Users\<用户名>\.cache\huggingface\datasets\physical-intelligence___libero\

Linux/Mac: ~/.cache/huggingface/datasets/physical-intelligence___libero/

结构:
├── default/
│   ├── 0.0.0/
│   │   ├── dataset_info.json
│   │   └── parquet files
│   └── ...

自定义数据集

~/.cache/huggingface/datasets/<你的用户名>___<数据集名>/

2. Checkpoint 保存位置

openpi\checkpoints\<配置名称>\<实验名称>\<步数>\

示例:
openpi\checkpoints\my_lora_finetune\test_run_1\
├── 1000/
│   ├── params/              # 模型权重(包含 LoRA 参数)
│   │   ├── manifest.ocdbt   # 元数据
│   │   └── state            # 训练状态
│   └── assets/
│       └── physical-intelligence_libero/
│           └── norm_stats.json  # 归一化统计(推理必需)
├── 2000/
├── ...
└── 30000/  # 最终模型

3. 预训练权重位置

云端(自动下载):

gs://openpi-assets/checkpoints/pi0_base/params

本地缓存

~/.cache/openpi/checkpoints/pi0_base/

一、训练配置

1.1 编辑配置文件

打开 src/openpi/training/config.py,滚动到文件末尾的 _CONFIGS 列表(约第 850 行),添加以下配置:

先介绍两个微调过程中的关键配置参数:

  • config-name:配置名称,唯一标识
  • exp-name:实验名称,用于区分不同实验
# filepath: openpi\src\openpi\training\config.py
# 在文件末尾的 _CONFIGS 列表中添加(约第 850 行)

TrainConfig(
    # ============ 基础配置 ============
    name="my_lora_finetune",  # 唯一配置名称
    exp_name="run_v1",        # 实验名称(命令行可覆盖)
    
    # ============ 模型配置(LoRA 变体)============
    model=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",      # 视觉-语言模型用 LoRA
        action_expert_variant="gemma_300m_lora"  # 动作专家用 LoRA
    ),
    
    # ============ 数据配置 ============
    data=LeRobotLiberoDataConfig(
        repo_id="physical-intelligence/libero",  # 数据集 ID
        base_config=DataConfig(
            prompt_from_task=True,  # 从数据集加载任务提示词
        ),
        extra_delta_transform=True,  # LIBERO 需要
    ),
    
    # ============ 预训练权重 ============
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_base/params"
    ),
    
    # ============ 冻结参数(关键!)============
    freeze_filter=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora"
    ).get_freeze_filter(),  # 自动冻结非 LoRA 参数
    
    # ============ 训练超参数 ============
    num_train_steps=30_000,
    batch_size=32,  # RTX 4090 可用 32-64
    
    # 学习率调度器类名
    lr_schedule=_optimizer.CosineDecaySchedule(
        warmup_steps=1_000,
        peak_lr=3e-4,  # LoRA 可以用更高的学习率
        decay_steps=30_000,
        decay_lr=1e-6,
    ),
    
    optimizer=_optimizer.AdamW(
        clip_gradient_norm=1.0,
    ),
    
    # ============ 关闭 EMA 节省显存 ============
    ema_decay=None,  # LoRA 不需要指数移动平均
    
    # ============ 其他配置 ============
    log_interval=100,
    save_interval=1000,
    keep_period=5000,
    wandb_enabled=True,
)

1.2 配置文件详解

1.2.1 模型配置参数
model=pi0_config.Pi0Config(
    paligemma_variant="gemma_2b_lora",      # 选项:gemma_2b, gemma_2b_lora
    action_expert_variant="gemma_300m_lora" # 选项:gemma_300m, gemma_300m_lora
)

变体说明

变体名称参数量用途显存需求
gemma_2b2B全量微调70+ GB
gemma_2b_lora2B (训练 ~20M)LoRA 微调22.5 GB
gemma_300m300M全量微调70+ GB
gemma_300m_lora300M (训练 ~3M)LoRA 微调22.5 GB

源码位置src/openpi/models/pi0_config.py 第 17-78 行


1.2.2 数据配置参数
data=LeRobotLiberoDataConfig(
    repo_id="physical-intelligence/libero",
    base_config=DataConfig(
        prompt_from_task=True,
    ),
    extra_delta_transform=True,
)

参数说明

参数类型作用示例值
repo_idstrHuggingFace 数据集 ID"physical-intelligence/libero"
prompt_from_taskbool是否从数据集的 task 字段加载提示词True
extra_delta_transformbool是否额外进行 delta 转换(LIBERO 特有)True

源码位置src/openpi/training/config.py 第 374-452 行


📖 关于 repo_id 的详细解释

1. 为什么叫 “repo_id” 而不是 “dataset_id”?

repo_id 这个命名来自 HuggingFace Hub 的设计理念:

  • HuggingFace Hub 使用 Git 仓库(Repository) 来管理数据集、模型和空间(Spaces)
  • 每个数据集本质上是一个 Git 仓库,包含:
    • 数据文件(Parquet、Arrow、CSV 等)
    • 元数据文件(dataset_info.jsonREADME.md
    • 版本控制历史(Git commits)
  • 因此使用 repo_id 强调了数据集的仓库属性

命名格式<用户名或组织名>/<数据集名称>

# 示例
repo_id = "physical-intelligence/libero"      # 组织:physical-intelligence
repo_id = "lerobot/aloha_sim_transfer_cube"   # 组织:lerobot
repo_id = "your_username/my_aloha_dataset"    # 个人用户

2. HuggingFace Hub 上有哪些可用的机器人数据集?

访问 https://huggingface.co/datasets 并搜索关键词:

关键词数据集类型示例 repo_id
lerobotLeRobot 格式数据集lerobot/aloha_sim_transfer_cube_human
liberoLIBERO 任务数据集physical-intelligence/libero
alohaALOHA 机器人数据集physical-intelligence/aloha_pen_uncap_diverse
droidDROID 机器人数据集需要自行转换后上传
robomimicRobomimic 数据集eai-lab/robomimic

1.2.3 学习率调度器
lr_schedule=_optimizer.CosineDecaySchedule(
    warmup_steps=1_000,   # 预热步数
    peak_lr=3e-4,         # 峰值学习率
    decay_steps=30_000,   # 衰减步数
    decay_lr=1e-6,        # 最终学习率
)

可用的学习率调度器

调度器类特点适用场景
CosineDecaySchedule余弦衰减 + 预热推荐用于 LoRA
LinearWarmupSchedule线性预热简单任务

源码位置src/openpi/training/optimizer.py 第 1-80 行


1.2.4 冻结参数机制
freeze_filter=pi0_config.Pi0Config(
    paligemma_variant="gemma_2b_lora",
    action_expert_variant="gemma_300m_lora"
).get_freeze_filter()

工作原理

# src/openpi/models/pi0_config.py 第 80-106 行
def get_freeze_filter(self):
    """返回需要冻结的参数过滤器"""
    filters = []
    
    if "lora" in self.paligemma_variant:
        # 冻结 Gemma 的所有参数
        filters.append(nnx_utils.PathRegex(".*llm.*"))
    
    if "lora" in self.action_expert_variant:
        # 冻结动作专家的所有参数
        filters.append(nnx_utils.PathRegex(".*llm.*_1.*"))
    
    # 排除所有 LoRA 参数(允许训练)
    filters.append(nnx.Not(nnx_utils.PathRegex(".*lora.*")))
    
    return nnx.All(*filters)

效果示意

# 模型参数示例
model.llm.layer_0.weight           # ❌ 冻结(~2GB)
model.llm.layer_0.lora_a           # ✅ 训练(~8MB)
model.llm.layer_0.lora_b           # ✅ 训练(~8MB)
model.action_expert.layer_1.weight # ❌ 冻结(~300MB)
model.action_expert.layer_1.lora_a # ✅ 训练(~3MB)

1.2.5 EMA(指数移动平均)
ema_decay=None  # LoRA 微调建议关闭

什么是 EMA

EMA t = α ⋅ θ t + ( 1 − α ) ⋅ EMA t − 1 \text{EMA}_t = \alpha \cdot \theta_t + (1-\alpha) \cdot \text{EMA}_{t-1} EMAt=αθt+(1α)EMAt1

  • 维护一份参数的移动平均版本
  • 推理时使用 EMA 参数,通常更稳定

为什么 LoRA 关闭 EMA

  1. 节省显存:EMA 需要额外存储一份完整参数(~10GB)
  2. LoRA 本身已足够稳定:少量参数不易震荡

源码位置src/openpi/training/config.py 第 481 行


二、归一化详解

2.1 为什么需要归一化

问题场景

# 原始数据(未归一化)
state = [
    0.5,      # 关节 1 角度(弧度)
    -2.3,     # 关节 2 角度(弧度)
    1500.0,   # 力传感器读数(牛顿)
    0.1,      # 夹爪开合度(0-1)
]

问题

  1. 数值跨度巨大(0.1 到 1500)
  2. 力传感器会主导梯度计算
  3. 训练不稳定,难以收敛

归一化后

# Z-score 归一化:(x - mean) / std
state_normalized = [
    -0.23,  # (0.5 - 0.6) / 0.4
    0.15,   # (-2.3 - (-2.5)) / 0.5
    0.08,   # (1500 - 1480) / 250
    -0.10,  # (0.1 - 0.15) / 0.05
]

效果

  • ✅ 均值接近 0,方差接近 1
  • ✅ 各维度贡献均衡
  • ✅ 梯度下降更高效

2.2 两种归一化方法

方法 1:Z-score 归一化(Pi0)
# config.py 第 211 行
use_quantile_norm=False  # 使用 Z-score

# 公式
normalized = (x - mean) / std

特点

  • 假设数据服从正态分布
  • 对离群值敏感
方法 2:Quantile 归一化(Pi0.5)
# config.py 第 211 行
use_quantile_norm=True  # 使用 Quantile

# 公式
normalized = (x - median) / IQR
# IQR = Q75 - Q25(四分位距)

特点

  • 基于分位数,更鲁棒
  • 对离群值不敏感

源码位置src/openpi/shared/normalize.py 第 1-150 行


2.3 归一化统计的加载

# config.py 第 209-223 行
def _load_norm_stats(
    self, 
    assets_dir: epath.Path, 
    asset_id: str | None
) -> dict[str, _transforms.NormStats] | None:
    """从 assets 目录加载归一化统计"""
    
    if asset_id is None:
        return None
    
    try:
        # 构建路径:assets/<asset_id>/norm_stats.json
        data_assets_dir = str(assets_dir / asset_id)
        norm_stats = _normalize.load(_download.maybe_download(data_assets_dir))
        logging.info(f"Loaded norm stats from {data_assets_dir}")
        return norm_stats
    except FileNotFoundError:
        logging.info(f"Norm stats not found in {data_assets_dir}, skipping.")
    
    return None

加载路径

checkpoints/my_lora_finetune/assets/physical-intelligence_libero/norm_stats.json

文件格式

{
  "state": {
    "mean": [0.1, -0.3, 0.5, ...],
    "std": [0.4, 0.6, 0.8, ...],
    "min": [-1.2, -2.5, -0.3, ...],
    "max": [1.5, 2.1, 1.8, ...]
  },
  "actions": {
    "mean": [0.0, 0.0, 0.0, ...],
    "std": [0.2, 0.3, 0.1, ...],
    "min": [-0.5, -0.6, -0.3, ...],
    "max": [0.5, 0.6, 0.3, ...]
  }
}

2.4 重用预训练的归一化统计

场景:您的机器人与预训练数据集中的机器人相同

TrainConfig(
    name="my_aloha_lora",
    
    data=LeRobotAlohaDataConfig(
        repo_id="your_username/my_aloha_dataset",
        assets=AssetsConfig(
            # 重用 pi0_base 的 Trossen 机器人归一化统计
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="trossen",  # 关键:指定机器人类型
        ),
    ),
    # ...
)

可用的预训练归一化统计

机器人asset_id适用模型
Trossen (ALOHA)trossenpi0_base, pi0_fast_base
DROID Frankadroidpi0_base, pi05_base
UR5eur5epi0_base
Pandapandapi0_base

源码位置docs/norm_stats.md 第 9-50 行


三、 完整训练流程

步骤 1:添加配置

编辑 src/openpi/training/config.py,添加上述配置。

步骤 2:计算归一化统计

uv run scripts/compute_norm_stats.py --config-name my_lora_finetune

步骤 3:启动训练

3.1 直接启动(前台运行)
# Windows PowerShell
$env:XLA_PYTHON_CLIENT_MEM_FRACTION="0.9"
uv run scripts/train.py my_lora_finetune --exp-name=test_run_1 --overwrite

# Linux/macOS
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
    uv run scripts/train.py my_lora_finetune \
    --exp-name=test_run_1 \
    --overwrite

缺点:关闭终端窗口,训练会中断 ❌

3.2 使用 tmux 后台运行(推荐,Linux/macOS)
📦 安装 tmux(如果还没有)
# Ubuntu/Debian
sudo apt-get install tmux

# CentOS/RHEL
sudo yum install tmux

# macOS
brew install tmux
🚀 使用 tmux 启动训练
# 1. 创建名为 "openpi_train" 的 tmux 会话
tmux new -s openpi_train

# 2. 在 tmux 会话中设置环境变量并启动训练
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
cd /path/to/openpi-main
uv run scripts/train.py my_lora_finetune --exp-name=test_run_1 --overwrite

# 3. 分离会话(训练在后台继续运行)
# 按键:Ctrl + B,然后按 D
🔍 管理 tmux 会话

查看所有会话

tmux ls

# 输出示例:
# openpi_train: 1 windows (created Thu Jan 15 10:30:00 2025)

重新连接到会话

# 连接到 openpi_train 会话
tmux attach -t openpi_train

# 或简写
tmux a -t openpi_train

再次分离会话

按键组合:Ctrl + B,然后按 D

杀死会话(训练完成后):

# 方法 1:在会话内部
exit

# 方法 2:从外部杀死
tmux kill-session -t openpi_train
💡 tmux 常用快捷键
快捷键功能
Ctrl + BD分离会话(detach)
Ctrl + BC创建新窗口
Ctrl + BN切换到下一个窗口
Ctrl + BP切换到上一个窗口
Ctrl + B%垂直分割窗口
Ctrl + B"水平分割窗口
Ctrl + B方向键在分割的窗格间切换
Ctrl + B[进入滚动模式(查看历史输出)

步骤 4:监控训练进度

4.1 本地终端日志

训练过程中,终端会实时显示:

[2025/01/15 10:30:20] Step 0: loss=2.345, lr=0.0
[2025/01/15 10:30:25] Step 100: loss=1.892, lr=1.2e-4
[2025/01/15 10:30:30] Step 200: loss=1.567, lr=2.1e-4
...
4.2 使用 Weights & Biases 可视化监控

训练启动后,Wandb 会自动记录数据。

4.3 检查本地 Checkpoint
# 查看已保存的 checkpoint
ls checkpoints\my_lora_finetune\test_run_1\

# 输出示例:
# 1000/  (约 500 MB)
# 2000/  (约 500 MB)
# ...
# 30000/ (约 500 MB - 最终模型)

每个 checkpoint 的内容

30000/
├── params/              # 模型权重(包含 LoRA 参数)
│   ├── manifest.ocdbt   # TensorStore 元数据
│   └── state            # 训练状态(优化器、步数等)
└── assets/
    └── physical-intelligence_libero/
        └── norm_stats.json  # 归一化统计(推理时必需)

步骤 5:使用训练好的模型推理

训练完成后,您可以:

  1. 启动策略服务器进行远程推理
  2. 在仿真环境中测试模型
  3. 录制 demo 视频展示效果
5.1 启动策略服务器(HTTP API)
基础启动
uv run scripts/serve_policy.py policy:checkpoint \
    --policy.config=my_lora_finetune \
    --policy.dir=checkpoints/my_lora_finetune/test_run_1/30000 \
    --port 8000

启动输出示例

[2025/01/15 18:00:00] Loading checkpoint from checkpoints/my_lora_finetune/test_run_1/30000
[2025/01/15 18:00:05] Loading norm stats from checkpoints/.../norm_stats.json
[2025/01/15 18:00:10] Model loaded successfully!
[2025/01/15 18:00:10] Starting server on http://0.0.0.0:8000
[2025/01/15 18:00:10] Health check: http://localhost:8000/healthz
验证服务器
# 健康检查
Invoke-WebRequest http://localhost:8000/healthz

# 输出:StatusCode: 200, Content: OK
API 使用示例

Python 客户端

import requests
import numpy as np
from PIL import Image

# 1. 加载图像
image = Image.open("camera_view.jpg")
image_array = np.array(image)  # shape: [H, W, 3]

# 2. 准备机器人状态
state = [0.1, -0.5, 0.3, 0.2, 0.8, 1.0, 0.0]  # 7 维状态

# 3. 发送请求
response = requests.post("http://localhost:8000/act", json={
    "observation": {
        "images": [image_array.tolist()],  # 转为列表
        "state": state,
    },
    "prompt": "pick up the red cube",  # 任务描述
})

# 4. 获取动作
result = response.json()
actions = np.array(result["action"])  # shape: [10, 7](10 步预测)

print(f"下一步动作: {actions[0]}")
print(f"完整动作序列: {actions.shape}")

JavaScript 客户端

const response = await fetch('http://localhost:8000/act', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
        observation: {
            images: [imageArray],  // 3D array
            state: [0.1, -0.5, 0.3, 0.2, 0.8, 1.0, 0.0]
        },
        prompt: "pick up the red cube"
    })
});

const result = await response.json();
console.log('Next action:', result.action[0]);
5.2 在仿真环境中测试模型
查找可用的仿真任务

查看 gym-aloha GitHub 仓库

  1. 访问 https://github.com/huggingface/gym-aloha
  2. 查看 gym_aloha/__init__.py 文件,查找 register() 函数调用
  3. 查看 gym_aloha/env.py 文件中的 _make_env_task() 方法

根据 gym-aloha 源码env.py,当前支持以下任务:

任务 ID任务名称描述最大步数奖励机制
gym_aloha/AlohaTransferCube-v0转移方块任务右臂抓取红色方块并转移到左臂3001分:右臂抓住方块
2分:方块被抬起
3分:转移到左臂
4分:成功转移且不触桌面
gym_aloha/AlohaInsertion-v0插入任务左右臂分别抓取插座和插针,然后在空中插入3001分:触摸插针和插座
2分:抓住两者
3分:对齐并接触
4分:成功插入

🚀 运行仿真环境测试

使用 examples/aloha_sim/run_task.py 脚本:

# examples/aloha_sim/run_task.py
import numpy as np
import imageio
import tqdm
from openpi_client import websocket_client_policy
from examples.aloha_sim import env as _env

# --- 1. 配置 ---
# 服务器地址和端口
SERVER_HOST = "localhost"
SERVER_PORT = 8000

# 仿真环境配置
SIM_TASK_NAME = "gym_aloha/AlohaTransferCube-v0"
SEED = 42  # 固定随机种子以保证可复现性

# 视频保存配置
OUTPUT_VIDEO_PATH = "data/aloha_sim/multi_prompt_task.mp4"
VIDEO_FPS = 15

# --- 2. 定义多阶段任务和自定义提示词 ---
# 每个阶段包含:提示词、执行的步数
task_phases = [
    {"prompt": "pick up the cube", "steps": 200}
]

def main():
    """
    主函数:连接服务器,运行多阶段任务,并保存视频。
    """
    print(">> 正在初始化策略客户端...")
    try:
        client = websocket_client_policy.WebsocketClientPolicy(
            host=SERVER_HOST, port=SERVER_PORT
        )
        print(f">> 成功连接到策略服务器 at {SERVER_HOST}:{SERVER_PORT}")
    except Exception as e:
        print(f"!! 无法连接到服务器: {e}")
        print("!! 请确保 'serve_policy.py' 正在运行。")
        return

    print(f">> 正在创建仿真环境: {SIM_TASK_NAME}")
    env = _env.AlohaSimEnvironment(task=SIM_TASK_NAME, seed=SEED)
    env.reset()
    print(">> 仿真环境已重置。")

    # 用于存储视频帧的列表
    video_frames = []
    
    # 计算总步数用于进度条
    total_steps = sum(phase["steps"] for phase in task_phases)
    
    with tqdm.tqdm(total=total_steps, desc="执行任务") as pbar:
        # --- 3. 循环执行每个任务阶段 ---
        for phase in task_phases:
            current_prompt = phase["prompt"]
            steps_for_phase = phase["steps"]
            pbar.set_description(f"阶段: {current_prompt}")

            for _ in range(steps_for_phase):
                # a. 从环境中获取观测
                observation = env.get_observation()

                # b. 添加当前的自定义提示词
                observation["prompt"] = current_prompt

                # c. 发送观测到服务器并获取动作
                try:
                    output = client.infer(observation)
                    action = output["actions"][0]  # 取动作序列的第一个动作
                except Exception as e:
                    print(f"\n!! 推理失败: {e}")
                    print("!! 检查服务器连接和状态。")
                    return

                # d. 将动作应用到仿真环境
                env.apply_action({"actions": action})

                # e. 渲染并保存当前帧用于录制视频
                # 注意:我们需要 'cam_high' 图像,并将其从 CHW 转换为 HWC 格式
                frame = observation["images"]["cam_high"]
                frame_hwc = np.transpose(frame, (1, 2, 0))
                video_frames.append(frame_hwc)
                
                pbar.update(1)

    print(f"\n>> 任务完成。总共录制了 {len(video_frames)} 帧。")

    # --- 4. 保存视频 ---
    print(f">> 正在将视频保存到: {OUTPUT_VIDEO_PATH}")
    try:
        imageio.mimwrite(OUTPUT_VIDEO_PATH, video_frames, fps=VIDEO_FPS, quality=8)
        print(">> 视频保存成功!")
    except Exception as e:
        print(f"!! 视频保存失败: {e}")

if __name__ == "__main__":
    main()

完整步骤

# 1. 启动策略服务器(在一个终端)
uv run scripts/serve_policy.py policy:checkpoint \
    --policy.config=my_lora_finetune \
    --policy.dir=checkpoints/my_lora_finetune/test_run_1/30000 \
    --port 8000

# 2. 运行仿真脚本(在另一个终端)
cd examples/aloha_sim

# 运行转移方块任务
uv run python run_task.py

# 或者直接指定任务(修改 run_task.py 中的 SIM_TASK_NAME)

自定义多阶段任务示例

您可以修改 run_task.py 中的 task_phases 列表来创建自定义的多阶段任务:

# run_task.py 中的自定义提示词示例
task_phases = [
    {"prompt": "pick up the cube", "steps": 70},
    {"prompt": "move the cube to the right side", "steps": 80},
    {"prompt": "place the cube into the bowl", "steps": 60},
    {"prompt": "move your arms back to a neutral position", "steps": 40},
]

5.3 评估模型性能
💡 关于评估代码

OpenPI 项目已包含官方评估脚本,位于:

  • LIBERO 评估examples/libero/main.py
  • ALOHA 仿真评估examples/aloha_sim/main.py

这些脚本是官方提供的标准评估工具,无需自己编写。下面将介绍如何使用它们。


使用官方评估脚本(推荐)⭐
方法 1:LIBERO 环境评估

适用场景:评估在 LIBERO 数据集上训练的模型

步骤

# 1. 启动策略服务器(在一个终端)
uv run scripts/serve_policy.py policy:checkpoint \
    --policy.config=my_lora_finetune \
    --policy.dir=checkpoints/my_lora_finetune/test_run_1/30000 \
    --port 8000

# 2. 运行评估脚本(在另一个终端)
cd examples/libero

# 评估 libero_spatial 任务套件(10 个任务,每个任务 50 次)
uv run main.py \
    --host 0.0.0.0 \
    --port 8000 \
    --task-suite-name libero_spatial \
    --num-trials-per-task 50 \
    --video-out-path ../../data/libero/eval_videos

# 其他可用任务套件:
# --task-suite-name libero_object    # 物体操作任务
# --task-suite-name libero_goal      # 目标导向任务
# --task-suite-name libero_10        # LIBERO-10 基准测试
# --task-suite-name libero_90        # LIBERO-90 基准测试

输出示例

Task suite: libero_spatial
Task: pick up the red block and put it in the drawer
Starting episode 1...
Success: True
# episodes completed so far: 1
# successes: 1 (100.0%)

Task: pick up the red block and put it in the drawer
Starting episode 2...
Success: True
# episodes completed so far: 2
# successes: 2 (100.0%)

...

Current task success rate: 0.86  # 当前任务成功率 86%
Current total success rate: 0.84 # 总体成功率 84%

Total success rate: 0.84         # 最终成功率
Total episodes: 500              # 总测试回合数

结果文件

data/libero/eval_videos/
├── rollout_pick_up_the_red_block_success.mp4
├── rollout_pick_up_the_red_block_failure.mp4
├── rollout_stack_the_blue_cube_success.mp4
└── ...

方法 2:ALOHA 仿真评估

适用场景:评估在 ALOHA 仿真环境训练的模型

评估脚本代码 (examples/aloha_sim/main.py):

import dataclasses
import logging
import pathlib

import env as _env
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import saver as _saver
import tyro


@dataclasses.dataclass
class Args:
    """ALOHA 仿真评估参数配置"""
    
    # 视频保存目录
    out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")
    
    # 仿真任务名称
    task: str = "gym_aloha/AlohaTransferCube-v0"
    
    # 随机种子
    seed: int = 0
    
    # 动作序列长度(必须与训练时一致)
    action_horizon: int = 10
    
    # 策略服务器地址
    host: str = "0.0.0.0"
    port: int = 8000
    
    # 是否显示仿真画面
    display: bool = False


def main(args: Args) -> None:
    """运行 ALOHA 仿真评估
    
    工作流程:
    1. 创建仿真环境 (AlohaSimEnvironment)
    2. 连接到策略服务器 (WebsocketClientPolicy)
    3. 使用动作分块代理执行策略 (ActionChunkBroker)
    4. 保存评估视频 (VideoSaver)
    """
    runtime = _runtime.Runtime(
        # 初始化 ALOHA 仿真环境
        environment=_env.AlohaSimEnvironment(
            task=args.task,
            seed=args.seed,
        ),
        
        # 初始化策略代理
        agent=_policy_agent.PolicyAgent(
            policy=action_chunk_broker.ActionChunkBroker(
                # 连接到远程策略服务器
                policy=_websocket_client_policy.WebsocketClientPolicy(
                    host=args.host,
                    port=args.port,
                ),
                action_horizon=args.action_horizon,
            )
        ),
        
        # 添加视频保存器
        subscribers=[
            _saver.VideoSaver(args.out_dir),
        ],
        
        # 控制频率 50Hz
        max_hz=50,
    )
    
    # 运行评估
    runtime.run()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    tyro.cli(main)

运行步骤

# 1. 启动策略服务器(在一个终端)
uv run scripts/serve_policy.py policy:checkpoint \
    --policy.config=pi0_aloha_sim_lora \
    --policy.dir=checkpoints/pi0_aloha_sim_lora/run_1/20000 \
    --port 8000

# 2. 运行评估脚本(在另一个终端)
cd examples/aloha_sim

uv run main.py \
    --host 0.0.0.0 \
    --port 8000 \
    --task gym_aloha/AlohaTransferCube-v0 \
    --out-dir ../../data/aloha_sim/eval_videos

可用任务

任务名称描述
gym_aloha/AlohaTransferCube-v0转移方块
gym_aloha/AlohaInsertion-v0插入任务
gym_aloha/AlohaThreadVelcro-v0穿线任务

5.4 部署到真实机器人

对于真实机器人部署,参考:

基本流程

  1. 启动策略服务器(在 GPU 服务器上)
  2. 连接机器人硬件
  3. 运行控制脚本(读取相机 → 调用策略 → 发送动作)
# 真实机器人推理示例(伪代码)
import robot_interface  # 机器人控制库

robot = robot_interface.connect()
camera = robot_interface.Camera()

while True:
    # 1. 获取观测
    image = camera.capture()
    state = robot.get_state()
    
    # 2. 调用策略
    response = requests.post("http://gpu-server:8000/act", json={
        "observation": {"images": [image], "state": state},
        "prompt": "pick up the object"
    })
    
    action = response.json()["action"][0]
    
    # 3. 执行动作
    robot.execute(action)
    time.sleep(0.1)  # 控制频率 10Hz

四、 不同机器人平台的训练配置总结

本章节根据 config.py 中的 TrainConfig 配置,总结了各种机器人平台的训练配置方案。

4.1 LIBERO 仿真环境

方案 A:Pi0 全量微调(需要 70GB+ 显存)
TrainConfig(
    name="pi0_libero",
    
    # 全量微调模型(所有参数都训练)
    model=pi0_config.Pi0Config(),
    
    data=LeRobotLiberoDataConfig(
        repo_id="physical-intelligence/libero",
        base_config=DataConfig(
            prompt_from_task=True,  # 从数据集加载任务提示词
        ),
        extra_delta_transform=True,  # LIBERO 特有的 delta 转换
    ),
    
    # 加载 pi0_base 预训练权重
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_base/params"
    ),
    
    num_train_steps=30_000,
    # 其他参数使用默认值
)

特点

  • ✅ 训练效果最好
  • ❌ 需要 70GB+ 显存(A100/H100)
  • ❌ RTX 4090 无法使用

方案 B:Pi0 LoRA 微调(推荐 RTX 4090)⭐
TrainConfig(
    name="pi0_libero_low_mem_finetune",
    
    # LoRA 变体:只训练 LoRA 参数
    model=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",      # 视觉-语言模型用 LoRA
        action_expert_variant="gemma_300m_lora"  # 动作专家用 LoRA
    ),
    
    data=LeRobotLiberoDataConfig(
        repo_id="physical-intelligence/libero",
        base_config=DataConfig(prompt_from_task=True),
        extra_delta_transform=True,
    ),
    
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_base/params"
    ),
    
    num_train_steps=30_000,
    
    # LoRA 专用设置
    freeze_filter=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora"
    ).get_freeze_filter(),  # 冻结非 LoRA 参数
    
    ema_decay=None,  # 关闭 EMA 节省显存
)

特点

  • ✅ 显存需求 22.5GB(RTX 4090 可用)
  • ✅ 训练速度快
  • ✅ 推荐配置

使用命令

# 1. 计算归一化统计
uv run scripts/compute_norm_stats.py --config-name pi0_libero_low_mem_finetune

# 2. 启动训练
uv run scripts/train.py pi0_libero_low_mem_finetune --exp-name=my_run_1 --overwrite

方案 C:Pi0-FAST 全量微调
TrainConfig(
    name="pi0_fast_libero",
    
    # Pi0-FAST 模型:更快的推理速度
    model=pi0_fast.Pi0FASTConfig(
        action_dim=7,           # LIBERO 动作维度
        action_horizon=10,      # 动作序列长度
        max_token_len=180,      # 最大 token 长度(单臂机器人)
    ),
    
    data=LeRobotLiberoDataConfig(
        repo_id="physical-intelligence/libero",
        base_config=DataConfig(prompt_from_task=True),
        extra_delta_transform=True,
    ),
    
    # 加载 pi0_fast_base 预训练权重
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_fast_base/params"
    ),
    
    num_train_steps=30_000,
)

特点

  • ✅ 推理速度更快(适合实时控制)
  • ❌ 需要 70GB+ 显存
  • ⚠️ 需要调整 max_token_len:单臂 ~180,双臂 ~250

方案 D:Pi0-FAST LoRA 微调(推荐 RTX 4090)⭐
TrainConfig(
    name="pi0_fast_libero_low_mem_finetune",
    
    # Pi0-FAST LoRA 变体
    model=pi0_fast.Pi0FASTConfig(
        action_dim=7,
        action_horizon=10,
        max_token_len=180,
        paligemma_variant="gemma_2b_lora"  # FAST 只需冻结 PaliGemma
    ),
    
    data=LeRobotLiberoDataConfig(
        repo_id="physical-intelligence/libero",
        base_config=DataConfig(prompt_from_task=True),
        extra_delta_transform=True,
    ),
    
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_fast_base/params"
    ),
    
    num_train_steps=30_000,
    
    # LoRA 专用设置
    freeze_filter=pi0_fast.Pi0FASTConfig(
        action_dim=7,
        action_horizon=10,
        max_token_len=180,
        paligemma_variant="gemma_2b_lora"
    ).get_freeze_filter(),
    
    ema_decay=None,
)

特点

  • ✅ 低显存(22.5GB)+ 快速推理
  • ✅ 适合实时机器人控制
  • ✅ RTX 4090 最佳选择

方案 E:Pi0.5 微调(最新架构)
TrainConfig(
    name="pi05_libero",
    
    # Pi0.5 模型(改进的架构)
    model=pi0_config.Pi0Config(
        pi05=True,                   # 启用 Pi0.5
        action_horizon=10,
        discrete_state_input=False,  # 使用连续状态输入
    ),
    
    data=LeRobotLiberoDataConfig(
        repo_id="physical-intelligence/libero",
        base_config=DataConfig(prompt_from_task=True),
        extra_delta_transform=False,  # Pi0.5 不需要
    ),
    
    batch_size=256,  # 更大的批次
    
    # 不同的学习率调度
    lr_schedule=_optimizer.CosineDecaySchedule(
        warmup_steps=10_000,
        peak_lr=5e-5,      # 较低的学习率
        decay_steps=1_000_000,
        decay_lr=5e-5,
    ),
    
    optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
    ema_decay=0.999,   # Pi0.5 使用 EMA
    
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi05_base/params"
    ),
    
    num_train_steps=30_000,
)

特点

  • ✅ 最新架构,性能更好
  • ❌ 需要大显存
  • ⚠️ 需要 PyTorch 权重路径(如果转换)

4.2 ALOHA 真实机器人

方案 A:Pi0 ALOHA 微调(自定义数据集)
TrainConfig(
    name="pi0_aloha_pen_uncap",
    
    model=pi0_config.Pi0Config(),
    
    data=LeRobotAlohaDataConfig(
        # 替换为您自己的 ALOHA 数据集
        repo_id="physical-intelligence/aloha_pen_uncap_diverse",
        
        # 重用 pi0_base 的 Trossen 归一化统计
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="trossen",  # Trossen 机器人配置
        ),
        
        default_prompt="uncap the pen",  # 默认任务提示词
        
        # 数据重打包(适配 LeRobot 格式)
        repack_transforms=_transforms.Group(
            inputs=[
                _transforms.RepackTransform({
                    "images": {
                        "cam_high": "observation.images.cam_high",
                        "cam_left_wrist": "observation.images.cam_left_wrist",
                        "cam_right_wrist": "observation.images.cam_right_wrist",
                    },
                    "state": "observation.state",
                    "actions": "action",
                })
            ]
        ),
    ),
    
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_base/params"
    ),
    
    num_train_steps=20_000,  # ALOHA 数据集较小
)

适用场景

  • ✅ ALOHA 真机数据集
  • ✅ 使用 Trossen 机器人
  • ⚠️ 需要先转换为 LeRobot 格式(参考 examples/aloha_real/README.md

数据转换命令

cd examples/aloha_real

# 转换 ALOHA HDF5 数据到 LeRobot 格式
python convert_aloha_data_to_lerobot.py \
    --input-dir /path/to/aloha_data \
    --output-repo your_hf_username/my_aloha_dataset

方案 B:Pi0 ALOHA LoRA 微调(RTX 4090)
TrainConfig(
    name="pi0_aloha_lora",
    
    # LoRA 变体
    model=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora"
    ),
    
    data=LeRobotAlohaDataConfig(
        repo_id="your_hf_username/my_aloha_dataset",
        
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="trossen",
        ),
        
        default_prompt="your task description",
        
        repack_transforms=_transforms.Group(
            inputs=[
                _transforms.RepackTransform({
                    "images": {
                        "cam_high": "observation.images.cam_high",
                        "cam_left_wrist": "observation.images.cam_left_wrist",
                        "cam_right_wrist": "observation.images.cam_right_wrist",
                    },
                    "state": "observation.state",
                    "actions": "action",
                })
            ]
        ),
    ),
    
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_base/params"
    ),
    
    num_train_steps=20_000,
    
    # LoRA 设置
    freeze_filter=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora"
    ).get_freeze_filter(),
    
    ema_decay=None,
)

特点

  • ✅ RTX 4090 可用(22.5GB)
  • ✅ 适合小数据集快速迭代
  • ⚠️ 需要调整 default_prompt 为您的任务

方案 C:Pi0.5 ALOHA 微调
TrainConfig(
    name="pi05_aloha_pen_uncap",
    
    model=pi0_config.Pi0Config(pi05=True),
    
    data=LeRobotAlohaDataConfig(
        repo_id="physical-intelligence/aloha_pen_uncap_diverse",
        
        # 使用 pi05_base 的归一化统计
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets",
            asset_id="trossen",
        ),
        
        default_prompt="uncap the pen",
        
        repack_transforms=_transforms.Group(
            inputs=[
                _transforms.RepackTransform({
                    "images": {
                        "cam_high": "observation.images.cam_high",
                        "cam_left_wrist": "observation.images.cam_left_wrist",
                        "cam_right_wrist": "observation.images.cam_right_wrist",
                    },
                    "state": "observation.state",
                    "actions": "action",
                })
            ]
        ),
    ),
    
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi05_base/params"
    ),
    
    num_train_steps=20_000,
    batch_size=64,  # Pi0.5 可以用更大的批次
)

4.3 ALOHA 仿真环境

TrainConfig(
    name="pi0_aloha_sim",
    
    model=pi0_config.Pi0Config(),
    
    data=LeRobotAlohaDataConfig(
        # LeRobot 官方 ALOHA 仿真数据集
        repo_id="lerobot/aloha_sim_transfer_cube_human",
        
        default_prompt="Transfer cube",
        
        # 仿真环境不使用 delta 动作
        use_delta_joint_actions=False,
    ),
    
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_base/params"
    ),
    
    num_train_steps=20_000,
)

适用场景

  • ✅ 快速原型验证
  • ✅ 算法开发和测试
  • ⚠️ 不需要真实机器人硬件

LoRA 版本

TrainConfig(
    name="pi0_aloha_sim_lora",
    
    model=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora"
    ),
    
    data=LeRobotAlohaDataConfig(
        repo_id="lerobot/aloha_sim_transfer_cube_human",
        default_prompt="Transfer cube",
        use_delta_joint_actions=False,
    ),
    
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_base/params"
    ),
    
    num_train_steps=20_000,
    
    freeze_filter=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora"
    ).get_freeze_filter(),
    
    ema_decay=None,
)

4.4 DROID 机器人(大规模数据集)

方案 A:Pi0-FAST 全量微调(完整 DROID 数据集)
TrainConfig(
    name="pi0_fast_full_droid_finetune",
    
    # Pi0-FAST 适合 DROID 的实时控制
    model=pi0_fast.Pi0FASTConfig(
        action_dim=8,           # DROID 动作维度
        action_horizon=16,      # 更长的动作序列
        max_token_len=180,
    ),
    
    # 使用 RLDS 数据加载器(大规模数据集)
    data=RLDSDroidDataConfig(
        repo_id="droid",
        
        # ⚠️ 设置为您的 DROID RLDS 数据集路径
        rlds_data_dir="<path_to_droid_rlds_dataset>",
        
        action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
    ),
    
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_fast_base/params"
    ),
    
    # 更低的学习率(大数据集)
    lr_schedule=_optimizer.CosineDecaySchedule(
        warmup_steps=1_000,
        peak_lr=5e-5,
        decay_steps=1_000_000,
        decay_lr=5e-5,
    ),
    
    num_train_steps=100_000,  # 100k 步,约 2 天(8x H100)
    batch_size=256,
    
    log_interval=100,
    save_interval=5000,
    keep_period=20_000,
    
    # ⚠️ 重要:RLDS 必须设置 num_workers=0
    num_workers=0,
)

特点

  • ✅ 适合完整 DROID 数据集(数百小时)
  • ✅ 使用 RLDS 高效数据加载
  • ❌ 需要 8x H100 GPU
  • ⚠️ num_workers=0 是必需的

方案 B:Pi0.5 DROID 微调(完整数据集)
TrainConfig(
    name="pi05_full_droid_finetune",
    
    model=pi0_config.Pi0Config(
        pi05=True,
        action_dim=32,       # Pi0.5 使用 32 维动作
        action_horizon=16,
    ),
    
    data=RLDSDroidDataConfig(
        repo_id="droid",
        rlds_data_dir="/mnt/pi-data/kevin",  # 您的 RLDS 路径
        action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
        
        # 使用 pi05_base 的 DROID 归一化统计
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets/",
            asset_id="droid",
        ),
    ),
    
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi05_base/params"
    ),
    
    lr_schedule=_optimizer.CosineDecaySchedule(
        warmup_steps=1_000,
        peak_lr=5e-5,
        decay_steps=1_000_000,
        decay_lr=5e-5,
    ),
    
    num_train_steps=100_000,
    batch_size=256,
    
    log_interval=100,
    save_interval=5000,
    keep_period=10_000,
    
    num_workers=0,
)

方案 C:Pi0.5 DROID 微调(小数据集,LeRobot 格式)
TrainConfig(
    name="pi05_droid_finetune",
    
    model=pi0_config.Pi0Config(
        pi05=True,
        action_dim=32,
        action_horizon=16,
    ),
    
    # 使用 LeRobot 数据格式(适合小数据集)
    data=LeRobotDROIDDataConfig(
        # ⚠️ 替换为您的 DROID 数据集
        repo_id="your_hf_username/my_droid_dataset",
        
        base_config=DataConfig(prompt_from_task=True),
        
        # ⚠️ 重要:重用原始 DROID 归一化统计
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi05_droid/assets",
            asset_id="droid",
        ),
    ),
    
    # 加载 pi05_droid 检查点(已在 DROID 上预训练)
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi05_droid/params"
    ),
    
    num_train_steps=20_000,
    batch_size=32,
)

数据转换

cd examples/droid

# 转换 DROID 数据到 LeRobot 格式
python convert_droid_data_to_lerobot.py \
    --input-dir /path/to/droid_data \
    --output-repo your_hf_username/my_droid_dataset

适用场景

  • ✅ 自定义 DROID 数据集(< 10 小时)
  • ✅ RTX 4090 可用(batch_size=32)
  • ✅ 重用 pi05_droid 预训练权重

4.5 配置对比表

LIBERO 数据集
配置名称模型LoRA显存需求推理速度RTX 4090
pi0_liberoPi070GB+中等
pi0_libero_low_mem_finetunePi022.5GB中等✅ ⭐
pi0_fast_liberoPi0-FAST70GB+
pi0_fast_libero_low_mem_finetunePi0-FAST22.5GB✅ ⭐⭐
pi05_liberoPi0.570GB+中等
ALOHA 数据集
配置名称数据集类型LoRA显存需求RTX 4090
pi0_aloha_pen_uncap真机70GB+
pi0_aloha_lora (自定义)真机22.5GB✅ ⭐
pi05_aloha_pen_uncap真机40GB+
pi0_aloha_sim仿真70GB+
pi0_aloha_sim_lora (自定义)仿真22.5GB✅ ⭐
DROID 数据集
配置名称数据规模数据格式LoRA显存需求推荐 GPU
pi0_fast_full_droid_finetune完整RLDS70GB+8x H100
pi05_full_droid_finetune完整RLDS70GB+8x H100
pi05_droid_finetune小型LeRobot40GB+A100

4.6 RTX 4090 用户推荐配置

根据您的需求选择:

🎯 场景 1:LIBERO 仿真训练

推荐pi0_fast_libero_low_mem_finetune

# 1. 添加配置到 config.py(已内置)
# 2. 计算归一化统计
uv run scripts/compute_norm_stats.py --config-name pi0_fast_libero_low_mem_finetune

# 3. 启动训练
uv run scripts/train.py pi0_fast_libero_low_mem_finetune --exp-name=my_run

理由:LoRA + 快速推理 + 22.5GB 显存


🎯 场景 2:ALOHA 真机数据集

推荐:创建自定义 pi0_aloha_lora 配置

# 添加到 config.py
TrainConfig(
    name="pi0_aloha_lora",
    model=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora"
    ),
    data=LeRobotAlohaDataConfig(
        repo_id="your_hf_username/my_aloha_dataset",
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="trossen",
        ),
        default_prompt="your task",
        repack_transforms=_transforms.Group(
            inputs=[
                _transforms.RepackTransform({
                    "images": {
                        "cam_high": "observation.images.cam_high",
                        "cam_left_wrist": "observation.images.cam_left_wrist",
                        "cam_right_wrist": "observation.images.cam_right_wrist",
                    },
                    "state": "observation.state",
                    "actions": "action",
                })
            ]
        ),
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_base/params"
    ),
    num_train_steps=20_000,
    freeze_filter=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora"
    ).get_freeze_filter(),
    ema_decay=None,
)

🎯 场景 3:ALOHA 仿真训练

推荐:创建 pi0_aloha_sim_lora 配置(参考上文 9.3 节)


🎯 场景 4:小规模 DROID 数据集

不推荐在 RTX 4090 上训练 DROID(即使是小数据集,也需要 32 维动作空间,显存可能不足)

如果必须尝试,使用 batch_size=16 并监控显存。


9.7 快速创建自定义配置模板

# filepath: f:\codespace\openpi-main\src\openpi\training\config.py
# 添加到 _CONFIGS 列表末尾

TrainConfig(
    # ============ 基础信息 ============
    name="<YOUR_PLATFORM>_<YOUR_TASK>_lora",  # 例如:aloha_pick_cube_lora
    
    # ============ 模型配置(LoRA)============
    model=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora",
        # 如果是 Pi0-FAST,使用:
        # model=pi0_fast.Pi0FASTConfig(
        #     action_dim=<YOUR_ACTION_DIM>,
        #     action_horizon=<YOUR_ACTION_HORIZON>,
        #     max_token_len=180,  # 单臂 180,双臂 250
        #     paligemma_variant="gemma_2b_lora"
        # ),
    ),
    
    # ============ 数据配置 ============
    # LIBERO:
    data=LeRobotLiberoDataConfig(
        repo_id="<YOUR_REPO_ID>",
        base_config=DataConfig(prompt_from_task=True),
        extra_delta_transform=True,
    ),
    # 或 ALOHA:
    # data=LeRobotAlohaDataConfig(
    #     repo_id="<YOUR_REPO_ID>",
    #     assets=AssetsConfig(
    #         assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
    #         asset_id="trossen",
    #     ),
    #     default_prompt="<YOUR_TASK_PROMPT>",
    #     repack_transforms=_transforms.Group(
    #         inputs=[_transforms.RepackTransform({
    #             "images": {
    #                 "cam_high": "observation.images.cam_high",
    #                 "cam_left_wrist": "observation.images.cam_left_wrist",
    #                 "cam_right_wrist": "observation.images.cam_right_wrist",
    #             },
    #             "state": "observation.state",
    #             "actions": "action",
    #         })]
    #     ),
    # ),
    
    # ============ 预训练权重 ============
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_base/params"
        # 或 pi0_fast_base: "gs://openpi-assets/checkpoints/pi0_fast_base/params"
        # 或 pi05_base: "gs://openpi-assets/checkpoints/pi05_base/params"
    ),
    
    # ============ 冻结参数(LoRA 必需)============
    freeze_filter=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora"
    ).get_freeze_filter(),
    # 如果是 Pi0-FAST LoRA:
    # freeze_filter=pi0_fast.Pi0FASTConfig(
    #     action_dim=<YOUR_ACTION_DIM>,
    #     action_horizon=<YOUR_ACTION_HORIZON>,
    #     max_token_len=180,
    #     paligemma_variant="gemma_2b_lora"
    # ).get_freeze_filter(),
    
    # ============ 训练超参数 ============
    num_train_steps=20_000,  # ALOHA/小数据集:20k,LIBERO:30k
    batch_size=32,           # RTX 4090 推荐 32
    
    lr_schedule=_optimizer.CosineDecaySchedule(
        warmup_steps=1_000,
        peak_lr=3e-4,
        decay_steps=20_000,
        decay_lr=1e-6,
    ),
    
    optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),
    
    # ============ LoRA 专用 ============
    ema_decay=None,  # 关闭 EMA
    
    # ============ 日志 ============
    log_interval=100,
    save_interval=1000,
    keep_period=5000,
    wandb_enabled=True,
)

附录:完整配置模板

# filepath: f:\codespace\openpi-main\src\openpi\training\config.py
# 复制此模板并修改相应参数

TrainConfig(
    # ============ 基础信息 ============
    name="<YOUR_CONFIG_NAME>",          # 唯一配置名称
    exp_name="<EXPERIMENT_NAME>",       # 实验名称
    project_name="openpi",              # WandB 项目名
    
    # ============ 模型配置 ============
    model=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora",
        action_dim=<ACTION_DIM>,        # 您的动作维度
        action_horizon=<ACTION_HORIZON>, # 动作序列长度
    ),
    
    # ============ 数据配置 ============
    data=LeRobotLiberoDataConfig(       # 或您自己的 DataConfig
        repo_id="<YOUR_DATASET_REPO_ID>",
        base_config=DataConfig(
            prompt_from_task=True,
        ),
    ),
    
    # ============ 预训练权重 ============
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi0_base/params"
    ),
    
    # ============ 冻结参数 ============
    freeze_filter=pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m_lora"
    ).get_freeze_filter(),
    
    # ============ 训练超参数 ============
    num_train_steps=<NUM_STEPS>,
    batch_size=<BATCH_SIZE>,
    
    lr_schedule=_optimizer.CosineDecaySchedule(
        warmup_steps=<WARMUP_STEPS>,
        peak_lr=<PEAK_LR>,
        decay_steps=<DECAY_STEPS>,
        decay_lr=<DECAY_LR>,
    ),
    
    optimizer=_optimizer.AdamW(
        clip_gradient_norm=1.0,
    ),
    
    ema_decay=None,
    
    # ============ 日志和保存 ============
    log_interval=100,
    save_interval=1000,
    keep_period=5000,
    wandb_enabled=True,
)

您可能感兴趣的与本文相关的镜像

ComfyUI

ComfyUI

AI应用
ComfyUI

ComfyUI是一款易于上手的工作流设计工具,具有以下特点:基于工作流节点设计,可视化工作流搭建,快速切换工作流,对显存占用小,速度快,支持多种插件,如ADetailer、Controlnet和AnimateDIFF等

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值