统计计算量

全连接层(Linear)

全连接层的计算量主要由矩阵乘法决定。假设:
输入维度:B × M (B为批量大小,M为输入特征数)
权重矩阵:M × N (N为输出特征数)
偏置:N

前向FLOPs:≈ 2 × B × M × N
训练FLOPs:≈ 6 × B × M × N

nn.Linear(in_features=d_in, out_features=d_out, bias=True)
  • 权重矩阵d_in × d_out
  • 偏置项d_out(若存在)

示例1:单样本推理

import torch

layer = torch.nn.Linear(1024, 4096)  # M=1024, N=4096
input = torch.randn(1, 1024)         # B=1

# 手动计算FLOPs
flops_forward = 2 * 1 * 1024 * 4096  # ≈ 8.4M FLOPs
print(f"FLOPs (forward): {flops_forward:,}")

层归一化(LayerNorm)

B:Batch size ; L:序列长度

nn.LayerNorm(d_model)
  • 公式
    FLOPs ≈ 5 × B × L × d_model
    • 5:均值、方差、归一化、缩放、偏移
  • 示例
    B=32, L=512, d_model=7685×32×512×768 ≈ 63 MFLOPs

词嵌入层(Embedding)

  • 通常不计FLOPs:仅为查表操作,但投影回词表需计算:
    2 × B × L × d_model × V
    (实际中常忽略,因V过大)

自注意力机制

自注意力机制的总FLOPs计算公式 4d²N + 2dN² 是通过逐步分析注意力机制的四个核心运算环节得到的。以下是详细的推导过程:

1. 自注意力的FLOPs

# 假设输入x: (B, L, d_model)
Q = x @ W_q  # W_q: (d_model, d_head)
K = x @ W_k
V = x @ W_v
attn = (Q @ K.T) / sqrt(d_head)
output = attn @ V @ W_o  # W_o: (d_head, d_model)

