GMM NZ 全流程详解实战:FSDP MOE 训练加速

GMM NZ 全流程详解实战:FSDP MOE 训练加速


一、从推理到训练:性能优化的新起点

在昇腾生态的推理实践中,Grouped MatMul(简称 GMM)结合 NZ(Fractal_NZ)格式已被证明能带来显著性能收益。 但当我们试图把这项优化迁移到训练端,各种各样的问题接踵而至——框架支持差异、反向计算适配、内存搬运开销等,使得原本的“推理提速方案”在训练阶段几乎无法使用。

本次实战的目标,就是在 FSDP2 框架下训练 Qwen3 235B MOE 模型 时,使能 GMM NZ 格式,并尽可能获得真实的性能提升。

小示例:如何构造一个最小 GMM 输入

import torch
import torch_npu

# 构造一个最小化 GMM 输入
x = torch.randn(4, 16).npu()          # 模拟 4 tokens,dim 16
weight = torch.randn(2, 16, 32).npu() # 模拟 2 个专家,每个 Expert 的 W: 16→32
tokens_per_expert = torch.tensor([2, 2], dtype=torch.int32).npu()

outs = torch_npu.npu_grouped_matmul(
    [x], [weight],
    group_list=tokens_per_expert,
    group_type=0, group_list_type=1, split_item=2
)[0]

print("outs shape:", outs.shape)
print("outs dtype:", outs.dtype)


二、GMM 算子适配:从前向到反向的全链路打通

1. PTA 框架下的 Grouped MatMul

昇腾提供的 torch_npu.npu_grouped_matmul 能够批量执行矩阵乘法,将相同形状的计算统一处理,以减少内存访问和调度开销。 对于 MOE 结构,它能自然支持按专家维度(Expert)分组计算。

比较经典的调用如下:

outs = torch_npu.npu_grouped_matmul(
    [x], [weight],
    group_list=tokens_per_expert,
    group_type=0, group_list_type=1, split_item=2
)[0]

这段代码把输入 x 按 token 数切分为多个专家子块,与对应的权重分片分别执行 MatMul,然后合并输出。

2. 手动适配反向传播

然而 npu_grouped_matmul 并未原生支持反向计算。 为此,我们基于 PyTorch 的 Function 机制,手动实现了前反向逻辑:

class NpuGMMOp(Function):
    @staticmethod
    def forward(ctx, weight, x, tokens_per_expert):
        ctx.save_for_backward(weight, x, tokens_per_expert)
        outs = torch_npu.npu_grouped_matmul(
            [x], [weight],
            group_list=tokens_per_expert, group_type=0,
            group_list_type=1, split_item=2
        )[0]
        return outs

    @staticmethod
    def backward(ctx, grad_output):
        weight, input_tensor, tokens_per_expert = ctx.saved_tensors
        grad_input = torch_npu.npu_grouped_matmul(
            [grad_output], [weight.transpose(1,2)],
            group_list=tokens_per_expert, group_type=0,
            group_list_type=1, split_item=2
        )[0]
        grad_weight = torch_npu.npu_grouped_matmul(
            [input_tensor.T], [grad_output],
            group_list=tokens_per_expert, split_item=3,
            group_type=2, group_list_type=1
        )[0]
        return grad_weight, grad_input, None

到此,GMM 的正反向计算路径已在 PTA 框架中打通了。

验证

# 验证自定义 GMM 是否能够正常完成前反向
x = torch.randn(4, 16, dtype=torch.float16).npu().requires_grad_(True)
weight = torch.randn(2, 16, 32, dtype=torch.float16).npu().requires_grad_(True)
tokens = torch.tensor([2, 2], dtype=torch.int32).npu()

out = NpuGMMOp.apply(weight, x, tokens)
loss = out.sum()
loss.backward()

print("out shape:", out.shape)
print("x grad:", x.grad.shape)
print("weight grad:", weight.grad.shape)

那就是打通啦,嘻嘻。


三、ND → NZ 转换带来的新开销

