FlagEmbedding项目微调教程:金融领域数据准备实战

FlagEmbedding项目微调教程:金融领域数据准备实战

FlagEmbedding Dense Retrieval and Retrieval-augmented LLMs FlagEmbedding 项目地址: https://gitcode.com/gh_mirrors/fl/FlagEmbedding

前言

在自然语言处理领域,微调预训练模型是提升模型在特定领域表现的关键步骤。FlagEmbedding项目提供了一套完整的嵌入模型解决方案,本教程将重点讲解如何为金融领域的微调任务准备数据集。

环境准备

首先需要安装必要的数据处理工具:

pip install -U datasets

这个工具包将帮助我们高效地加载和处理数据集。

数据集介绍

我们选用了一个开源的金融问答数据集,该数据集包含7000条记录,每一条记录包含以下字段:

  • question:金融相关问题
  • answer:对应问题的答案
  • context:答案所在的上下文
  • ticker:股票代码
  • filing:文件类型

数据预处理

1. 加载原始数据

from datasets import load_dataset
ds = load_dataset("virattt/financial-qa-10K", split="train")

2. 字段重命名与选择

为了适配FlagEmbedding的微调格式,我们需要对字段进行选择和重命名:

ds = ds.select_columns(column_names=["question", "context"])
ds = ds.rename_column("question", "query")
ds = ds.rename_column("context", "pos")
ds = ds.add_column("id", [str(i) for i in range(len(ds))])

3. 处理正负样本

在嵌入模型训练中,负样本对模型性能至关重要。原始数据没有提供负样本,我们可以从整个语料库中随机采样:

import numpy as np
np.random.seed(520)
neg_num = 10

def str_to_lst(data):
    data["pos"] = [data["pos"]]
    return data

# 采样负样本
new_col = []
for i in range(len(ds)):
    ids = np.random.randint(0, len(ds), size=neg_num)
    while i in ids:
        ids = np.random.randint(0, len(ds), size=neg_num)
    neg = [ds[i.item()]["pos"] for i in ids]
    new_col.append(neg)
ds = ds.add_column("neg", new_col)

# 将正样本转换为列表形式
ds = ds.map(str_to_lst)

4. 添加查询提示

FlagEmbedding在推理时会使用查询指令(query_instruction_for_retrieval),我们需要在数据中添加相应的提示:

instruction = "Represent this sentence for searching relevant passages: "
ds = ds.add_column("prompt", [instruction]*len(ds))

5. 数据格式说明

最终每条数据应包含以下字段:

  • query:查询文本
  • pos:正样本文本列表
  • neg:负样本文本列表
  • id:唯一标识符
  • prompt:查询提示

数据集划分

将数据集划分为训练集和测试集:

split = ds.train_test_split(test_size=0.1, shuffle=True, seed=520)
train = split["train"]
test = split["test"]

评估数据准备

为了后续评估模型性能,我们需要准备专门的测试数据:

1. 查询数据

queries = test.select_columns(column_names=["id", "query"])
queries = queries.rename_column("query", "text")

2. 语料数据

corpus = ds.select_columns(column_names=["id", "pos"])
corpus = corpus.rename_column("pos", "text")

3. 查询-文档相关性数据

qrels = test.select_columns(["id"])
qrels = qrels.rename_column("id", "qid")
qrels = qrels.add_column("docid", list(test["id"]))
qrels = qrels.add_column("relevance", [1]*len(test))

数据保存

最后,将所有处理好的数据保存到本地:

# 训练数据
train.to_json("ft_data/training.json")

# 评估数据
queries.to_json("ft_data/test_queries.jsonl")
corpus.to_json("ft_data/corpus.jsonl")
qrels.to_json("ft_data/test_qrels.jsonl")

总结

通过本教程,我们完成了FlagEmbedding项目在金融领域微调前的数据准备工作。关键步骤包括:

  1. 原始数据加载与字段选择
  2. 正负样本处理
  3. 查询提示添加
  4. 数据集划分
  5. 评估数据准备

这些准备工作为后续的模型微调奠定了坚实基础,确保模型能够在金融领域获得更好的嵌入表示效果。

FlagEmbedding Dense Retrieval and Retrieval-augmented LLMs FlagEmbedding 项目地址: https://gitcode.com/gh_mirrors/fl/FlagEmbedding

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

毕素丽

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值