全连接层的算力(矩阵乘法)计算方式

本文介绍了神经网络全连接层的矩阵乘法计算过程,包括乘法和加法次数的计算,并以3x3矩阵为例进行具体演示。此外,讨论了稀疏矩阵的概念及其在计算中的作用,特别是对于数值为0的元素,它们如何影响计算效率。文中还提到了A100中的稀疏结构,展示了如何通过结构化稀疏矩阵减少内存存储和带宽,提高计算效率。

神经网络的全链接层计算过程可以看成两个矩阵相乘,如下图所示,一个MxN的矩阵乘以一个NxP的矩阵,得到一个MxP的矩阵,进行乘法的次数为:

(N)*(M*P)

加法次数为:

(N-1)*M*P

所以,矩阵乘法总的计算量为(N)*(M*P)+(N-1)*M*P = (2N-1)*M*P

每计算出一个结果,需要对一个N维向量作内积,内积需要进行N次乘法和N-1次加法(第一次计算不需要作加法,或者看成+0,就不需要-1了),计算一个结果的计算次数为2N-1.

比如,就拿3*3的矩阵乘法为例:

计算如下:

所以,它的计算量为:

乘法次数:3*3*3=27次.

加法次数:  2*3*3 =18次.

算在一起浮点操作为27+18=45次.

用公式计算(2N-1)*M*P=5*3*3=45次,互相印证符合。

当然,如果将MAC计算初始值看成0,则初始情况下实际上做了一个+0的加法操作,每个结果元素进行的加法次数也可以认为是N次而非N-1次,这样相当于增加了M*P个加法(每个结果元素1个+0操作),因为你可以认为一开始进行了加0操作。这样的化公式就更加简介,直接就是2N*M*P就,N*M*P个乘法和N*M*P个加法,乘法和加法次数各占一半,每次乘法对应一次加法,正好可以由一个MAC单元去执行。

这样,矩阵乘法总的计算次数就变成了2N*M*P。

计算/访存比

所以,对于一个M*K*N的矩阵乘加运算(MxK与KxN的矩阵相乘,再与MxN的矩阵相加),它的计算访存比为:

\frac{2 \times m \times k \times n}{m \times k + k \times n + 2 \times m \times n}

分子表示计算量,前面已经推导过了,分母中mxk表示读取第一个矩阵的读次数,相应的kxn是第二个矩阵的读次数,因为第三个矩阵既是操作数,又是结果,所以需要读写两次,为 2xmxn。

如果对于方阵,K=N=M,此时计算/访存比可以简单表示为:

\frac{m}{2}

说明矩阵规模越大,计算/访存比会越高,利用大矩阵配合有效分块算法,会获得较大的计算密度。

根据上面的公式也可以看出,对于标量运算来说,计算密度为1/2,是小于矩阵运算的。

矩阵乘法的计算模式

矩阵乘法的计算模式相对固定,面向神经网络的专用加速器多遵循两种计算模式:

1.模式一,矩阵乘法被看作若干对向量进行逐元素对应相乘,得到新的向量后在进行向量内相加归约加合得到最终结果。结果矩阵为mxn时,会有mxn个向量进行该操作。这种模式对应向量乘法单元和加法树单元的结构:

下面是一个加法树乘法器的verilog实现:

module multi_add_tree(a,b,clk,out);
output [15:0] out;
input [7:0] a,b;
input clk;
wire [15:0] out;

wire [15:0] out1,c1;
wire [13:0] out2;
wire [11:0] out3,c2;
wire [9:0] out4;

reg [14:0] temp0; 
reg [13:0] temp1;
reg [12:0] temp2;
reg [11:0] temp3;
reg [10:0] temp4;
reg [9:0] temp5;
reg [8:0] temp6;
reg [7:0] temp7;

// 8*1乘法器

function [7:0] mut8_1;
input [7:0] operand;
input sel;

begin
	mut8_1 = sel ? operand : 8'b0000_0000;
end
endfunction 

//操作数b各位与操作数a相乘
always @(posedge clk)
begin
	temp7 = mut8_1(a,b[0]);
	temp6 = (mut8_1(a,b[1]))<<1;
	temp5 = (mut8_1(a,b[2]))<<2;
	temp4 = (mut8_1(a,b[3]))<<3;
	temp3 = (mut8_1(a,b[4]))<<4;
	temp2 = (mut8_1(a,b[5]))<<5;
	temp1 = (mut8_1(a,b[6]))<<6;
	temp0 = (mut8_1(a,b[7]))<<7;
