1. Introduction
In natural language processing (NLP), text polishing is a practical application: it helps users improve writing quality, correct grammar, and express themselves more clearly. The T5 (Text-to-Text Transfer Transformer) model, with its strong generalization ability and flexible text-to-text formulation, is an ideal candidate for fine-tuning. This article walks through a simple example of fine-tuning T5 so that it can polish text automatically.
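To get a feel for the text-to-text formulation before any fine-tuning, here is a minimal sketch (assuming the transformers and sentencepiece packages are installed) that runs the stock t5-small checkpoint on one of its built-in task prefixes:

from transformers import T5ForConditionalGeneration, T5Tokenizer

# Every T5 task is phrased as "prefix: input text" -> "output text".
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# "translate English to German:" is one of the prefixes T5 was pre-trained with.
inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
output_ids = model.generate(**inputs, max_length=40)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Fine-tuning simply teaches the model a new prefix and mapping, which is what the rest of this article does with "polish:".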
2. Preparing the Dataset
Fine-tuning T5 requires pairs of input text and target text. For example, we want the model to turn sentences that are ungrammatical or awkwardly phrased into fluent, correct ones. We define the following dataset:
input_texts = [
"I write bad, need help.",
"This is a good products.",
"Meeting is tomorrow, I not ready.",
"He talk too much and not clear."
]
target_texts = [
"I write poorly and need assistance.",
"This is a great product.",
"The meeting is tomorrow, and I'm not prepared.",
"He talks too much and isn't clear."
]
In T5, every task needs a prefix, for example "polish:", which indicates the task type. This lets the model understand the intended task during generation.
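To illustrate what the model actually sees, here is a small sketch (the t5-small tokenizer is the same one loaded later in section 4) that prepends the prefix and inspects the tokenized input:

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Prepend the task prefix exactly as the dataset class in the next section does.
prefixed = "polish: " + "I write bad, need help."
print(tokenizer.tokenize(prefixed))        # subword pieces, prefix included
print(tokenizer(prefixed)["input_ids"])    # the corresponding token ids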
3. Loading the Dataset
In PyTorch, Dataset is the base class for data loading, and we can create a custom PolishingDataset class:
from torch.utils.data import Dataset
from transformers import T5Tokenizer
class PolishingDataset(Dataset):
    def __init__(self, input_texts, target_texts, tokenizer, max_length=128):
        self.input_texts = input_texts
        self.target_texts = target_texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.input_texts)

    def __getitem__(self, idx):
        input_text = "polish: " + self.input_texts[idx]  # add the task prefix
        target_text = self.target_texts[idx]
        input_encoding = self.tokenizer(
            input_text, truncation=True, padding='max_length', max_length=self.max_length, return_tensors="pt"
        )
        target_encoding = self.tokenizer(
            target_text, truncation=True, padding='max_length', max_length=self.max_length, return_tensors="pt"
        )
        labels = target_encoding['input_ids'].squeeze(0)
        # Replace padding token ids in the labels with -100 so the loss ignores them
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            'input_ids': input_encoding['input_ids'].squeeze(0),
            'attention_mask': input_encoding['attention_mask'].squeeze(0),
            'labels': labels
        }
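As a quick sanity check, the sketch below instantiates the dataset with the input_texts/target_texts lists from section 2; it loads the tokenizer here purely for illustration (section 4 loads it again alongside the model):

tokenizer = T5Tokenizer.from_pretrained("t5-small")
sample_dataset = PolishingDataset(input_texts, target_texts, tokenizer)
example = sample_dataset[0]
# Each field is a 1-D tensor of length max_length (128 by default).
print(example['input_ids'].shape, example['attention_mask'].shape, example['labels'].shape)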
4. Loading the Pre-trained T5 Model
from transformers import T5ForConditionalGeneration
model_name = "t5-small"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)
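This toy example runs comfortably on CPU, but the model can be moved to a GPU if one is available; note that each batch in the training loop below would then also need a matching .to(device) call. A minimal sketch (the device variable is an addition for illustration, not part of the original code):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# t5-small has roughly 60M parameters.
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")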
5. Training the Model
The training loop consists of the forward pass, loss computation, backpropagation, and the optimizer step. We use AdamW as the optimizer.
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
train_dataset = PolishingDataset(input_texts, target_texts, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
optimizer = AdamW(model.parameters(), lr=3e-4)
model.train()
for epoch in range(3):
total_loss = 0
for batch in train_dataloader:
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = batch['labels']
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
total_loss += loss.item()
print(f"Epoch {epoch + 1}, Average Loss: {total_loss / len(train_dataloader):.4f}")
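Once training finishes, it is worth saving the fine-tuned weights so they can be reused without retraining. A short sketch; the ./t5-polishing directory name is an arbitrary choice, not part of the original tutorial:

# Persist the fine-tuned model and its tokenizer to a local directory.
model.save_pretrained("./t5-polishing")
tokenizer.save_pretrained("./t5-polishing")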
6. Inference (Testing the Model)
After training, we can run inference with model.generate().
model.eval()
test_texts = [
"I not good at writing.",
"This phone is nice but heavy."
]
with torch.no_grad():
for test_text in test_texts:
input_text = "polish: " + test_text
encoded_input = tokenizer(input_text, return_tensors="pt")
output_ids = model.generate(
input_ids=encoded_input['input_ids'],
attention_mask=encoded_input['attention_mask'],
max_length=50,
num_beams=4,
early_stopping=True
)
polished_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"Original: '{test_text}' -> Polished: '{polished_text}'")
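To reuse the fine-tuned model in a later session, the directory saved in the previous sketch can be loaded back like any other checkpoint. A minimal sketch, assuming the model was saved to ./t5-polishing as above:

from transformers import T5ForConditionalGeneration, T5Tokenizer

# Reload the fine-tuned checkpoint from the local directory.
reloaded_model = T5ForConditionalGeneration.from_pretrained("./t5-polishing")
reloaded_tokenizer = T5Tokenizer.from_pretrained("./t5-polishing")
reloaded_model.eval()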
7. Complete Code Example
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer
# Define the dataset class
class PolishingDataset(Dataset):
def __init__(self, input_texts, target_texts, tokenizer, max_length=128):
self.input_texts = input_texts
self.target_texts = target_texts
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.input_texts)
    def __getitem__(self, idx):
        input_text = "polish: " + self.input_texts[idx]  # add the task prefix
        target_text = self.target_texts[idx]
        # Encode the input
        input_encoding = self.tokenizer(
            input_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors="pt"
        )
        # Encode the target output
        target_encoding = self.tokenizer(
            target_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors="pt"
        )
        # The target text's input_ids serve as the labels;
        # padding token ids are replaced with -100 so the loss ignores them.
        labels = target_encoding['input_ids'].squeeze(0)
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            'input_ids': input_encoding['input_ids'].squeeze(0),
            'attention_mask': input_encoding['attention_mask'].squeeze(0),
            'labels': labels
        }
# Load the pre-trained T5 model and tokenizer
model_name = "google-t5/t5-small"  # canonical Hub name; equivalent to the "t5-small" alias used earlier
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)
# Prepare the training data
input_texts = [
"I write bad, need help.",
"This is a good products.",
"Meeting is tomorrow, I not ready.",
"He talk too much and not clear."
]
target_texts = [
"I write poorly and need assistance.",
"This is a great product.",
"The meeting is tomorrow, and I'm not prepared.",
"He talks too much and isn't clear."
]
# Create the dataset and the DataLoader
dataset = PolishingDataset(input_texts, target_texts, tokenizer)
train_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# Define the optimizer (all model parameters are updated)
optimizer = AdamW(model.parameters(), lr=3e-4)
# Put the model in training mode
model.train()
# Training loop
for epoch in range(30):  # train for 30 epochs; the tiny dataset benefits from many passes
total_loss = 0
for batch in train_dataloader:
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = batch['labels']
        # Forward pass
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
loss = outputs.loss
        # Backward pass and optimizer step
loss.backward()
optimizer.step()
optimizer.zero_grad()
total_loss += loss.item()
avg_loss = total_loss / len(train_dataloader)
print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")
# Testing (inference)
model.eval()
test_texts = [
"I not good at writing.",
"This phone is nice but heavy."
]
with torch.no_grad():
for test_text in test_texts:
input_text = "polish: " + test_text
encoded_input = tokenizer(input_text, return_tensors="pt")
output_ids = model.generate(
input_ids=encoded_input['input_ids'],
attention_mask=encoded_input['attention_mask'],
max_length=50,
            num_beams=4,  # beam search improves generation quality
early_stopping=True
)
polished_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"Original: '{test_text}' -> Polished: '{polished_text}'")
8. Conclusion
In this article we have seen how to fine-tune a T5 model for text polishing. The complete workflow covers data preparation, data loading, model training, optimizer setup, and inference. The same fine-tuning recipe extends to other NLP tasks such as machine translation and summarization. We hope this tutorial helps you better understand the T5 fine-tuning process and apply it to your own tasks!