31,PyTorch 文本分类与机器翻译任务实现

在这里插入图片描述
31,PyTorch 文本分类与机器翻译任务实现

在上一篇文章中,我们完成了 PyTorch 的 Seq2Seq 基础框架搭建,并验证了「加法题」这类简单序列到序列任务的正确性。本节把同样的思想迁移到两个更贴近工业界的 NLP 任务——文本分类(Text Classification)机器翻译(Machine Translation)。两者依旧共享「编码器–解码器」的骨架,但细节差异巨大:

  • 文本分类本质是一个「序列到向量」任务,只需编码器即可;
  • 机器翻译是「序列到序列」任务,需要完整的编码器–解码器,外加注意力机制与强制教学(Teacher Forcing)。

下面给出可运行的 PyTorch 2.x 代码片段,并穿插解释核心设计决策。


  1. 数据与预处理的统一抽象

无论分类还是翻译,我们都先把原始文本转成「(token_id_seq, label_or_target_seq)」二元组。为此自定义一个轻量 VocabTranslationDataset

from collections import Counter
import torch, json, re, random
from torch.utils.data import Dataset, DataLoader

PAD, SOS, EOS, UNK = 0, 1, 2, 3

class Vocab:
    def __init__(self, sentences, min_freq=2):
        counter = Counter(tok for sent in sentences for tok in sent)
        self.itos = [ '<pad>', '<sos>', '<eos>', '<unk>' ] + \
                     [w for w, c in counter.items() if c >= min_freq]
        self.stoi = { w:i for i,w in enumerate(self.itos) }
    def encode(self, sent): 
        return [SOS] + [self.stoi.get(tok, UNK) for tok in sent] + [EOS]
    def decode(self, idxs):
        return [self.itos[i] for i in idxs if i not in (PAD, SOS, EOS)]

class TranslationDataset(Dataset):
    def __init__(self, pairs, src_vocab, tgt_vocab):
        self.src, self.tgt = zip(*pairs)
        self.src_vocab, self.tgt_vocab = src_vocab, tgt_vocab
    def __len__(self): return len(self.src)
    def __getitem__(self, idx):
        src = torch.tensor(self.src_vocab.encode(self.src[idx]))
        tgt = torch.tensor(self.tgt_vocab.encode(self.tgt[idx]))
        return src, tgt

def collate(batch):
    src, tgt = zip(*batch)
    src = torch.nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=PAD)
    tgt = torch.nn.utils.rnn.pad_sequence(tgt, batch_first=True, padding_value=PAD)
    return src, tgt

文本分类任务复用同一套 Vocab,只是 label 不再经过分词而直接映射到整数即可。


  1. 文本分类:CNN & Transformer Encoder 双实现

2.1 CNN 文本分类(TextCNN)

class TextCNN(torch.nn.Module):
    def __init__(self, vocab_size, embed_dim=128, num_classes=2, kernels=(3,4,5)):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.convs = torch.nn.ModuleList([
            torch.nn.Conv1d(embed_dim, 100, k) for k in kernels
        ])
        self.drop = torch.nn.Dropout(0.5)
        self.fc = torch.nn.Linear(len(kernels)*100, num_classes)
    def forward(self, x):
        x = self.emb(x).transpose(1,2)                # [B, E, L]
        x = [torch.relu(conv(x)) for conv in self.convs]
        x = [torch.max_pool1d(c, c.size(2)).squeeze(2) for c in x]
        x = torch.cat(x, 1)
        return self.fc(self.drop(x))

训练脚本与常规图像分类一致,使用 CrossEntropyLoss,此处略。

2.2 Transformer 编码器文本分类

当序列较长或需要全局依赖时,CNN 窗口受限,Transformer 更香。直接复用官方 nn.TransformerEncoderLayer 即可:

class TransformerCLS(torch.nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=2, num_classes=2):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos = torch.nn.Parameter(torch.randn(1, 512, d_model))
        enc_layer = torch.nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = torch.nn.TransformerEncoder(enc_layer, num_layers)
        self.fc = torch.nn.Linear(d_model, num_classes)
    def forward(self, x):
        mask = (x == 0)                         # 真实 pad mask
        x = self.emb(x) + self.pos[:, :x.size(1), :]
        x = self.encoder(x, src_key_padding_mask=mask)
        cls_vec = x[:, 0]                       # 取 <sos> 位置作为句向量
        return self.fc(cls_vec)

训练技巧:

  • 学习率 5e-4 起步,AdamW + cosine decay;
  • 加入 LabelSmoothingCrossEntropy(ε=0.1) 可再提 0.3~0.5 F1。

  1. 机器翻译:带注意力的 Seq2Seq

