代码结构
prompt 定义
from langchain_core.prompts import ChatPromptTemplate
layout_prompt = ChatPromptTemplate([
("system", """您作为学术架构师,具备跨学科论文指导经验,擅长将研究构思转化为严谨的学术框架。请根据研究要素构建具有方法论指导价值的论文架构"""),
("human", """
# 论文大纲及章节规划生成指令
## 输入参数说明
• 论文标题:"{title}"
• 描述信息:"{description}"
## 核心任务
生成具备可扩展性的论文框架体系,包含:
1. 逻辑严谨的6章节大纲(一级目录)
2. 各章节详细的内容建设规划
## 大纲生成规范(请逐条遵守)
[结构要求]
▸ 必须包含:引言、相关工作、方法、实验与结果、讨论、结论
▸ 章节间需形成"问题提出→方法构建→验证分析→理论升华"的逻辑闭环
▸ 每个一级目录需标注3个核心要点(用中文分号分隔)
[内容要求]
▶ 引言:须包含研究背景破题、理论缺口、方法论创新三维度
▶ 相关工作:需设计对比分析矩阵,明确前人研究局限
▶ 方法:应区分理论框架与技术实现两个层面
▶ 实验:需预设对照组、评估指标、可复现性方案
▶ 讨论:要求建立与引言的理论呼应机制
▶ 结论:需提炼方法论贡献与实践指导价值
## 章节简介标准
[内容维度]
① 研究任务:本章的核心学术使命
② 内容要素:需涵盖的关键组件(按优先级排序)
③ 衔接设计:与前后章节的逻辑接口
④ 创新载体:区别于常规论文的突破点
[质量要求]
✓ 各章节描述需形成"方法-实验-结论"证据链
✓ 保持技术细节与理论深度的平衡
✓ 突出从具体实验到一般规律的研究跃迁
## 输出格式(严格遵循)
{format_instructions}
感谢你的创作 ~
""")
])
outline_prompt = ChatPromptTemplate([
("system", """你是一名论文导师, 非常擅长写论文, 用户将会给到你(层级标题/当前标题/描述信息)三部分内容, 请根据用户输入的信息生成 大纲和大纲下的简介 的列表
参考信息介绍:
层级标题: 代表当前的层级标题一般格式为: ["文章标题", "一级标题", "二级标题"]
当前标题: 当前需要生成大纲的标题, 一般格式为(三级标题)
描述信息: 描述当前标题需要做哪些内容
大纲列表生成要求如下:
大纲应具有逻辑性, 每个部分应包含关键点
确保大纲结构清晰,便于后续扩展成完整的论文
大纲只需要一次返回一级目录, 不需要返回多级
大纲对应简介列表生成要求如下:
简介内容为当前标题下的描述
简介介绍应该本段需要编写哪些内容的方向和指定计划
简介应当前后呼应(譬如实验准备和实验细节以及最后的实验结论)
简介内容应当详细, 该描述需要准确和全面, 对于本章节具有完全概括的能力
整体要求如下:
认真思考给到的“当前标题”和“描述信息”
生成输出的内容需要严谨, 请结合“层级标题”以内的内容作为边界, 可以参考当前所有的层级, 但是请注意不要拓展太多或者产生超纲的行为
内容可以是创新也可以是现有资料
"""),
("human", """请帮我生成大纲和对应的大纲简介谢谢:
层级标题: {deep_title}
当前标题: {title}
描述信息: {description}
并严格使用如下格式输出, 请不要输出严用格式以外的内容:
{format_instructions}
感谢你的创作 ~
""")
])
content_prompt = ChatPromptTemplate([
("system", """
【角色定义】
你是一位资深论文指导专家,具备跨学科研究经验,擅长将碎片化信息整合为严谨的学术论述。请根据用户提供的层级架构和内容描述,生成符合学术规范的段落内容。
【输入信息说明】
用户将依次提供:
1. 层级路径:["文章标题", "一级标题", "二级标题"]组成的树状结构
2. 当前节点:需撰写的具体标题名称及层级位置
3. 内容纲要:该段落需达成的论述目标与关键要素
【内容生成规范】
1. 结构控制:
- 以当前节点为圆心,向上兼容父级标题逻辑,向下预留子级展开空间
- 段落长度控制在200-800字,每段聚焦1个核心论点
- 使用学术过渡词("值得注意的是""需要强调的是")衔接段落
2. 术语运用:
- 采用《学科分类与代码》国家标准中的规范术语
- 新兴概念首次出现时标注英文原文及提出年份(例:具身认知(Embodied Cognition, Lakoff 1999))
- 术语密度控制在每百字3-5个专业词汇
3. 论述展开:
- 论点-论据-论证三层结构完整
- 实验数据类:说明样本量/显著性水平/效应值
- 理论分析类:引用2-3个经典理论框架
- 应用研究类:结合典型案例进行机理阐释
【内容质量要求】
1. 创新性处理:
- 对成熟理论采用"背景假设-核心主张-当代挑战"分析框架
- 对新兴领域运用"技术路线-关键瓶颈-突破方向"论述模式
- 交叉学科内容建立维度矩阵(例:时间/空间/主体三个分析维度)
2. 边界控制:
- 学术深度不超过该节点在层级树中的位置等级
- 二级标题内容需预留至少3个可量化研究空白点
- 方法论部分保持范式中立(实证/解释/批判)
【输出格式规范】
1. 段落构建:
[主题句]→[理论支撑]→[论证展开]→[承启句]
示例:
"数字化转型重构了企业知识管理范式(主题句)。基于动态能力观(Teece, 1997),组织学习机制正从单环学习向双环学习演进(理论支撑)。笔者对长三角412家制造企业的调研显示,采用数据中台的企业知识迭代速度提升37.2%(p<0.01),这印证了Nonaka(2021)提出的知识涌现理论(论证展开)。这种变革对传统管理体系提出了三方面挑战(承启句)。"
2. 特殊处理:
- 争议观点标注学术阵营(例:新古典学派主张...,而演化经济学派认为...)
- 重大发现标注学界反响(例:该结论被《管理世界》2023年度综述列为十大争议议题)
- 关键技术说明应用场景(例:区块链溯源技术在农产品供应链中的典型应用包括...)
"""),
("human", """请帮我生成内容谢谢:
层级标题: {deep_title}
当前标题: {title}
描述信息: {description}
并严格使用如下格式输出, 请不要输出严用格式以外的内容:
请直接返回文本即可, 不需要返回标题
历史对话如下:
{history}
感谢你的创作 ~
""")
])
方法定义
from typing import List
from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.runnables import RunnableWithMessageHistory
from langchain.memory import VectorStoreRetrieverMemory
from langchain_openai import ChatOpenAI
from paper_agent.bge_embedding import init_vector_store, init_rerank
from paper_agent.model import AgentState, ArticleStruct, OutlineModel, SectionModel
from langchain.globals import set_debug
from paper_agent.prompt import outline_prompt, layout_prompt, content_prompt
set_debug(True)
llm = ChatOpenAI(
api_key="",
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model="qwen2.5-72b-instruct",
verbose=True
)
store = {}
def get_session_history(session_id) -> BaseChatMessageHistory:
if session_id not in store:
store[session_id] = InMemoryChatMessageHistory()
return store[session_id]
vector_store_memory, rerank_memory = init_vector_store(), init_rerank()
retriever_memory = VectorStoreRetrieverMemory(retriever=vector_store_memory.as_retriever(search_kwargs=dict(k=5)))
outline_parser = PydanticOutputParser(pydantic_object=OutlineModel)
outline_history_chain = RunnableWithMessageHistory(outline_prompt | llm, get_session_history, input_messages_key="title", history_messages_key="history")
layout_history_chain = RunnableWithMessageHistory(layout_prompt | llm, get_session_history, input_messages_key="title", history_messages_key="history")
content_history_chain = RunnableWithMessageHistory(content_prompt | llm, get_session_history, input_messages_key="title", history_messages_key="history")
def start_input_node(state: AgentState):
return {
"input": input("文章标题: "),
"description": input("文章简介: ")
}
def build_article(sections: List[SectionModel]):
content = []
for section in sections:
children = []
if section.sections:
children = build_article(section.sections)
if children:
content.append(
ArticleStruct(title=section.title, description=section.description, confirmed_outline=False, children=children)
)
else:
content.append(
ArticleStruct(title=section.title, description=section.description, confirmed_outline=False)
)
return content
def gen_content(paper_title, current_title, current_description, content, level=2, deep_title=None):
"""
:param paper_title: 论文标题
:param current_title: 当前子标题
:param current_description: 当前子描述
:param content:
:param level:
:param deep_title:
:return:
"""
if not content:
if paper_title == current_title:
chain_res: AIMessage = layout_history_chain.invoke({
"title": current_title,
"description": current_description,
"format_instructions": outline_parser.get_format_instructions(),
}, config={"configurable": {"session_id": paper_title}})
else:
chain_res: AIMessage = outline_history_chain.invoke({
"deep_title": deep_title,
"title": current_title,
"description": current_description,
"format_instructions": outline_parser.get_format_instructions(),
}, config={"configurable": {"session_id": paper_title}})
outline_model: OutlineModel = outline_parser.parse(chain_res.content)
content = build_article(outline_model.sections)
return {"content": content, "gen_end": False}
for item in content:
if not item.content:
chain_res: AIMessage = content_history_chain.invoke({
"deep_title": deep_title,
"title": item.title,
"description": current_description,
"format_instructions": outline_parser.get_format_instructions(),
"history": (retriever_memory.load_memory_variables({"input": current_description}) or {}).get("history", "")
}, config={"configurable": {"session_id": paper_title}})
item.content = chain_res.content
item.confirmed_content = False
return {"content": content, "gen_end": False}
if isinstance(item.children, list):
res = gen_content(paper_title, item.title, item.description, item.children, level + 1, [*deep_title, item.title])
if not res["gen_end"]:
item.children = res["content"]
return {"content": content, "gen_end": False}
return {"content": content, "gen_end": True}
def gen_content_node(state: AgentState):
"""根据用户的想法生成大纲"""
title = state["input"]
description = state["description"]
return gen_content(title, title, description, state["content"], level=2, deep_title=[title])
def gen_response(content: List[ArticleStruct], level=2):
response = ""
for item in content:
response += f"{level * '#'} {item.title}\n"
if item.content:
response += f"{item.content}\n"
if item.children:
response += gen_response(item.children, level=level + 1)
return response
def gen_response_node(state: AgentState):
"""根据content生成最终结果"""
response = f"# {state['input']}\n"
content = state["content"]
response += gen_response(content, level=2)
return {"response": response}
def user_confirm(content: List[ArticleStruct]):
confirmed_outline = content[0].confirmed_outline
if not confirmed_outline:
while True:
outlines = [f"{i}. {item.title}" for i, item in enumerate(content)]
outline_str = "\n".join(outlines)
user_input = input(f'{outline_str}\n, 请确认当前大纲(确认/重新生成/修改)?')
if user_input == "确认":
for item in content:
item.confirmed_outline = True
return {"content": content, "check_end": True}
elif user_input == "重新生成":
return {"content": [], "check_end": True}
elif user_input == "修改":
i = int(input("请输入需要修改的编号: "))
title = input("请输入新的标题: ")
content[i].title = title
for item in content:
if not item.confirmed_content:
user_input = input(f'{item.title}\n{item.content}\n, 请确认当前内容(确认/重新生成/修改/生成细节)?')
if user_input == "确认":
item.confirmed_content = True
retriever_memory.save_context({"system": item.description, "ai": item.content})
return {"content": content, "check_end": True}
elif user_input == "重新生成":
item.content = ""
item.confirmed_content = False
return {"content": content, "check_end": True}
elif user_input == "修改":
user_input_content = input("请输入新的内容: ")
item.content = user_input_content
elif user_input == "生成细节":
item.confirmed_content = True
item.children = []
return {"content": content, "check_end": True}
if item.children:
print("user_confirm item.children: ", item.children)
res = user_confirm(item.children)
if res["check_end"]:
item.children = res["content"]
return {"content": content, "check_end": True}
return {"content": content, "check_end": False}
def user_confirm_node(state: AgentState):
"""请求用户确认是否继续"""
content, is_confirm = state["content"], False
while not is_confirm:
res = user_confirm(content)
content = res["content"]
is_confirm = not res["check_end"]
return {"content": content}
embedding 历史记录
from typing import List
import faiss
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain_community.vectorstores import FAISS
from langchain_community.docstore import InMemoryDocstore
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.memory import VectorStoreRetrieverMemory
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
import torch
class CustomTransformerEmbeddings(Embeddings):
"""自定义 Transformers 嵌入模型"""
def __init__(self, model_name: str, device: str = "cuda:0"):
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(device)
self.model.eval()
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""文档嵌入"""
return self._embed(texts)
def embed_query(self, text: str) -> List[float]:
"""查询嵌入"""
return self._embed([text])[0]
def _embed(self, texts: List[str]) -> List[List[float]]:
"""实际嵌入实现"""
inputs = self.tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt",
max_length=512
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
embeddings = self._mean_pooling(outputs, inputs['attention_mask'])
return embeddings.cpu().numpy().tolist()
def _mean_pooling(self, model_output, attention_mask):
"""池化方法"""
token_embeddings = model_output.last_hidden_state
input_mask_expanded = (
attention_mask
.unsqueeze(-1)
.expand(token_embeddings.size())
.float()
)
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
class CrossEncoderReranker:
"""基于交叉编码器的重排序器"""
def __init__(self, model_name: str, device: str = "cuda:0"):
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
self.model.eval()
def rerank(self, query: str, documents: List[Document], top_k: int = 5) -> List[Document]:
"""重排序文档"""
pairs = [(query, doc.page_content) for doc in documents]
scores = self._predict(pairs)
scored_docs = list(zip(documents, scores))
scored_docs.sort(key=lambda x: x[1], reverse=True)
return [doc for doc, score in scored_docs[:top_k]]
def _predict(self, pairs: List[tuple]) -> List[float]:
"""计算相关性分数"""
inputs = self.tokenizer(
pairs,
padding=True,
truncation=True,
return_tensors="pt",
max_length=512
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
scores = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()
return scores.tolist()
def init_vector_store():
embedding_model = CustomTransformerEmbeddings("~/Documents/models/BAAI/bge_m3", "cpu")
index = faiss.IndexFlatL2(len(embedding_model.embed_query("hello world")))
return FAISS(
embedding_function=embedding_model,
index=index,
docstore=InMemoryDocstore(),
index_to_docstore_id={}
)
def init_rerank():
return CrossEncoderReranker("~/Documents/models/BAAI/bge_reranker_v2_m3", "cpu")
if __name__ == "__main__":
embedding_model = CustomTransformerEmbeddings("~/Documents/models/BAAI/bge_m3", "cpu")
reranker = CrossEncoderReranker("~/Documents/models/BAAI/bge_reranker_v2_m3", "cpu")
index = faiss.IndexFlatL2(len(embedding_model.embed_query("hello world")))
vector_store = FAISS(
embedding_function=embedding_model,
index=index,
docstore=InMemoryDocstore(),
index_to_docstore_id={}
)
memory = VectorStoreRetrieverMemory(retriever=vector_store.as_retriever(search_kwargs=dict(k=1)))
memory.save_context({"input": "What is FAISS?"}, {"output": "FAISS is FAISS"})
memory.save_context({"input": "What is LangChain?"}, {"output": "LangChain is a framework for developing LLM applications"})
docs = memory.load_memory_variables({"input": "What is LangChain?"})
print(docs)
结构体定义
from typing import TypedDict, Optional, List
from pydantic import Field, BaseModel
"""
1. 用户输入题目和简介, 开始生成
2. 生成大纲 -> 用户确认
3. 根据大纲逐步生成细节 -> 用户确认
4. 输出
"""
class SectionModel(BaseModel):
title: str = Field(description="标题, 在20个字以内")
description: str = Field(description="内容简介, 在200字以内")
sections: List['SectionModel'] = Field(description="大纲子项")
class OutlineModel(BaseModel):
sections: List[SectionModel] = Field(description="大纲列表(章节)")
class ArticleStruct(BaseModel):
title: str = Field(description="标题, 在20个字以内")
description: str = Field(description="内容简介, 在200字以内")
content: str = Field(description="正文内容", default="")
children: List['ArticleStruct'] = Field(description="子内容", default=None)
confirmed_outline: bool = Field(description="大纲是否确认", default=False)
confirmed_content: bool = Field(description="内容是否确认", default=False)
class AgentState(TypedDict):
input: str
description: str
content: List[ArticleStruct]
gen_end: bool
response: str
主流程构建
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END, START
from langchain.prompts import ChatPromptTemplate
from paper_agent.func_call import gen_content_node, user_confirm_node, \
start_input_node, gen_response_node
from paper_agent.model import AgentState
def main():
"""
用户输入 标题/简介
生成大纲 用户确认 -> 大纲细化/重新生成/确认
大纲细化 -> 生成大纲
重新生成 -> 生成大纲
确认 -> 生成正文
"""
builder = StateGraph(AgentState)
builder.add_node("start_input_node", start_input_node)
builder.add_node("gen_content_node", gen_content_node)
builder.add_node("user_confirm_node", user_confirm_node)
builder.add_node("gen_response_node", gen_response_node)
builder.add_edge("start_input_node", "gen_content_node")
builder.add_edge("user_confirm_node", "gen_content_node")
builder.add_conditional_edges(
"gen_content_node", lambda state: "gen_response_node" if state["gen_end"] else "user_confirm_node", {
"gen_response_node": "gen_response_node",
"user_confirm_node": "user_confirm_node"
})
builder.set_entry_point("start_input_node")
app = builder.compile(debug=True)
result = app.invoke({
"input": "",
"description": "",
"content": [],
"gen_end": False,
"response": ""
}, {"recursion_limit": 100000})
print("result: ", result)
output_path = f"./paper/{result['input']}.md"
with open(output_path, "w") as file:
file.write(result["response"])
if __name__ == "__main__":
main()