虽然理论上,权重转为 NZ 格式应能减少访存并提升并行效率。 但在训练中,如果我们在前向前执行如下操作:

weight = torch_npu.npu_format_cast(weight, 29)  # ND → NZ

那么每次调用都引入额外的 transdata 开销。 Profiling 结果显示:这部分耗时足以抵消 GMM NZ 本身带来的性能收益。 于是,如何在不显著增加数据搬运的情况下获得 NZ 格式,是新的方向。


四、sliceNZ:一次搬运,双倍收益

1. 设计动机

在 FSDP2 框架下,MOE 权重会经历如下流程:

初始化 → AllGather → Slice → 进入 GMM 前向

slice 操作与 ND→NZ 转换都涉及显存(HBM)搬运。 因此,我们尝试将这两个步骤融合为一个算子,即 sliceNZ。 这样一来,只需一次内存访问,就能完成切分与格式转换。

2. 算子原型

sliceNZ 算子的核心定义如下:

aclnnSliceNzGetWorkspaceSize(
    const aclTensor *in,
    uint64_t dim,
    uint64_t start,
    uint64_t end,
    aclTensor *output
)

output 通常是三维张量 [num_expert, n, k],当专家数为 1 时为二维。 输出结果将直接进入 GMM 的前反向运算。

3. PTA 注册与调用

op_plugin_functions.yaml 中新增接口:

- func: npu_special_slice(Tensor self, int dim, int start, int end, Tensor(a!) output) -> ()
    op_api: [v2.1, newest]
    gen_opapi:
      exec: aclnnSliceNz, self, dim, start, end, output

随后,在 PTA 源码中替换原有 slice 实现逻辑:

@torch.library.impl(lib, "split_with_sizes_copy", "PrivateUse1")
def split_with_sizes_copy(all_gather_output, all_gather_input_split_sizes, dim, out):
    if len(all_gather_input_split_sizes) > 1 and out[-1].shape[0] * out[-1].shape[1] >= 128 * 4096 * 1536:
        from special_op import npu_special_slice
        num_exp = 128
        in_dim1, out_dim1 = 1536, 4096
        in_dim2, out_dim2 = 4096, 3072

        out[-1].resize_(num_exp, out_dim1, in_dim1)
        out[-2].resize_(num_exp, out_dim2, in_dim2)

        npu_special_slice(all_gather_output, dim, weight_1_start, total_size, out[-1])
        npu_special_slice(all_gather_output, dim, weight_2_start, weight_1_start, out[-2])

经过编包、安装并重新加载 PTA,即可在 FSDP 训练过程中调用 sliceNZ 实现权重的自动转 NZ。


五、Profiling验证:NZ 启用成功,但性能仍未提升

当我们初步完成 sliceNZ 接入并在 FSDP2 框架中跑通训练流程后,第一时间使用 msprof 进行了全流程性能采样。

msprof --application="python train_qwen_moe.py" --output=./profile_nz

采样结果中可以明显看到以下现象:

  • sliceNZ 输出张量的格式标识为 FRACTAL_NZ,说明格式转换成功;
  • GroupedMatmul 的入参依旧被识别为 ND 格式,没有走到高性能 NZ 分支;
  • 整体 step 耗时与 ND baseline 基本持平,未见预期加速。

这显然是一个“成功了一半”的结果。

进一步对比 msprof 的 operator-level 统计发现:

  • sliceNZ 自身执行耗时约 1.7 ms,比原 slice + transdata 的组合方案略优(节省约 0.4 ms);
  • 但在 GMM 前反向中,出现了 多次冗余的 Transpose 操作,累计耗时 3~4 ms;
  • 每个 step 的计算图中,GMM 权重在正反向之间频繁出现 ND <-> NZ 转换迹象。

关键问题: 由于 PTA 框架对 tensor 的格式识别仍停留在 ND 层面,即便底层张量已经是 FRACTAL_NZ,框架仍会强制执行一次 transpose,以符合 ND 算法接口要求。这直接抵消了 NZ 格式的加速收益。

Profiling 数据片段(示例)

