import copy
import math

import torch
import torch.nn as nn

# Utility function to create deep copies of a module
def get_clones(module, num_of_deep_copies):
    # Create deep copies so that we can tweak each module's weights independently
    return nn.ModuleList([copy.deepcopy(module) for _ in range(num_of_deep_copies)])
get_clones is a utility function whose purpose is to create deep copies of a module. A deep copy duplicates the module together with all of its internal parameters and state, so every copy is fully independent of the others: you can train or modify each copy separately without touching the weights or state of any other copy.
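As a quick sanity check, here is a minimal sketch (continuing from the get_clones definition above; the layer size 4 and the clone count 3 are just illustrative) showing that the clones really hold independent parameters: zeroing one clone's weights leaves the others untouched.

# Three independent copies of the same template layer (sizes chosen arbitrarily for the demo)
clones = get_clones(nn.Linear(4, 4), 3)

# Zero out the first clone's weights only
with torch.no_grad():
    clones[0].weight.zero_()

print(torch.all(clones[0].weight == 0).item())  # True  -> the first clone was modified
print(torch.all(clones[1].weight == 0).item())  # False -> the other clones keep their own (randomly initialized) weights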
class MultiHeadedAttention(nn.Module):
    def __init__(self, model_dimension, number_of_heads, dropout_probability, log_attention_weights):
        super().__init__()
        assert model_dimension % number_of_heads == 0, 'Model dimension must be divisible by the number of heads.'

        self.head_dimension = int(model_dimension / number_of_heads)
        self.number_of_heads = number_of_heads

        # get_clones(..., 3) makes 3 deep copies of this linear layer, giving three independent projections:
        # one for the queries (Q), one for the keys (K) and one for the values (V)
        self.qkv_nets = get_clones(nn.Linear(model_dimension, model_dimension), 3)
        self.out_projection_net = nn.Linear(model_dimension, model_dimension)

        self.attention_dropout = nn.Dropout(p=dropout_probability)  # no pun intended, not explicitly mentioned in paper
        self.softmax = nn.Softmax(dim=-1)  # -1 stands for apply the softmax along the last dimension

        self.log_attention_weights = log_attention_weights  # should we log attention weights
        self.attention_weights = None  # for visualization purposes, I cache the weights here (translation_script.py)
    def attention(self, query, key, value):
        # Scaled dot-product attention: softmax(Q K^T / sqrt(head_dimension)) V
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dimension)
        attention_weights = self.softmax(scores)
        attention_weights = self.attention_dropout(attention_weights)
        intermediate_token_representations = torch.matmul(attention_weights, value)

        return intermediate_token_representations, attention_weights  # attention weights for visualization purposes
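    # For reference, a shape walk-through of attention(); the notation is assumed here for clarity
    # (B = batch size, NH = number_of_heads, S_q / S_k = query / key sequence lengths, HD = head_dimension):
    #   query: (B, NH, S_q, HD),  key and value: (B, NH, S_k, HD)
    #   scores = Q @ K^T / sqrt(HD)                -> (B, NH, S_q, S_k)
    #   attention_weights = softmax over last dim  -> (B, NH, S_q, S_k), each row sums to 1
    #   output = attention_weights @ V             -> (B, NH, S_q, HD)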
    def forward(self, query, key, value):
        batch_size = query.shape[0]

        # Pass query/key/value through their own linear projections (the three independent clones created by
        # get_clones above), then split the model dimension into (number_of_heads, head_dimension) and move the
        # head dimension in front of the sequence dimension
        query, key, value = [net(x).view(batch_size, -1, self.number_of_heads, self.head_dimension).transpose(1, 2)
                             for net, x in zip(self.qkv_nets, (query, key, value))]

        intermediate_token_representations, attention_weights = self.attention(query, key, value)

        if self.log_attention_weights:  # cache the weights for later visualization
            self.attention_weights = attention_weights

        # Merge the heads back together: (B, NH, S, HD) -> (B, S, NH * HD), then apply the output projection
        reshaped = intermediate_token_representations.transpose(1, 2).reshape(batch_size, -1, self.number_of_heads * self.head_dimension)
        token_representations = self.out_projection_net(reshaped)

        return token_representations
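To see the module end to end, here is a minimal usage sketch for self-attention; the hyperparameter values (model_dimension=512, number_of_heads=8, dropout_probability=0.1) and tensor sizes are illustrative choices, not taken from any particular training setup.

# Minimal usage sketch (hyperparameters and tensor sizes are illustrative)
mha = MultiHeadedAttention(model_dimension=512, number_of_heads=8,
                           dropout_probability=0.1, log_attention_weights=True)

batch_size, seq_length = 2, 10
tokens = torch.rand(batch_size, seq_length, 512)  # (B, S, model_dimension)

# Self-attention: query, key and value all come from the same token representations
out = mha(tokens, tokens, tokens)

print(out.shape)                    # torch.Size([2, 10, 512]) -> same shape as the input
print(mha.attention_weights.shape)  # torch.Size([2, 8, 10, 10]) -> (B, NH, S, S), cached because log_attention_weights=True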