使用Sentence Transformers训练Quora重复问题检测模型的技术指南

使用Sentence Transformers训练Quora重复问题检测模型的技术指南

sentence-transformers Multilingual Sentence & Image Embeddings with BERT sentence-transformers 项目地址: https://gitcode.com/gh_mirrors/se/sentence-transformers

项目背景

Sentence Transformers是一个强大的框架,用于训练和部署句子嵌入模型。本文将重点介绍如何使用该框架训练专门用于检测Quora重复问题的模型。这类模型在信息检索和问答系统中有着广泛的应用价值。

数据集介绍

Quora重复问题数据集包含超过50万条句子和40万对标注数据,每对问题都被标记为是否重复。这个数据集非常适合训练能够识别语义相似性的模型。

训练方法选择

在训练过程中,选择合适的损失函数至关重要。针对Quora重复问题检测任务,我们主要考虑两种损失函数:

1. 对比损失(Contrastive Loss)

对比损失特别适合处理成对分类任务。它的工作原理是:

  • 将相似对(标记为1)在向量空间中拉近
  • 将不相似对(标记为0)推离至超过设定的边界值

改进版的在线对比损失(OnlineContrastiveLoss)能自动检测批次中的困难样本,并仅针对这些样本计算损失。

from datasets import load_dataset
from sentence_transformers import losses

train_dataset = load_dataset("sentence-transformers/quora-duplicates", "pair-class", split="train")
train_loss = losses.OnlineContrastiveLoss(model=model, margin=0.5)

2. 多重负样本排序损失(MultipleNegativesRankingLoss)

这种损失函数特别适合信息检索任务,它只需要正样本对(即重复问题对)。其优势在于:

  • 不需要调整超参数
  • 能有效处理大规模候选集
train_dataset = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train")
train_loss = losses.MultipleNegativesRankingLoss(model)

为了提高训练效果,我们可以利用问题的对称性,将(anchor, positive)和(positive, anchor)都加入训练集:

from datasets import concatenate_datasets

train_dataset = concatenate_datasets([
    train_dataset,
    train_dataset.rename_columns({"anchor": "positive", "positive": "anchor"})
])

多任务学习策略

结合两种损失函数的优势,我们可以采用多任务学习策略:

  • 对比损失:擅长区分重复和非重复问题对
  • 多重负样本排序损失:擅长从大量候选中找出相似问题

实现代码示例:

from sentence_transformers import SentenceTransformerTrainer

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset={
        "mnrl": mnrl_train_dataset,
        "cl": cl_train_dataset,
    },
    loss={
        "mnrl": mnrl_train_loss,
        "cl": cl_train_loss,
    },
)
trainer.train()

训练技巧

  1. 批次大小:尽可能使用大的批次大小,因为从更多候选中识别正确匹配能提高模型性能
  2. 硬件配置:在32GB GPU内存上,可以设置批次大小为350
  3. 数据假设:多重负样本排序损失假设随机采样的问题对通常不重复,如果数据集不满足此假设,效果可能不佳

预训练模型

目前可用的预训练模型包括:

  1. distilbert-base-nli-stsb-quora-ranking:基于distilbert-base-nli-stsb-mean-tokens模型,使用两种损失函数在Quora数据集上微调
  2. distilbert-multilingual-nli-stsb-quora-ranking:上述模型的多语言版本,支持50种语言

加载预训练模型:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("distilbert-base-nli-stsb-quora-ranking")

应用场景

训练好的模型可以用于:

  • 重复问题挖掘
  • 大规模语义相似性搜索
  • 问答系统去重
  • 社区内容管理

通过本文介绍的方法,开发者可以训练出高效的重复问题检测模型,应用于各种实际场景中。

sentence-transformers Multilingual Sentence & Image Embeddings with BERT sentence-transformers 项目地址: https://gitcode.com/gh_mirrors/se/sentence-transformers

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

杨洲泳Egerton

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值