BERT模型
import torch
from torch import nn
from d2l import torch as d2l
模型输入
#@save
def get_tokens_and_segments(tokens_a, tokens_b=None):
"""获取输入序列的词元及其片段索引"""
tokens = ['<cls>'] + tokens_a + ['<sep>']
# 0和1分别标记片段A和B
segments = [0] * (len(tokens_a) + 2)
if tokens_b is not None:
tokens += tokens_b + ['<sep>']
segments += [1] * (len(tokens_b) + 1)
return tokens, segments
两句话
tokens_a和tokens_b前面加上开始符和分隔符
segments是BERT三个编码token embedding、segment embedding、position embedding的segment embedding,标记是第几句话
BERT编码器
#@save
class BERTEncoder(nn.Module):
"""BERT编码器"""
def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
ffn_num_hiddens, num_heads, num_layers, dropout,
max_len=1000, key_size=768, query_size=768, value_size=768,
**kwargs):
super(BERTEncoder, self).__init__(**kwargs)
self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
self.segment_embedding = nn.Embedding(2, num_hiddens)
self.blks = nn.Sequential()
for i in range(num_layers):
self.blks.add_module(f"{
i}", d2l.EncoderBlock(
key_size,

该博客详细介绍了BERT模型的构建过程,包括输入处理、编码器的设计、掩蔽语言模型任务以及下一句预测任务。BERT模型使用词元、片段和位置嵌入,并通过Transformer编码器抽取序列信息。掩蔽语言模型用于预测被掩蔽的词元,而下一句预测任务则判断两个句子是否连续。整个模型在预训练阶段用于这两个任务,为下游任务提供强大的语义表示。
最低0.47元/天 解锁文章
1万+

被折叠的 条评论
为什么被折叠?



