文章目录
全连接层(Linear)
全连接层的计算量主要由矩阵乘法决定。假设:
输入维度:B × M (B为批量大小,M为输入特征数)
权重矩阵:M × N (N为输出特征数)
偏置:N
前向FLOPs:≈ 2 × B × M × N
训练FLOPs:≈ 6 × B × M × N
nn.Linear(in_features=d_in, out_features=d_out, bias=True)
- 权重矩阵:
d_in × d_out
- 偏置项:
d_out
(若存在)
示例1:单样本推理
import torch
layer = torch.nn.Linear(1024, 4096) # M=1024, N=4096
input = torch.randn(1, 1024) # B=1
# 手动计算FLOPs
flops_forward = 2 * 1 * 1024 * 4096 # ≈ 8.4M FLOPs
print(f"FLOPs (forward): {flops_forward:,}")
层归一化(LayerNorm)
B:Batch size ; L:序列长度
nn.LayerNorm(d_model)
- 公式:
FLOPs ≈ 5 × B × L × d_model
5
:均值、方差、归一化、缩放、偏移
- 示例:
B=32, L=512, d_model=768
→5×32×512×768 ≈ 63 MFLOPs
词嵌入层(Embedding)
- 通常不计FLOPs:仅为查表操作,但投影回词表需计算:
2 × B × L × d_model × V
(实际中常忽略,因V过大)
自注意力机制
自注意力机制的总FLOPs计算公式 4d²N + 2dN² 是通过逐步分析注意力机制的四个核心运算环节得到的。以下是详细的推导过程:
1. 自注意力的FLOPs
# 假设输入x: (B, L, d_model)
Q = x @ W_q # W_q: (d_model, d_head)
K = x @ W_k
V = x @ W_v
attn = (Q @ K.T) / sqrt(d_head)
output = attn @ V @ W_o # W_o: (d_head, d_model)
假设:
- 模型维度(隐藏层大小): d d d(即 d model d_{\text{model}} dmodel)
- 序列长度: N N N
- 注意力头数: h h h(但最终FLOPs与头数无关,因 d = h × d k d = h \times d_k d=h×dk)
(1) Q/K/V投影(输入线性变换)
- 三个独立的矩阵乘法: X W Q XW_Q XWQ, X W K XW_K XWK, X W V XW_V XWV,其中 X ∈ R N × d X \in \mathbb{R}^{N \times d} X∈RN×d, W Q / W K / W V ∈ R d × d W_Q/W_K/W_V \in \mathbb{R}^{d \times d} WQ/WK/WV∈Rd×d。
- 每个投影的FLOPs:
2 × N × d × d = 2 N d 2 2 \times N \times d \times d = 2Nd^2 2×N×d×d=2Nd2
(解释:矩阵乘法 ( N × d ) × ( d × d ) (N \times d) \times (d \times d) (N×d)×(d×d) 的FLOPs为 2 × 2 \times 2× 行×列×内积维度) - 三个投影总FLOPs: 3 × 2 N d 2 = 6 N d 2 3 \times 2Nd^2 = 6Nd^2 3×2Nd2=6Nd2
(2) QKᵀ计算(注意力分数)
- 计算 Q K T QK^T QKT,其中 Q , K ∈ R N × d Q, K \in \mathbb{R}^{N \times d} Q,K∈RN×d。
- FLOPs:
2
×
N
×
N
×
d
=
2
N
2
d
2 \times N \times N \times d = 2N^2d
2×N×N×d=2N2d
(解释:每个注意力分数 Q i K j T Q_i K_j^T QiKjT 需要 d d d 次乘加运算,共 N 2 N^2 N2 个分数)
(3) Softmax归一化
- 对
N
×
N
N \times N
N×N 的注意力矩阵每行计算softmax:
- 指数运算: N 2 N^2 N2 次
- 求和及除法: 2 N 2 2N^2 2N2 次
- 总FLOPs:
3 N 2 3N^2 3N2
(通常远小于其他项,可忽略)
(4) 注意力加权(AV计算)
- 计算 A × V A \times V A×V,其中 A ∈ R N × N A \in \mathbb{R}^{N \times N} A∈RN×N, V ∈ R N × d V \in \mathbb{R}^{N \times d} V∈RN×d。
- FLOPs: 2 × N × N × d = 2 N 2 d 2 \times N \times N \times d = 2N^2d 2×N×N×d=2N2d
(5) 输出投影
- 将注意力结果通过矩阵 W O ∈ R d × d W_O \in \mathbb{R}^{d \times d} WO∈Rd×d 投影
- FLOPs: 2 × N × d × d = 2 N d 2 2 \times N \times d \times d = 2Nd^2 2×N×d×d=2Nd2
2. 合并同类项
将各步骤FLOPs相加:
总FLOPs
=
6
N
d
2
⏟
Q/K/V投影
+
2
N
2
d
⏟
QKᵀ
+
2
N
2
d
⏟
AV
+
2
N
d
2
⏟
输出投影
=
(
6
N
d
2
+
2
N
d
2
)
+
(
2
N
2
d
+
2
N
2
d
)
=
8
N
d
2
+
4
N
2
d
\begin{align*} \text{总FLOPs} &= \underbrace{6Nd^2}_{\text{Q/K/V投影}} + \underbrace{2N^2d}_{\text{QKᵀ}} + \underbrace{2N^2d}_{\text{AV}} + \underbrace{2Nd^2}_{\text{输出投影}} \\ &= (6Nd^2 + 2Nd^2) + (2N^2d + 2N^2d) \\ &= 8Nd^2 + 4N^2d \end{align*}
总FLOPs=Q/K/V投影
6Nd2+QKᵀ
2N2d+AV
2N2d+输出投影
2Nd2=(6Nd2+2Nd2)+(2N2d+2N2d)=8Nd2+4N2d
为什么常见公式是 4 N d 2 + 2 N 2 d 4Nd^2 + 2N^2d 4Nd2+2N2d?
- 优化实现:实际中,Q/K/V的投影可通过单个大矩阵乘法完成(合并参数矩阵),将三个投影的FLOPs从 6 N d 2 6Nd^2 6Nd2 降为 2 × N × d × 3 d = 6 N d 2 2 \times N \times d \times 3d = 6Nd^2 2×N×d×3d=6Nd2(无变化)。但输出投影的 2 N d 2 2Nd^2 2Nd2 通常保留。
- 更精确的简化:
通常将总FLOPs近似为:
总FLOPs ≈ 4 N d 2 + 2 N 2 d \text{总FLOPs} \approx 4Nd^2 + 2N^2d 总FLOPs≈4Nd2+2N2d
这是因为:- Q/K/V投影的 6 N d 2 6Nd^2 6Nd2 中,输出投影的 2 N d 2 2Nd^2 2Nd2 已单独计算;(原因:Q/K/V投影可部分合并计算)
- 实际代码优化可能减少重复计算(如融合操作)。
3. 直观理解
-
4
N
d
2
4Nd^2
4Nd2 项:
来自输入/输出的线性变换(Q/K/V投影和输出投影),与模型维度 d d d 的平方成正比。 -
2
N
2
d
2N^2d
2N2d 项:
来自注意力分数的计算和加权,与序列长度 N N N 的平方成正比,是长序列的瓶颈。
4. 实例验证(GPT-3)
- 参数: d = 12288 d=12288 d=12288, N = 2048 N=2048 N=2048
- 计算:
4 × 12288 2 × 2048 ≈ 1.24 × 10 12 2 × 12288 × 2048 2 ≈ 1.03 × 10 11 总FLOPs ≈ 1.34 × 10 12 (每层) 4 \times 12288^2 \times 2048 \approx 1.24 \times 10^{12} \\ 2 \times 12288 \times 2048^2 \approx 1.03 \times 10^{11} \\ \text{总FLOPs} \approx 1.34 \times 10^{12} \text{(每层)} 4×122882×2048≈1.24×10122×12288×20482≈1.03×1011总FLOPs≈1.34×1012(每层)
与GPT-3论文中的单层FLOPs估算一致。
5. 总结
公式 FLOPs ≈ 4 N d 2 + 2 N 2 d \text{FLOPs} \approx 4Nd^2 + 2N^2d FLOPs≈4Nd2+2N2d 的由来:
- 线性变换(
4
N
d
2
4Nd^2
4Nd2):
- Q/K/V合并投影: 2 N d 2 2Nd^2 2Nd2
- 输出投影: 2 N d 2 2Nd^2 2Nd2
- 注意力交互(
2
N
2
d
2N^2d
2N2d):
- QKᵀ + AV计算:各 N 2 d N^2d N2d
此公式是理论最小值,实际实现可能因优化(如融合kernel)略有差异。
前馈网络(FFN)
nn.Sequential(
nn.Linear(d_model, 4*d_model), # 扩展
nn.GELU(),
nn.Linear(4*d_model, d_model) # 收缩
)
通常为两层全连接,若中间维度 4 d model 4d_{\text{model}} 4dmodel:
- 公式:
FLOPs = 2 × B × L × (d_model × 4d_model + 4d_model × d_model)
- 示例:
B=32, L=512, d_model=768
→2×32×512×(768×3072 + 3072×768) ≈ 1.5 TFLOPs
FLOPs = 2 × d model × 4 d model × L × 2 = 16 d model 2 N \text{FLOPs} = 2 \times d_{\text{model}} \times 4d_{\text{model}} \times L \times 2 = 16d_{\text{model}}^2 N FLOPs=2×dmodel×4dmodel×L×2=16dmodel2N
整体模型计算量
对于包含 ( L ) 层的Transformer模型:
- 单层FLOPs:
FLOPs layer = FLOPs attention + FLOPs FFN \text{FLOPs}_{\text{layer} = \text{FLOPs}_{\text{attention}} + \text{FLOPs}_{\text{FFN}}} FLOPslayer=FLOPsattention+FLOPsFFN - 总FLOPs(序列长度
(
N
)
( N )
(N)):
FLOPs total ≈ L × ( 4 d model 2 N + 2 d model N 2 + 16 d model 2 N ) \text{FLOPs}_{\text{total}} \approx L \times (4d_{\text{model}}^2 N + 2d_{\text{model}} N^2 + 16d_{\text{model}}^2 N) FLOPstotal≈L×(4dmodel2N+2dmodelN2+16dmodel2N)
可简化为:
FLOPs total ≈ L × ( 20 d model 2 N + 2 d model N 2 ) \text{FLOPs}_{\text{total}} \approx L \times (20d_{\text{model}}^2 N + 2d_{\text{model}} N^2) FLOPstotal≈L×(20dmodel2N+2dmodelN2)
卷积层
卷积层的计算量通常以浮点运算次数(FLOPs)衡量,以下是标准卷积和深度可分离卷积的计算公式及关键影响因素:
1. 标准卷积的计算量
公式(考虑乘法和加法):
FLOPs = 2 × C in × K 2 × C out × H out × W out \text{FLOPs} = 2 \times C_{\text{in}} \times K^2 \times C_{\text{out}} \times H_{\text{out}} \times W_{\text{out}} FLOPs=2×Cin×K2×Cout×Hout×Wout
其中:
-
C in C_{\text{in}} Cin:输入通道数
-
C out C_{\text{out}} Cout:输出通道数
-
K:卷积核尺寸(假设为方形)
-
H out H_{\text{out}} Hout, W out W_{\text{out}} Wout:输出特征图尺寸,由输入尺寸、步长(stride)和填充(padding)决定:
H out = ⌊ H in + 2 P − K S ⌋ + 1 H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2P - K}{S} \right\rfloor + 1 Hout=⌊SHin+2P−K⌋+1
关键点:
-
系数2:包含乘法(C_{\text{in}} \times K^2次)和加法(C_{\text{in}} \times K^2 - 1次)。
-
偏置项:每个输出点加1次加法,但通常忽略(占比小)。
2. 深度可分离卷积的计算量
公式(分两部分):
-
深度卷积(逐通道卷积):
FLOPs depthwise = C in × K 2 × H out × W out \text{FLOPs}_{\text{depthwise}} = C_{\text{in}} \times K^2 \times H_{\text{out}} \times W_{\text{out}} FLOPsdepthwise=Cin×K2×Hout×Wout
-
逐点卷积(1×1卷积):
FLOPs pointwise = C in × C out × H out × W out \text{FLOPs}_{\text{pointwise}} = C_{\text{in}} \times C_{\text{out}} \times H_{\text{out}} \times W_{\text{out}} FLOPspointwise=Cin×Cout×Hout×Wout
-
总计算量:
FLOPs total = C in × H out × W out × ( K 2 + C out ) \text{FLOPs}_{\text{total}} = C_{\text{in}} \times H_{\text{out}} \times W_{\text{out}} \times (K^2 + C_{\text{out}}) FLOPstotal=Cin×Hout×Wout×(K2+Cout)
优势:相比标准卷积,计算量减少约 1 C out + 1 K 2 \frac{1}{C_{\text{out}}} + \frac{1}{K^2} Cout1+K21倍。
3. 注意事项
FLOPs与MACs区别:部分工具(如fvcore)返回MACs(乘加次数),1 MACs ≈ 2 FLOPs。
工具
** 手动计算**
def calculate_llm_flops(L, d_model, N):
flops_attention = 4 * d_model**2 * N + 2 * d_model * N**2
flops_ffn = 16 * d_model**2 * N
return L * (flops_attention + flops_ffn)
# GPT-3 175B示例
flops = calculate_llm_flops(L=96, d_model=12288, N=2048)
print(f"总FLOPs: {flops / 1e12:.0f} TFLOPs")
** 库**
pip install torchprofile
from torchprofile import profile_macs
macs = profile_macs(model, input)
flops = 2 * macs # 1 MAC = 2 FLOPs
工具汇总
以下是 PyTorch 模型计算量统计工具的详细总结,涵盖 主流工具、核心功能 和 具体使用方法,帮助您快速选择适合的工具并正确统计 FLOPs、参数量等指标。
工具对比
工具名称 | 维护状态 | 核心功能 | 特点 | 适用场景 |
---|---|---|---|---|
fvcore | 活跃 | FLOPs、参数量、激活值统计 | Facebook 官方支持,精度高 | 科研论文、模型优化 |
thop | 维护中 | FLOPs(MACs×2)、参数量 | 轻量级,API简单 | 快速验证、原型开发 |
torchstat | 停止更新 | 分层FLOPs/参数量/内存消耗 | 可视化网络结构 | 教学/调试(静态输入) |
ptflops | 维护中 | 类似 torchstat,兼容新版 PyTorch | 支持动态输入 | 需要分层统计的场景 |
PyTorch Profiler | 官方工具 | 时间/内存/FLOPs(需CUDA) | 结合运行时性能分析 | 性能调优 |
工具 | 功能 | 适用场景 |
---|---|---|
THOP | FLOPs + 参数量 | 通用模型计算量统计 |
torchsummary | 参数量 + 层结构 | 快速检查模型架构 |
ptflops | FLOPs + 参数量(逐层) | 深度优化分析 |
PyTorch 内置 | 仅参数量 | 简单参数统计 |
ONNX 导出 | 计算图可视化 | 调试复杂模型 |
工具选型建议
需求 | 推荐工具 | 理由 |
---|---|---|
高精度FLOPs统计 | fvcore | 支持逐层分解,结果可靠 |
快速验证 | thop | 一行代码获取结果,适合开发调试 |
可视化网络结构 | ptflops | 替代 torchstat,兼容新版 PyTorch |
动态输入模型 | fvcore/ptflops | 均可处理可变输入尺寸 |
性能调优(时间/内存) | PyTorch Profiler | 官方工具,结合运行时分析 |
推荐优先使用 fvcore
,需要快速验证时选择 thop
,可视化分析用 ptflops
。
选择工具时,可以根据需求决定:
- 如果只需参数量,用
torchsummary
或 PyTorch 内置方法。 - 如果需要 FLOPs 和参数量,推荐 THOP 或 ptflops。
- 对于更复杂的分析(如逐层计算),ptflops 更合适。
1. 为什么不同工具统计结果不一致?
- 原因:FLOPs 计算方式不同(如是否忽略偏置项、是否×2转换MACs)。
- 解决方案:统一使用
fvcore
或手动验证基础层(如Conv2d
的 FLOPs=2 × Cin × Cout × K × K × H × W
)。
2. 如何统计自定义层的FLOPs?
- fvcore:通过
FlopCountAnalysis
自动检测(需继承torch.nn.Module
)。 - thop:使用
custom_ops
参数手动注册计算规则。
3. 动态输入如何处理?
# fvcore/ptflops 均支持
input = torch.randn(1, 3, 256, 256) # 任意尺寸
flops = FlopCountAnalysis(model, input).total()