torch.einsum 的 10 个常见用法详解以及多头注意力实现

torch.einsum 是 PyTorch 提供的一个高效的张量运算函数,能够用紧凑的 Einstein Summation 约定(Einstein Summation Convention, Einsum)描述复杂的张量操作,例如矩阵乘法、转置、内积、外积、批量矩阵乘法等。


1. 基本语法

torch.einsum(equation, *operands)

• equation:爱因斯坦求和表示法的字符串,例如 “ij,jk->ik”
• operands:参与计算的张量,可以是多个

2. 基本概念

Einsum 使用 -> 将输入与输出模式分开:
• 左侧:表示输入张量的索引
• 右侧:表示输出张量的索引
• 省略求和索引:会自动对省略的索引进行求和(即 Einstein Summation 规则)


3. torch.einsum 的 10 个常见用法

(1) 矩阵乘法 (torch.mm)

import torch

A = torch.randn(2, 3)
B = torch.randn(3, 4)

C = torch.einsum("ij,jk->ik", A, B)  # 矩阵乘法
print(C.shape)  # torch.Size([2, 4])

解析:
• ij 表示 A 的形状 (2,3)
• jk 表示 B 的形状 (3,4)
• 由于 j 在 -> 右侧没有出现,因此对其求和,最终得到形状 (2,4)


(2) 向量点积 (torch.dot)

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

dot_product = torch.einsum("i,i->", a, b)  # 向量点积
print(dot_product)  # 输出: 32

解析:
• i,i-> 代表对应位置相乘并求和,等价于 torch.dot(a, b)


(3) 矩阵转置 (torch.transpose)

A = torch.randn(2, 3)

A_T = torch.einsum("ij->ji", A)  # 矩阵转置
print(A_T.shape)  # torch.Size([3, 2])

解析:
• ij->ji 交换 i 和 j 维度,相当于 A.T


(4) 矩阵外积 (torch.outer)

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

outer_product = torch.einsum("i,j->ij", a, b)  # 外积
print(outer_product)
# tensor([[ 4,  5,  6],
#         [ 8, 10, 12],
#         [12, 15, 18]])

解析:
• i,j->ij 生成形状 (3,3) 的矩阵


(5) 批量矩阵乘法 (torch.bmm)

A = torch.randn(5, 2, 3)
B = torch.randn(5, 3, 4)

C = torch.einsum("bij,bjk->bik", A, B)  # 批量矩阵乘法
print(C.shape)  # torch.Size([5, 2, 4])

解析:
• b 代表 batch 维度,不求和,保持
• j 出现在两个输入中但未出现在输出中,所以对其求和


(6) 计算均值 (torch.mean)

A = torch.randn(3, 4)

mean_A = torch.einsum("ij->", A) / A.numel()  # 计算均值
print(mean_A)

解析:
• ij-> 表示所有元素求和
• A.numel() 是总元素数,等价于 torch.mean(A)


(7) 计算范数 (torch.norm)

A = torch.randn(3, 4)

norm_A = torch.einsum("ij,ij->", A, A).sqrt()  # Frobenius 范数
print(norm_A)

解析:
• ij,ij-> 表示 A 的所有元素平方求和
• .sqrt() 计算范数

(8) 计算 Softmax

A = torch.randn(3, 4)

softmax_A = torch.einsum("ij->ij", torch.exp(A)) / torch.einsum("ij->i1", torch.exp(A))
print(softmax_A)

解析:
• torch.exp(A) 计算指数
• torch.einsum(“ij->i1”, torch.exp(A)) 计算行和


(9) 对角线提取 (torch.diagonal)

A = torch.randn(3, 3)

diag_A = torch.einsum("ii->i", A)  # 提取主对角线
print(diag_A)

解析:
• ii->i 只保留对角线元素,等价于 torch.diagonal(A)


(10) 计算张量 Hadamard 积(逐元素乘法)

