Prompt-Tuning:让预训练模型更高效地适配任务

1. 引言

随着预训练语言模型(如 BERT、GPT、T5)的发展,Fine-Tuning 已成为 NLP 任务中的常见方法。然而,Fine-Tuning 需要调整大量模型参数,导致计算成本高、存储需求大。Prompt-Tuning 作为一种轻量级替代方案,通过添加任务相关的提示(Prompt)来调整模型的行为,极大减少了参数更新的需求,同时提升任务适配性。

本文介绍 Prompt-Tuning 技术,并通过一个意图分类的示例展示其应用。

2. Prompt-Tuning 的原理

Prompt-Tuning 的核心思想是利用 “提示模板” 来引导预训练模型的输出,而不直接修改模型权重。例如,在文本分类任务中,我们可以构造如下 Prompt:

  • 原始文本: “Set a reminder for 3 PM.”
  • Prompt 形式: “The user wants to: [MASK]”

这样,我们可以利用 Masked Language Model(MLM)来预测最合适的填充词,例如 “set a reminder”。

3. Prompt-Tuning 代码示例

以下示例展示了如何使用 DistilBERT 进行意图分类,并采用 Prompt-Tuning 技术来改进模型效果。

3.1 数据集构建

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

class IntentDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = "Intent: " + self.texts[idx]  # 在输入中添加提示词
        label = self.labels[idx]
        encoding = self.tokenizer(
            text, truncation=True, padding='max_length', max_length=self.max_length, return_tensors="pt"
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

3.2 训练模型

model_name = "distilbert-base-uncased"
num_intents = 3  # 假设有 3 种意图
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_intents)
tokenizer = AutoTokenizer.from_pretrained(model_name)

texts = [
    "What's the weather like today?",
    "Set a reminder for 3 PM.",
    "Tell me a joke.",
    "How's the weather tomorrow?",
    "Remind me to call mom at 6 PM."
]
labels = [0, 1, 2, 0, 1]  # 意图类别:0=询问天气, 1=设置提醒, 2=讲笑话

train_dataset = IntentDataset(texts, labels, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

optimizer = AdamW(model.parameters(), lr=5e-5)
model.train()

for epoch in range(3):
    total_loss = 0
    for batch in train_dataloader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Average Loss: {total_loss / len(train_dataloader):.4f}")

3.3 进行推理

model.eval()
test_texts = [
    "What's the forecast for this weekend?",
    "Set an alarm for 7 AM."
]
intent_map = {0: "Ask about weather", 1: "Set reminder", 2: "Tell joke"}

with torch.no_grad():
    for test_text in test_texts:
        input_text = "Intent: " + test_text  # 通过 Prompt-Tuning 提示任务
        encoded_input = tokenizer(input_text, return_tensors="pt")
        outputs = model(**encoded_input)
        logits = outputs.logits
        predicted_label = torch.argmax(logits, dim=1).item()
        predicted_intent = intent_map[predicted_label]
        print(f"Text: '{test_text}' -> Predicted Intent: {predicted_intent}")

4. Prompt-Tuning 的应用场景

Prompt-Tuning 适用于多种 NLP 任务:

  • 文本分类:如意图识别、情感分析
  • 文本生成:如摘要生成、对话系统
  • 问答系统:使用 Prompt 提示用户问题类别
  • 代码生成:在代码补全任务中提供任务相关提示

5. Prompt-Tuning 的优劣势

优势

  1. 减少参数调整:相比 Fine-Tuning,Prompt-Tuning 仅需调整部分参数,计算开销更低。
  2. 泛化能力强:可以在不同任务之间快速适配,提高迁移能力。
  3. 适用于小样本场景:无需大量标注数据,即可获得较好的效果。

劣势

  1. 对 Prompt 设计敏感:Prompt 选择会影响模型性能,需要实验不同的提示词。
  2. 无法完全替代 Fine-Tuning:对于需要精细控制的任务(如医学、法律等专业领域),Fine-Tuning 仍然更有效。
  3. 复杂任务表现受限:Prompt 适用于结构化任务,对于需要深度理解的任务,仍需结合 Fine-Tuning。

6. 完整代码实例

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset


# 定义数据集类
class IntentDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors="pt"
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),  # 移除多余的 batch 维度
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }


# 加载预训练模型和分词器
model_name = "distilbert-base-uncased"
num_intents = 3  # 假设有 3 种意图
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_intents)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 准备训练数据
texts = [
    "What's the weather like today?",
    "Set a reminder for 3 PM.",
    "Tell me a joke.",
    "How's the weather tomorrow?",
    "Remind me to call mom at 6 PM."
]
labels = [0, 1, 2, 0, 1]  # 意图标签:0=询问天气, 1=设置提醒, 2=讲笑话

# 创建数据集和数据加载器
dataset = IntentDataset(texts, labels, tokenizer)
train_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 定义优化器
optimizer = AdamW(model.parameters(), lr=5e-5)

# 设置模型为训练模式
model.train()

# 训练循环
for epoch in range(3):  # 训练 3 个 epoch
    total_loss = 0
    for batch in train_dataloader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        # 前向传播
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss

        # 反向传播
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")

# 测试(推理)
model.eval()
test_texts = [
    "What's the forecast for this weekend?",
    "Set an alarm for 7 AM."
]
intent_map = {0: "Ask about weather", 1: "Set reminder", 2: "Tell joke"}

with torch.no_grad():
    for test_text in test_texts:
        encoded_input = tokenizer(test_text, return_tensors="pt")
        outputs = model(**encoded_input)
        logits = outputs.logits
        predicted_label = torch.argmax(logits, dim=1).item()
        predicted_intent = intent_map[predicted_label]
        print(f"Text: '{test_text}' -> Predicted Intent: {predicted_intent}")

7. 结论

Prompt-Tuning 是一种高效的 NLP 任务适配方法,能够在不改变大量模型参数的情况下,引导预训练模型完成特定任务。虽然其性能可能不及全面 Fine-Tuning,但在低资源场景、跨任务适配等方面具有巨大优势。希望本文的示例能帮助你理解 Prompt-Tuning,并在实际应用中灵活使用这一技术!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值