FlagEmbedding项目微调教程:金融领域数据准备实战
前言
在自然语言处理领域,微调预训练模型是提升模型在特定领域表现的关键步骤。FlagEmbedding项目提供了一套完整的嵌入模型解决方案,本教程将重点讲解如何为金融领域的微调任务准备数据集。
环境准备
首先需要安装必要的数据处理工具:
pip install -U datasets
这个工具包将帮助我们高效地加载和处理数据集。
数据集介绍
我们选用了一个开源的金融问答数据集,该数据集包含7000条记录,每一条记录包含以下字段:
- question:金融相关问题
- answer:对应问题的答案
- context:答案所在的上下文
- ticker:股票代码
- filing:文件类型
数据预处理
1. 加载原始数据
from datasets import load_dataset
ds = load_dataset("virattt/financial-qa-10K", split="train")
2. 字段重命名与选择
为了适配FlagEmbedding的微调格式,我们需要对字段进行选择和重命名:
ds = ds.select_columns(column_names=["question", "context"])
ds = ds.rename_column("question", "query")
ds = ds.rename_column("context", "pos")
ds = ds.add_column("id", [str(i) for i in range(len(ds))])
3. 处理正负样本
在嵌入模型训练中,负样本对模型性能至关重要。原始数据没有提供负样本,我们可以从整个语料库中随机采样:
import numpy as np
np.random.seed(520)
neg_num = 10
def str_to_lst(data):
data["pos"] = [data["pos"]]
return data
# 采样负样本
new_col = []
for i in range(len(ds)):
ids = np.random.randint(0, len(ds), size=neg_num)
while i in ids:
ids = np.random.randint(0, len(ds), size=neg_num)
neg = [ds[i.item()]["pos"] for i in ids]
new_col.append(neg)
ds = ds.add_column("neg", new_col)
# 将正样本转换为列表形式
ds = ds.map(str_to_lst)
4. 添加查询提示
FlagEmbedding在推理时会使用查询指令(query_instruction_for_retrieval),我们需要在数据中添加相应的提示:
instruction = "Represent this sentence for searching relevant passages: "
ds = ds.add_column("prompt", [instruction]*len(ds))
5. 数据格式说明
最终每条数据应包含以下字段:
- query:查询文本
- pos:正样本文本列表
- neg:负样本文本列表
- id:唯一标识符
- prompt:查询提示
数据集划分
将数据集划分为训练集和测试集:
split = ds.train_test_split(test_size=0.1, shuffle=True, seed=520)
train = split["train"]
test = split["test"]
评估数据准备
为了后续评估模型性能,我们需要准备专门的测试数据:
1. 查询数据
queries = test.select_columns(column_names=["id", "query"])
queries = queries.rename_column("query", "text")
2. 语料数据
corpus = ds.select_columns(column_names=["id", "pos"])
corpus = corpus.rename_column("pos", "text")
3. 查询-文档相关性数据
qrels = test.select_columns(["id"])
qrels = qrels.rename_column("id", "qid")
qrels = qrels.add_column("docid", list(test["id"]))
qrels = qrels.add_column("relevance", [1]*len(test))
数据保存
最后,将所有处理好的数据保存到本地:
# 训练数据
train.to_json("ft_data/training.json")
# 评估数据
queries.to_json("ft_data/test_queries.jsonl")
corpus.to_json("ft_data/corpus.jsonl")
qrels.to_json("ft_data/test_qrels.jsonl")
总结
通过本教程,我们完成了FlagEmbedding项目在金融领域微调前的数据准备工作。关键步骤包括:
- 原始数据加载与字段选择
- 正负样本处理
- 查询提示添加
- 数据集划分
- 评估数据准备
这些准备工作为后续的模型微调奠定了坚实基础,确保模型能够在金融领域获得更好的嵌入表示效果。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考