推荐系统之采样修正的双塔模型

本文介绍的论文题目是:《Sampling-Bias-Corrected Neural Modeling for Large Corpus Item Recommendations》
论文下载地址是:Google工业风最新论文, Youtube提出双塔结构流式模型进行大规模推荐

本文是谷歌工业风论文的新作,介绍了在大规模推荐系统中使用双塔模型来做召回的一些经验,值得细细品读。本文仅对文章内容做一个简单介绍,更多细节建议阅读原论文。

1、背景

大规模推荐系统一般分为两阶段,即召回和排序阶段,本文重点关注召回阶段。

给定{用户,上下文,物品}的三元组,一个通用的方法首先是分别计算{用户,上下文} 和 {物品} 的向量表示,然后通过一定的方式如点积来计算二者的匹配得分。这种基于表示学习的方法通常面临两个方面的挑战:

1)工业界中物品的数量十分巨大。
2)通过收集用户反馈得到的数据集十分稀疏,导致模型对于长尾物品的预测具有很大的方差,同时也面临着物品冷启动的问题。

工业界现有的推荐系统都需要从一个超大规模的候选集中拉取item进行打分排序。解决数据稀疏和指数级候选集分布的一种通常做法是从item的内容特征中学习出item的稠密表示。这里很自然地就想到了工业界大名鼎鼎且应用广泛的双塔神经网络结构,其中的一塔就是从丰富的item内容特征中学习到item的表示。

工业界目前训练双塔结构一般是通过随机mini-batch的方式来优化损失函数。这种训练方式存在的一个显著问题就是in-batch loss会因为随机采样偏差而导致模型效果不好,尤其是当样本分布出现明显倾斜的时候。我们提出了一种全新的算法,可以从流式数据中预估item的频率。通过理论分析和实验,新算法有能力在不知道候选集全部的词典情况下做出无偏差的估计并且可以自适应候选集分布的变化。在Youtube线上的实验也证明了该算法的有效性。
我们考虑一种通用的推荐问题设定:给定一系列query和候选集,目标就是在给定query的情况下返回最相关的一个候选子集。针对这里的query和候选集中的item,都可以用各自的特征向量来进行表示。在个性化推荐场景中,则是用户user和会话的上下文context构成这里的query侧。
都知道Youtube的推荐架构主要分为两个阶段:召回和排序。而本文则主要聚焦于新增一路如下图所示的双塔召回。query侧的塔是由大量的用户观看历史形成的user features以及共同的seed features构成,候选集侧的塔则是由视频特征构成。训练样本的Label则是由用户的点击和播放时长加权得到。

模型架构

近几年来,随着深度学习的发展,双塔模型常用来用做召回阶段的模型,双塔模型的一般结构如下:
在这里插入图片描述
可以看到,双塔模型两侧分别对{用户,上下文} 和 {物品} 进行建模,并在最后一层计算二者的内积。对于每一个正样本,需要随机采样一些负样本,当物品数量十分巨大的