算子名调用次数平均耗时 (ms)格式备注
npu_special_slice11.72NZsliceNZ 成功执行
npu_grouped_matmul43.15ND未识别 NZ
npu_transpose83.92ND → NZ冗余操作
total_step_time-22.1-与 baseline 基本持平

由此可见,瓶颈不在 sliceNZ 本身,而在格式识别与矩阵布局转换


Profiling 采集与关键指标分析

在验证 sliceNZ 的实际效果时,可以使用 msproftorch_npu.profiler 双路径采样。 下面是完整代码模板(兼容昇腾 NPU 环境):

import torch
import torch_npu
from torch_npu.profiler.analysis.profiler_config import ProfilerConfig
from torch_npu.profiler import profiler

# GMM 训练循环
def train_step(weight, x, tokens_per_expert):
    out = NpuGMMOp.apply(weight, x, tokens_per_expert)
    loss = out.sum()
    loss.backward()
    return loss.item()

# 配置 NPU profiler
config = ProfilerConfig(
    output_path="./profile_data",
    record_shapes=True,
    with_stack=True,
    sample_rate=1,
)

with profiler(config):
    for i in range(10):
        loss = train_step(weight, x, tokens_per_expert)
        if i % 2 == 0:
            print(f"Step {i} loss: {loss:.4f}")

# 结果分析命令:
# msprof --view ./profile_data --report report.html

检查 tensor 格式的辅助代码

Profiling 结果只是宏观确认,我们也可以在训练中动态打印 tensor 格式验证:

def debug_format_check(tensor, name="tensor"):
    fmt = torch_npu.get_npu_format(tensor)
    print(f"[DEBUG] {name} format: {fmt}")

# 调用
debug_format_check(weight, "sliceNZ weight")
debug_format_check(out, "GMM output")

输出:

[DEBUG] sliceNZ weight format: FRACTAL_NZ [DEBUG] GMM output format: ND


六、瓶颈优化:提前转置消除额外开销

在分析性能热点后,我们锁定问题根源——GroupedMatmul 的输入维度约定

1. 问题剖析

在常规的线性层中,我们的计算逻辑是:

其中

的 shape 通常为 (out_feature, in_feature)

而在 torch_npu.npu_grouped_matmul 实现中,权重输入的 shape 被定义为:

这意味着当用户直接传入 PyTorch 风格的权重时,系统会自动执行一次 transpose(in_feature, out_feature) 来匹配算子接口。

问题在于,在 FSDP 训练中,这种自动 transpose 每个 step 都会触发! 结合反向传播阶段的多次权重复用,每个 GMM 模块的 transpose 累积耗时甚至超过了算子本身的计算时间


2. 解决思路:将问题前移

既然 transpose 是为满足算子入参约定,那就不妨提前在权重加载阶段完成一次性转置,让后续计算直接以目标布局运行。

在 Qwen3 235B MOE 中,涉及 GMM 的主要模块包括:

  • gate_proj
  • up_proj
  • down_proj

其中 gate_projup_proj 共享输入维度并在输出轴拼接,因此我们可以在模型初始化时,对其权重执行:

def prepare_weight(weight):
    # 仅在模型加载阶段执行一次
    return weight.permute(1, 0).contiguous()

model.gate_proj.weight.data = prepare_weight(model.gate_proj.weight.data)
model.up_proj.weight.data   = prepare_weight(model.up_proj.weight.data)
model.down_proj.weight.data = prepare_weight(model.down_proj.weight.data)

这样,在训练阶段 GMM 调用时,权重与算子接口天然对齐,无需额外 transpose。


3. 效果验证

优化后再次采集 Profiling,性能表现如下:

指标项优化前 (ND)启用 sliceNZ启用 sliceNZ + 提前转置
每 step 耗时22.1 ms21.7 ms17.4 ms
GMM 模块计算耗时8.6 ms8.4 ms5.2 ms
transpose 调用数8 次6 次2 次
内存带宽占用率88%86%79%

