分布式训练与多设备加速:MirroredStrategy、TPUStrategy、ParameterServerStrategy 深度剖析
📌 本章重点:
- 理解 TensorFlow 分布式策略的设计理念
- 使用 MirroredStrategy 在单机多卡训练
- 使用 TPUStrategy 实现 TPU 上的高性能训练
- 使用 ParameterServerStrategy 进行多机多进程调度
- 编写分布式友好的训练代码(可单机/分布式通用)
- 完整分布式训练流程示例(含代码)
一、分布式训练架构图谱总览
┌───────────────────────────────┐
│ tf.distribute.Strategy │
└────────────┬──────────────────┘
│
┌─────────────────┼──────────────────┐
▼ ▼ ▼
MirroredStrategy TPUStrategy ParameterServerStrategy
(单机多GPU) (TPU) (多机多worker)
二、核心理念:分布式训练中的数据、梯度与变量管理
概念 | 解释 |
---|---|
Global Batch Size | 总 batch size,会被每个 replica 平分 |
Replica | 每块设备(GPU/TPU)上运行的模型副本 |
Strategy Scope | with strategy.scope(): 内部定义的变量会自动同步 |
AllReduce | 多卡之间同步梯度的操作(用于同步训练) |
CrossReplicaContext | 用于获取/控制多副本状态,如 strategy.reduce() |
三、📍MirroredStrategy:单机多卡的首选
适合:单机多GPU(如 2/4/8块卡)训练,最常见策略。
✅ 使用方式:
strategy = tf.distribute.MirroredStrategy()
print("Num GPUs:", strategy.num_replicas_in_sync)
✅ 使用 strategy.scope()
构建模型与优化器
with strategy.scope():
model = create_model()
optimizer = tf.keras.optimizers.Adam()
model.compile(optimizer=optimizer, loss='mse')
所有
tf.Variable
必须在strategy.scope()
内部创建,以确保设备复制一致。
✅ 使用 fit()
的自动并行训练:
model.fit(train_ds, validation_data=val_ds, epochs=5)
Keras 会自动:
- 拆分 batch 到每个 GPU
- 汇总 loss 和 metrics
- 自动 AllReduce 进行梯度平均
四、🚀 TPUStrategy:极致训练速度(支持 Google Cloud / Colab TPU)
适合:大规模预训练、Transformer、ViT、LLM 等需要大 batch 的任务
✅ 启用 TPU:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
✅ 与 MirroredStrategy 使用完全一致:
with strategy.scope():
model = create_model()
model.compile(...)
model.fit(...)
⚠️ 注意:TPU 不支持某些 op(如
tf.image.resize_bicubic
),建议提前测试模型兼容性。
五、🔗 ParameterServerStrategy:多机异步训练(工业部署)
适合:多机异步任务调度,尤其用于大模型 + 非对称硬件资源集群
✅ 启动多角色进程:
# worker 0
python train.py --job_name=worker --task_index=0
# worker 1
python train.py --job_name=worker --task_index=1
# ps 0
python train.py --job_name=ps --task_index=0
配置集群地址:
cluster_spec = {
"worker": ["host1:port", "host2:port"],
"ps": ["host3:port"]
}
os.environ["TF_CONFIG"] = json.dumps({
"cluster": cluster_spec,
"task": {"type": job_name, "index": task_index}
})
然后用:
strategy = tf.distribute.ParameterServerStrategy()
工业中,PS 策略用于异步更新梯度,效率高但一致性差
六、自定义训练循环(跨策略通用写法)
✅ 推荐写法:
with strategy.scope():
model = MyModel()
optimizer = tf.keras.optimizers.Adam()
@tf.function
def distributed_train_step(batch_data):
per_replica_losses = strategy.run(train_step, args=(batch_data,))
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
for epoch in range(epochs):
for batch in train_dist_ds:
loss = distributed_train_step(batch)
七、数据加载:使用 strategy.experimental_distribute_dataset
GLOBAL_BATCH_SIZE = 256
train_ds = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10000).batch(GLOBAL_BATCH_SIZE)
train_dist_ds = strategy.experimental_distribute_dataset(train_ds)
- 自动分配数据子集给每个副本
- 可以与
@tf.function
配合自动图优化
八、完整实战:2卡训练 MNIST
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28,28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(128)
train_dist_ds = strategy.experimental_distribute_dataset(train_ds)
model.fit(train_dist_ds, epochs=5)
支持 GPU 并行、TPU、Cloud 分布式一键切换
九、各种策略对比总结:
策略 | 场景 | 特点 | 推荐使用 |
---|---|---|---|
MirroredStrategy | 单机多GPU | 同步训练,易部署 | ✅ 常用 |
TPUStrategy | Cloud TPU | 高吞吐量,支持超大 batch | ✅ 预训练 |
MultiWorkerMirroredStrategy | 多机 | 自动 AllReduce,同步更新 | ✅ 工程级 |
ParameterServerStrategy | 多机异步 | 异步更新,效率高一致性低 | 高级场景 |
CentralStorageStrategy | 多GPU但内存不足 | 参数放在 CPU | 特殊情况 |