PyTorch TorchTune 项目中的提示模板使用指南
什么是提示模板
提示模板是结构化文本模板,用于格式化用户提示以优化模型在特定任务上的性能。在PyTorch TorchTune项目中,提示模板扮演着至关重要的角色,它能够帮助开发者更好地引导模型行为。
提示模板主要有三种用途:
- 模型专用模板:某些模型在预训练时就使用了特定标记(如Llama2和Mistral模型中的[INST]标签),在推理时使用这些模板能确保最佳性能
- 任务专用模板:针对特定任务设计的模板,如语法纠错、摘要生成、问答等
- 社区标准模板:如ChatML等广泛采用的标准化模板
为什么需要提示模板
提示模板与模型分词器添加的特殊标记不同。特殊标记是模型词汇表的一部分,用于标记文本的结构(如句子开始、结束等),而提示模板则是更高层次的文本结构,用于指导模型如何理解和响应输入。
如何使用提示模板
在TorchTune中,可以通过两种方式使用提示模板:
1. 通过点路径字符串指定
from torchtune.models.mistral import mistral_tokenizer
m_tokenizer = mistral_tokenizer(
path="/tmp/Mistral-7B-v0.1/tokenizer.model",
prompt_template="torchtune.models.mistral.MistralChatTemplate"
)
或者在配置文件中:
tokenizer:
_component_: torchtune.models.mistral.mistral_tokenizer
path: /tmp/Mistral-7B-v0.1/tokenizer.model
prompt_template: torchtune.models.mistral.MistralChatTemplate
2. 通过字典定义
字典中的每个角色对应一个元组,元组包含要在内容前后添加的文本。例如,要实现以下模板:
System: {content}\n
User: {content}\n
Assistant: {content}\n
Tool: {content}\n
可以这样定义:
template = {
"system": ("System: ", "\n"),
"user": ("User: ", "\n"),
"assistant": ("Assistant: ", "\n"),
"ipython": ("Tool: ", "\n"),
}
然后传递给分词器:
m_tokenizer = mistral_tokenizer(
path="/tmp/Mistral-7B-v0.1/tokenizer.model",
prompt_template=template,
)
使用PromptTemplate类
字典模板也可以传递给PromptTemplate类,作为独立的自定义提示模板使用:
from torchtune.data import PromptTemplate, Message
def my_custom_template() -> PromptTemplate:
return PromptTemplate(
template={
"user": ("User: ", "\n"),
"assistant": ("Assistant: ", "\n"),
},
)
template = my_custom_template()
msgs = [
Message(role="user", content="Hello world!"),
Message(role="assistant", content="Is AI overhyped?"),
]
templated_msgs = template(msgs)
创建自定义提示模板
对于更复杂的场景,可以创建继承自PromptTemplateInterface的新类:
from torchtune.data import Message, PromptTemplateInterface
class EurekaTemplate(PromptTemplateInterface):
def __call__(self, messages: list[Message], inference: bool = False) -> list[Message]:
formatted_dialogue = []
for message in messages:
if message.role == "assistant":
content = "Eureka!"
else:
content = message.content
formatted_dialogue.append(
Message(
role=message.role,
content=content,
masked=message.masked,
ipython=message.ipython,
eot=message.eot,
),
)
return formatted_dialogue
然后通过点路径字符串使用:
m_tokenizer = mistral_tokenizer(
path="/tmp/Mistral-7B-v0.1/tokenizer.model",
prompt_template="path.to.template.EurekaTemplate",
)
内置提示模板
TorchTune提供了多种内置模板:
- GrammarErrorCorrectionTemplate:语法纠错
- SummarizeTemplate:摘要生成
- QuestionAnswerTemplate:问答任务
- ChatMLTemplate:ChatML标准格式
最佳实践
- 一致性原则:训练和推理时应使用相同的提示模板
- 任务适配:选择与目标任务最匹配的模板
- 性能测试:不同模板可能影响模型表现,建议进行对比测试
- 可读性:模板设计应保持清晰易读,便于维护
通过合理使用提示模板,开发者可以显著提升模型在特定任务上的表现,同时保持代码的整洁和可维护性。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考