Langgraph 论文生成最佳实践

代码结构

prompt 定义

from langchain_core.prompts import ChatPromptTemplate

layout_prompt = ChatPromptTemplate([
    ("system", """您作为学术架构师,具备跨学科论文指导经验,擅长将研究构思转化为严谨的学术框架。请根据研究要素构建具有方法论指导价值的论文架构"""),
    ("human", """
# 论文大纲及章节规划生成指令

## 输入参数说明
• 论文标题:"{title}"
• 描述信息:"{description}"

## 核心任务
生成具备可扩展性的论文框架体系,包含:
1. 逻辑严谨的6章节大纲(一级目录)
2. 各章节详细的内容建设规划

## 大纲生成规范(请逐条遵守)
[结构要求]
▸ 必须包含:引言、相关工作、方法、实验与结果、讨论、结论
▸ 章节间需形成"问题提出→方法构建→验证分析→理论升华"的逻辑闭环
▸ 每个一级目录需标注3个核心要点(用中文分号分隔)

[内容要求]
▶ 引言:须包含研究背景破题、理论缺口、方法论创新三维度
▶ 相关工作:需设计对比分析矩阵,明确前人研究局限
▶ 方法:应区分理论框架与技术实现两个层面
▶ 实验:需预设对照组、评估指标、可复现性方案
▶ 讨论:要求建立与引言的理论呼应机制
▶ 结论:需提炼方法论贡献与实践指导价值

## 章节简介标准
[内容维度]
① 研究任务:本章的核心学术使命
② 内容要素:需涵盖的关键组件(按优先级排序)
③ 衔接设计:与前后章节的逻辑接口
④ 创新载体:区别于常规论文的突破点

[质量要求]
✓ 各章节描述需形成"方法-实验-结论"证据链
✓ 保持技术细节与理论深度的平衡
✓ 突出从具体实验到一般规律的研究跃迁

## 输出格式(严格遵循)
{format_instructions}

感谢你的创作 ~
    """)
])


outline_prompt = ChatPromptTemplate([
    ("system", """你是一名论文导师, 非常擅长写论文, 用户将会给到你(层级标题/当前标题/描述信息)三部分内容, 请根据用户输入的信息生成 大纲和大纲下的简介 的列表
参考信息介绍:
    层级标题: 代表当前的层级标题一般格式为: ["文章标题", "一级标题", "二级标题"]
    当前标题: 当前需要生成大纲的标题, 一般格式为(三级标题)
    描述信息: 描述当前标题需要做哪些内容


大纲列表生成要求如下:
    大纲应具有逻辑性, 每个部分应包含关键点
    确保大纲结构清晰,便于后续扩展成完整的论文
    大纲只需要一次返回一级目录, 不需要返回多级


大纲对应简介列表生成要求如下:
    简介内容为当前标题下的描述
    简介介绍应该本段需要编写哪些内容的方向和指定计划
    简介应当前后呼应(譬如实验准备和实验细节以及最后的实验结论)
    简介内容应当详细, 该描述需要准确和全面, 对于本章节具有完全概括的能力


整体要求如下:
    认真思考给到的“当前标题”和“描述信息”
    生成输出的内容需要严谨, 请结合“层级标题”以内的内容作为边界, 可以参考当前所有的层级, 但是请注意不要拓展太多或者产生超纲的行为
    内容可以是创新也可以是现有资料

"""),
    ("human", """请帮我生成大纲和对应的大纲简介谢谢:
层级标题: {deep_title}
当前标题: {title}
描述信息: {description}


并严格使用如下格式输出, 请不要输出严用格式以外的内容:
{format_instructions}

感谢你的创作 ~
    """)
])


