第八章:PyTorch生态简介 — 深入浅出PyTorch datawhale ai 共学
datasets / transforms / model-zoo / audio & text
8.1 生态概览
| 领域 | 官方包 | 用途 |
|---|---|---|
| CV | torchvision | 图像/视频数据集、数据增强、预训练模型 |
| NLP | torchtext | 文本数据处理、词表、常见数据集 |
| Audio | torchaudio | 音频 I/O、特征、datasets、pipelines |
| Video | pytorchvideo | SOTA 视频模型 & pipeline (Meta) |
| Graphs | torch_geometric | GNN 旗舰库 (生态外) |
8.2 torchvision
8.2.1 datasets
from torchvision import datasets, transforms
train_ds = datasets.CIFAR10(root='./data',
train=True,
download=True,
transform=transforms.ToTensor())
-
-
root:本地缓存目录。 -
download=True:如不存在本地文件则从 mirror 自动拉取 -
transform:对 单张 样本调用,不对 batch 起效
-
| 症状 | 原因 | 修复 |
|---|---|---|
| 下载卡死 0B/s | 默认站点被墙 | TORCHVISION_DOWNLOAD_MIRRORS=http://download.pytorch.org 或换清华镜像 |
| 爆内存 | FakeData 默认 1000×3×224×224 | 指定 size= & image_size= |
8.2.2 transforms
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
img = Image.open("./figures/lenna.jpg") # ① 读入 PIL.Image
trans = transforms.Compose([ # ② pipeline
transforms.Resize(256), # 等比 ⇒ 短边 256
transforms.CenterCrop(224), # 中心裁 224²
transforms.ColorJitter(0.4,0.4,0.4,0.1), # 随机亮度/对比度/饱和/色调
transforms.ToTensor(), # [0,255] PIL ⇒ [0,1] Tensor
transforms.Normalize( # 标准化(Imagenet 统计值)
mean=[0.485,0.456,0.406],
std =[0.229,0.224,0.225])
])
t_img = trans(img) # ③ 调用
| 报错 | 含义 | 解决 |
|---|---|---|
TypeError: pic should be PIL Image | 在 ToTensor 前用了 plt.imread | 保持 PIL 或改用 transforms.ToPILImage() |
| 图像全黑 | ColorJitter 参数太大 | 控制在 0–0.5 范围 |
拓展
-
RandAug / TrivialAug 已内置于
torchvision.transforms.v2(≥0.15) -
批量增强 用
torchvision.transforms.functional+vmap/for,Compose 仅对 sample
8.2.3 models
import torchvision.models as models
net = models.resnet18(weights='IMAGENET1K_V1') # ① 加载权重
net.fc = torch.nn.Linear(net.fc.in_features, 4) # ② 微调到 4 类
-
weights新接口 (0.13+);旧pretrained=True已 deprecated -
微调时记得
for p in net.parameters(): p.requires_grad = False再解冻fc
8.3 PyTorchVideo
1. Hub 调用
import torch
model = torch.hub.load('facebookresearch/pytorchvideo',
model='slowfast_r50',
pretrained=True)
model.eval()
2. 输入尺寸
(B, C=3, T, H, W),默认 32 x 3 clip
clip = torch.randn(1, 3, 32, 224, 224)
with torch.no_grad(): logits = model(clip)
常见错误
RuntimeError: Given groups=1, weight ... → 维度顺序,把 (B,C,H,W,T) 写反
8.4 torchtext
8.4.1 快速流水线
from torchtext.data import Field, BucketIterator
from torchtext.datasets import IMDB
token = lambda x: x.split()
TEXT = Field(tokenize=token, lower=True, batch_first=True)
LABEL = Field(sequential=False, unk_token=None)
train_ds, test_ds = IMDB.splits(TEXT, LABEL)
TEXT.build_vocab(train_ds, max_size=20000, vectors='glove.6B.100d')
LABEL.build_vocab(train_ds)
train_iter, test_iter = BucketIterator.splits(
(train_ds, test_ds), batch_size=32, sort_key=lambda x: len(x.text))
-
踩坑:新版 torchtext (0.12+) 移除了上面老接口,需要用
torchtext.legacy或新 API (torchtext.data.functional,torchtext.datasets.IMDB(split='train'))
8.5 torchaudio
8.5.1 基础 I/O
import torchaudio
wave, sr = torchaudio.load('speech.wav') # (channel, time)
resample = torchaudio.transforms.Resample(sr, 16_000)
mel_spect = torchaudio.transforms.MelSpectrogram(16_000)
feat = mel_spect(resample(wave)) # (C, n_mels, frames)
8.5.2 预训练 ASR
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
asr = bundle.get_model().eval()
tokens = asr(wave.unsqueeze(0)) # logits
| 需求 | 方案 |
|---|---|
| 移动端 on-device 推理 | torchaudio.models.emformer_rnnt + TorchScript |
| 多 说话人 | Asteroid, SpeechBrain 社区项目 |
常见跨包报错速查表
| Error | 触发场景 | 修正 |
|---|---|---|
libsox not found | torchaudio I/O | conda install -c conda-forge sox |
No module named 'pathlib' | 老 Python | >=3.7 |
TypeError: expected scalar type Double but found Float | torchaudio + GPU | .to(dtype=torch.float32) |
Checklist
-
版本对齐:
torch == torchvision == torchaudiomajor/minor 相同 -
数据 Aug:影像
transforms.v2,音频torchaudio.sox_effects.apply_effects_tensor,文本nlpaug。 -
模型 Zoo:Torch-Hub(CV/Video)+
torchaudio.pipelines(Audio)+transformers(NLP) -
可视化:
torchinfo(结构)+ TensorBoard / wandb(指标) -
部署:移动端 → PyTorchVideo
accelerator; 服务器 →torch.compile≥2.0。
8944

被折叠的 条评论
为什么被折叠?