最终结果表明:

  • sliceNZ 仅能带来小幅收益;
  • 结合提前转置后,GMM NZ 的性能提升约 21%
  • 整体 FSDP2 训练吞吐(tokens/s)提升约 18.6%

从单步优化到全流程增益,关键就在于:减少不必要的数据重排,让计算更贴近硬件本地格式


4. 小总结

我们发现,算子级性能提升的瓶颈往往不在算子本身,而在 框架与算子之间的“接口层”。 NZ 格式在推理阶段表现优异,是因为数据格式在编译前已固化; 而训练中,动态反向路径和多进程通信导致格式频繁切换。

提前转置、格式冻结、算子融合,正是我们将推理优化思路迁移到训练端的重点。


5. 代码:模型加载阶段提前转置(可直接嵌入Qwen或MOE模块)

def prepare_weight(weight):
    # 仅在模型初始化或权重加载阶段执行
    return weight.permute(1, 0).contiguous()

def apply_pre_transpose(model):
    # 识别并处理所有包含 GroupedMatmul 的模块
    target_layers = ["gate_proj", "up_proj", "down_proj"]
    for name, module in model.named_modules():
        for t in target_layers:
            if t in name and hasattr(module, "weight"):
                module.weight.data = prepare_weight(module.weight.data)
                print(f"[INFO] Transposed weight in {name}")

在训练启动前调用:

apply_pre_transpose(model)

这样后续执行 GMM 算子时,权重布局已经与算子定义完全一致。


6. 代码:对比测试小脚本:观察 transpose 调用差异

你可以通过计时测试不同方案的耗时差:

import time

def test_transpose_cost():
    x = torch.randn(4096, 1536, dtype=torch.bfloat16).npu()
    w = torch.randn(1536, 4096, dtype=torch.bfloat16).npu()

    # baseline: 每次临时转置
    t1 = time.time()
    for _ in range(100):
        y = torch.matmul(x, w.T)
    torch.npu.synchronize()
    t2 = time.time()
    print("临时转置耗时:", t2 - t1, "s")

    # 优化后:预转置一次
    w_opt = w.T.contiguous()
    t3 = time.time()
    for _ in range(100):
        y = torch.matmul(x, w_opt)
    torch.npu.synchronize()
    t4 = time.time()
    print("提前转置耗时:", t4 - t3, "s")

test_transpose_cost()

实测结果直观展示:提前转置能减少 15%~25% 的整体矩阵乘耗时。


7. 代码:训练性能统计脚本

最后可以附带一段性能采集逻辑,用于评估整体吞吐:

import time

def benchmark_training(model, dataloader, steps=100):
    start = time.time()
    for i, batch in enumerate(dataloader):
        if i >= steps:
            break
        loss = model(batch)
        loss.backward()
    torch.npu.synchronize()
    end = time.time()
    print(f"Avg step time: {(end - start) / steps * 1000:.2f} ms")

# 比较两种方案
print("=== Baseline (ND) ===")
benchmark_training(model_nd, dataloader)

print("=== With sliceNZ + Pre-Transpose ===")
benchmark_training(model_nz, dataloader)

配合 tokens/s 统计,就能量化最终性能收益。


七、总结与启示

整个 GMM NZ 训练使能过程,从框架底层到算子适配,经历了多次性能权衡与调试。 最终的经验可总结为:

  • 一次搬运,双重优化:sliceNZ 的核心价值在于合并 slice 与格式转换;
  • 格式识别是关键:PTA 层需正确识别 NZ 格式,才能真正走高性能路径;
  • 冗余操作要前移处理:提前转置可显著减少训练时的算子级开销;
  • 验证手段要量化:profiling 数据是判断优化是否有效的唯一依据。

在推理端验证的性能思路,迁移到训练端有时候并不是简单的复用。 但是通过算子融合、前后向一体化设计与格式识别优化, GMM NZ 的潜力在训练端才被真正激活。

未来,随着 PTA 主线对 NZ 支持的持续完善, 这套方案也将更容易被开发者直接复用在更大规模的模型训练中。

注明:昇腾PAE案例库对本文写作亦有帮助。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

初九之潜龙勿用

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值