content_prompt = ChatPromptTemplate([
    ("system", """
【角色定义】
你是一位资深论文指导专家,具备跨学科研究经验,擅长将碎片化信息整合为严谨的学术论述。请根据用户提供的层级架构和内容描述,生成符合学术规范的段落内容。

【输入信息说明】
用户将依次提供:
1. 层级路径:["文章标题", "一级标题", "二级标题"]组成的树状结构
2. 当前节点:需撰写的具体标题名称及层级位置
3. 内容纲要:该段落需达成的论述目标与关键要素

【内容生成规范】
1. 结构控制:
- 以当前节点为圆心,向上兼容父级标题逻辑,向下预留子级展开空间
- 段落长度控制在200-800字,每段聚焦1个核心论点
- 使用学术过渡词("值得注意的是""需要强调的是")衔接段落

2. 术语运用:
- 采用《学科分类与代码》国家标准中的规范术语
- 新兴概念首次出现时标注英文原文及提出年份(例:具身认知(Embodied Cognition, Lakoff 1999))
- 术语密度控制在每百字3-5个专业词汇

3. 论述展开:
- 论点-论据-论证三层结构完整
- 实验数据类:说明样本量/显著性水平/效应值
- 理论分析类:引用2-3个经典理论框架
- 应用研究类:结合典型案例进行机理阐释

【内容质量要求】
1. 创新性处理:
- 对成熟理论采用"背景假设-核心主张-当代挑战"分析框架
- 对新兴领域运用"技术路线-关键瓶颈-突破方向"论述模式
- 交叉学科内容建立维度矩阵(例:时间/空间/主体三个分析维度)

2. 边界控制:
- 学术深度不超过该节点在层级树中的位置等级
- 二级标题内容需预留至少3个可量化研究空白点
- 方法论部分保持范式中立(实证/解释/批判)

【输出格式规范】
1. 段落构建:
[主题句]→[理论支撑]→[论证展开]→[承启句]
示例:
"数字化转型重构了企业知识管理范式(主题句)。基于动态能力观(Teece, 1997),组织学习机制正从单环学习向双环学习演进(理论支撑)。笔者对长三角412家制造企业的调研显示,采用数据中台的企业知识迭代速度提升37.2%(p<0.01),这印证了Nonaka(2021)提出的知识涌现理论(论证展开)。这种变革对传统管理体系提出了三方面挑战(承启句)。"

2. 特殊处理:
- 争议观点标注学术阵营(例:新古典学派主张...,而演化经济学派认为...)
- 重大发现标注学界反响(例:该结论被《管理世界》2023年度综述列为十大争议议题)
- 关键技术说明应用场景(例:区块链溯源技术在农产品供应链中的典型应用包括...)

"""),
    ("human", """请帮我生成内容谢谢:
层级标题: {deep_title}
当前标题: {title}
描述信息: {description}


并严格使用如下格式输出, 请不要输出严用格式以外的内容:
请直接返回文本即可, 不需要返回标题


历史对话如下:
{history}


感谢你的创作 ~
    """)
])

方法定义

# ===== 节点函数定义 =====
from typing import List

from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.runnables import RunnableWithMessageHistory
from langchain.memory import VectorStoreRetrieverMemory
from langchain_openai import ChatOpenAI

from paper_agent.bge_embedding import init_vector_store, init_rerank
from paper_agent.model import AgentState, ArticleStruct, OutlineModel, SectionModel

from langchain.globals import set_debug  # 导入在 langchain 中设置调试模式的函数

from paper_agent.prompt import outline_prompt, layout_prompt, content_prompt

set_debug(True)  # 启用 langchain 的调试模式


llm = ChatOpenAI(
    api_key="",
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    model="qwen2.5-72b-instruct",
    verbose=True
)
store = {}

def get_session_history(session_id) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = InMemoryChatMessageHistory()
    return store[session_id]



vector_store_memory, rerank_memory = init_vector_store(), init_rerank()
retriever_memory = VectorStoreRetrieverMemory(retriever=vector_store_memory.as_retriever(search_kwargs=dict(k=5)))


outline_parser = PydanticOutputParser(pydantic_object=OutlineModel)