A = torch.randn(3, 4)
B = torch.randn(3, 4)

hadamard_product = torch.einsum("ij,ij->ij", A, B)  # 逐元素乘法
print(hadamard_product)

解析:
• ij,ij->ij 表示对相同索引位置元素相乘


总结

Einsum 公式作用等价 PyTorch 代码
ij,jk->ik矩阵乘法torch.mm(A, B)
i,i->向量点积torch.dot(a, b)
i,j->ji矩阵转置A.T
bij,bjk->bik批量矩阵乘法torch.bmm(A, B)
ii->提取对角线torch.diagonal(A)
ij->矩阵所有元素求和A.sum()
ij,ij->ijHadamard 乘法A * B
ij,ij->Frobenius 范数的平方(A**2).sum()

使用 torch.einsum 计算多头注意力中的点积相似性

下面的代码示例演示如何使用 PyTorch 的 torch.einsum 函数来计算 Transformer 多头注意力机制中的点积注意力分数和输出。代码包含以下步骤:
1. 定义输入 Q, K, V:随机初始化查询(Query)、键(Key)、值(Value)张量,形状符合多头注意力的规范(包含 batch 维度和多头维度)。
2. 计算 QK^T / sqrt(d_k):使用 torch.einsum 计算每个注意力头的 Q 与 K 转置的点积相似性,并除以 d k \sqrt{d_k} dk (注意力头维度的平方根)进行缩放。
3. 计算 softmax 注意力权重:对第2步得到的相似性分数应用 softmax(在最后一个维度上),得到注意力权重分布。
4. 计算最终的注意力输出:将 softmax 得到的注意力权重与值 V 相乘(加权求和)得到每个头的输出。
5. 完整代码注释:代码中包含详尽的注释,解释每一步的用途。
6. 可视化注意力权重:使用 Matplotlib 可视化一个头的注意力权重矩阵,以便更好地理解注意力分布。
7. 具体参数设置:在代码开头指定 batch_size、sequence_length、embedding_dim、num_heads 等参数,便于调整。

import torch
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt

# 7. 参数设置:定义 batch 大小、序列长度、嵌入维度、注意力头数等
batch_size = 2           # 批处理大小
sequence_length = 5      # 序列长度(假设查询和键序列长度相同)
embedding_dim = 16       # 整体嵌入维度(embedding维度)
num_heads = 4            # 注意力头数量
head_dim = embedding_dim // num_heads  # 每个注意力头的维度 d_k(需保证能够整除)

# 1. 定义输入 Q, K, V 张量(随机初始化)
# 形状约定:[batch_size, num_heads, seq_len, head_dim]
Q = torch.randn(batch_size, num_heads, sequence_length, head_dim)
K = torch.randn(batch_size, num_heads, sequence_length, head_dim)
V = torch.randn(batch_size, num_heads, sequence_length, head_dim)

# 打印 Q, K, V 的形状以验证
print("Q shape:", Q.shape)  # 预期: (batch_size, num_heads, sequence_length, head_dim)
print("K shape:", K.shape)  # 预期: (batch_size, num_heads, sequence_length, head_dim)
print("V shape:", V.shape)  # 预期: (batch_size, num_heads, sequence_length, head_dim)

# 2. 计算 QK^T / sqrt(d_k)
# 使用 torch.einsum 进行张量乘法:
# 'b h q d, b h k d -> b h q k' 表示:
#  - b: batch维度
#  - h: 多头维度
#  - q: 查询序列长度维度
#  - k: 键序列长度维度
#  - d: 每个头的维度(将对该维度进行求和,相当于点积)
# Q 的形状是 [b, h, q, d],K 的形状是 [b, h, k, d]。
# einsum 根据 'd' 维度对 Q 和 K 相乘并求和,输出形状 [b, h, q, k],即每个头的 Q 与每个 K 的点积。
scores = torch.einsum('b h q d, b h k d -> b h q k', Q, K)  # 点积 Q * K^T (尚未除以 sqrt(d_k))
scores = scores / math.sqrt(head_dim)  # 缩放除以 sqrt(d_k)

