
31. Text Classification and Machine Translation in PyTorch
In the previous article we built the basic Seq2Seq scaffolding in PyTorch and verified it on the simple "addition problems" sequence-to-sequence task. This article carries the same ideas over to two NLP tasks that are much closer to industrial practice: text classification and machine translation. Both still share the encoder-decoder backbone, but the details differ considerably:
- Text classification is essentially a sequence-to-vector task, so the encoder alone is enough;
- Machine translation is a sequence-to-sequence task, so it needs the full encoder-decoder plus an attention mechanism and teacher forcing.
Runnable PyTorch 2.x code snippets follow, interleaved with explanations of the key design decisions.
1. A unified abstraction for data and preprocessing
Whether the task is classification or translation, we first turn raw text into (token_id_seq, label_or_target_seq) pairs. For this we define a lightweight `Vocab` and a `TranslationDataset`.
```python
from collections import Counter
import torch, json, re, random
from torch.utils.data import Dataset, DataLoader

PAD, SOS, EOS, UNK = 0, 1, 2, 3

class Vocab:
    def __init__(self, sentences, min_freq=2):
        counter = Counter(tok for sent in sentences for tok in sent)
        self.itos = ['<pad>', '<sos>', '<eos>', '<unk>'] + \
                    [w for w, c in counter.items() if c >= min_freq]
        self.stoi = {w: i for i, w in enumerate(self.itos)}

    def encode(self, sent):
        return [SOS] + [self.stoi.get(tok, UNK) for tok in sent] + [EOS]

    def decode(self, idxs):
        return [self.itos[i] for i in idxs if i not in (PAD, SOS, EOS)]

class TranslationDataset(Dataset):
    def __init__(self, pairs, src_vocab, tgt_vocab):
        self.src, self.tgt = zip(*pairs)
        self.src_vocab, self.tgt_vocab = src_vocab, tgt_vocab

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        src = torch.tensor(self.src_vocab.encode(self.src[idx]))
        tgt = torch.tensor(self.tgt_vocab.encode(self.tgt[idx]))
        return src, tgt

def collate(batch):
    src, tgt = zip(*batch)
    src = torch.nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=PAD)
    tgt = torch.nn.utils.rnn.pad_sequence(tgt, batch_first=True, padding_value=PAD)
    return src, tgt
```
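As a quick usage check, the pieces above can be wired into a `DataLoader` as follows; the toy sentence pairs are made up purely for illustration and are not the data used later in this article:

```python
# Toy parallel data: (source tokens, target tokens) pairs.
pairs = [
    (["ich", "liebe", "katzen"], ["i", "love", "cats"]),
    (["das", "ist", "gut"],      ["this", "is", "good"]),
]

# min_freq=1 because the toy corpus is tiny; the default of 2 would drop every word.
src_vocab = Vocab([src for src, _ in pairs], min_freq=1)
tgt_vocab = Vocab([tgt for _, tgt in pairs], min_freq=1)

dataset = TranslationDataset(pairs, src_vocab, tgt_vocab)
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate)

for src_batch, tgt_batch in loader:
    print(src_batch.shape, tgt_batch.shape)   # e.g. torch.Size([2, L_src]) torch.Size([2, L_tgt])
    break
```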
Text classification reuses the same `Vocab`; the only difference is that the label is not tokenized but mapped directly to an integer, as in the sketch below.
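A matching classification dataset takes only a few lines. The following is an illustrative sketch; the `ClassificationDataset` and `collate_cls` names are assumptions, not part of the code above:

```python
class ClassificationDataset(Dataset):
    """Pairs a tokenized sentence with an integer class label."""
    def __init__(self, texts, labels, vocab):
        self.texts, self.labels, self.vocab = texts, labels, vocab

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        ids = torch.tensor(self.vocab.encode(self.texts[idx]))
        label = torch.tensor(self.labels[idx])   # the label is already an integer, no tokenization
        return ids, label

def collate_cls(batch):
    # Pad the token sequences, stack the integer labels.
    texts, labels = zip(*batch)
    texts = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True, padding_value=PAD)
    return texts, torch.stack(labels)
```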
2. Text classification: CNN and Transformer encoder implementations
2.1 CNN text classification (TextCNN)
```python
class TextCNN(torch.nn.Module):
    def __init__(self, vocab_size, embed_dim=128, num_classes=2, kernels=(3, 4, 5)):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.convs = torch.nn.ModuleList([
            torch.nn.Conv1d(embed_dim, 100, k) for k in kernels
        ])
        self.drop = torch.nn.Dropout(0.5)
        self.fc = torch.nn.Linear(len(kernels) * 100, num_classes)

    def forward(self, x):
        x = self.emb(x).transpose(1, 2)                             # [B, E, L]
        x = [torch.relu(conv(x)) for conv in self.convs]
        x = [torch.max_pool1d(c, c.size(2)).squeeze(2) for c in x]
        x = torch.cat(x, 1)
        return self.fc(self.drop(x))
```
The training script is the same as for ordinary image classification and uses `CrossEntropyLoss`; details are omitted here, but a minimal sketch follows.
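Purely for reference, here is a minimal sketch of such a loop; the optimizer, learning rate, epoch count, and the `vocab`/`train_loader` objects are placeholders rather than the exact setup behind the reported numbers:

```python
# Minimal TextCNN training-loop sketch (hyperparameters are illustrative).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TextCNN(vocab_size=len(vocab.itos)).to(device)     # `vocab` built as in section 1
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    model.train()
    for texts, labels in train_loader:                     # a classification DataLoader
        texts, labels = texts.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(texts), labels)
        loss.backward()
        optimizer.step()
```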
2.2 Text classification with a Transformer encoder
When sequences are long or global dependencies matter, the CNN's limited receptive field becomes a bottleneck and a Transformer is the more attractive option. We can reuse the built-in `nn.TransformerEncoderLayer` directly:
```python
class TransformerCLS(torch.nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=2, num_classes=2):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos = torch.nn.Parameter(torch.randn(1, 512, d_model))  # learned positions, max length 512
        enc_layer = torch.nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = torch.nn.TransformerEncoder(enc_layer, num_layers)
        self.fc = torch.nn.Linear(d_model, num_classes)

    def forward(self, x):
        mask = (x == 0)                                  # padding mask over <pad> positions
        x = self.emb(x) + self.pos[:, :x.size(1), :]
        x = self.encoder(x, src_key_padding_mask=mask)
        cls_vec = x[:, 0]                                # use the <sos> position as the sentence vector
        return self.fc(cls_vec)
```
Training tips (see the sketch below):
- Start the learning rate at 5e-4, with AdamW and cosine decay;
- Adding label smoothing (`LabelSmoothingCrossEntropy`, ε = 0.1) buys another 0.3 to 0.5 points of F1.
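A sketch of that recipe using only built-in PyTorch pieces; since PyTorch 1.10, `torch.nn.CrossEntropyLoss` accepts a `label_smoothing` argument, which stands in here for a separate `LabelSmoothingCrossEntropy` class, and the epoch count is a placeholder:

```python
model = TransformerCLS(vocab_size=len(vocab.itos)).to(device)
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)                   # ε = 0.1
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)  # cosine decay over 10 epochs

for epoch in range(10):
    model.train()
    for texts, labels in train_loader:
        texts, labels = texts.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(texts), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()                                     # step the schedule once per epoch
```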
3. Machine translation: Seq2Seq with attention
3.1 Model architecture (Luong attention)
```python
class Encoder(torch.nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.rnn = torch.nn.GRU(emb_dim, hid_dim, batch_first=True, bidirectional=True)
        self.fc = torch.nn.Linear(hid_dim * 2, hid_dim)  # project the 2H bidirectional state down to H

    def forward(self, x):
        x = self.emb(x)
        out, hidden = self.rnn(x)
        # concatenate the last layer's forward and backward final states
        hidden = torch.tanh(self.fc(torch.cat([hidden[-2], hidden[-1]], dim=1)))
        return out, hidden.unsqueeze(0)                  # encoder_outputs, hidden
```
```python
class Attention(torch.nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.w = torch.nn.Linear(hid_dim * 3, hid_dim)
        self.v = torch.nn.Linear(hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask):
        # hidden: [1, B, H], encoder_outputs: [B, L, 2H]
        src_len = encoder_outputs.size(1)
        hidden = hidden.repeat(src_len, 1, 1).transpose(0, 1)   # [B, L, H]
        energy = torch.tanh(self.w(torch.cat([hidden, encoder_outputs], dim=2)))
        attention = self.v(energy).squeeze(2)                   # [B, L]
        attention = attention.masked_fill(mask, -1e10)
        return torch.softmax(attention, dim=1)
```
```python
class Decoder(torch.nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.rnn = torch.nn.GRU(emb_dim + hid_dim * 2, hid_dim, batch_first=True)
        self.fc_out = torch.nn.Linear(hid_dim * 3 + emb_dim, vocab_size)
        self.attn = Attention(hid_dim)

    def forward(self, inp, hidden, enc_out, mask):
        inp = inp.unsqueeze(1)                           # [B, 1]
        emb = self.emb(inp)                              # [B, 1, E]
        a = self.attn(hidden, enc_out, mask)             # [B, L]
        a = a.unsqueeze(1)                               # [B, 1, L]
        weighted = torch.bmm(a, enc_out)                 # [B, 1, 2H]
        rnn_inp = torch.cat([emb, weighted], dim=2)      # [B, 1, E+2H]
        out, hidden = self.rnn(rnn_inp, hidden)          # [B, 1, H]
        out = self.fc_out(torch.cat([out, weighted, emb], dim=2))
        return out.squeeze(1), hidden, a.squeeze(1)      # [B, V], [1, B, H], [B, L]
```
```python
class Seq2Seq(torch.nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder, self.decoder, self.device = encoder, decoder, device

    def forward(self, src, tgt, teacher_forcing=0.5):
        batch_size, max_len = src.size(0), tgt.size(1)
        vocab_size = self.decoder.fc_out.out_features
        outputs = torch.zeros(batch_size, max_len, vocab_size).to(self.device)
        enc_out, hidden = self.encoder(src)
        mask = (src == 0)
        inp = tgt[:, 0]                                  # <sos>
        for t in range(1, max_len):
            out, hidden, _ = self.decoder(inp, hidden, enc_out, mask)
            outputs[:, t] = out
            teacher = random.random() < teacher_forcing
            inp = tgt[:, t] if teacher else out.argmax(1)
        return outputs
```
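Wiring the three modules together might look like this; the embedding and hidden sizes are illustrative choices, not necessarily those behind the results reported below:

```python
device = "cuda" if torch.cuda.is_available() else "cpu"
enc = Encoder(vocab_size=len(src_vocab.itos), emb_dim=256, hid_dim=512)
dec = Decoder(vocab_size=len(tgt_vocab.itos), emb_dim=256, hid_dim=512)
model = Seq2Seq(enc, dec, device).to(device)

# One forward pass on a batch from the collated DataLoader of section 1.
out = model(src_batch.to(device), tgt_batch.to(device))   # [B, tgt_len, tgt_vocab_size]
```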
3.2 Training and validation
- Loss: `CrossEntropyLoss(ignore_index=0)`;
- Optimizer: `torch.optim.AdamW` + `StepLR(step_size=1, gamma=0.7)`;
- Early stopping: stop when validation BLEU fails to improve three evaluations in a row;
- Inference: both greedy decoding and beam search (beam = 4); beam search improves BLEU by roughly 1.5 points. A minimal training-step and greedy-decoding sketch is given below.
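The sketch below covers one training step and greedy decoding, building on the `model`/`device` from the instantiation snippet above; the gradient-clipping value and maximum decode length are placeholder choices:

```python
criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)

def train_step(src, tgt):
    model.train()
    optimizer.zero_grad()
    out = model(src.to(device), tgt.to(device))                # [B, T, V]
    # Position 0 is <sos> and is never predicted, so drop it before computing the loss.
    loss = criterion(out[:, 1:].reshape(-1, out.size(-1)),
                     tgt[:, 1:].reshape(-1).to(device))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)    # placeholder clip value
    optimizer.step()
    return loss.item()

@torch.no_grad()
def greedy_decode(src, max_len=50):
    model.eval()
    src = src.to(device)
    enc_out, hidden = model.encoder(src)
    mask = (src == PAD)
    inp = torch.full((src.size(0),), SOS, dtype=torch.long, device=device)
    preds = []
    for _ in range(max_len):
        out, hidden, _ = model.decoder(inp, hidden, enc_out, mask)
        inp = out.argmax(1)
        preds.append(inp)
    return torch.stack(preds, dim=1)   # [B, max_len]; truncate at the first <eos> downstream
```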
4. Results at a glance
| Task | Dataset | Model | Metric | Result |
|---|---|---|---|---|
| Text classification | IMDb | TextCNN | Accuracy | 0.885 |
| Text classification | IMDb | TransformerCLS | Accuracy | 0.897 |
| Machine translation | IWSLT14 De-En | GRU + Luong | BLEU | 27.6 |
| Machine translation | IWSLT14 De-En | Transformer (official) | BLEU | 34.5 |
(Note: all numbers come from reproduction runs of 10 epochs each on a single RTX 3060.)
5. Summary and extensions
- Text classification needs only the encoder; the CNN is lightweight, while the Transformer gives steadier results;
- Machine translation needs the decoder and attention; GRU + Luong is a teaching-friendly baseline, while the Transformer is the state of the art;
- Both tasks share the same `Vocab`/`Dataset`/`DataLoader` abstraction, so switching between classification and translation is painless.
The next article migrates the Transformer fully to the "pretrain + fine-tune" paradigm, implementing BERT-based text classification and mBART-based machine translation.
More technical articles are available on the WeChat official account: 大城市小农民