<think> 我们正在讨论如何训练一个支持检索增强生成(RAG)的模型。根据引用内容,RAG的核心思想是结合检索和生成两个过程:首先从外部知识源检索相关信息,然后利用这些信息生成最终输出[^1][^2][^3]。因此,训练一个RAG模型需要同时考虑检索模块和生成模块的训练训练RAG模型的方法可以分为以下几个关键步骤: 1. **数据准备**: - 构建一个包含查询(query)、相关文档(relevant documents)和参考答案(reference answer)的数据集。 - 例如,对于每个查询,我们需要有对应的人工标注的相关文档(或段落)以及人工撰写的答案。 2. **训练检索器(Retriever)**: - 目标:训练一个检索器,能够对于给定的查询,从大规模文档库中检索出最相关的文档(或段落)。 - 方法: - 可以使用双塔模型(Dual Encoder)结构:一个编码器用于编码查询,另一个编码器用于编码文档,然后计算查询向量和文档向量的相似度(如余弦相似度)。 - 训练目标:最大化相关查询-文档对的相似度,最小化不相关对的相似度。常用的损失函数有对比损失(Contrastive Loss)或交叉熵损失(使用负采样)。 - 负样本采样:从同一个批(batch)中随机选择其他文档作为负样本,或者使用难负例挖掘(hard negative mining)[^3]。 ```python # 伪代码:双塔模型训练 query_encoder = Transformer(...) # 查询编码器 doc_encoder = Transformer(...) # 文档编码器 # 前向传播 query_embed = query_encoder(query_input) doc_embed = doc_encoder(doc_input) # 计算相似度矩阵(batch内计算) similarities = torch.matmul(query_embed, doc_embed.T) # [batch_size, batch_size] # 对比损失:对角线上的样本为正样本 labels = torch.arange(batch_size) # 对角线位置为正样本 loss = cross_entropy(similarities, labels) ``` 3. **训练生成器(Generator)**: - 目标:训练一个生成模型(通常是序列到序列的模型),它能够根据检索到的相关文档来生成答案。 - 方法: - 输入:将查询(query)和检索到的文档(retrieved documents)拼接起来,作为生成器的输入。 - 输出:参考答案(reference answer)。 - 训练目标:最大化生成答案的条件对数似然(即标准的自回归语言模型训练目标)。 ```python # 伪代码:生成器训练 generator = T5ForConditionalGeneration(...) # 或其他Seq2Seq模型 # 输入格式:将查询和检索到的文档拼接 input_text = f"问题:{query} 上下文:{retrieved_documents}" output_text = reference_answer # 训练步骤 outputs = generator(input_ids=input_encoded, labels=output_encoded) loss = outputs.loss ``` 4. **联合训练(可选)**: - 在某些RAG框架中(如RAG-Token[^3]),检索器和生成器可以进行端到端的联合训练。这意味着检索器的梯度可以通过生成器的梯度进行更新。 - 方法: - 在每一步,检索器检索出文档后,生成器基于检索到的文档生成答案。 - 损失函数:生成器的损失(如交叉熵损失)会反向传播到检索器,从而更新检索器的参数。 - 挑战:由于检索过程通常是离散的(比如从大量文档中选取top-K),因此需要通过一些技巧(如REINFORCE或Gumbel-Softmax)来实现梯度传播。 5. **迭代训练(Iterative Training)**: - 可以采用迭代的方式改进RAG模型: 1. 首先使用初始数据集训练检索器和生成器。 2. 使用训练好的模型在未标注的数据上生成答案,并收集那些模型预测置信度低或答案错误的样本。 3. 对这些困难的样本进行人工标注(包括相关文档和参考答案),然后重新训练模型。 - 这个过程可以迭代多次,以逐步提升模型性能。 6. **评估与调优**: - 使用验证集评估RAG模型的整体性能(如生成答案的质量)以及检索模块的性能(如召回率@K)。 - 根据评估结果调整超参数,例如: - 检索文档的数量(top-K) - 生成器的最大输出长度 - 训练时的学习率等 7. **注意事项**: - **外部知识库更新**:RAG的一个优势是外部知识库可以独立更新,而无需重新训练整个模型[^3]。但是,如果知识库的结构或内容发生巨大变化,可能需要重新训练检索器。 - **计算效率**:检索器的训练需要处理大规模文档库,因此需要高效的向量索引(如FAISS)和负采样策略。 8. **工具支持**: - 可以使用现有的框架来简化RAG模型的训练,例如: - Hugging Face Transformers库:提供了预训练的检索模型(如DPR)和生成模型(如BART、T5)的实现。 - Haystack:一个用于构建端到端问答系统的框架,支持RAG。 - FAISS:用于高效相似度搜索的库。 综上所述,训练一个RAG模型需要分别训练检索器和生成器,并可能进行联合训练。迭代训练和难负例挖掘是提升模型性能的关键技术。 参考文献: [^1]: RAG(检索增强生成)一文全解。 [^2]: 大模型检索增强生成RAG。 [^3]: Lewis et al., Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks (arXiv:2005.11401). 相关问题: 1. 如何构建适用于RAG训练的数据集? 2. 在训练检索器时,如何有效进行负样本采样? 3. RAG模型的联合训练具体是如何实现的?有哪些挑战? 4. 如何评估RAG模型的性能? 5. RAG与微调(Fine-tuning)预训练语言模型相比有哪些优势? </think> ### 训练支持检索增强生成(RAG)模型的方法 训练RAG模型涉及三个核心模块的协同优化:**检索器(Retriever)**、**生成器(Generator)**和**融合机制**。以下是关键训练方法和技术: --- #### 1. **双模块联合训练框架** $$ \mathcal{L}_{\text{total}} = \mathcal{L}_{\text{ret}} + \lambda \mathcal{L}_{\text{gen}} $$ - **检索器训练**: - 目标:学习查询$q$与文档$d$的语义匹配 - 损失函数:对比损失(Contrastive Loss) $$ \mathcal{L}_{\text{ret}} = -\log \frac{e^{\text{sim}(q,d^+)}}{\sum_{d^-} e^{\text{sim}(q,d^-)}} $$ 其中$d^+$为正例文档,$d^-$为负例文档[^3] - 常用模型:DPR(Dense Passage Retrieval) - **生成器训练**: - 目标:基于检索结果$d$生成答案$a$ - 损失函数:标准语言模型损失 $$ \mathcal{L}_{\text{gen}} = -\sum_{t} \log P(a_t | a_{<t}, q, d) $$ - 常用模型:BART/T5等seq2seq架构[^4] > **联合训练技巧**:冻结检索器参数初始训练生成器,再联合微调[^2] --- #### 2. **负采样策略(关键优化)** 检索器的性能依赖高质量的负样本: - **随机负采样**:从语料库随机选取无关文档 - **难负例挖掘**: - BM25检索结果中相关性中等(非最相关)的文档[^3] - 生成器置信度高但实际错误的样本 - **对抗负采样**:GAN生成混淆性负例 ```python # 难负例挖掘示例(伪代码) def hard_negative_mining(query, top_k=50): bm25_results = bm25_retriever(query, k=top_k) gold_doc = get_gold_document(query) # 标注正例 # 选择排名5-15的文档作为难负例 negatives = [doc for doc in bm25_results[5:15] if doc != gold_doc] return negatives ``` --- #### 3. **端到端训练方法** **RAG-Token模式**(更灵活): $$ P(a|q) = \prod_t P(a_t|q, d_t) \quad \text{(每个token可参考不同文档)} $$ - **梯度传播路径**:生成器梯度反向传播至检索器 - **实现方案**: 1. 检索器输出文档的概率分布 2. 通过Gumbel-Softmax实现可微检索[^3] 3. 生成器基于加权文档向量生成文本 --- #### 4. **多阶段训练流程 ```mermaid graph LR A[预训练初始化] --> B[检索器训练] B --> C[生成器训练] C --> D[端到端微调] D --> E[迭代增强] ``` 1. **阶段1:检索器预训练** - 数据集:MS MARCO/NQ等问答数据集 - 目标:学习query-document映射 2. **阶段2:生成器训练** - 输入格式:拼接查询与检索文档 ```text [CLS] 问题:{question} [SEP] 上下文:{retrieved_text} [SEP] ``` 3. **阶段3:联合微调** - 使用多任务损失:$\lambda$通常取0.8~0.9[^4] - 逐步降低检索器学习率 --- #### 5. **迭代增强策略** 通过自生成数据改进模型: 1. 初始模型生成答案$a_{\text{pred}}$ 2. 人工/规则修正错误答案得$a_{\text{corrected}}$ 3. 将$(q, a_{\text{corrected}})$作为新训练数据 4. 重新检索与$a_{\text{corrected}}$最匹配的文档$d_{\text{new}}$ 5. 用新三元组$(q, d_{\text{new}}, a_{\text{corrected}})$微调模型[^4] > 实验显示3轮迭代可使准确率提升12%+[^4] --- #### 6. **关键训练配置 | 组件 | 推荐配置 | 作用 | |------------|-----------------------------------|----------------------| | 优化器 | AdamW (lr=3e-5) | 参数优化 | | 批大小 | 32-64 (检索器), 16-32 (生成器) | 平衡显存与稳定性 | | 文档数量 | 5-10篇/查询 | 信息覆盖与噪声平衡 | | 混合精度 | FP16+梯度裁剪 | 加速训练 | > 硬件建议:单机8×A100(检索器需大内存支持)[^2] --- #### 7. **评估与调优 - **检索器指标**: - Recall@k:前k结果包含正确答案的比例 - MRR(平均倒数排名) - **生成器指标**: - ROUGE-L/BLEU-4 - 人工评估答案流畅性 - **端到端测试**: - 纠错率:$\frac{\text{修正错误数}}{\text{总错误数}}$ - 幻觉减少率[^1] > **调优重点**:负采样质量 > 文档数量 > 损失权重$\lambda$[^3] --- ### 相关问题 1. 如何为RAG模型构建高质量的外部知识库? 2. RAG与微调(Fine-tuning)预训练语言模型的优劣比较? 3. 如何处理RAG中的检索噪声导致的生成错误?[^1] 4. 小型企业部署RAG系统的最低硬件要求是什么?[^4] 5. 如何实现RAG模型的持续学习以适应更新的知识库?[^3]
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值