1. 环境准备
pip install transformers datasets peft accelerate bitsandbytes
2. 构造训练数据(指定话术)
# data.json
[
{"instruction": "用户问:你们的售后政策是什么?",
"output": "您好,我们提供7天无理由退货,30天质量问题免费换新服务。"},
{"instruction": "用户问:发票怎么开?",
"output": "请您下单时选择“需要发票”,我们会随货寄送正规发票。"},
{"instruction": "用户问:你们的客服工作时间?",
"output": "您好,我们的客服在线时间是每天早9点到晚9点。"}
]
3. 训练脚本
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
# 选择一个基础大模型
model_name = "meta-llama/Llama-2-7b-hf" # 可换成 Qwen1.5、Baichuan2 等
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 加载数据
dataset = load_dataset("json", data_files="data.json")
def format_examples(example):
return {"text": f"用户: {example['instruction']}\n客服: {example['output']}"}
dataset = dataset.map(format_examples)
# LoRA 配置
lora_config = LoraConfig(
r=8, lora_alpha=16, target_modules=["q_proj","v_proj"], lora_dropout=0.05, bias="none"
)
# 加载模型
model = AutoModelForCausalLM.from_pretrained(
model_name,
load_in_8bit=True, # 节省显存
device_map="auto"
)
model = get_peft_model(model, lora_config)
# 训练参数
training_args = TrainingArguments(
output_dir="./lora-model",
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
learning_rate=2e-4,
num_train_epochs=3,
logging_steps=10,
save_steps=200,
save_total_limit=2,
)
# Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
tokenizer=tokenizer,
)
# 开始训练
trainer.train()
4. 推理测试
from transformers import pipeline
pipe = pipeline("text-generation", model="./lora-model", tokenizer=tokenizer)
res = pipe("用户: 你们的售后政策是什么?\n客服: ", max_new_tokens=50)
print(res[0]['generated_text'])
或者不训练,直接用 Prompt 约束话术
from transformers import pipeline
pipe = pipeline("text-generation", model="Qwen/Qwen1.5-7B-Chat")
system_prompt = """
你是电商平台的客服机器人,所有回答必须严格遵循以下固定话术:
1. 售后政策:7天无理由退货,30天质量问题免费换新。
2. 发票:下单时选择“需要发票”,随货寄送正规发票。
3. 客服时间:每天早9点到晚9点。
"""
user_input = "你们的客服时间是多久?"
res = pipe(f"{system_prompt}\n用户: {user_input}\n客服:", max_new_tokens=50)
print(res[0]['generated_text'])
大模型训练与话术约束方法
1481

被折叠的 条评论
为什么被折叠?



