【2025】Datawhale AI夏令营-多模态RAG-Task1、Task2笔记-任务理解与Baseline代码解读
任务背景
科大讯飞AI比赛-多模态RAG图文问答挑战赛:https://challenge.xfyun.cn/topic/info?type=Multimodal-RAG-QA&ch=dwsf2517
此次比赛以财报问答任务作为具体场景,旨在解决图文信息混排的情况下需要同时整合图表和文字的复杂推理问题。在这类问题中,模型不仅需要理解自然语言,还要”看懂“图像,将二者关联并进行推理。
例:
问题:根据图表显示,产品A的销售额在哪个季度开始下降?
材料:

此比赛的输入与输出要求如下:
输入:
test.json:包含所有评测问题,文件里只包含question字段 ,answer、filename和page缺失,需要参赛者回答补充。财报数据库.zip:包含多个PDF文件,这些PDF是真实世界的公司财报,格式上图文混排,内容上包含大量段落、数据表格以及各种图表(如条形图、饼图、折线图等)。test.json中所有问题的答案必须从这些PDF文档中寻找,不能依赖任何外部知识。
输出:
为 test.json 中的每个问题,预测答案(answer) 、来源文件名(filename )和来源页码( page ),然后整理成JSON文件,形如:
[
{
"filename": "xx.pdf",
"page": 1,
"question": "广联达在苏中建设集团南宁龙湖春江天越项目中,具体运用了哪些BIM技术,并取得了哪些成果?",
"answer": "广联达在苏中建设集团南宁龙湖春江天越项目中,具体运用了哪些BIM技术,并取得了哪些成果?"
},
……
]
任务难点
-
如何同时利用图文混排PDF中的图表和文本信息回答问题?
(多模态信息理解、跨模态检索、图文关联推理)在传统的RAG方法中,信息检索的对象只有文本,不能充分利用图表信息,这使得部分需要联合图文才能正确回答的问题无法得到有效回答。
✅Baseline方案如何处理此挑战?
Baseline解决方案的代码中暂时没有考虑图表信息,只使用文本信息。此方案使用PyMuPDF提取PDF文件中的内容。
-
如何准确地从文档中检索可能和问题相关的内容?
检索的内容是LLM回答问题的前提,如果检索的内容本身不包含与问题相关的内容,LLM也无法正确地回答问题。特别地,如果检索得到的K个内容中,只有1个与问题相关,其余K-1个不相关,这将严重干扰LLM的判断,导致LLM基于错误的信息作答。
如何优化检索策略以提高检索返回内容的信噪比?——建议引入重排Re-ranking技术
✅Baseline方案如何处理此挑战?
Baseline解决方案计算文档分片(chunk)的embedding和问题(query)的embedding二者之间的sin相似度,提取sin相似度最高的K个chunk作为检索结果。没有处理信噪比相关的问题。
-
如何基于检索结果正确地回答问题、追溯回答依据?
(答案生成)答案中需要准确地指出答案的来源(具体到文件名、页码)。
✅Baseline方案如何处理此挑战?
Baseline解决方案的代码在整个处理过程中定义并维护变量结构,该结构中使用"filename"和"page"字段记录检索结果的来源文件及具体页码。
Baseline方案设计
Baseline代码的设计将会分成几大块:
-
PDF文档处理
- 从PDF文档中提取文本信息(离线)
- 将文本信息分片(离线)
- 将文本分片转化为嵌入(在线,一次完成)
- 将文本信息和嵌入导入知识库中(在线,一次完成)
-
文档信息检索
- 将查询转化为嵌入(在线,每次查询)
- 计算查询嵌入与知识库中所有文本分片嵌入的相似性,选择相似性最高的K个分片作为检索结果(在线,每次查询)
-
LLM推理回答
使用prompt模板组织检索结果和问题,通过API调用LLM推理服务,获得LLM回答结果(在线,每次查询)
-
回答格式整理
将相关检索结果的来源与LLM回答结果按规定格式(比赛输出要求)保存在json文件中。
Baseline的整体流程图如下:

图源Datawhale教程。“process_pdfs_to_chunks.py”应为“fitz_pipeline_all.py”。“阶段二”的代码由rag_from_page_chunks.py文件实现。
Baseline代码仓库
modelscope仓库(实际使用、可以预览数据):https://www.modelscope.cn/datasets/Datawhale/AISumerCamp_multiModal_RAG.git
作者的github仓库(内容更丰富、功能说明更详细):https://github.com/li-xiu-qi/spark_multi_rag
Baseline运行
1、在Git 环境中初始化并配置Git LFS(Large File Storage)功能
git lfs install
Git LFS是Git的扩展,用于更高效地处理大型二进制文件。把大文件上传到一个独立的LFS存储服务器上,而Git仓库里只存一个指向它的“指针”。
执行此命令后出现Git LFS initialized.即初始化成功(其他多余的输出,如Error: failed to call git rev-parse --git-dir: exit status 128 : fatal: 不是 git 仓库(或者任何父目录):.git可以忽略,不影响实际使用)。
执行此命令后,如果出现提示:
git:'lfs' 不是一个 git 命令。参见 'git --help'。
最相似的命令是
log
说明没有安装Git LFS插件。可以使用sudo apt install git-lfs安装,然后重新执行此命令。
2、下载baseline代码
git clone https://www.modelscope.cn/datasets/Datawhale/AISumerCamp_multiModal_RAG.git
下载完成后,进入AISumerCamp_multiModal_RAG文件夹。
3、配置密钥
打开baseline文件夹 -> 双击打开env.txt文件 -> 在LOCAL_API_KEY处粘贴个人使用的API密钥 -> 保存env.txt文件,然后重命名为.env
4、运行baseline文件
运行baseline.ipynb文件。运行完成后得到rag_top1_pred.json即为结果文件。
Baseline代码(baseline.ipynb及引入的其他文件)理解
目录结构

这里的文件不是每个都用到,真正用到的文件有:
- fitz_pipeline_all.py
- rag_from_page_chunks.py
- get_text_embedding.py
- extract_json_array.py
一、解压数据
!unzip -n datas/财报数据库.zip -d datas/
执行输出如下图所示:

二、安装依赖包
!pip install -r requirements.txt
除了numpy、tqdm、loguru(python日志记录库)等常见的依赖,需要安装的特别依赖有:
-
python-dotenv:读取
.env环境变量文件,将文件中的变量自动加载到os.environ中,供程序使用 -
openai:调用LLM服务的API接口
-
PyMuPDF:处理PDF和其他文档文件,可以进行读取、修改等操作
三、执行fitz_pipeline_all.py文件
!python fitz_pipeline_all.py
此文件功能:将每个PDF文件中每个页面的内容处理为一个chunk,然后存储到./datas/all_pdf_page_chunks.json文件中。
具体代码解读如下:
依赖包
import fitz # fitz即为PyMuPDF
import json
from pathlib import Path
主函数
def main():
'''设置PDF文件输入目录及最终输出json文件路径,然后调用process_pdfs_to_chunks进行处理'''
base_dir = Path(__file__).parent
datas_dir = base_dir / 'datas' # 设置PDF文件的输入目录
chunk_json_path = base_dir / 'all_pdf_page_chunks.json' # 设置最终输出json的路径
process_pdfs_to_chunks(datas_dir, chunk_json_path)
process_pdfs_to_chunks函数
def process_pdfs_to_chunks(datas_dir: Path, output_json_path: Path):
"""
使用 PyMuPDF 直接从 PDF 提取每页文本,并生成最终的 JSON 文件。
Args:
datas_dir (Path): 包含 PDF 文件的输入目录。
output_json_path (Path): 最终输出的 JSON 文件路径。
"""
all_chunks = []
# 递归查找 datas_dir 目录下的所有 .pdf 文件
pdf_files = list(datas_dir.rglob('*.pdf')) # rglob方法:递归地在指定目录及其子目录下查找所有符合指定模式的文件并返回Path对象
if not pdf_files:
print(f"警告:在目录 '{datas_dir}' 中未找到任何 PDF 文件。")
return
print(f"找到 {len(pdf_files)} 个 PDF 文件,开始处理...")
# 【重点】
for pdf_path in pdf_files:
file_name_stem = pdf_path.stem # 文件名(不含扩展名)
full_file_name = pdf_path.name # 完整文件名(含扩展名)
print(f" - 正在处理: {full_file_name}")
try:
# 使用 with 语句确保文件被正确关闭(with...as...语句确保文件处理完后自动关闭)
with fitz.open(pdf_path) as doc: # 使用fitz(即PyMuPDF)打开Path对象对应的文件,返回fitz.Document对象,这个对象代表整个PDF文档
# 遍历 PDF 的每一页
for page_idx, page in enumerate(doc): # page是fitz.Page对象,代表文档中的一页
# 提取当前页面的所有文本
content = page.get_text("text")
# 如果页面没有文本内容,则跳过
if not content.strip(): # strip()用于去除字符串首尾的空白字符
continue
# 构建符合最终格式的 chunk 字典
chunk = {
"id": f"{file_name_stem}_page_{page_idx}",
"content": content,
"metadata": {
"page": page_idx, # 0-based page index
"file_name": full_file_name
}
}
all_chunks.append(chunk)
except Exception as e:
print(f"处理文件 '{pdf_path}' 时发生错误: {e}")
# 确保输出目录存在
output_json_path.parent.mkdir(parents=True, exist_ok=True) # parents=True:如果父目录不存在,一并创建
# 将所有 chunks 写入一个 JSON 文件
with open(output_json_path, 'w', encoding='utf-8') as f:
json.dump(all_chunks, f, ensure_ascii=False, indent=2)
print(f"\n处理完成!所有内容已保存至: {output_json_path}")
简单而言,这个函数对PDF的处理方式较为简单,通过使用content = page.get_text("text")就完成了对一页PDF内容的提取。
此外,这个函数也初步定义了每个文本分片chunk的字典存储格式,为后续LLM的回答依据溯源服务。chunk字典的格式定义如下:
chunk = {
"id": (str),
"content": (str),
"metadata": {
"page": (int), # 0-based page index
"file_name": (str)
}
}
四、执行rag_from_page_chunks.py文件
!python rag_from_page_chunks.py
依赖包
import json
import os
import hashlib
from typing import List, Dict, Any
from tqdm import tqdm
import sys
import concurrent.futures
import random
import numpy as np
import time
from get_text_embedding import get_text_embedding
from dotenv import load_dotenv
from openai import OpenAI
# 统一加载项目根目录的.env
load_dotenv()
这里引入了get_text_embedding.py中的函数。
SimpleRAG类
SimpleRAG类主要负责:
-
(setup函数)
创建知识库:以vector_store为知识库,将所有chunk映射为embedding后存储到vector_store中,作为可检索的文档。
原始的baseline代码在将所有chunk映射为embedding的这个环节中,每次重新运行都要重新映射,造成不必要的时间消耗。我将代码修改为第一次运行时将映射结果存储到npy文件中,后续如果需要重新运行可以直接加载npy文件中的embedding,不需要重新映射。
-
(generate_answer函数)
-
检索知识库:接收到一个新的查询(query)时,使用相同的embedding_model将query转化为embedding,然后调用知识库中的search方法,找出跟query最相关的K个chunk。
-
整理LLM输入文本:检索得到与query最相关的K个chunk后,使用prompt模板组织文本作为LLM的输入。
-
使用LLM推理:使用API向LLM服务平台发送请求,在请求时,为了避免因为网络断连、请求频繁等问题而失败,需要添加异常处理和重试机制(这是我自己添加的,这个处理机制虽然在原来的baseline代码里没有,但是在大批量任务请求的场景中非常重要,所以我自己添加了)。
-
整理LLM推理输出:获取LLM的回答后,从LLM的回答中解析json对象并提取其中的’answer’、‘filename’、'page’元素,以字典结构返回针对该次query的输出。
检索的时候检索相关文档不止1个,最后组织LLM形成答案的时候只填1个文档来源,这是根据LLM的回答决定的,由LLM回答它认为的文档来源。如果无法提取这部分信息,则默认是第1个检索文档来源来充当。
-
SimpleRAG类在初始化时也初始化了3个对象:PageChunkLoader、EmbeddingModel、SimpleVectorStore。这3个对象在后续会继续介绍。
class SimpleRAG:
def __init__(self, chunk_json_path: str, model_path: str = None, batch_size: int = 32):
'''初始化PageChunkLoader、EmbeddingModel、SimpleVectorStore'''
self.loader = PageChunkLoader(chunk_json_path)
self.embedding_model = EmbeddingModel(batch_size=batch_size)
self.vector_store = SimpleVectorStore()
def setup(self):
'''
加载所有chunk、为所有chunk生成embedding、将chunk和embedding存储到vector_store中
'''
print("加载所有页chunk...")
chunks = self.loader.load_chunks()
print(f"共加载 {len(chunks)} 个chunk")
embeddings_cache_path = "./chunks_embeddings.npy"
if os.path.exists(embeddings_cache_path):
embeddings = np.load(embeddings_cache_path, allow_pickle=True)
else:
print("生成嵌入...")
embeddings = self.embedding_model.embed_texts([c['content'] for c in chunks])
np.save(embeddings_cache_path, np.array(embeddings))
print("存储向量...")
self.vector_store.add_chunks(chunks, embeddings)
print("RAG向量库构建完成!")
def generate_answer(self, question: str, top_k: int = 3) -> Dict[str, Any]:
"""
检索+大模型生成式回答,返回结构化结果
Args:
question (str)
top_k (int)
Returns:
answer_dict (Dict{
"question" (str):
"answer" (str):
"filename" (str):
"page" (int):
"retrieval_chunks" (List[Dict]):
})
"""
qwen_api_key = os.getenv('LOCAL_API_KEY')
qwen_base_url = os.getenv('LOCAL_BASE_URL')
qwen_model = os.getenv('LOCAL_TEXT_MODEL') # "Qwen/Qwen3-8B"
if not qwen_api_key or not qwen_base_url or not qwen_model:
raise ValueError('请在.env中配置LOCAL_API_KEY、LOCAL_BASE_URL、LOCAL_TEXT_MODEL')
# 为question生成embedding
q_emb = self.embedding_model.embed_text(question)
chunks = self.vector_store.search(q_emb, top_k)
# 拼接检索内容,带上元数据
context = "\n".join([
f"[文件名]{c['metadata']['file_name']} [页码]{c['metadata']['page']}\n{c['content']}" for c in chunks
])
# 明确要求输出JSON格式 answer/page/filename
prompt = (
f"你是一名专业的金融分析助手,请根据以下检索到的内容回答用户问题。\n"
f"请严格按照如下JSON格式输出:\n"
f'{{"answer": "你的简洁回答", "filename": "来源文件名", "page": "来源页码"}}'"\n"
f"检索内容:\n{context}\n\n问题:{question}\n"
f"请确保输出内容为合法JSON字符串,不要输出多余内容。"
)
client = OpenAI(api_key=qwen_api_key, base_url=qwen_base_url)
# 重试机制和异常处理
max_retries = 3
retry_delay = 2
for attempt in range(max_retries):
try:
completion = client.chat.completions.create(
model=qwen_model,
messages=[
{"role": "system", "content": "你是一名专业的金融分析助手。"},
{"role": "user", "content": prompt}
],
temperature=0.2,
max_tokens=1024
)
break
except Exception as e:
print(f"API请求失败 (尝试 {attempt + 1}/{max_retries}): {str(e)}")
if attempt < max_retries:
time.sleep(retry_delay)
retry_delay *= 2
import json as pyjson
from extract_json_array import extract_json_array
raw = completion.choices[0].message.content.strip()
# 用 extract_json_array 提取 JSON 对象
json_str = extract_json_array(raw, mode='objects')
if json_str:
try:
arr = pyjson.loads(json_str)
# 只取第一个对象
if isinstance(arr, list) and arr: # 成功提取出包含至少一个JSON对象的数组
j = arr[0]
answer = j.get('answer', '')
filename = j.get('filename', '')
page = j.get('page', '')
else: # 提取JSON对象数组失败或数组为空
answer = raw
filename = chunks[0]['metadata']['file_name'] if chunks else '' # 强制使用检索出的最相关的第1个chunk
page = chunks[0]['metadata']['page'] if chunks else ''
except Exception:
answer = raw
filename = chunks[0]['metadata']['file_name'] if chunks else ''
page = chunks[0]['metadata']['page'] if chunks else ''
else:
answer = raw
filename = chunks[0]['metadata']['file_name'] if chunks else ''
page = chunks[0]['metadata']['page'] if chunks else ''
# 结构化输出
return {
"question": question,
"answer": answer,
"filename": filename,
"page": page,
"retrieval_chunks": chunks
}
extract_json_array函数(extract_json_array.py)
此函数负责从字符串中提取第一个JSON数组或由多个对象组成的数组。
主要由各种正则表达式组成提取规则,这些规则代码比较繁琐,就不展开分析了,需要使用时可以使用GPT生成。
PageChunkLoader类
PageChunkLoader类主要负责:
- (load_chunks函数)将json文件中存储的chunk加载出来,供后续EmbeddingModel类转化为embedding。
class PageChunkLoader:
def __init__(self, json_path: str):
'''初始化存储chunk的json路径'''
self.json_path = json_path
def load_chunks(self) -> List[Dict[str, Any]]:
'''
将json文件中的每个chunk以Dict类型导出
Dict{
"id": (str)
"content": (str)
"metadata": {
"page": (int)
"file_name": (str)
}
}
'''
with open(self.json_path, 'r', encoding='utf-8') as f:
return json.load(f)
EmbeddingModel类
EmbeddingModel类主要负责:
-
(embed_texts函数)批量化地将chunk转化为embedding
-
(embed_text函数)每次接收到新的查询query时,将query转化为embedding
(以上两个函数本质是一样的)
class EmbeddingModel:
def __init__(self, batch_size: int = 64):
'''初始化EmbeddingModel的API配置'''
self.api_key = os.getenv('LOCAL_API_KEY')
self.base_url = os.getenv('LOCAL_BASE_URL')
self.embedding_model = os.getenv('LOCAL_EMBEDDING_MODEL') # 目前是BAAI/bge-m3
self.batch_size = batch_size
if not self.api_key or not self.base_url:
raise ValueError('请在.env中配置LOCAL_API_KEY和LOCAL_BASE_URL')
def embed_texts(self, texts: List[str]) -> List[List[float]]:
'''
Args:
texts (List[str]): chunk的内容的列表
Returns:
embed_texts (List[List[float]]): chunk的内容的embedding的列表
'''
return get_text_embedding(
texts,
api_key=self.api_key,
base_url=self.base_url,
embedding_model=self.embedding_model,
batch_size=self.batch_size
)
def embed_text(self, text: str) -> List[float]:
return self.embed_texts([text])[0]
get_text_embedding函数(get_text_embedding.py)
这个文件使用平台提供的embedding模型API服务,通过调用平台上的embedding模型API完成embedding的映射,避免在本地直接部署embedding模型。
from http import client
from openai import OpenAI
from dotenv import load_dotenv
import os
import hashlib
from typing import List, Dict, Optional, Tuple
import json
# LOCAL_API_KEY,LOCAL_BASE_URL,LOCAL_TEXT_MODEL,LOCAL_EMBEDDING_MODEL
load_dotenv() # 加载环境变量(可选,用户可自行读取)
def get_openai_client(api_key: str, base_url: str) -> OpenAI:
"""
获取 OpenAI 客户端,必须传递 api_key 和 base_url
"""
if not api_key or not base_url:
raise ValueError("api_key 和 base_url 必须显式传递!")
return OpenAI(
api_key=api_key,
base_url=base_url,
)
from tqdm import tqdm
def batch_get_embeddings(
texts: List[str],
batch_size: int = 64,
api_key: str = None,
base_url: str = None,
embedding_model: str = None
) -> List[List[float]]:
"""
批量获取文本的嵌入向量
:param texts: 文本列表
:param batch_size: 批处理大小
:param api_key: 可选,自定义 API KEY
:param base_url: 可选,自定义 BASE URL
:param embedding_model: 可选,自定义嵌入模型
:return: 嵌入向量列表
"""
if not api_key or not base_url or not embedding_model:
raise ValueError("api_key、base_url、embedding_model 必须显式传递!")
all_embeddings = []
client = get_openai_client(api_key, base_url) # 获取openai客户端
total = len(texts)
if total == 0:
return []
iterator = range(0, total, batch_size)
if total > 1:
iterator = tqdm(iterator, desc="Embedding", unit="batch")
import time
from openai import RateLimitError
for i in iterator:
batch_texts = texts[i:i + batch_size]
retry_count = 0
while True:
try:
response = client.embeddings.create( # 一次性处理batch_size个texts
model=embedding_model,
input=batch_texts
)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
break
except RateLimitError as e:
retry_count += 1
print(f"RateLimitError: {e}. 等待10秒后重试(第{retry_count}次)...")
time.sleep(10)
return all_embeddings
def get_text_embedding(
texts: List[str],
api_key: str = None,
base_url: str = None,
embedding_model: str = None,
batch_size: int = 64
) -> List[List[float]]:
"""
获取文本的嵌入向量,支持批次处理,保持输出顺序与输入顺序一致
:param texts: 文本列表
:param api_key: 可选,自定义 API KEY
:param base_url: 可选,自定义 BASE URL
:param embedding_model: 可选,自定义嵌入模型
:param batch_size: 批处理大小
:return: 嵌入向量列表
"""
if not api_key or not base_url or not embedding_model:
raise ValueError("api_key、base_url、embedding_model 必须显式传递!")
# 直接批量获取文本的embedding,不做缓存
return batch_get_embeddings(
texts,
batch_size=batch_size,
api_key=api_key,
base_url=base_url,
embedding_model=embedding_model
)
SimpleVectorStore类
SimpleVectorStore类主要负责:
- (add_chunks函数)将chunk、embedding分别加入知识库中
- (search函数)接收query的embedding,然后计算query与知识库中所有文本chunk的sim相似度,返回相似度最高的前K个chunk
class SimpleVectorStore:
def __init__(self):
self.embeddings = []
self.chunks = []
def add_chunks(self, chunks: List[Dict[str, Any]], embeddings: List[List[float]]):
self.chunks.extend(chunks)
self.embeddings.extend(embeddings)
def search(self, query_embedding: List[float], top_k: int = 3) -> List[Dict[str, Any]]:
'''
通过计算query_embedding和vectorStore中embedding的sim相似度,返回sim相似度最高的top_k个embedding
Args:
query_bemdding (List[float])
top_k (int)
Returns:
searched_chunk (List[Dict[str, Any]])
'''
from numpy import dot
from numpy.linalg import norm
# import numpy as np
if not self.embeddings:
return []
emb_matrix = np.array(self.embeddings) # emb_matrix.shape=[num_samples, embed_dim]
query_emb = np.array(query_embedding) # query_emb.shape=[embed_dim, ]
sims = emb_matrix @ query_emb / (norm(emb_matrix, axis=1) * norm(query_emb) + 1e-8) # sims.shape=[num_samples, ]
idxs = sims.argsort()[::-1][:top_k] # sims.argsort():对sims数组进行排序,默认是升序;[(start):(end):-1(step)]:反转数组(变成降序排序)
return [self.chunks[i] for i in idxs]
主函数
初始的baseline代码推理完规定数量的全部题目后才会保存结果文件,这种处理方式存在程序因网络故障(如:openai.APITimeoutError: Request timed out.)等原因而中断时已推理的结果被浪费的问题。因此,我修改了baseline代码,每次(少量并发)推理完之后就保存结果文件。具体如下所示:
if __name__ == '__main__':
# 路径可根据实际情况调整
chunk_json_path = "./all_pdf_page_chunks.json"
rag = SimpleRAG(chunk_json_path)
rag.setup()
# 控制测试时读取的题目数量,默认只随机抽取10个,实际跑全部时设为None
TEST_SAMPLE_NUM = None # 10 # 设置为None则全部跑
FILL_UNANSWERED = True # 未回答的也输出默认内容
# 批量评测脚本:读取测试集,检索+大模型生成,输出结构化结果
test_path = "./datas/test.json"
if os.path.exists(test_path):
with open(test_path, 'r', encoding='utf-8') as f:
test_data = json.load(f)
results_to_save = []
raw_out_path = "./rag_top1_pred_raw.json"
# 记录所有原始索引
all_indices = list(range(len(test_data)))
selected_indices = all_indices
if TEST_SAMPLE_NUM is not None and TEST_SAMPLE_NUM > 0: # 随机抽取部分题目用于测试
if len(test_data) > TEST_SAMPLE_NUM:
selected_indices = sorted(random.sample(all_indices, TEST_SAMPLE_NUM))
elif TEST_SAMPLE_NUM is None: # 需要测试所有题目
if os.path.exists(raw_out_path):
# 如果已经有"./rag_top1_pred_raw.json"文件,打开并检查哪些问题的idx不存在,重新推理这些问题
with open(raw_out_path, 'r', encoding='utf-8') as f:
exist_results = json.load(f)
for exist_result in exist_results:
selected_indices.remove(exist_result[0])
results_to_save.append(exist_result)
# 不存在这个文件,说明没有任何问题被推理,要全部重新推理
print(f"selected_indices: {selected_indices}")
def save_results(results, raw_out_path):
tmp = raw_out_path + ".tmp"
with open(tmp, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
os.replace(tmp, raw_out_path) # 用新文件替换旧文件
def process_one(idx):
item = test_data[idx]
question = item['question']
tqdm.write(f"[{selected_indices.index(idx)+1}/{len(selected_indices)}] 正在处理: {question[:30]}...")
result = rag.generate_answer(question, top_k=5)
return idx, result
if selected_indices:
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: # 创建一个线程池,最多同时运行5个线程,用于并发执行任务
futures = [executor.submit(process_one, idx) for idx in selected_indices]
for i, future in enumerate(
tqdm(
concurrent.futures.as_completed(futures),
total=len(futures),
desc="并发批量生成"
)
):
idx, result = future.result()
results_to_save.append((idx, result))
save_results(results_to_save, raw_out_path) # 每推理完成1个问题就保存1次,避免已推理的问题被浪费
# 只保留结果部分,并去除 retrieval_chunks 字段
idx2result = {result[0]: {k: v for k, v in result[1].items() if k != "retrieval_chunks"} for result in results_to_save}
# result[0]是idx,result[1]是result
filtered_results = []
for idx, item in enumerate(test_data):
if idx in idx2result:
filtered_results.append(idx2result[idx])
elif FILL_UNANSWERED:
# 未被回答的,补默认内容
filtered_results.append({
"question": item.get("question", ""),
"answer": "",
"filename": "",
"page": "",
})
# 输出结构化结果到json
out_path = "./rag_top1_pred.json"
with open(out_path, 'w', encoding='utf-8') as f:
json.dump(filtered_results, f, ensure_ascii=False, indent=2)
print(f'已输出结构化检索+大模型生成结果到: {out_path}')
else:
print("datas/test.json 不存在")

被折叠的 条评论
为什么被折叠?



