1. Introduction
With the development of pre-trained language models such as BERT, GPT, and T5, fine-tuning has become the standard approach for NLP tasks. However, fine-tuning updates a large number of model parameters, which leads to high computational cost and large storage requirements. Prompt-Tuning is a lightweight alternative: by adding task-specific prompts to steer the model's behavior, it greatly reduces the number of parameters that need to be updated while still adapting the model to the task.
This article introduces Prompt-Tuning and demonstrates it with an intent-classification example.
2. How Prompt-Tuning Works
The core idea of Prompt-Tuning is to guide the output of a pre-trained model with a "prompt template" instead of directly modifying the model weights. For example, in a text-classification task we can construct the following prompt:
- Original text: "Set a reminder for 3 PM."
- Prompt form: "The user wants to: [MASK]"
A Masked Language Model (MLM) can then predict the most suitable token for the [MASK] slot, and a small verbalizer maps the predicted token (e.g., "remind") to the corresponding intent, such as "set a reminder".
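As a minimal sketch of this idea, the snippet below uses the Hugging Face fill-mask pipeline with distilbert-base-uncased (the same backbone used in Section 3) to fill the [MASK] slot of such a template; the exact template wording and the token-to-intent mapping are illustrative assumptions rather than a fixed recipe.
from transformers import pipeline

# Fill-mask pipeline over the same backbone used later in this article.
# The prompt template below is an illustrative assumption, not a fixed format.
fill_mask = pipeline("fill-mask", model="distilbert-base-uncased")

prompt = "Set a reminder for 3 PM. The user wants to: [MASK]."
for candidate in fill_mask(prompt, top_k=5):
    # Each candidate is a dict containing the predicted token and its score.
    print(f"{candidate['token_str']:>12}  (score: {candidate['score']:.3f})")
In a prompt-based classifier, the top predicted tokens would then be mapped to intent labels by a small verbalizer (e.g., "remind" → set reminder).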
3. Prompt-Tuning Code Example
The following example shows how to use DistilBERT for intent classification, with a prompt prefix added to each input. Note that this example still fine-tunes all of the model's parameters around a hard text prompt; parameter-efficient soft-prompt tuning, where only the prompt embeddings are trained, is sketched in Section 5.
3.1 Building the Dataset
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
class IntentDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = "Intent: " + self.texts[idx]  # prepend the prompt to the input
        label = self.labels[idx]
        encoding = self.tokenizer(
            text, truncation=True, padding='max_length', max_length=self.max_length, return_tensors="pt"
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }
3.2 Training the Model
model_name = "distilbert-base-uncased"
num_intents = 3  # assume 3 intent classes
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_intents)
tokenizer = AutoTokenizer.from_pretrained(model_name)
texts = [
"What's the weather like today?",
"Set a reminder for 3 PM.",
"Tell me a joke.",
"How's the weather tomorrow?",
"Remind me to call mom at 6 PM."
]
labels = [0, 1, 2, 0, 1]  # intent classes: 0 = ask about weather, 1 = set reminder, 2 = tell joke
train_dataset = IntentDataset(texts, labels, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
optimizer = AdamW(model.parameters(), lr=5e-5)
model.train()
for epoch in range(3):
    total_loss = 0
    for batch in train_dataloader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Average Loss: {total_loss / len(train_dataloader):.4f}")
3.3 Running Inference
model.eval()
test_texts = [
"What's the forecast for this weekend?",
"Set an alarm for 7 AM."
]
intent_map = {0: "Ask about weather", 1: "Set reminder", 2: "Tell joke"}
with torch.no_grad():
    for test_text in test_texts:
        input_text = "Intent: " + test_text  # apply the same prompt used during training
        encoded_input = tokenizer(input_text, return_tensors="pt")
        outputs = model(**encoded_input)
        logits = outputs.logits
        predicted_label = torch.argmax(logits, dim=1).item()
        predicted_intent = intent_map[predicted_label]
        print(f"Text: '{test_text}' -> Predicted Intent: {predicted_intent}")
4. Application Scenarios for Prompt-Tuning
Prompt-Tuning is applicable to a wide range of NLP tasks (a small template sketch follows this list):
- Text classification: e.g., intent recognition and sentiment analysis
- Text generation: e.g., summarization and dialogue systems
- Question answering: using prompts to indicate the question category
- Code generation: providing task-related prompts in code-completion tasks
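As a rough illustration of how prompts differ across these tasks, the template strings below are hypothetical examples; the task names and wording are assumptions, not prescribed formats.
# Hypothetical prompt templates for the task types listed above.
# The wording of each template is an illustrative assumption.
PROMPT_TEMPLATES = {
    "intent_classification": "Intent: {text}",
    "sentiment_analysis":    "Review: {text} Overall, the sentiment is [MASK].",
    "summarization":         "Summarize the following article: {text}",
    "question_answering":    "Question type: {category}. Question: {text}",
}

# Fill one template with an example input and print the resulting prompt.
example = PROMPT_TEMPLATES["sentiment_analysis"].format(text="The movie was wonderful.")
print(example)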
5. Strengths and Weaknesses of Prompt-Tuning
Strengths
- Fewer parameters to update: compared with full fine-tuning, Prompt-Tuning only trains a small set of prompt parameters, so the computational overhead is lower (see the soft-prompt sketch after this list).
- Strong generalization: prompts can be adapted quickly across tasks, improving transferability.
- Suitable for few-shot settings: reasonable results can be obtained without large amounts of labeled data.
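To make the "fewer parameters" point concrete, here is a minimal sketch of soft-prompt tuning using the Hugging Face peft library (assuming peft is installed; the number of virtual tokens and other settings are illustrative assumptions). Unlike the hard-prompt example in Section 3, only the learned prompt embeddings (plus the classification head) are trained:
from peft import PromptTuningConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification

# Base classifier, same backbone as in Section 3.
base_model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=3
)

# Soft-prompt configuration: 20 trainable virtual tokens are prepended to every input.
# num_virtual_tokens = 20 is an illustrative choice, not a recommended value.
peft_config = PromptTuningConfig(
    task_type=TaskType.SEQ_CLS,
    num_virtual_tokens=20,
)

peft_model = get_peft_model(base_model, peft_config)
# Only the prompt embeddings (and the classifier head) are trainable;
# the DistilBERT backbone stays frozen.
peft_model.print_trainable_parameters()
The wrapped peft_model can then be trained with the same DataLoader and loop shown in Section 3.2.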
Weaknesses
- Sensitive to prompt design: the choice of prompt affects model performance, so different wordings need to be tried (see the sketch after this list).
- Not a full replacement for fine-tuning: for tasks that require fine-grained control (e.g., specialized domains such as medicine or law), fine-tuning remains more effective.
- Limited on complex tasks: prompts work well for structured tasks, but tasks that require deep understanding still benefit from being combined with fine-tuning.
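To illustrate prompt sensitivity, the short sketch below runs the same fill-mask model over a few alternative prompt wordings (the templates are hypothetical) and shows how the top prediction shifts with the phrasing:
from transformers import pipeline

fill_mask = pipeline("fill-mask", model="distilbert-base-uncased")

# Hypothetical prompt variants for the same input sentence.
templates = [
    "Set a reminder for 3 PM. The user wants to: [MASK].",
    "Set a reminder for 3 PM. This request is about a [MASK].",
    "Set a reminder for 3 PM. Intent: [MASK].",
]

for template in templates:
    # Take only the highest-scoring prediction for each template.
    top = fill_mask(template, top_k=1)[0]
    print(f"{template!r} -> {top['token_str']} ({top['score']:.3f})")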
6. Complete Code Example
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
# Define the dataset class
class IntentDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = "Intent: " + self.texts[idx]  # prepend the prompt, as in Section 3.1
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors="pt"
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),  # remove the extra batch dimension
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }
# Load the pre-trained model and tokenizer
model_name = "distilbert-base-uncased"
num_intents = 3  # assume 3 intent classes
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_intents)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Prepare the training data
texts = [
"What's the weather like today?",
"Set a reminder for 3 PM.",
"Tell me a joke.",
"How's the weather tomorrow?",
"Remind me to call mom at 6 PM."
]
labels = [0, 1, 2, 0, 1]  # intent labels: 0 = ask about weather, 1 = set reminder, 2 = tell joke
# Create the dataset and data loader
dataset = IntentDataset(texts, labels, tokenizer)
train_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# Define the optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)
# Set the model to training mode
model.train()
# Training loop
for epoch in range(3):  # train for 3 epochs
    total_loss = 0
    for batch in train_dataloader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss
        # Backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")
# Test (inference)
model.eval()
test_texts = [
    "What's the forecast for this weekend?",
    "Set an alarm for 7 AM."
]
intent_map = {0: "Ask about weather", 1: "Set reminder", 2: "Tell joke"}
with torch.no_grad():
    for test_text in test_texts:
        input_text = "Intent: " + test_text  # apply the same prompt used during training
        encoded_input = tokenizer(input_text, return_tensors="pt")
        outputs = model(**encoded_input)
        logits = outputs.logits
        predicted_label = torch.argmax(logits, dim=1).item()
        predicted_intent = intent_map[predicted_label]
        print(f"Text: '{test_text}' -> Predicted Intent: {predicted_intent}")
7. Conclusion
Prompt-Tuning is an efficient way to adapt pre-trained models to NLP tasks: it steers the model toward a specific task without updating a large fraction of its parameters. Although its performance may not match full fine-tuning, it offers major advantages in low-resource settings and for quick adaptation across tasks. We hope the examples in this article help you understand Prompt-Tuning and apply it flexibly in practice!