# 使用参考: https://python.langchain.com/api_reference/core/runnables/langchain_core.runnables.history.RunnableWithMessageHistory.html
outline_history_chain = RunnableWithMessageHistory(outline_prompt | llm, get_session_history, input_messages_key="title", history_messages_key="history")
layout_history_chain = RunnableWithMessageHistory(layout_prompt | llm, get_session_history, input_messages_key="title", history_messages_key="history")
content_history_chain = RunnableWithMessageHistory(content_prompt | llm, get_session_history, input_messages_key="title", history_messages_key="history")


def start_input_node(state: AgentState):
    return {
        "input": input("文章标题: "),
        "description": input("文章简介: ")
    }

def build_article(sections: List[SectionModel]):
    content = []
    for section in sections:
        children = []
        if section.sections:
            children = build_article(section.sections)
        if children:
            content.append(
                ArticleStruct(title=section.title, description=section.description, confirmed_outline=False, children=children)
            )
        else:
            content.append(
                ArticleStruct(title=section.title, description=section.description, confirmed_outline=False)
            )
    return content


def gen_content(paper_title, current_title, current_description, content, level=2, deep_title=None):
    """
    :param paper_title: 论文标题
    :param current_title: 当前子标题
    :param current_description: 当前子描述
    :param content:
    :param level:
    :param deep_title:
    :return:
    """
    if not content:
        if paper_title == current_title:
            chain_res: AIMessage = layout_history_chain.invoke({
                "title": current_title,
                "description": current_description,
                "format_instructions": outline_parser.get_format_instructions(),
            }, config={"configurable": {"session_id": paper_title}})
        else:
            chain_res: AIMessage = outline_history_chain.invoke({
                "deep_title": deep_title,
                "title": current_title,
                "description": current_description,
                "format_instructions": outline_parser.get_format_instructions(),
            }, config={"configurable": {"session_id": paper_title}})

        outline_model: OutlineModel = outline_parser.parse(chain_res.content)
        content = build_article(outline_model.sections)

        return {"content": content, "gen_end": False}


    for item in content:
        if not item.content:
            chain_res: AIMessage = content_history_chain.invoke({
                "deep_title": deep_title,
                "title": item.title,
                "description": current_description,
                "format_instructions": outline_parser.get_format_instructions(),
                "history": (retriever_memory.load_memory_variables({"input": current_description}) or {}).get("history", "")
            }, config={"configurable": {"session_id": paper_title}})
            item.content = chain_res.content
            item.confirmed_content = False
            return {"content": content, "gen_end": False}

        if isinstance(item.children, list):
            res = gen_content(paper_title, item.title, item.description, item.children, level + 1, [*deep_title, item.title])
            if not res["gen_end"]:
                item.children = res["content"]
                return {"content": content, "gen_end": False}

    return {"content": content, "gen_end": True}


def gen_content_node(state: AgentState):
    """根据用户的想法生成大纲"""
    title = state["input"]
    description = state["description"]

    return gen_content(title, title, description, state["content"], level=2, deep_title=[title])


def gen_response(content: List[ArticleStruct], level=2):
    response = ""
    for item in content:
        response += f"{level * '#'} {item.title}\n"
        if item.content:
            response += f"{item.content}\n"
        if item.children:
            response += gen_response(item.children, level=level + 1)

    return response


def gen_response_node(state: AgentState):
    """根据content生成最终结果"""
    response = f"# {state['input']}\n"
    content = state["content"]

    response += gen_response(content, level=2)

    return {"response": response}


