PyTorch Vision模型库全面解析:从分类到光流估计
概述
PyTorch Vision(torchvision)模型库提供了一系列预训练模型,覆盖了计算机视觉领域的多个重要任务。本文将全面介绍这些模型的使用方法、最佳实践和注意事项,帮助开发者快速上手并应用于实际项目中。
预训练模型基础
权重加载机制
PyTorch Vision的所有预训练模型权重都通过PyTorch的hub模块管理。当实例化一个预训练模型时,其权重会自动下载到缓存目录中。可以通过设置TORCH_HOME
环境变量来指定缓存位置。
重要提示:
- 不同预训练模型可能有不同的使用许可,取决于其训练数据集的授权条款
- 模型权重与PyTorch版本的兼容性:加载旧版PyTorch保存的state_dict可以保证兼容性,但加载整个模型或ScriptModule可能无法保证历史行为
新版权重API(v0.13+)
从v0.13版本开始,PyTorch Vision引入了多权重支持API,允许为同一模型加载不同的权重版本:
from torchvision.models import resnet50, ResNet50_Weights
# 旧权重(准确率76.130%)
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# 新权重(准确率80.858%)
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# 最佳可用权重(当前是IMAGENET1K_V2的别名)
resnet50(weights=ResNet50_Weights.DEFAULT)
# 字符串形式也支持
resnet50(weights="IMAGENET1K_V2")
# 不使用预训练权重(随机初始化)
resnet50(weights=None)
迁移指南: 旧的pretrained
参数已被弃用,将在v0.15版本中移除。以下是新旧API的等价调用方式:
# 使用预训练权重
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) # 新API
resnet50(pretrained=True) # 旧API(已弃用)
# 不使用预训练权重
resnet50(weights=None) # 新API
resnet50(pretrained=False) # 旧API(已弃用)
模型使用最佳实践
预处理流程
使用预训练模型时,正确的图像预处理至关重要。不同模型家族、变体甚至权重版本可能需要不同的预处理方式。PyTorch Vision通过weights.transforms()
方法简化了这一过程:
from torchvision.models import resnet50, ResNet50_Weights
# 1. 初始化权重转换
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
# 2. 应用预处理
img_transformed = preprocess(img)
模型模式切换
某些模型组件(如批归一化层)在训练和评估时有不同行为。使用前应正确设置模型模式:
model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.eval() # 设置为评估模式
# 或
model.train() # 设置为训练模式
模型检索与列表(v0.14+)
PyTorch Vision提供了便捷的模型检索功能:
from torchvision.models import list_models, get_model, get_model_weights
# 列出所有可用模型
all_models = list_models()
# 获取特定模型
model = get_model("mobilenet_v3_large", weights=None)
# 获取模型对应的权重枚举
weights_enum = get_model_weights("mobilenet_v3_large")
主要模型类别
图像分类模型
PyTorch Vision提供了丰富的分类模型,包括:
- 经典CNN架构:AlexNet、VGG、ResNet等
- 轻量级模型:MobileNet、ShuffleNet等
- 新型架构:Vision Transformer、Swin Transformer等
使用示例:
from torchvision.models import resnet50, ResNet50_Weights
# 初始化模型和预处理
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).eval()
preprocess = weights.transforms()
# 推理
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
category = weights.meta["categories"][class_id]
量化模型
PyTorch Vision支持多种INT8量化模型,包括:
- Quantized ResNet
- Quantized MobileNet
- Quantized ShuffleNet等
使用示例:
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights
model = resnet50(weights=ResNet50_QuantizedWeights.DEFAULT, quantize=True)
语义分割模型
支持的模型包括:
- FCN
- DeepLabV3
- LR-ASPP
输出格式: 语义分割模型的输出是一个字典,包含键"out",对应形状为(N, C, H, W)
的张量,其中C是类别数。
目标检测与实例分割
包括以下模型:
- Faster R-CNN
- RetinaNet
- SSD
- Mask R-CNN
注意事项: 检测模型需要输入一个图像列表(List[Tensor[C, H, W]]
),输出是包含预测框、分数和标签的字典列表。
视频分类模型
支持以下架构:
- 3D ResNet
- S3D
- MViT
- Swin Transformer
光流估计
当前支持RAFT模型,用于计算两帧图像之间的光流。
性能指标
PyTorch Vision为每个预训练模型提供了详细的性能指标,包括:
- 分类模型:ImageNet-1K top-1/top-5准确率
- 检测/分割模型:COCO数据集上的mAP
- 视频模型:Kinetics-400上的准确率
开发者可以根据这些指标选择最适合自己需求的模型。
总结
PyTorch Vision模型库提供了从传统到前沿的各种计算机视觉模型,通过统一的API设计简化了模型加载和使用流程。掌握这些模型的特性和使用方法,可以大幅提升开发效率,快速构建高质量的计算机视觉应用。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考