# 3. 计算 softmax 注意力权重
# 对最后一个维度 k 应用 softmax,得到注意力权重矩阵 (对每个 query位置,在所有 key位置上的权重分布和为1)
attention_weights = F.softmax(scores, dim=-1)

# 打印注意力权重矩阵的形状以验证
print("Attention weights shape:", attention_weights.shape)  # 预期: (batch_size, num_heads, seq_len, seq_len)

# 4. 计算最终的注意力输出
# 将注意力权重矩阵与值 V 相乘,得到每个查询位置的加权值。
# 我们再次使用 einsum:
# 'b h q k, b h k d -> b h q d' 表示:
#  - 将 attention_weights [b, h, q, k] 与 V [b, h, k, d] 在 k 维相乘并对 k 求和,
#    得到输出形状 [b, h, q, d](每个头针对每个查询位置输出一个长度为d的向量)。
attention_output = torch.einsum('b h q k, b h k d -> b h q d', attention_weights, V)

# (可选)如果需要将多头的输出合并为一个张量,可以进一步 reshape/transpose 
# 并通过线性层投影。但这里我们仅关注多头内部的注意力计算。
# 合并示例: 将 out 从 [b, h, q, d] 变形为 [b, q, h*d],再通过线性层投影回 [b, q, embedding_dim]。
combined_output = attention_output.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, -1)
# 上面这行代码将 [b, h, q, d] 先变为 [b, q, h, d],再合并h和d维度为[h*d]。
print("Combined output shape (after concatenating heads):", combined_output.shape)
# 注意:combined_output 的最后一维大小应当等于 embedding_dim(num_heads * head_dim)。

# 打印一个注意力输出张量的示例值(比如第一个 batch,第一头,第一查询位置的输出向量)
print("Sample attention output (batch 0, head 0, query 0):", attention_output[0, 0, 0])

# 5. 完整代码注释已在上方各步骤体现。

# 6. 可视化注意力权重
# 我们以第一个样本(batch 0)的第一个注意力头(head 0)的注意力权重矩阵为例进行可视化。
# 这个矩阵形状为 [seq_len, seq_len],其中每行表示查询位置,每列表示键位置。
attn_matrix = attention_weights[0, 0].detach().numpy()  # 取出 batch 0, head 0 的注意力权重矩阵并转换为 numpy

plt.figure(figsize=(5,5))
plt.imshow(attn_matrix, cmap='viridis', origin='upper')
plt.colorbar()
plt.title("Attention Weights (Head 0 of Batch 0)")
plt.xlabel("Key position")
plt.ylabel("Query position")
plt.show()

运行上述代码后,您将看到打印的张量形状和示例值,以及一幅可视化的注意力权重热力图。图中纵轴为查询序列的位置,横轴为键序列的位置,颜色越亮表示注意力权重越高。通过该示例,您可以直观理解多头注意力机制中各查询对不同键“关注”的程度。
输出:

Q shape: torch.Size([2, 4, 5, 4])
K shape: torch.Size([2, 4, 5, 4])
V shape: torch.Size([2, 4, 5, 4])
Attention weights shape: torch.Size([2, 4, 5, 5])
Combined output shape (after concatenating heads): torch.Size([2, 5, 16])
Sample attention output (batch 0, head 0, query 0): tensor([-0.8224, -1.1715, -0.0423, -0.0106])

多头部分的计算:

import torch

# 定义多头注意力机制的点积计算函数
def compute_attention_scores(queries, keys):
    # 计算点积相似性分数
    energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
    return energy

# 示例数据
N = 1            # 批次大小
q = 2            # 查询序列长度
k = 3            # 键序列长度
h = 2            # 注意力头数量
d = 4            # 每个注意力头的维度

