多头注意力机制
import math
import torch
from torch import nn
from d2l import torch as d2l
实现类
#@save
class MultiHeadAttention(nn.Module):
"""多头注意力"""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(valu

本文介绍了如何使用PyTorch实现多头注意力机制,通过维度变换提升计算效率,适用于Transformer等模型。核心内容包括MultiHeadAttention类的定义,以及transpose_qkv和transpose_output函数的作用。实例展示了如何创建并测试一个具有5个注意力头的模型。
最低0.47元/天 解锁文章
3164

被折叠的 条评论
为什么被折叠?



