29,PyTorch 文本预处理与词嵌入

代码示例:

29, PyTorch 文本预处理与词嵌入

在把 GAN 迁移到文本、序列或表格数据时,第一步永远是「把符号变成向量」。本节提供一套面向 2024 年生产环境的 PyTorch 文本预处理流水线:从原始 .txt.csv 到可直接喂给 Transformer、RNN 或 1D-GAN 的 FloatTensor。所有代码均可直接复制到 Jupyter Notebook 或上一节的 train.py 中运行。


29.1 数据获取与文件约定

data/
  corpus.txt          # 原始文本,一行一条样本
  labels.csv          # 可选:与 corpus.txt 行对齐的类别标签

29.2 文本清洗 3 行代码

import re, emoji, ftfy

def clean(text:str) -> str:
    text = ftfy.fix_text(text)          # 修复编码
    text = emoji.replace_emoji(text, '')# 移除 emoji
    text = re.sub(r'\s+', ' ', text)    # 合并空格
    return text.strip()

29.3 分词:三种策略一键切换

策略适用场景PyTorch 代码
空格分词英文日志、代码text.split()
SpaCy多语言、需词性import spacy; nlp("text")
HuggingFace Tokenizer直接对接 TransformerAutoTokenizer.from_pretrained("bert-base-uncased")

下面以「HF Tokenizer」为例,因为它同时产出 input_idsattention_mask,后续可直接喂给任何模型。

from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)

29.4 构建 Dataset & DataLoader

from torch.utils.data import Dataset, DataLoader
import pandas as pd, torch

class TextDataset(Dataset):
    def __init__(self, txt_path, tok, max_len=128):
        with open(txt_path, encoding='utf-8') as f:
            self.lines = [clean(l) for l in f]
        self.tok, self.max_len = tok, max_len

    def __len__(self): return len(self.lines)

    def __getitem__(self, idx):
        enc = self.tok(
            self.lines[idx],
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': enc['input_ids'].squeeze(0),       # (L,)
            'attention_mask': enc['attention_mask'].squeeze(0)  # (L,)
        }

ds = TextDataset('data/corpus.txt', tok)
dl = DataLoader(ds, batch_size=256, shuffle=True, num_workers=4)

29.5 词嵌入:四种方案按需选用

方案特点调用示例
nn.Embedding自己训练nn.Embedding(vocab, 256)
Pre-trained Word2Vec轻量、静态gensim.models.KeyedVectors.load('wv.bin')
Pre-trained GloVe同上torchtext.vocab.GloVe(name='6B', dim=300)
Transformer Embedding上下文相关、重AutoModel.from_pretrained('bert-base-uncased')

下面给出「自己训练」与「直接拿 BERT」两套代码,方便快速切换。

29.5.1 自己训练 256 维词向量

V = tok.vocab_size
E = 256
emb = torch.nn.Embedding(V, E, padding_idx=tok.pad_token_id)

29.5.2 直接拿 BERT Embedding(冻结权重)

from transformers import AutoModel
bert = AutoModel.from_pretrained('bert-base-uncased')
for p in bert.parameters():
    p.requires_grad = False

29.6 把文本喂给 1D-GAN

上一节我们实现的是 1D-GAN,只需把卷积核改为 1D,通道视为词向量维度。下面给出从 DataLoaderEmbeddingConv1d 的完整片段。

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, nz=128, embed_dim=256, seq_len=128):
        super().__init__()
        self.seq_len = seq_len
        self.fc = nn.Linear(nz, seq_len * embed_dim)
        self.conv = nn.Sequential(
            nn.Conv1d(embed_dim, 512, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(512, embed_dim, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, z):
        # z: (B, nz)
        x = self.fc(z).view(z.size(0), -1, self.seq_len)  # (B, E, L)
        return self.conv(x)                               # (B, E, L)

# 判别器同理,Conv1d 反向即可

注意:GAN 输出的是连续向量,因此仍需「解码」回文本。常见做法:

  1. 训练一个单独的 解码器(如 nn.Embedding → LSTM → softmax);
  2. 直接用 余弦相似度 把生成向量映射到最近邻 token;
  3. 使用 Gumbel-Softmax 在训练阶段做可导采样。

29.7 一键脚本:文本→张量

pip install transformers datasets torchtext ftfy emoji spacy
python text2tensor.py --corpus data/corpus.txt --max_len 128 --batch 256

text2tensor.py 已整合:

  • 清洗、分词、编码、Dataset、DataLoader;
  • 可选 --bert 直接加载 BERT Embedding;
  • 输出 train.pt/val.pt 供 GAN 训练时 --data_path 直接读取。

29.8 小结

  1. 文本预处理 ≠ 繁琐脚本,HF Tokenizer + Dataset 10 行代码即可上线。
  2. 词嵌入是桥梁,静态(Word2Vec/GloVe)轻量、动态(BERT)上下文丰富,按算力选。
  3. GAN 文本化 的关键是把 (B, L) 的离散 token 变成 (B, E, L) 的连续张量,再交给 1D 卷积。

掌握本节流水线,你就能把上一节的 1D-GAN 无缝迁移到日志、评论、代码、表格等任何符号序列。下一节我们将进入 SeqGAN & LeakGAN 原理与 PyTorch 实现,用强化学习彻底解决「GAN 离散采样不可导」难题。
更多技术文章见公众号: 大城市小农民

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

乔丹搞IT

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值