代码示例:
29, PyTorch 文本预处理与词嵌入
在把 GAN 迁移到文本、序列或表格数据时,第一步永远是「把符号变成向量」。本节提供一套面向 2024 年生产环境的 PyTorch 文本预处理流水线:从原始 .txt
或 .csv
到可直接喂给 Transformer、RNN 或 1D-GAN 的 FloatTensor
。所有代码均可直接复制到 Jupyter Notebook 或上一节的 train.py
中运行。
29.1 数据获取与文件约定
data/
corpus.txt # 原始文本,一行一条样本
labels.csv # 可选:与 corpus.txt 行对齐的类别标签
29.2 文本清洗 3 行代码
import re, emoji, ftfy
def clean(text:str) -> str:
text = ftfy.fix_text(text) # 修复编码
text = emoji.replace_emoji(text, '')# 移除 emoji
text = re.sub(r'\s+', ' ', text) # 合并空格
return text.strip()
29.3 分词:三种策略一键切换
策略 | 适用场景 | PyTorch 代码 |
---|---|---|
空格分词 | 英文日志、代码 | text.split() |
SpaCy | 多语言、需词性 | import spacy; nlp("text") |
HuggingFace Tokenizer | 直接对接 Transformer | AutoTokenizer.from_pretrained("bert-base-uncased") |
下面以「HF Tokenizer」为例,因为它同时产出 input_ids
与 attention_mask
,后续可直接喂给任何模型。
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)
29.4 构建 Dataset & DataLoader
from torch.utils.data import Dataset, DataLoader
import pandas as pd, torch
class TextDataset(Dataset):
def __init__(self, txt_path, tok, max_len=128):
with open(txt_path, encoding='utf-8') as f:
self.lines = [clean(l) for l in f]
self.tok, self.max_len = tok, max_len
def __len__(self): return len(self.lines)
def __getitem__(self, idx):
enc = self.tok(
self.lines[idx],
max_length=self.max_len,
padding='max_length',
truncation=True,
return_tensors='pt'
)
return {
'input_ids': enc['input_ids'].squeeze(0), # (L,)
'attention_mask': enc['attention_mask'].squeeze(0) # (L,)
}
ds = TextDataset('data/corpus.txt', tok)
dl = DataLoader(ds, batch_size=256, shuffle=True, num_workers=4)
29.5 词嵌入:四种方案按需选用
方案 | 特点 | 调用示例 |
---|---|---|
nn.Embedding | 自己训练 | nn.Embedding(vocab, 256) |
Pre-trained Word2Vec | 轻量、静态 | gensim.models.KeyedVectors.load('wv.bin') |
Pre-trained GloVe | 同上 | torchtext.vocab.GloVe(name='6B', dim=300) |
Transformer Embedding | 上下文相关、重 | AutoModel.from_pretrained('bert-base-uncased') |
下面给出「自己训练」与「直接拿 BERT」两套代码,方便快速切换。
29.5.1 自己训练 256 维词向量
V = tok.vocab_size
E = 256
emb = torch.nn.Embedding(V, E, padding_idx=tok.pad_token_id)
29.5.2 直接拿 BERT Embedding(冻结权重)
from transformers import AutoModel
bert = AutoModel.from_pretrained('bert-base-uncased')
for p in bert.parameters():
p.requires_grad = False
29.6 把文本喂给 1D-GAN
上一节我们实现的是 1D-GAN,只需把卷积核改为 1D,通道视为词向量维度。下面给出从 DataLoader
→ Embedding
→ Conv1d
的完整片段。
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, nz=128, embed_dim=256, seq_len=128):
super().__init__()
self.seq_len = seq_len
self.fc = nn.Linear(nz, seq_len * embed_dim)
self.conv = nn.Sequential(
nn.Conv1d(embed_dim, 512, 3, 1, 1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(512, embed_dim, 3, 1, 1),
nn.Tanh()
)
def forward(self, z):
# z: (B, nz)
x = self.fc(z).view(z.size(0), -1, self.seq_len) # (B, E, L)
return self.conv(x) # (B, E, L)
# 判别器同理,Conv1d 反向即可
注意:GAN 输出的是连续向量,因此仍需「解码」回文本。常见做法:
- 训练一个单独的 解码器(如
nn.Embedding
→ LSTM → softmax); - 直接用 余弦相似度 把生成向量映射到最近邻 token;
- 使用 Gumbel-Softmax 在训练阶段做可导采样。
29.7 一键脚本:文本→张量
pip install transformers datasets torchtext ftfy emoji spacy
python text2tensor.py --corpus data/corpus.txt --max_len 128 --batch 256
text2tensor.py
已整合:
- 清洗、分词、编码、Dataset、DataLoader;
- 可选
--bert
直接加载 BERT Embedding; - 输出
train.pt/val.pt
供 GAN 训练时--data_path
直接读取。
29.8 小结
- 文本预处理 ≠ 繁琐脚本,HF Tokenizer + Dataset 10 行代码即可上线。
- 词嵌入是桥梁,静态(Word2Vec/GloVe)轻量、动态(BERT)上下文丰富,按算力选。
- GAN 文本化 的关键是把
(B, L)
的离散 token 变成(B, E, L)
的连续张量,再交给 1D 卷积。
掌握本节流水线,你就能把上一节的 1D-GAN 无缝迁移到日志、评论、代码、表格等任何符号序列。下一节我们将进入 SeqGAN & LeakGAN 原理与 PyTorch 实现,用强化学习彻底解决「GAN 离散采样不可导」难题。
更多技术文章见公众号: 大城市小农民