假设:

  • 模型维度(隐藏层大小): d d d(即 d model d_{\text{model}} dmodel
  • 序列长度: N N N
  • 注意力头数: h h h(但最终FLOPs与头数无关,因 d = h × d k d = h \times d_k d=h×dk

(1) Q/K/V投影(输入线性变换)

  • 三个独立的矩阵乘法: X W Q XW_Q XWQ, X W K XW_K XWK, X W V XW_V XWV,其中 X ∈ R N × d X \in \mathbb{R}^{N \times d} XRN×d W Q / W K / W V ∈ R d × d W_Q/W_K/W_V \in \mathbb{R}^{d \times d} WQ/WK/WVRd×d
  • 每个投影的FLOPs:
    2 × N × d × d = 2 N d 2 2 \times N \times d \times d = 2Nd^2 2×N×d×d=2Nd2
    (解释:矩阵乘法 ( N × d ) × ( d × d ) (N \times d) \times (d \times d) (N×d)×(d×d) 的FLOPs为 2 × 2 \times 2× 行×列×内积维度)
  • 三个投影总FLOPs: 3 × 2 N d 2 = 6 N d 2 3 \times 2Nd^2 = 6Nd^2 3×2Nd2=6Nd2

(2) QKᵀ计算(注意力分数)

  • 计算 Q K T QK^T QKT,其中 Q , K ∈ R N × d Q, K \in \mathbb{R}^{N \times d} Q,KRN×d
  • FLOPs: 2 × N × N × d = 2 N 2 d 2 \times N \times N \times d = 2N^2d 2×N×N×d=2N2d
    (解释:每个注意力分数 Q i K j T Q_i K_j^T QiKjT 需要 d d d 次乘加运算,共 N 2 N^2 N2 个分数)

(3) Softmax归一化

  • N × N N \times N N×N 的注意力矩阵每行计算softmax:
    • 指数运算: N 2 N^2 N2
    • 求和及除法: 2 N 2 2N^2 2N2
  • 总FLOPs:
    3 N 2 3N^2 3N2
    (通常远小于其他项,可忽略)

(4) 注意力加权(AV计算)

  • 计算 A × V A \times V A×V,其中 A ∈ R N × N A \in \mathbb{R}^{N \times N} ARN×N V ∈ R N × d V \in \mathbb{R}^{N \times d} VRN×d
  • FLOPs: 2 × N × N × d = 2 N 2 d 2 \times N \times N \times d = 2N^2d 2×N×N×d=2N2d

(5) 输出投影

  • 将注意力结果通过矩阵 W O ∈ R d × d W_O \in \mathbb{R}^{d \times d} WORd×d 投影
  • FLOPs: 2 × N × d × d = 2 N d 2 2 \times N \times d \times d = 2Nd^2 2×N×d×d=2Nd2

2. 合并同类项

将各步骤FLOPs相加:
总FLOPs = 6 N d 2 ⏟ Q/K/V投影 + 2 N 2 d ⏟ QKᵀ + 2 N 2 d ⏟ AV + 2 N d 2 ⏟ 输出投影 = ( 6 N d 2 + 2 N d 2 ) + ( 2 N 2 d + 2 N 2 d ) = 8 N d 2 + 4 N 2 d \begin{align*} \text{总FLOPs} &= \underbrace{6Nd^2}_{\text{Q/K/V投影}} + \underbrace{2N^2d}_{\text{QKᵀ}} + \underbrace{2N^2d}_{\text{AV}} + \underbrace{2Nd^2}_{\text{输出投影}} \\ &= (6Nd^2 + 2Nd^2) + (2N^2d + 2N^2d) \\ &= 8Nd^2 + 4N^2d \end{align*} FLOPs=Q/K/V投影 6Nd2+QKᵀ 2N2d+AV 2N2d+输出投影 2Nd2=(6Nd2+2Nd2)+(2N2d+2N2d)=8Nd2+4N2d

为什么常见公式是 4 N d 2 + 2 N 2 d 4Nd^2 + 2N^2d 4Nd2+2N2d
  • 优化实现:实际中,Q/K/V的投影可通过单个大矩阵乘法完成(合并参数矩阵),将三个投影的FLOPs从 6 N d 2 6Nd^2 6Nd2 降为 2 × N × d × 3 d = 6 N d 2 2 \times N \times d \times 3d = 6Nd^2 2×N×d×3d=6Nd2(无变化)。但输出投影的 2 N d 2 2Nd^2 2Nd2 通常保留。
  • 更精确的简化:
    通常将总FLOPs近似为:
    总FLOPs ≈ 4 N d 2 + 2 N 2 d \text{总FLOPs} \approx 4Nd^2 + 2N^2d FLOPs4Nd2+2N2d
    这是因为:
    1. Q/K/V投影的 6 N d 2 6Nd^2 6Nd2 中,输出投影的 2 N d 2 2Nd^2 2Nd2 已单独计算;(原因:Q/K/V投影可部分合并计算)
    2. 实际代码优化可能减少重复计算(如融合操作)。

3. 直观理解

  • 4 N d 2 4Nd^2 4Nd2 项:
    来自输入/输出的线性变换(Q/K/V投影和输出投影),与模型维度 d d d 的平方成正比。
  • 2 N 2 d 2N^2d 2N2d 项:
    来自注意力分数的计算和加权,与序列长度 N N N 的平方成正比,是长序列的瓶颈。

4. 实例验证(GPT-3)

  • 参数: d = 12288 d=12288 d=12288, N = 2048 N=2048 N=2048
  • 计算:
    4 × 12288 2 × 2048 ≈ 1.24 × 10 12 2 × 12288 × 2048 2 ≈ 1.03 × 10 11 总FLOPs ≈ 1.34 × 10 12 (每层) 4 \times 12288^2 \times 2048 \approx 1.24 \times 10^{12} \\ 2 \times 12288 \times 2048^2 \approx 1.03 \times 10^{11} \\ \text{总FLOPs} \approx 1.34 \times 10^{12} \text{(每层)} 4×122882×20481.24×10122×12288×204821.03×1011FLOPs1.34×1012(每层)
    与GPT-3论文中的单层FLOPs估算一致。

5. 总结

公式 FLOPs ≈ 4 N d 2 + 2 N 2 d \text{FLOPs} \approx 4Nd^2 + 2N^2d FLOPs4Nd2+2N2d 的由来:

  1. 线性变换( 4 N d 2 4Nd^2 4Nd2):
    • Q/K/V合并投影: 2 N d 2 2Nd^2 2Nd2
    • 输出投影: 2 N d 2 2Nd^2 2Nd2
  2. 注意力交互( 2 N 2 d 2N^2d 2N2d):
    • QKᵀ + AV计算:各 N 2 d N^2d N2d

此公式是理论最小值,实际实现可能因优化(如融合kernel)略有差异。

前馈网络(FFN)

nn.Sequential(
    nn.Linear(d_model, 4*d_model),  # 扩展
    nn.GELU(),
    nn.Linear(4*d_model, d_model)    # 收缩
)

通常为两层全连接,若中间维度 4 d model 4d_{\text{model}} 4dmodel

  • 公式
    FLOPs = 2 × B × L × (d_model × 4d_model + 4d_model × d_model)
  • 示例
    B=32, L=512, d_model=7682×32×512×(768×3072 + 3072×768) ≈ 1.5 TFLOPs

FLOPs = 2 × d model × 4 d model × L × 2 = 16 d model 2 N \text{FLOPs} = 2 \times d_{\text{model}} \times 4d_{\text{model}} \times L \times 2 = 16d_{\text{model}}^2 N FLOPs=2×dmodel×4dmodel×L×2=16dmodel2N

整体模型计算量

对于包含 ( L ) 层的Transformer模型:

  • 单层FLOPs
    FLOPs layer = FLOPs attention + FLOPs FFN \text{FLOPs}_{\text{layer} = \text{FLOPs}_{\text{attention}} + \text{FLOPs}_{\text{FFN}}} FLOPslayer=FLOPsattention+FLOPsFFN
  • 总FLOPs(序列长度 ( N ) ( N ) (N))
    FLOPs total ≈ L × ( 4 d model 2 N + 2 d model N 2 + 16 d model 2 N ) \text{FLOPs}_{\text{total}} \approx L \times (4d_{\text{model}}^2 N + 2d_{\text{model}} N^2 + 16d_{\text{model}}^2 N) FLOPstotalL×(4dmodel2N+2dmodelN2+16dmodel2N)
    可简化为:
    FLOPs total ≈ L × ( 20 d model 2 N + 2 d model N 2 ) \text{FLOPs}_{\text{total}} \approx L \times (20d_{\text{model}}^2 N + 2d_{\text{model}} N^2) FLOPstotalL×(20dmodel2N+2dmodelN2)

卷积层

卷积层的计算量通常以浮点运算次数(FLOPs)衡量,以下是标准卷积和深度可分离卷积的计算公式及关键影响因素:
1. 标准卷积的计算量

公式(考虑乘法和加法):

FLOPs = 2 × C in × K 2 × C out × H out × W out \text{FLOPs} = 2 \times C_{\text{in}} \times K^2 \times C_{\text{out}} \times H_{\text{out}} \times W_{\text{out}} FLOPs=2×Cin×K2×Cout×Hout×Wout

其中:

  • C in C_{\text{in}} Cin:输入通道数

  • C out C_{\text{out}} Cout:输出通道数

  • K:卷积核尺寸(假设为方形)

  • H out H_{\text{out}} Hout, W out W_{\text{out}} Wout:输出特征图尺寸,由输入尺寸、步长(stride)和填充(padding)决定:

    H out = ⌊ H in + 2 P − K S ⌋ + 1 H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2P - K}{S} \right\rfloor + 1 Hout=SHin+2PK+1

关键点:

  • 系数2:包含乘法(C_{\text{in}} \times K^2次)和加法(C_{\text{in}} \times K^2 - 1次)。

  • 偏置项:每个输出点加1次加法,但通常忽略(占比小)。

2. 深度可分离卷积的计算量

公式(分两部分):

  • 深度卷积(逐通道卷积):

    FLOPs depthwise = C in × K 2 × H out × W out \text{FLOPs}_{\text{depthwise}} = C_{\text{in}} \times K^2 \times H_{\text{out}} \times W_{\text{out}} FLOPsdepthwise=Cin×K2×Hout×Wout

  • 逐点卷积(1×1卷积):

    FLOPs pointwise = C in × C out × H out × W out \text{FLOPs}_{\text{pointwise}} = C_{\text{in}} \times C_{\text{out}} \times H_{\text{out}} \times W_{\text{out}} FLOPspointwise=Cin×Cout×Hout×Wout

  • 总计算量:
    FLOPs total = C in × H out × W out × ( K 2 + C out ) \text{FLOPs}_{\text{total}} = C_{\text{in}} \times H_{\text{out}} \times W_{\text{out}} \times (K^2 + C_{\text{out}}) FLOPstotal=Cin×Hout×Wout×(K2+Cout)

优势:相比标准卷积,计算量减少约 1 C out + 1 K 2 \frac{1}{C_{\text{out}}} + \frac{1}{K^2} Cout1+K21倍。

3. 注意事项

FLOPs与MACs区别:部分工具(如fvcore)返回MACs(乘加次数),1 MACs ≈ 2 FLOPs。

工具

** 手动计算**

def calculate_llm_flops(L, d_model, N):
    flops_attention = 4 * d_model**2 * N + 2 * d_model * N**2
    flops_ffn = 16 * d_model**2 * N
    return L * (flops_attention + flops_ffn)

# GPT-3 175B示例
flops = calculate_llm_flops(L=96, d_model=12288, N=2048)
print(f"总FLOPs: {flops / 1e12:.0f} TFLOPs")

** 库**

pip install torchprofile
from torchprofile import profile_macs
macs = profile_macs(model, input)
flops = 2 * macs  # 1 MAC = 2 FLOPs

工具汇总

以下是 PyTorch 模型计算量统计工具的详细总结,涵盖 主流工具核心功能具体使用方法,帮助您快速选择适合的工具并正确统计 FLOPs、参数量等指标。


工具对比

工具名称维护状态核心功能特点适用场景
fvcore活跃FLOPs、参数量、激活值统计Facebook 官方支持,精度高科研论文、模型优化
thop维护中FLOPs(MACs×2)、参数量轻量级,API简单快速验证、原型开发
torchstat停止更新分层FLOPs/参数量/内存消耗可视化网络结构教学/调试(静态输入)
ptflops维护中类似 torchstat,兼容新版 PyTorch支持动态输入需要分层统计的场景
PyTorch Profiler官方工具时间/内存/FLOPs(需CUDA)结合运行时性能分析性能调优

工具功能适用场景
THOPFLOPs + 参数量通用模型计算量统计
torchsummary参数量 + 层结构快速检查模型架构
ptflopsFLOPs + 参数量(逐层)深度优化分析
PyTorch 内置仅参数量简单参数统计
ONNX 导出计算图可视化调试复杂模型

工具选型建议

需求推荐工具理由
高精度FLOPs统计fvcore支持逐层分解,结果可靠
快速验证thop一行代码获取结果,适合开发调试
可视化网络结构ptflops替代 torchstat,兼容新版 PyTorch
动态输入模型fvcore/ptflops均可处理可变输入尺寸
性能调优(时间/内存)PyTorch Profiler官方工具,结合运行时分析

推荐优先使用 fvcore,需要快速验证时选择 thop,可视化分析用 ptflops
选择工具时,可以根据需求决定:

  • 如果只需参数量,用 torchsummary 或 PyTorch 内置方法。
  • 如果需要 FLOPs 和参数量,推荐 THOPptflops
  • 对于更复杂的分析(如逐层计算),ptflops 更合适。

1. 为什么不同工具统计结果不一致?

  • 原因:FLOPs 计算方式不同(如是否忽略偏置项、是否×2转换MACs)。
  • 解决方案:统一使用 fvcore 或手动验证基础层(如 Conv2d 的 FLOPs=2 × Cin × Cout × K × K × H × W)。

2. 如何统计自定义层的FLOPs?

  • fvcore:通过 FlopCountAnalysis 自动检测(需继承 torch.nn.Module)。
  • thop:使用 custom_ops 参数手动注册计算规则。

3. 动态输入如何处理?

# fvcore/ptflops 均支持
input = torch.randn(1, 3, 256, 256)  # 任意尺寸
flops = FlopCountAnalysis(model, input).total()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值