目录
0. 本栏目竞赛汇总表
1. 本文主旨
- 大白话:上一篇文章已经实现了输入数据的准备,本文将定义用于训练AI模型的Transformer类,并概要讲解每个组件的作用。
- 通过本文可收获技能:Transformer类定义案例、Transformer类常用组件。
- 上文回顾:Eedi竞赛Transformer框架解决方案03-定义Transformer数据输入层
2. Transformer类架构
3. Transformer编码层
3.1 代码实现
def encode(self, features):
if features is None:
return None
# Transformer编码层的核心部分
psg_out = self.model(
input_ids=features['input_ids'], # token ids
attention_mask=features['attention_mask'], # 注意力掩码
return_dict=True
)
# 后续是池化层的操作
p_reps = self.sentence_embedding(psg_out.last_hidden_state,
features['attention_mask'])
if self.normlized:
p_reps = torch.nn.functional.normalize(p_reps, dim=-1)
return p_reps.contiguous()
3.2 大白话Transformer编码层
- 输入:文本的token ids和attention mask
- 核心:使用self.model(Qwen模型)进行编码
- 输出:每个token的深层语义表示
- 作用:像一个翻译官,把文字翻译成数学向量
这个函数中的self.model(…)调用就是Transformer编码层的实际执行部分,它利用预训练语言模型的能力来理解文本。
4. 池化层
4.1 代码实现
def sentence_embedding(self, hidden_state, mask):
if self.sentence_pooling_method == 'mean':
# 平均池化:所有token加权平均
s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
d = mask.sum(axis=1, keepdim=True).float()
return s / d
elif self.sentence_pooling_method == 'cls':
# CLS池化:取第一个token
return hidden_state[:, 0]
elif self.sentence_pooling_method == 'last':
# 最后token池化:取最后一个有效token
return self.last_token_pool(hidden_state, mask)
4.2 大白话池化层
就像是一个文章总结器:
输入:一篇文章(每个词都有一个向量表示)
处理:提供三种总结方式
- mean:看完所有词,取个平均值
- cls:只看第一个特殊标记词
- last:看最后一个有意义的词
输出:一个总结向量(代表整篇文章的意思)
这就像是三种读文章的方式:
- mean:仔细读完每个字
- cls:只看开头的摘要
- last:看最后的结论
不管用哪种方式,最终都能得到一个能代表整篇文章含义的"总结"。
5. 相似度计算层
5.1 代码实现
def forward(self, query, doc):
# 1. 获取编码
query_emb = self.encode(query) # [batch_size, hidden_size]
doc_emb = self.encode(doc) # [batch_size, hidden_size]
# 2. 计算相似度
scores = self.compute_similarity(query_emb, doc_emb) # [batch_size, batch_size]
scores = scores / self.temperature # 温度缩放
scores = scores.view(query_emb.size(0), -1) # 调整形状
# 3. 计算损失
target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
loss = self.cross_entropy(scores, target)
5.2 大白话相似度计算层
就像是一个阅卷老师:
- 拿到两份文章(query和doc的向量表示)
- 计算它们的相似度(用compute_similarity)
- 用温度调节打分的严格程度(temperature)
- 整理分数的格式(view reshape)
- 最后算出一个总分(loss)
这个过程就像:
- 老师对比学生答案和标准答案
- 给出相似度分数
- 调整打分标准的严格程度
- 最后得出这次作答的总体评分
这样模型就能学会区分哪些文本对是相关的,哪些是不相关的。
6. Transformer模型类(汇总)
6.1 代码汇总
class CustomSimCSEModel(nn.Module):
def __init__(self, path, config, quantization_config, emb_size=1024,
sentence_pooling_method='last', normlized=True, temperature=0.02):
super().__init__()
self.model = AutoModel.from_pretrained(path, config=config,
quantization_config=quantization_config)
self.config = self.model.config
self.sentence_pooling_method = sentence_pooling_method
self.normlized = normlized
self.temperature = temperature
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean', label_smoothing=0.0)
def last_token_pool(self, last_hidden_states: Tensor,
attention_mask: Tensor) -> Tensor:
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
return last_hidden_states[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_states.shape[0]
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
def gradient_checkpointing_enable(self, **kwargs):
self.model.gradient_checkpointing_enable(**kwargs)
def sentence_embedding(self, hidden_state, mask):
if self.sentence_pooling_method == 'mean':
s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
d = mask.sum(axis=1, keepdim=True).float()
return s / d
elif self.sentence_pooling_method == 'cls':
return hidden_state[:, 0]
elif self.sentence_pooling_method == 'last':
return self.last_token_pool(hidden_state, mask)
def encode(self, features):
if features is None:
return None
psg_out = self.model(input_ids=features['input_ids'],
attention_mask=features['attention_mask'],
return_dict=True)
p_reps = self.sentence_embedding(psg_out.last_hidden_state,
features['attention_mask'])
if self.normlized:
p_reps = torch.nn.functional.normalize(p_reps, dim=-1)
return p_reps.contiguous()
def compute_similarity(self, q_reps, p_reps):
if len(p_reps.size()) == 2:
return torch.matmul(q_reps, p_reps.transpose(0, 1))
return torch.matmul(q_reps, p_reps.transpose(-2, -1))
def forward(self, query, doc):
query_emb = self.encode(query)
doc_emb = self.encode(doc)
scores = self.compute_similarity(query_emb, doc_emb) / self.temperature
scores = scores.view(query_emb.size(0), -1)
target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
loss = self.cross_entropy(scores, target)
return dict(
loss=loss,
scores=scores,
query_emb=query_emb,
doc_emb=doc_emb,
)
6.2 大白话Transformer模型类
让我用一个教学场景来解释CustomSimCSEModel类:
这个类就像一个特殊的阅卷老师,专门处理这样一种题目:
“看到学生的错误答案,判断学生是因为什么误解导致做错的”
工作流程像这样:
- 阅读理解(Transformer编码层):
- 老师仔细阅读题目和错误答案
- 理解每个词的含义和它们之间的关系
- 就像一个专业翻译,把文字转成"理解向量"
- 提取要点(池化层):
- 把长长的文字总结成关键点
- 可以选择:看全文平均/看开头/看结尾
- 就像写读书笔记,提取核心内容
- 判断相似度(相似度计算层):
- 拿着题目和各种常见误解概念比对
- 计算它们之间的相似程度
- 可以调节判断的严格程度
- 就像给答案打分,看匹配度
- 学习改进(损失计算):
- 根据打分结果总结经验
- 提高自己判断的准确度
- 就像老师在不断积累教学经验
整体来说,这就是一个能通过对比学习,逐渐掌握"看到错误答案就能判断出背后误解原因"的智能助教系统。
(To be continued)