def user_confirm(content: List[ArticleStruct]):
    confirmed_outline = content[0].confirmed_outline
    # 首先判断大纲是否确认, 然后再判断内容是否确认
    if not confirmed_outline:
        while True:
            outlines = [f"{i}. {item.title}" for i, item in enumerate(content)]
            outline_str = "\n".join(outlines)

            user_input = input(f'{outline_str}\n, 请确认当前大纲(确认/重新生成/修改)?')
            if user_input == "确认":
                for item in content:
                    item.confirmed_outline = True
                return {"content": content, "check_end": True}
            elif user_input == "重新生成":
                return {"content": [], "check_end": True}
            elif user_input == "修改":
                i = int(input("请输入需要修改的编号: "))
                title = input("请输入新的标题: ")
                content[i].title = title

    # 判断具体的内容, 是否需要通过
    for item in content:
        if not item.confirmed_content:
            user_input = input(f'{item.title}\n{item.content}\n, 请确认当前内容(确认/重新生成/修改/生成细节)?')
            if user_input == "确认":
                item.confirmed_content = True
                # 确认完毕, 存储到历史记录
                retriever_memory.save_context({"system": item.description, "ai": item.content})
                return {"content": content, "check_end": True}
            elif user_input == "重新生成":
                item.content = ""
                item.confirmed_content = False
                return {"content": content, "check_end": True}
            elif user_input == "修改":
                user_input_content = input("请输入新的内容: ")
                item.content = user_input_content
            elif user_input == "生成细节":
                item.confirmed_content = True
                item.children = []
            return {"content": content, "check_end": True}

        if item.children:
            print("user_confirm item.children: ", item.children)
            res = user_confirm(item.children)
            if res["check_end"]:
                item.children = res["content"]
                return {"content": content, "check_end": True}

    return {"content": content, "check_end": False}



def user_confirm_node(state: AgentState):
    """请求用户确认是否继续"""
    content, is_confirm = state["content"], False
    while not is_confirm:
        res = user_confirm(content)
        content = res["content"]
        is_confirm = not res["check_end"]
    return {"content": content}

embedding 历史记录

from typing import List

import faiss
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain_community.vectorstores import FAISS
from langchain_community.docstore import InMemoryDocstore
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.memory import VectorStoreRetrieverMemory
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
import torch


