SBERT 的输入处理过程
SBERT(Sentence-BERT)是基于 BERT 的模型,用于生成句子嵌入。它的输入处理过程主要包括以下几个步骤:分词、编码、池化和输出句子嵌入。SBERT 的设计简化了句子对任务中的计算,并专注于生成句子的固定长度嵌入。
1. 输入格式
SBERT 的输入可以是:
- 单个句子,用于生成单个句子的嵌入。
- 句子对,用于比较句子之间的语义相似性。
输入示例:
单个句子:Sentence: "A man is playing a guitar."
句子对:
Sentence 1: "A man is playing a guitar."
Sentence 2: "Someone is performing music."
2. 输入处理步骤
Step 1: 分词(Tokenization)
SBERT 使用与 BERT 相同的分词器(如 WordPiece 或 SentencePiece),将输入句子分解为子词(Tokens)。
- 特殊标记:
- [CLS]:表示句子的起始,用于句子级别的分类任务。
- [SEP]:表示句子的结束,用于分隔句子对。
单个句子分词:
输入句子:
"A man is playing a guitar."
分词结果:
['[CLS]', 'A', 'man', 'is', 'playing', 'a', 'guitar', '.', '[SEP]']
句子对分词:
输入句子对:
Sentence 1: "A man is playing a guitar."
Sentence 2: "Someone is performing music."
分词结果:
['[CLS]', 'A', 'man', 'is', 'playing', 'a', 'guitar', '.', '[SEP]', 'Someone', 'is', 'performing', 'music', '.', '[SEP]']
Step 2: 将分词结果转换为输入 ID
每个 Token 会映射到词汇表中的唯一 ID(一个整数值),这些 ID 将被模型作为输入。
示例:
['[CLS]', 'A', 'man', 'is', 'playing', 'a', 'guitar', '.', '[SEP]']
→ [101, 1037, 2158, 2003, 2652, 1037, 9676, 1012, 102]
Step 3: 创建输入张量
SBERT 的输入包含以下张量:
Input IDs:分词结果的 Token ID,表示每个 Token 的索引值。
Attention Mask:用于指示哪些位置是有效的 Token(值为 1),哪些位置是填充(值为 0)。
Token Type IDs(仅用于句子对):
-
- 区分句子 1 和句子 2 的标记:
- 句子 1 的 Token Type ID 为 0。
- 句子 2 的 Token Type ID 为 1。
- 区分句子 1 和句子 2 的标记:
示例:
Input IDs: [101, 1037, 2158, 2003, 2652, 1037, 9676, 1012, 102, 3642, 2003, 5640, 2189, 1012, 102]
Attention Mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Token Type IDs: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
Step 4: 编码(Embedding Generation)
经过分词和张量化的输入会被送入 SBERT 的编码器(基于 BERT 的 Transformer 模型)。编码器对每个 Token 生成上下文相关的嵌入(转到token embedding部分)。
- 输出:
- Token Embeddings:每个 Token 的上下文向量(通常是 768 维)。
- Sequence Embeddings:表示整个序列的特征(如 [CLS] 的嵌入)。
Step 5: 池化(Pooling Layer)
SBERT 在 BERT 的基础上添加了一个池化层,用于将 Token 级别的嵌入转化为句子级别的嵌入。
常用的池化方法:
- Mean-Pooling:对所有 Token 的嵌入向量取平均值,生成句子嵌入。
- [CLS] Token:使用 [CLS] Token 的嵌入作为句子嵌入。
- Max-Pooling:对每个 Token 的每个维度取最大值。
示例:
假设输入的句子含有 5 个 Token,BERT 输出每个 Token 的嵌入为 768 维的向量:
Token Embeddings:
[
[0.1, 0.2, 0.3, ..., 0.7], # Token 1
[0.2, 0.1, 0.4, ..., 0.6], # Token 2
[0.5, 0.3, 0.2, ..., 0.8], # Token 3
[0.6, 0.2, 0.1, ..., 0.5], # Token 4
[0.4, 0.4, 0.3, ..., 0.6], # Token 5
]
Mean-Pooling 的结果:
less
复制代码
Sentence Embedding: [平均值(0.1, 0.2, ..., 0.4), ..., 平均值(0.7, 0.6, ..., 0.6)]
Step 6: 输出句子嵌入
经过池化后,SBERT 生成一个固定长度的向量(通常是 768 维),表示句子的全局语义嵌入。
对于句子对任务:
- 分别计算两个句子的嵌入向量。
- 使用余弦相似度或其他度量方法计算句子之间的相似性。
完整流程示意图
- 输入文本:输入单个句子或句子对。
- 分词:使用 BERT 分词器将文本分解为子词。
- 张量化:将分词结果转换为 Input IDs、Attention Mask 和 Token Type IDs。
- 编码:将张量输入 BERT 编码器,生成每个 Token 的嵌入。
- 池化:应用池化层生成固定维度的句子嵌入。
- 输出:返回句子嵌入或句子对的相似度。
总结
SBERT 的输入处理过程是从文本到嵌入的一整套流水线,包括:
- 分词(Tokenization)。
- 张量化(Tensorization)。
- 编码(Embedding Generation)。
- 池化(Pooling)。
- 输出句子嵌入。
相比于 BERT,SBERT 的优化使得句子嵌入生成更高效,尤其在语义相似度计算和检索任务中表现突出。