Implementing Self-Attention is a common question in MLE interviews, especially in the NLP field.
Multi-Head Self-Attention:
Given: $d_f, n_h$, where $d_f$ is the embedding dimension and $n_h$ is the number of heads.
- Compute the Q, K, V matrices.
- Split Q, K, V into $n_h$ heads and swap the head and sequence dimensions for later computation.
- Compute the attention scores between Q and K, remembering to divide by $\sqrt{d_f/n_h}$ (see the formula after this list).
- Aggregate V according to the attention scores, then transpose the result back to a 3-dimensional tensor.
- Apply a linear transformation to the output.
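Written out, each head computes standard scaled dot-product attention on its slice of the projections (here $Q_i, K_i, V_i$ denote the per-head slices and $W^O$ the final output projection; this notation is introduced here for clarity and does not appear in the code below):

$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_f/n_h}}\right) V_i, \qquad \mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_{n_h})\, W^O$$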
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, feature_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert feature_dim % num_heads == 0, "Feature dimension must be divisible by num_heads"
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads
        self.q_linear = nn.Linear(feature_dim, feature_dim)
        self.k_linear = nn.Linear(feature_dim, feature_dim)
        self.v_linear = nn.Linear(feature_dim, feature_dim)
        self.out = nn.Linear(feature_dim, feature_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, feature_dim = x.size()
        # Project the input: (B, S, F) -> (B, S, F)
        Q = self.q_linear(x)
        K = self.k_linear(x)
        V = self.v_linear(x)
        # Split into heads: (B, S, F) -> (B, S, H, D) -> (B, H, S, D)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product scores: (B, H, S, D) @ (B, H, D, S) -> (B, H, S, S)
        weights = torch.matmul(Q, K.transpose(-1, -2)) / (self.head_dim ** 0.5)
        if mask is not None:
            weights = weights.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        # Aggregate values: (B, H, S, S) @ (B, H, S, D) -> (B, H, S, D)
        output = torch.matmul(weights, V)
        # Merge heads back: (B, H, S, D) -> (B, S, H, D) -> (B, S, F)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, feature_dim)
        # Final linear projection: (B, S, F)
        return self.out(output)
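A quick shape sanity check (a minimal sketch; the batch size, sequence length, and causal mask below are arbitrary illustrative choices):

mha = MultiHeadAttention(feature_dim=512, num_heads=8)
x = torch.randn(2, 10, 512)                   # (B, S, F) dummy input
causal_mask = torch.tril(torch.ones(10, 10))  # (S, S); broadcasts over (B, H, S, S)
out = mha(x, mask=causal_mask)
print(out.shape)                              # torch.Size([2, 10, 512])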
Transformer Block:
Besides the Multi-Head Self-Attention, a transformer block contains 2 LayerNorm layers, 2 residual connections, 2 dropout layers, and a feed-forward network (FFN).
Each block can be viewed as two sub-layers: self-attention and the FFN. In each sub-layer, the output is passed through dropout, added back to the input through a residual connection, and finally normalized with layer norm to regulate the distribution fed to the next sub-layer.
class TransformerBlock(nn.Module):
    def __init__(self, feature_dim, num_heads, hidden_dim, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(feature_dim, num_heads)
        self.norm1 = nn.LayerNorm(feature_dim)
        self.norm2 = nn.LayerNorm(feature_dim)
        self.ffn = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention sub-layer: dropout -> residual -> layer norm
        attn_output = self.attention(x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        # FFN sub-layer: dropout -> residual -> layer norm
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))
        return x
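Putting it together, a decoder-style stack is just several of these blocks applied in sequence with a causal mask (a minimal sketch; the depth and hyperparameters are arbitrary illustrative choices):

blocks = nn.ModuleList(
    [TransformerBlock(feature_dim=512, num_heads=8, hidden_dim=2048) for _ in range(6)]
)
x = torch.randn(2, 10, 512)                   # (B, S, F) token embeddings
causal_mask = torch.tril(torch.ones(10, 10))  # prevent attending to future positions
for block in blocks:
    x = block(x, mask=causal_mask)
print(x.shape)                                # torch.Size([2, 10, 512])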