# 随机生成 queries 和 keys
queries = torch.rand((N, q, h, d))  # Shape (1, 2, 2, 4)
keys = torch.rand((N, k, h, d))    # Shape (1, 3, 2, 4)

# 计算注意力分数
energy = compute_attention_scores(queries, keys)

print("Energy shape:", energy.shape)
print(energy)

输出
# Energy shape: torch.Size([1, 2, 2, 3])
# tensor([[[[0.7102, 0.3867, 0.5860],
#           [0.9586, 0.5920, 0.6626]],

#          [[1.3163, 0.9486, 0.5482],
#           [1.0403, 0.4555, 0.3656]]]])

更多资料:
torch.einsum用法详解
多头注意力:torch.einsum详解
一文学会 Pytorch 中的 einsum
Python广播机制

### 注意力机制的实现方法详解 #### 1. 基本概念 注意力机制是一种模拟人类视觉注意的方法,在处理序列数据时能够聚焦于重要的部分。通过加权求和的方式,模型可以更关注输入中的某些特定位置[^1]。 #### 2.注意力机制 (Self-Attention) 自注意力允许模型的不同位置相互关联,从而更好地捕捉上下文信息。具体来说,对于给定的一组词向量 \(X\) ,计算三个矩阵:查询矩阵\(Q\)、键矩阵\(K\) 和值矩阵\(V\) 。这三个矩阵由原始输入经过线性变换得到: \[ Q=W_Q X, K=W_K X,V=W_V X \] 其中 \(W_Q , W_K,W_V\) 是可训练参数。接着利用缩放点积计算注意力分数并应用Softmax函数获得权重分布: \[ Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}}) V \] 这里 \(d_k\) 表示键维度大小用于稳定梯度。 ```python import torch.nn as nn import math class SelfAttention(nn.Module): def __init__(self, embed_size, heads): super(SelfAttention, self).__init__() self.embed_size = embed_size self.heads = heads # 定义线性层 self.values = nn.Linear(self.embed_size, self.embed_size, bias=False) self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False) self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False) def forward(self, values, keys, query): N = query.shape[0] value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1] # 计算QKV queries = self.queries(query) keys = self.keys(keys) values = self.values(values) energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3) out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape( N, query_len, self.embed_size ) return out ``` #### 3. 多头注意力机制(Multi-head Attention) 为了使模型能够在不同表示子空间中捕获更多信息,引入了多头结构。即创建多个独立的小型自注意力模块并将它们的结果拼接起来作为最终输出: \[ MultiHead(Q,K,V )=\text {Concat}(head_1,..., head_h)W_O \] 每个多头内部依旧遵循上述单头操作流程。 #### 4. SE通道注意力(Squeeze-and-Excitation Networks) 该网络旨在增强特征图中有用的信息而抑制无用的部分。其核心思想是对每个channel的重要性打分再重新调整激活程度。具体过程分为两步:“挤压”——全局池化获取全局描述; “激励” —— 使用MLP预测各通道权重[^2]. ```python from keras.layers import GlobalAveragePooling2D,Dense,Multiply,Lambda,Reshape def se_block(input_feature,ratio=16): channel_axis = 1 if K.image_data_format() == "channels_first" else -1 channel = input_feature._keras_shape[channel_axis] shared_layer_one = Dense(channel//ratio, activation='relu', kernel_initializer='he_normal', use_bias=True, bias_initializer='zeros') shared_layer_two = Dense(channel, activation='sigmoid', kernel_initializer='he_normal', use_bias=True, bias_initializer='zeros') avg_pool = GlobalAveragePooling2D()(input_feature) avg_pool = Reshape((1,1,channel))(avg_pool) assert isinstance(avg_pool, object) avg_pool = shared_layer_one(avg_pool) avg_pool = shared_layer_two(avg_pool) multiply = Multiply()([input_feature, Lambda(lambda x: x + 1)(avg_pool)]) return multiply ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值