end

//加法树运算
assign out1 = temp0 + temp1;
assign out2 = temp2 + temp3;
assign out3 = temp4 + temp5;
assign out4 = temp6 + temp7;
assign c1 = out1 + out2;
assign c2 = out3 + out4;
assign out = c1 + c2;

endmodule

仿真计算2x100=0xc8=200:

2.模式2,把矩阵相乘分解至以标量为单位,每次操作都是三个标量(axb+c)间的乘加运算,这种模式对应乘加单元组成的脉动阵列结构。

两种方法的差别在于各自的数据流调度方式,即数据存储和复用上的差异,nvidia tensor core则遵循了矩阵乘法的计算模式,结合寄存器的读取共享,尽可能地优化数据的复用。

两个矩阵相乘的几何结构

下图中的每个小立方体都是一个MAC单元。

稀疏化操作原理

若矩阵中数值为0的元素数目远远多于非0元素的数目,并且非0元素分布没有规律时,则称该矩阵为稀疏矩阵;与之相反,若非0元素数目占大多数时,则称该矩阵为稠密矩阵。定义非零元素的总数比上矩阵所有元素的总数为矩阵的稠密度。

对于AXB=C的形式来说,A的行中对应为0的项,对应B中列的项即便不为0,对最终的结果也起不到任何作用。所以,稀疏化操作的原理如下:

1.首先对输入矩阵进行fine tuning, 将小于某个阈值的矩阵元素设置为0参与计算。

2.对稀疏矩阵进行压缩,压缩后生成一个规则的稠密矩阵以及压缩指示矩阵dices, 后者用来描述稀疏矩阵的压缩规则,比如典型的,dices矩阵中的元素代表对应行的压缩矩阵元素在原矩阵中的行坐标。这样DICE矩阵行向量的维数必须和压缩矩阵行向量的维数相同。

3.与激活数据做乘法的时候,用DICES矩阵中的行向量作为selector, 选择激活数据中的对应元素参与计算,没有选中的代表对应的权重为0, 其激活数据已经没有意义了,不需要参与计算。

至此,矩阵稀疏花单元达到了压缩稀疏矩阵,减少运算量的目的。

A100中的稀疏结构符合一些前置约定,是通过新的2:4稀疏矩阵定义来实现的,该定义允许在每个四项向量中有两个非零值。支持行上2:4的结构化稀疏性,如上图所示。由于矩阵的结构定义良好,它可以有效地压缩,并将内存存储和带宽减少近2倍。

参考文档

脉动阵列实现卷积计算_脉动阵列 卷积-优快云博客

数字电路基础知识——组合逻辑电路之乘法器的设计(一)—— 并行、移位相加、加法树、查找表乘法器_乘法器电路图-优快云博客


结束

