A brief introduction to the computation flow of torch.einsum

torch.einsum

>>> a = torch.arange(60.).reshape(5, 3, 4)
>>> b = torch.arange(24.).reshape(3, 4, 2)
>>> torch.einsum('ijk,jkl->il', a, b).shape
torch.Size([5, 2])
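To make the computation flow concrete, here is a minimal sketch. The original snippet was truncated, so the reshape of b and the 'ijk,jkl->il' subscripts above are illustrative assumptions, not taken from the source. The rule einsum follows: indices that appear in the inputs but not in the output (here j and k) are multiplied element-wise and summed away, while the remaining indices (i and l) index the result.

import torch

a = torch.arange(60.).reshape(5, 3, 4)   # indices i, j, k
b = torch.arange(24.).reshape(3, 4, 2)   # indices j, k, l

out = torch.einsum('ijk,jkl->il', a, b)  # result shape (5, 2)

# Equivalent explicit loops: multiply matching (j, k) entries, then sum them out
ref = torch.zeros(5, 2)
for i in range(5):
    for l in range(2):
        for j in range(3):
            for k in range(4):
                ref[i, l] += a[i, j, k] * b[j, k, l]

assert torch.allclose(out, ref)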
import torch
import torch.nn as nn
import torch.nn.functional as F


class HybridAttentionPooling(nn.Module):
    def __init__(self, d_model, num_heads=4, lstm_layers=2):
        super().__init__()
        # LSTM layer: extracts bidirectional context information
        self.lstm = nn.LSTM(
            input_size=d_model,
            hidden_size=d_model,
            num_layers=lstm_layers,
            bidirectional=True,
            batch_first=True
        )
        # Multi-head attention pooling
        self.num_heads = num_heads
        # Learned queries over the bidirectional LSTM output (width d_model*2)
        self.query = nn.Parameter(torch.randn(num_heads, d_model * 2))
        self.scale = (d_model * 2) ** -0.5  # scale by the bidirectional output width
        self.out_proj = nn.Linear(d_model * 2 * num_heads, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """
        x : (B, L, d_model) input sequence
        """
        B, L, _ = x.shape

        # 1. Extract bidirectional context with the LSTM
        lstm_out, _ = self.lstm(x)  # (B, L, d_model*2)

        # 2. Attention between the learned queries and the LSTM output
        queries = self.query[None, :, None, :] * self.scale  # (1, num_heads, 1, d_model*2)
        keys = lstm_out.transpose(1, 2)                      # (B, d_model*2, L)

        # Attention scores via broadcasted matmul
        attn = torch.matmul(
            queries,            # (1, num_heads, 1, d_model*2)
            keys.unsqueeze(1)   # (B, 1, d_model*2, L)
        )                       # -> (B, num_heads, 1, L)
        attn = attn.squeeze(2)  # (B, num_heads, L)

        # Softmax normalization over the sequence dimension
        attn = F.softmax(attn, dim=-1)

        # Use expand instead of repeat_interleave to avoid extra memory overhead
        lstm_out_heads = lstm_out.unsqueeze(1).expand(-1, self.num_heads, -1, -1)  # (B, num_heads, L, d_model*2)

        # Attention-weighted sum over the sequence
        attn_heads = attn.unsqueeze(-1)                         # (B, num_heads, L, 1)
        pooled = torch.sum(attn_heads * lstm_out_heads, dim=2)  # (B, num_heads, d_model*2)

        # Flatten the heads into a single vector
        pooled = pooled.reshape(B, -1)  # (B, num_heads * d_model*2)

        # Project back to the original dimension
        pooled = self.out_proj(pooled)

        # Layer normalization
        return self.norm(pooled)
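A hypothetical usage sketch (the batch size, sequence length, and d_model below are arbitrary illustration values, not from the original post): the module pools a (B, L, d_model) sequence into a single (B, d_model) vector.

# Illustrative sizes only: B=8, L=50, d_model=128
model = HybridAttentionPooling(d_model=128, num_heads=4, lstm_layers=2)
x = torch.randn(8, 50, 128)   # (B, L, d_model)
pooled = model(x)
print(pooled.shape)           # torch.Size([8, 128])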