LightningCLI 详解
LightningCLI 是 PyTorch Lightning 框架提供的一个强大的命令行接口工具,它可以自动化处理深度学习项目中的许多繁琐工作。
主要功能
1. 自动创建命令行接口
cli = LightningCLI(
model_class, # 你的模型类
datamodule_class # 你的数据模块类
)
自动生成以下命令:
python script.py fit- 训练模型python script.py test- 测试模型python script.py predict- 推理预测python script.py validate- 验证模型
2. 自动解析配置文件
支持通过 YAML 配置文件管理所有超参数:
python main.py fit --config config.yaml
配置文件示例:
# config.yaml
model:
learning_rate: 0.001
hidden_dim: 256
data:
batch_size: 32
num_workers: 4
trainer:
max_epochs: 100
gpus: 2
3. 自动参数绑定
自动将命令行参数、配置文件参数绑定到模型和数据模块:
class MyModel(LightningModule):
def __init__(self, learning_rate=0.001, hidden_dim=128):
super().__init__()
# 这些参数可以从命令行或配置文件自动传入
self.lr = learning_rate
self.hidden_dim = hidden_dim
核心优势
1. 减少样板代码
传统方式需要手写大量代码:
# 传统方式 - 需要手写很多代码
parser = ArgumentParser()
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--epochs', type=int, default=100)
# ... 还有几十个参数
args = parser.parse_args()
model = Model(lr=args.lr)
datamodule = DataModule(batch_size=args.batch_size)
trainer = Trainer(max_epochs=args.epochs)
trainer.fit(model, datamodule)
使用 LightningCLI:
# 使用 LightningCLI - 一行搞定
cli = LightningCLI(Model, DataModule)
2. 自动处理复杂功能
- 分布式训练:自动处理多GPU/多节点设置
- 混合精度训练:简单配置即可启用
- 日志记录:自动集成 TensorBoard、WandB 等
- 检查点保存:自动保存最佳模型和恢复训练
- 早停机制:自动配置早停策略
在 UniVLA 项目中的使用
cli = LightningCLI(
DINO_LAM, # 潜在动作模型
LightningOpenX, # 数据模块
seed_everything_default=42,
)
这行代码实现了:
- 自动创建训练命令
# 训练 Stage 1
python main.py fit --config config/lam-stage-1.yaml
# 训练 Stage 2
python main.py fit --config config/lam-stage-2.yaml
- 自动管理配置
配置文件可能包含:
model:
encoder_type: "dino"
latent_dim: 512
codebook_size: 1024
data:
data_mix: "oxe_magic_soup" # 数据集混合策略
batch_size: 512
window_size: 2 # 视频帧窗口
trainer:
max_steps: 100000
accumulate_grad_batches: 4
devices: 8 # 使用8个GPU
- 自动处理分布式训练
# LightningCLI 自动识别 torchrun 启动的分布式环境
torchrun --nproc-per-node 8 main.py fit --config config.yaml
实际好处
-
实验管理方便
- 只需修改配置文件就能尝试不同超参数
- 配置文件可以版本控制,方便复现实验
-
代码更清晰
- 模型类只关注模型架构
- 数据类只关注数据处理
- 训练逻辑由 Lightning 自动处理
-
扩展性强
- 添加新参数只需在模型的
__init__中定义 - 支持自定义回调、日志器、优化器等
- 添加新参数只需在模型的
-
调试方便
# 快速调试模式
python main.py fit --config config.yaml --trainer.fast_dev_run=true
# 只跑2个batch测试
python main.py fit --config config.yaml --trainer.limit_train_batches=2
配置文件内容的确定方法
1. 配置项来源于类的 __init__ 参数
LightningCLI 会自动扫描你的类,将所有 __init__ 参数变成可配置项:
# 假设 DINO_LAM 类定义如下
class DINO_LAM(LightningModule):
def __init__(
self,
encoder_type: str = "dino",
latent_dim: int = 512,
codebook_size: int = 1024,
learning_rate: float = 1e-4,
beta: float = 0.25, # VQ-VAE loss weight
):
super().__init__()
self.encoder_type = encoder_type
self.latent_dim = latent_dim
# ...
# 数据模块类
class LightningOpenX(LightningDataModule):
def __init__(
self,
data_root_dir: str = "/path/to/data",
batch_size: int = 32,
num_workers: int = 4,
data_mix: str = "oxe_magic_soup",
):
super().__init__()
self.data_root_dir = data_root_dir
# ...
对应的配置文件就是:
# config.yaml
model: # 对应 DINO_LAM 类
encoder_type: "dino"
latent_dim: 512
codebook_size: 1024
learning_rate: 1e-4
beta: 0.25
data: # 对应 LightningOpenX 类
data_root_dir: "/path/to/data"
batch_size: 32
num_workers: 4
data_mix: "oxe_magic_soup"
2. 查看可用配置的方法
方法1:使用 --help 查看所有可配置参数
python main.py fit --help
这会显示所有可用的配置选项及其说明。
方法2:生成默认配置文件
# 生成包含所有默认值的配置文件
python main.py fit --print_config > default_config.yaml
方法3:查看源代码
在 UniVLA 项目中,查看这些文件:
# 1. 查看模型定义
# latent_action_model/genie/model.py 或类似位置
class DINO_LAM(LightningModule):
def __init__(self, ...): # 这里的参数就是可配置的
pass
# 2. 查看数据模块定义
# latent_action_model/genie/dataset.py
class LightningOpenX(LightningDataModule):
def __init__(self, ...): # 这里的参数也是可配置的
pass
3. UniVLA 项目中的实际配置
根据项目文档,UniVLA 已经提供了预设配置文件:
# Stage 1 配置文件
latent_action_model/config/lam-stage-1.yaml
# Stage 2 配置文件
latent_action_model/config/lam-stage-2.yaml
让我们看看一个典型的配置文件结构:
# lam-stage-1.yaml 可能的内容
model:
# VQ-VAE 架构参数
encoder_type: "dino_v2"
decoder_type: "conv"
latent_dim: 512
codebook_size: 1024
codebook_dim: 256
# 训练参数
learning_rate: 1e-4
beta: 0.25 # VQ loss weight
# DINO 特征参数
dino_model: "dinov2_vitb14"
freeze_encoder: true
data:
# 数据路径
data_root_dir: "/path/to/oxe/data"
# 数据集配置
data_mix: "oxe_magic_soup" # 预定义的数据集组合
# 批处理参数
batch_size: 64 # 每个GPU的batch size
num_workers: 8
# 视频采样参数
window_size: 2 # 连续帧数量
frame_spacing: 1
trainer: # PyTorch Lightning Trainer 参数
max_steps: 100000
accumulate_grad_batches: 8 # 梯度累积
precision: "16-mixed" # 混合精度训练
# 分布式训练
devices: 8
num_nodes: 1
strategy: "ddp"
# 检查点
enable_checkpointing: true
val_check_interval: 1000
# 日志
log_every_n_steps: 50
4. 配置优先级
LightningCLI 支持多层配置,优先级从高到低:
# 1. 命令行参数(最高优先级)
python main.py fit --model.learning_rate=0.001
# 2. 配置文件
python main.py fit --config config.yaml
# 3. 类中的默认值(最低优先级)
5. 查看项目特定配置的位置
在 UniVLA 项目中,关键配置定义位置:
# 1. 预定义的数据集组合
# prismatic/vla/datasets/rlds/oxe/mixtures.py
OXE_MAGIC_SOUP = [
("fractal20220817_data", 1.0),
("kuka", 1.0),
("bridge_v2", 1.0),
# ...
]
# 2. VLA 配置
# prismatic/conf/vla.py
VLA_CONFIGS = {
"prism-dinosiglip-224px+mx-oxe-magic-soup": {
"encoder": "dino_siglip",
"resolution": 224,
# ...
}
}
# 3. Stage 2 配置需要指定 Stage 1 的检查点
# config/lam-stage-2.yaml
model:
stage_one_ckpt: "/path/to/stage1/checkpoint.ckpt" # 需要修改这个
# 其他参数...
6. 实践建议
-
先用默认配置:项目通常提供了优化过的默认配置
# 直接使用提供的配置 python main.py fit --config config/lam-stage-1.yaml -
微调特定参数:通过命令行覆盖特定参数
python main.py fit --config config/lam-stage-1.yaml \ --data.batch_size=128 \ --model.learning_rate=5e-5 -
创建自定义配置:复制默认配置并修改
# 复制默认配置 cp config/lam-stage-1.yaml config/my_config.yaml # 编辑 my_config.yaml # 使用自定义配置 python main.py fit --config config/my_config.yaml
使用 LightningCLI 的完整指南
让我通过一个具体的例子来说明传统方式和 LightningCLI 方式的区别:
一、传统方式的代码结构
# main.py - 传统方式
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
# 1. 手写参数解析
parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--data_path', type=str, default='./data')
parser.add_argument('--gpus', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=4)
args = parser.parse_args()
# 2. 定义模型
class MyModel(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.fc = nn.Linear(784, hidden_dim)
self.output = nn.Linear(hidden_dim, 10)
def forward(self, x):
return self.output(torch.relu(self.fc(x)))
# 3. 定义数据集
class MyDataset(Dataset):
def __init__(self, data_path):
self.data_path = data_path
# 加载数据...
def __len__(self):
return 1000
def __getitem__(self, idx):
return torch.randn(784), torch.randint(0, 10, (1,))
# 4. 手写训练循环
model = MyModel(args.hidden_dim)
dataset = MyDataset(args.data_path)
dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
optimizer = Adam(model.parameters(), lr=args.lr)
# 5. 训练循环
for epoch in range(args.epochs):
for batch in dataloader:
x, y = batch
# 前向传播
out = model(x)
loss = nn.functional.cross_entropy(out, y.squeeze())
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Loss: {loss.item()}")
二、LightningCLI 方式的代码结构
# model.py - Lightning 模型
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
class MyLightningModel(LightningModule):
def __init__(
self,
hidden_dim: int = 128, # 这些参数会自动变成配置项!
learning_rate: float = 1e-3,
weight_decay: float = 0.01,
dropout: float = 0.1,
):
super().__init__()
self.save_hyperparameters() # 自动保存所有参数
# 定义网络结构
self.fc = nn.Linear(784, hidden_dim)
self.dropout = nn.Dropout(dropout)
self.output = nn.Linear(hidden_dim, 10)
# 保存超参数
self.learning_rate = learning_rate
self.weight_decay = weight_decay
def forward(self, x):
x = torch.relu(self.fc(x))
x = self.dropout(x)
return self.output(x)
def training_step(self, batch, batch_idx):
x, y = batch
out = self(x)
loss = nn.functional.cross_entropy(out, y.squeeze())
# 自动日志记录
self.log('train_loss', loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
out = self(x)
loss = nn.functional.cross_entropy(out, y.squeeze())
acc = (out.argmax(dim=1) == y.squeeze()).float().mean()
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', acc, prog_bar=True)
return loss
def configure_optimizers(self):
return torch.optim.Adam(
self.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay
)
# datamodule.py - Lightning 数据模块
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split
class MyDataModule(LightningDataModule):
def __init__(
self,
data_path: str = './data', # 这些参数也会自动变成配置项!
batch_size: int = 32,
num_workers: int = 4,
train_val_split: float = 0.8,
pin_memory: bool = True,
drop_last: bool = False,
):
super().__init__()
self.save_hyperparameters()
self.data_path = data_path
self.batch_size = batch_size
self.num_workers = num_workers
self.train_val_split = train_val_split
self.pin_memory = pin_memory
self.drop_last = drop_last
def setup(self, stage=None):
# 这个方法会在训练开始前自动调用
full_dataset = MyDataset(self.data_path)
train_size = int(len(full_dataset) * self.train_val_split)
val_size = len(full_dataset) - train_size
self.train_dataset, self.val_dataset = random_split(
full_dataset, [train_size, val_size]
)
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=True,
drop_last=self.drop_last,
)
def val_dataloader(self):
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=False,
)
class MyDataset(Dataset):
def __init__(self, data_path):
self.data_path = data_path
# 加载数据...
def __len__(self):
return 1000
def __getitem__(self, idx):
return torch.randn(784), torch.randint(0, 10, (1,))
# main.py - 超级简洁的入口文件
from pytorch_lightning.cli import LightningCLI
from model import MyLightningModel
from datamodule import MyDataModule
if __name__ == "__main__":
cli = LightningCLI(
MyLightningModel,
MyDataModule,
seed_everything_default=42,
save_config_kwargs={"overwrite": True},
)
三、配置文件的使用
# config.yaml - 所有参数都在这里配置
model:
hidden_dim: 256
learning_rate: 0.001
weight_decay: 0.01
dropout: 0.2
data:
data_path: "./data"
batch_size: 64
num_workers: 8
train_val_split: 0.8
pin_memory: true
drop_last: true
trainer: # PyTorch Lightning Trainer 的参数
max_epochs: 100
accelerator: gpu
devices: 2
precision: 16-mixed
# 梯度相关
accumulate_grad_batches: 4
gradient_clip_val: 1.0
# 日志和检查点
log_every_n_steps: 10
val_check_interval: 1.0
# 早停
callbacks:
- class_path: pytorch_lightning.callbacks.EarlyStopping
init_args:
monitor: val_loss
patience: 10
mode: min
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
init_args:
monitor: val_loss
save_top_k: 3
mode: min
filename: "{epoch}-{val_loss:.2f}"
四、使用命令对比
# 传统方式
python main.py --lr 0.001 --batch_size 32 --epochs 100
# LightningCLI 方式
python main.py fit --config config.yaml # 训练
python main.py validate --config config.yaml # 验证
python main.py test --config config.yaml # 测试
python main.py predict --config config.yaml # 推理
# 覆盖配置文件中的参数
python main.py fit --config config.yaml --model.learning_rate=0.01
# 使用不同的训练策略
python main.py fit --config config.yaml --trainer.strategy=ddp --trainer.devices=4
五、需要注意的关键点
1. 类的设计要求
# ❌ 错误:参数没有类型注解和默认值
class BadModel(LightningModule):
def __init__(self, hidden_dim, learning_rate): # 没有类型和默认值
super().__init__()
# ✅ 正确:有类型注解和默认值
class GoodModel(LightningModule):
def __init__(
self,
hidden_dim: int = 128, # 有类型注解
learning_rate: float = 1e-3, # 有默认值
):
super().__init__()
self.save_hyperparameters() # 保存超参数
2. 目录结构建议
project/
├── models/
│ └── my_model.py # LightningModule
├── data/
│ └── datamodule.py # LightningDataModule
├── configs/
│ ├── default.yaml # 默认配置
│ ├── debug.yaml # 调试配置
│ └── production.yaml # 生产配置
├── callbacks/
│ └── custom_callbacks.py # 自定义回调
├── main.py # CLI 入口
└── requirements.txt
3. 常见陷阱和解决方案
# 陷阱 1:忘记 save_hyperparameters()
class Model(LightningModule):
def __init__(self, lr: float = 1e-3):
super().__init__()
self.save_hyperparameters() # 必须加这行!
# 否则参数不会被保存到检查点
# 陷阱 2:参数名冲突
class Model(LightningModule):
def __init__(
self,
model: str = "resnet", # ❌ 'model' 是保留字
data: str = "cifar", # ❌ 'data' 是保留字
):
# 改成:
# model_name: str = "resnet"
# dataset_name: str = "cifar"
# 陷阱 3:动态参数
class Model(LightningModule):
def __init__(self, **kwargs): # ❌ CLI 无法解析
# 必须显式声明所有参数
六、高级特性
1. 多个配置文件组合
# 组合多个配置文件
python main.py fit \
--config configs/base.yaml \
--config configs/experiment.yaml
2. 子命令和子模块
# 支持多个模型
from pytorch_lightning.cli import LightningCLI
class MyCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_argument("--note", type=str, help="实验备注")
parser.link_arguments("data.batch_size", "model.init_args.batch_size")
cli = MyCLI(
model_class=MyModel,
datamodule_class=MyDataModule,
subclass_mode_model=True, # 允许子类化
subclass_mode_data=True,
)
3. 自动调优
# tune_config.yaml
trainer:
auto_lr_find: true
auto_scale_batch_size: binsearch
python main.py fit --config tune_config.yaml
七、迁移清单
从传统方式迁移到 LightningCLI:
- 将模型改写为
LightningModule - 将数据加载改写为
LightningDataModule - 添加类型注解和默认值
- 添加
save_hyperparameters() - 创建配置文件
- 简化 main.py 为 CLI 入口
- 移除手写的训练循环
- 移除手写的参数解析
- 测试配置文件是否正确
八、性能和调试技巧
# debug.yaml - 快速调试配置
trainer:
fast_dev_run: 5 # 只跑5个batch
limit_train_batches: 0.1 # 只用10%数据
limit_val_batches: 0.1
max_epochs: 2
profiler: simple # 性能分析
detect_anomaly: true # 检测数值异常
4496

被折叠的 条评论
为什么被折叠?