<think>我们正在讨论PyTorch中的广义矩阵乘法函数`torch.matmul`。根据用户要求,我们需要提供一个使用指南和示例,并参考站内引用内容(虽然引用内容有限,但我们会尽量关联)。同时,我们需要按照格式要求使用LaTeX,并生成相关问题。###一、`torch.matmul`概述`torch.matmul`是PyTorch中执行**矩阵乘法**的核心函数,支持从向量到高维张量的多种乘法操作。其核心规则遵循线性代数中的矩阵乘法法则:对于两个张量,它计算的是**最后两个维度**的矩阵乘积,而前面的维度视为批量维度(并支持广播)[^1][^3]。数学表达:-若输入为两个矩阵$A\in\mathbb{R}^{m\timesn}$和$B\in\mathbb{R}^{n\timesp}$,则输出$C=AB\in\mathbb{R}^{m\timesp}$,其中$C_{ij}=\sum_{k=1}^{n}A_{ik}B_{kj}$。-对于高维张量,如$A\in\mathbb{R}^{b\timesm\timesn}$和$B\in\mathbb{R}^{b\timesn\timesp}$,则输出$C\in\mathbb{R}^{b\timesm\timesp}$,其中每个批量$C_b=A_bB_b$。###二、核心使用场景与示例####1.向量点积(1维张量)```pythonimporttorcha=torch.tensor([1,2,3])#形状(3,)b=torch.tensor([4,5,6])#形状(3,)c=torch.matmul(a,b)#输出:tensor(32)即1*4+2*5+3*6=32[^3]```####2.矩阵乘法(2维张量)```pythonA=torch.tensor([[1,2],[3,4]])#(2,2)B=torch.tensor([[5,6],[7,8]])#(2,2)C=torch.matmul(A,B)#输出:tensor([[19,22],[43,50]])#计算:#[1*5+2*7,1*6+2*8]=[19,22]#[3*5+4*7,3*6+4*8]=[43,50][^1]```####3.矩阵与向量相乘```pythonA=torch.tensor([[1,2],[3,4]])#(2,2)v=torch.tensor([5,6])#(2,)result=torch.matmul(A,v)#输出:tensor([17,39])```####4.批量矩阵乘法(3维张量)```pythonbatch_size=3A=torch.randn(batch_size,4,5)#3个4x5矩阵B=torch.randn(batch_size,5,6)#3个5x6矩阵C=torch.matmul(A,B)#输出形状:(3,4,6)```####5.广播机制当张量的批量维度不一致但可广播时:```pythonA=torch.randn(2,1,4,5)#形状(2,1,4,5)B=torch.randn(3,5,6)#形状(3,5,6)->可广播为(1,3,5,6)C=torch.matmul(A,B)#输出形状:(2,3,4,6)```###三、与相关函数的对比|函数|适用维度|是否支持广播|典型场景||-------------------|-----------------------|--------------|------------------------||`torch.matmul`|任意维度|是|通用矩阵乘法||`torch.mm`|仅2维|否|无批量的矩阵乘法||`torch.bmm`|仅3维(批量)|否|固定批量矩阵乘法[^1]||`torch.einsum`|任意维度|是|复杂张量运|###四、性能优化建议1.**维度对齐**:确保最后两个维度满足矩阵乘法规则($m\timesn$乘$n\timesp$)。2.**避免显式转置**:使用`permute`或`transpose`调整维度时,尽量在内存连续时操作。3.**利用广播**:当批量维度不一致时,广播机制能减少数据复制开销。4.**混合精度计算**:使用`torch.cuda.amp`自动混合精度可加速计算(尤其在GPU上)。###五、典型错误及修复**错误1:维度不匹配**```pythonA=torch.randn(3,4)B=torch.randn(2,4)#错误!需要(4,?)但实际是(2,4)#修正:转置B或调整AB=B.t()#转为(4,2)C=torch.matmul(A,B)#正确:(3,4)@(4,2)->(3,2)```**错误2:批量维度不兼容**```pythonA=torch.randn(2,3,4)B=torch.randn(3,4,5)#批量维度2≠3#修正:对齐批量维度B=B[:2]#取前两个批次C=torch.matmul(A,B)#形状(2,3,5)```###六、在神经网络中的应用1.**全连接层**:权重矩阵乘法(参考引用[2]中嵌入层的权重矩阵)```pythoninput=torch.randn(128,64)#批量128,特征64weight=torch.randn(64,10)#权重矩阵[^2]output=input@weight#等价于torch.matmul(input,weight)```2.**注意机制**:```pythonQ=torch.randn(8,10,64)#8头注意,每头10个词,64维K=torch.randn(8,10,64)attn_scores=Q@K.transpose(-2,-1)#形状(8,10,10)```§§相关问题§§1.如何在不使用`torch.matmul`的情况下手动实现批量矩阵乘法?2.`torch.matmul`在分布式训练中如何优化性能?3.高维张量(>3维)使用`torch.matmul`时广播机制的具体规则是什么?4.`torch.matmul`与`torch.einsum`在计算性能和可读性上的对比?</think>### PyTorch `torch.matmul` 广义矩阵乘法使用指南 `torch.matmul` 是 PyTorch 中执行**广义矩阵乘法**的核心函数,支持从标量到高维张量的各种乘法操作。以下通过分类示例详细解析其用法: #### 一、基础乘法类型 ##### 1. 标量乘法(0维张量) ```python import torch a = torch.tensor(2) # 标量 b = torch.tensor(3) # 标量 result = torch.matmul(a, b) print(result) # 输出: tensor(6) [^3] ``` ##### 2. 向量点积(1维张量) ```python a = torch.tensor([1, 2, 3]) # 向量 (3,) b = torch.tensor([4, 5, 6]) # 向量 (3,) result = torch.matmul(a, b) print(result) # 输出: tensor(32) # 计算: $1\times4 + 2\times5 + 3\times6 = 32$ [^3] ``` ##### 3. 矩阵乘法(2维张量) ```python a = torch.tensor([[1, 2], [3, 4]]) # 2x2 b = torch.tensor([[5, 6], [7, 8]]) # 2x2 result = torch.matmul(a, b) print(result) # 输出: tensor([[19, 22], [43, 50]]) # 计算逻辑: # $[1\times5+2\times7, 1\times6+2\times8] = [19,22]$ # $[3\times5+4\times7, 3\times6+4\times8] = [43,50]$ [^1] ``` ##### 4. 矩阵-向量乘法 ```python a = torch.tensor([[1, 2], [3, 4]]) # 2x2 b = torch.tensor([5, 6]) # 向量 (2,) result = torch.matmul(a, b) print(result) # 输出: tensor([17, 39]) # 计算: $[1\times5+2\times6, 3\times5+4\times6]$ [^3] ``` #### 二、批量矩阵乘法(3维+张量) ##### 1. 三维张量乘法 ```python a = torch.randn(2, 3, 4) # 两个3x4矩阵 b = torch.randn(2, 4, 5) # 两个4x5矩阵 result = torch.matmul(a, b) print(result.shape) # 输出: torch.Size([2, 3, 5]) [^3] ``` ##### 2. 广播机制示例 ```python a = torch.randn(3, 4, 5) # 三个4x5矩阵 b = torch.randn(5, 6) # 单个5x6矩阵 → 自动广播 result = torch.matmul(a, b) print(result.shape) # 输出: torch.Size([3, 4, 6]) [^2] ``` #### 三、运符等价形式 `@` 运符是 `torch.matmul` 的简化写法: ```python a = torch.randn(2, 3) b = torch.randn(3, 4) result1 = torch.matmul(a, b) result2 = a @ b # 完全等价 print(torch.allclose(result1, result2)) # 输出: True [^2] ``` #### 四、性能优化技巧 1. **维度对齐**:确保相乘维度匹配 ```python # 优化前:需要转置操作 result = torch.matmul(a, b.t()) # 优化后:直接使用原始维度 b = torch.randn(4, 3) # 预先调整为匹配维度 result = torch.matmul(a, b) ``` 2. **批量处理**:利用三维张量减少循环 ```python # 低效方式 results = [torch.matmul(a[i], b[i]) for i in range(batch_size)] # 高效方式 result = torch.matmul(a, b) # a.shape=(batch, m, n), b.shape=(batch, n, p) ``` 3. **混合精度**:使用FP16加速计算 ```python with torch.cuda.amp.autocast(): result = torch.matmul(a.half(), b.half()) ``` #### 五、与相关函数对比 | 函数 | 适用维度 | 特点 | |------------------|-------------------|------------------------------| | `torch.matmul()` | 任意维度 | 支持广播,推荐首选 | | `torch.mm()` | 仅 2D × 2D | 标准矩阵乘法,无广播 | | `torch.bmm()` | 仅 3D × 3D | 批量矩阵乘法,无广播 [^1] | | `torch.einsum()` | 任意维度 | Einstein求和约定,更灵活 | #### 六、常见错误及解决 ##### 维度不匹配错误 ```python a = torch.randn(3, 4) b = torch.randn(2, 3) # 错误! a的列数(4) ≠ b的行数(2) # 正确做法: 调整维度满足 (m×n) @ (n×p) 规则 b = b.transpose(0, 1) # 转置为3×2 → 仍不匹配! b = torch.randn(4, 2) # 修正为4×2 result = torch.matmul(a, b) # 3×4 @ 4×2 → 得到3×2结果 [^2] ``` ##### 批量维度不匹配 ```python a = torch.randn(2, 3, 4) b = torch.randn(3, 4, 5) # 错误! 批量维度2≠3 # 解决方案: 使用广播或调整维度 b = b[:2] # 截取前两个批次 result = torch.matmul(a, b) # 正常执行 ``` #### 七、典型应用场景 1. **全连接层计算** ```python input = torch.randn(128, 64) # 批量128, 特征64 weight = torch.randn(64, 10) # 权重矩阵 [^2] output = input @ weight # 128×10输出 ``` 2. **注意机制** ```python Q = torch.randn(8, 10, 64) # 8头注意 K = torch.randn(8, 10, 64) scores = torch.matmul(Q, K.transpose(-2, -1)) # 8×10×10 ``` 3. **嵌入层计算** ```python embeddings = torch.nn.Embedding(1000, 128) # 嵌入矩阵 [^2] indices = torch.tensor([1, 5, 9]) result = embeddings(indices) @ weight # 嵌入向量矩阵乘法 ```
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

papaofdoudou

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值