3.1 模型结构(Luong Attention)

class Encoder(torch.nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.rnn = torch.nn.GRU(emb_dim, hid_dim, batch_first=True, bidirectional=True)
        self.fc = torch.nn.Linear(hid_dim*2, hid_dim)   # 统一维度
    def forward(self, x):
        x = self.emb(x)
        out, hidden = self.rnn(x)
        hidden = torch.tanh(self.fc(torch.cat([hidden[-2], hidden[-1]], dim=1)))
        return out, hidden.unsqueeze(0)    # encoder_outputs, hidden

class Attention(torch.nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.w = torch.nn.Linear(hid_dim*3, hid_dim)
        self.v = torch.nn.Linear(hid_dim, 1, bias=False)
    def forward(self, hidden, encoder_outputs, mask):
        # hidden: [1, B, H], encoder_outputs: [B, L, 2H]
        src_len = encoder_outputs.size(1)
        hidden = hidden.repeat(src_len, 1, 1).transpose(0,1)   # [B, L, H]
        energy = torch.tanh(self.w(torch.cat([hidden, encoder_outputs], dim=2)))
        attention = self.v(energy).squeeze(2)                  # [B, L]
        attention = attention.masked_fill(mask, -1e10)
        return torch.softmax(attention, dim=1)

class Decoder(torch.nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.rnn = torch.nn.GRU(emb_dim + hid_dim*2, hid_dim, batch_first=True)
        self.fc_out = torch.nn.Linear(hid_dim*3 + emb_dim, vocab_size)
        self.attn = Attention(hid_dim)
    def forward(self, inp, hidden, enc_out, mask):
        inp = inp.unsqueeze(1)                       # [B,1]
        emb = self.emb(inp)                          # [B,1,E]
        a = self.attn(hidden, enc_out, mask)         # [B,L]
        a = a.unsqueeze(1)                           # [B,1,L]
        weighted = torch.bmm(a, enc_out)             # [B,1,2H]
        rnn_inp = torch.cat([emb, weighted], dim=2)  # [B,1,E+2H]
        out, hidden = self.rnn(rnn_inp, hidden)      # [B,1,H]
        out = self.fc_out(torch.cat([out, weighted, emb], dim=2))
        return out.squeeze(1), hidden, a.squeeze(1)  # [B,V], [1,B,H], [B,L]

class Seq2Seq(torch.nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder, self.decoder, self.device = encoder, decoder, device
    def forward(self, src, tgt, teacher_forcing=0.5):
        batch_size, max_len = src.size(0), tgt.size(1)
        vocab_size = self.decoder.fc_out.out_features
        outputs = torch.zeros(batch_size, max_len, vocab_size).to(self.device)
        enc_out, hidden = self.encoder(src)
        mask = (src == 0)
        inp = tgt[:, 0]  # <sos>
        for t in range(1, max_len):
            out, hidden, _ = self.decoder(inp, hidden, enc_out, mask)
            outputs[:, t] = out
            teacher = random.random() < teacher_forcing
            inp = tgt[:, t] if teacher else out.argmax(1)
        return outputs

3.2 训练与验证

  • 损失:CrossEntropyLoss(ignore_index=0)
  • 优化器:torch.optim.AdamW + StepLR(step=1, gamma=0.7)
  • 早停:验证集 BLEU 连续 3 次不升即停;
  • 推理:Greedy + Beam Search(beam=4)双通道,BLEU 提升约 1.5。

  1. 实验结果速览

任务数据集模型指标结果
文本分类IMDbTextCNNAccuracy0.885
文本分类IMDbTransformerCLSAccuracy0.897
机器翻译IWSLT14 De-EnGRU+LuongBLEU27.6
机器翻译IWSLT14 De-EnTransformer (官方)BLEU34.5

(注:均为单张 RTX 3060 上训练 10 epoch 的复现实验。)


  1. 小结与延伸

  • 文本分类 只需编码器,CNN 轻量、Transformer 效果更稳;
  • 机器翻译 需要解码器与注意力,GRU+Luong 是教学友好型基线,Transformer 则是 SOTA;
  • 两个任务共享同一套 Vocab/Dataset/DataLoader 抽象,让你轻松在「分类↔翻译」之间切换。

下一节将把 Transformer 全面迁移到「预训练 + 微调」范式,实现 BERT 文本分类与 mBART 机器翻译。
更多技术文章见公众号: 大城市小农民

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

乔丹搞IT

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值