# ======================
# 自定义 Embedding 类
# ======================
class CustomTransformerEmbeddings(Embeddings):
    """自定义 Transformers 嵌入模型"""

    def __init__(self, model_name: str, device: str = "cuda:0"):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)
        self.model.eval()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """文档嵌入"""
        return self._embed(texts)

    def embed_query(self, text: str) -> List[float]:
        """查询嵌入"""
        return self._embed([text])[0]

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """实际嵌入实现"""
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            embeddings = self._mean_pooling(outputs, inputs['attention_mask'])

        return embeddings.cpu().numpy().tolist()

    def _mean_pooling(self, model_output, attention_mask):
        """池化方法"""
        token_embeddings = model_output.last_hidden_state
        input_mask_expanded = (
            attention_mask
            .unsqueeze(-1)
            .expand(token_embeddings.size())
            .float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


# ======================
# 自定义 Rerank 类
# ======================
class CrossEncoderReranker:
    """基于交叉编码器的重排序器"""

    def __init__(self, model_name: str, device: str = "cuda:0"):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
        self.model.eval()

    def rerank(self, query: str, documents: List[Document], top_k: int = 5) -> List[Document]:
        """重排序文档"""
        pairs = [(query, doc.page_content) for doc in documents]
        scores = self._predict(pairs)

        # 组合文档与分数
        scored_docs = list(zip(documents, scores))
        # 按分数降序排序
        scored_docs.sort(key=lambda x: x[1], reverse=True)
        # 返回前k个文档
        return [doc for doc, score in scored_docs[:top_k]]

    def _predict(self, pairs: List[tuple]) -> List[float]:
        """计算相关性分数"""
        inputs = self.tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            scores = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()

        return scores.tolist()



def init_vector_store():
    embedding_model = CustomTransformerEmbeddings("~/Documents/models/BAAI/bge_m3", "cpu")
    index = faiss.IndexFlatL2(len(embedding_model.embed_query("hello world")))
    return FAISS(
        embedding_function=embedding_model,
        index=index,
        docstore=InMemoryDocstore(),
        index_to_docstore_id={}
    )


def init_rerank():
     return CrossEncoderReranker("~/Documents/models/BAAI/bge_reranker_v2_m3", "cpu")



if __name__ == "__main__":
    # 初始化组件
    embedding_model = CustomTransformerEmbeddings("~/Documents/models/BAAI/bge_m3", "cpu")
    reranker = CrossEncoderReranker("~/Documents/models/BAAI/bge_reranker_v2_m3", "cpu")

    # 创建 FAISS 向量库
    index = faiss.IndexFlatL2(len(embedding_model.embed_query("hello world")))
    vector_store = FAISS(
                embedding_function=embedding_model,
                index=index,
                docstore=InMemoryDocstore(),
                index_to_docstore_id={}
            )

    memory = VectorStoreRetrieverMemory(retriever=vector_store.as_retriever(search_kwargs=dict(k=1)))
    memory.save_context({"input": "What is FAISS?"}, {"output": "FAISS is FAISS"})
    memory.save_context({"input": "What is LangChain?"}, {"output": "LangChain is a framework for developing LLM applications"})

    docs = memory.load_memory_variables({"input": "What is LangChain?"})
    print(docs)

结构体定义

from typing import TypedDict, Optional, List
from pydantic import Field, BaseModel

"""
1. 用户输入题目和简介, 开始生成
2. 生成大纲 -> 用户确认
3. 根据大纲逐步生成细节 -> 用户确认
4. 输出
"""

class SectionModel(BaseModel):
    title: str = Field(description="标题, 在20个字以内")
    description: str = Field(description="内容简介, 在200字以内")  # 隐藏
    sections: List['SectionModel'] = Field(description="大纲子项")


class OutlineModel(BaseModel):
    sections: List[SectionModel] = Field(description="大纲列表(章节)")


class ArticleStruct(BaseModel):
    title: str = Field(description="标题, 在20个字以内")
    description: str = Field(description="内容简介, 在200字以内")  # 隐藏
    content: str = Field(description="正文内容", default="")
    children: List['ArticleStruct'] = Field(description="子内容", default=None)
    confirmed_outline: bool  = Field(description="大纲是否确认", default=False)
    confirmed_content: bool  = Field(description="内容是否确认", default=False)


class AgentState(TypedDict):
    input: str  # 用户原始输入
    description: str # 文章简介
    content: List[ArticleStruct]  # 段落列表
    gen_end: bool  # 是否生成结束
    response: str

主流程构建

from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END, START
from langchain.prompts import ChatPromptTemplate


from paper_agent.func_call import gen_content_node, user_confirm_node, \
    start_input_node, gen_response_node

from paper_agent.model import AgentState

def main():
    """
    用户输入 标题/简介
    生成大纲 用户确认 -> 大纲细化/重新生成/确认
        大纲细化 -> 生成大纲
        重新生成 -> 生成大纲
        确认    -> 生成正文

    """

    builder = StateGraph(AgentState)
    # 添加开始的用户输入
    builder.add_node("start_input_node", start_input_node)
    # 添加大纲生成的结构
    builder.add_node("gen_content_node", gen_content_node)
    builder.add_node("user_confirm_node", user_confirm_node)
    builder.add_node("gen_response_node", gen_response_node)


    builder.add_edge("start_input_node", "gen_content_node")
    builder.add_edge("user_confirm_node", "gen_content_node")

    builder.add_conditional_edges(
        "gen_content_node", lambda state: "gen_response_node" if state["gen_end"] else "user_confirm_node", {
        "gen_response_node": "gen_response_node",
        "user_confirm_node": "user_confirm_node"
    })

    builder.set_entry_point("start_input_node")

    app = builder.compile(debug=True)
    result = app.invoke({
        "input": "",
        "description": "",
        "content": [],  # 段落列表,
        "gen_end": False,  # 是否生成结束
        "response": ""
    }, {"recursion_limit": 100000})


    print("result: ", result)

    output_path = f"./paper/{result['input']}.md"
    with open(output_path, "w") as file:
        file.write(result["response"])


if __name__ == "__main__":
    main()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值