复现resnet18花分类,并计算torch/onnx/mnn fp32和int8精度

该文介绍了如何利用PyTorch的resnet18模型对花卉数据集进行训练,定义了数据预处理,加载训练和验证数据,调整模型以适应5类分类任务。训练完成后,模型被导出为ONNX格式,然后转换为MNN模型并进行量化,以供部署在移动设备上使用。
部署运行你感兴趣的模型镜像

【pytorch花分类】使用torchvision的resnet18

提前下载好数据集并且分割好训练和验证集

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# 定义数据的路径和预处理方法
data_dir = '/home/ruoji/MNN/data/flowers'
train_dir = data_dir + '/train'
val_dir = data_dir + '/val'

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# 加载数据
train_data = datasets.ImageFolder(train_dir, data_transforms['train'])
val_data = datasets.ImageFolder(val_dir, data_transforms['val'])
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=4)

# 加载模型
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 5)  # 将全连接层改为5分类
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 10
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

best_acc = 0.0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    scheduler.step()

    model.eval()
    num_correct = 0
    num_total = 0
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        num_correct += (predicted == labels).sum().item()
        num_total += labels.size(0)

    epoch_loss = running_loss / len(train_data)
    epoch_acc = num_correct / num_total
    print('Epoch {} - Loss: {:.4f} Acc: {:.4f}'.format(epoch+1, epoch_loss, epoch_acc))
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        PATH = '/home/ruoji/MNN/data/resnet18_flowers_best.pth'
        torch.save(model.state_dict(), PATH)

print('Finished Training')

导出onnx模型

import torch
import torchvision

# 加载训练好的模型
model = torchvision.models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(512, 5)
model.load_state_dict(torch.load('/home/ruoji/MNN/data/resnet18_flowers.pth'))

# 设置输入张量的形状
batch_size = 1
input = torch.randn(batch_size, 3, 224, 224)

input_name = 'input'
output_name = 'output'
# 将模型转换为ONNX格式
torch.onnx.export(model, input, 
                  'resnet18_flower_onnx.onnx', 
                  input_names = [input_name],
                  output_names = [output_name],
                  verbose=True,
                  opset_version=11,
                  dynamic_axes={input_name: {0: 'batch_size'},
                                output_name: {0: 'batch_size'}})

onnx转mnn int量化

mnnconvert -f ONNX --bizCode MNN --modelFile resnet18_flower_onnx.onnx --MNNModel resnet18_flower_fp32.mnn --keepInputFormat


mnnquant resnet18_flower_fp32.mnn resnet18_flower_int8.mnn quant_flower.json 
{
    "format":"RGB",
    "mean":[
        103.94,
        116.78,
        123.68
    ],
    "normal":[
        0.017,
        0.017,
        0.017
    ],
    "width":224,
    "height":224,
    "path":"/home/ruoji/MNN/data/flowers/val/daisy",
    "used_image_num":50,
    "feature_quantize_method":"KL",
    "weight_quantize_method":"MAX_ABS"
}

测试精度

pth: 0.9478
在这里插入图片描述

onnx : 0.9341
在这里插入图片描述

mnn_fp32: 0.9341
在这里插入图片描述
mnn_int8: 0.9341 离线量化
在这里插入图片描述

您可能感兴趣的与本文相关的镜像

PyTorch 2.9

PyTorch 2.9

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

在将 PyTorch 模型转换为 ONNX 模型时,FP32FP16、INT8 精度的转换方法各有不同。 ### FP32 精度转换 FP32 是默认的精度类型,使用 PyTorch 自带的 `torch.onnx.export` 方法即可完成转换。此方法能方便地将仅含通用算子的网络的 PyTorch 模型转为 ONNX 格式。示例代码如下: ```python import torch import torchvision.models as models # 加载预训练模型 model = models.resnet18(pretrained=True) model.eval() # 定义输入张量 dummy_input = torch.randn(1, 3, 224, 224) # 导出为 ONNX 模型 torch.onnx.export(model, dummy_input, "resnet18_fp32.onnx", verbose=True) ``` ### FP16 精度转换 要进行 FP16 精度转换,需先将模型输入数据转换为半精度FP16),再使用 `torch.onnx.export` 方法导出。示例代码如下: ```python import torch import torchvision.models as models # 加载预训练模型 model = models.resnet18(pretrained=True) model = model.half() # 将模型转换为 FP16 model.eval() # 定义输入张量转换为 FP16 dummy_input = torch.randn(1, 3, 224, 224).half() # 导出为 ONNX 模型 torch.onnx.export(model, dummy_input, "resnet18_fp16.onnx", verbose=True) ``` ### INT8 精度转换 INT8 量化通常借助 ONNX Runtime 的量化工具。步骤如下: 1. 安装依赖库:`pip install onnx onnxruntime onnxruntime-tools` 2. 转换模型: ```python import torch import torchvision.models as models import onnx from onnxruntime.quantization import quantize_dynamic, QuantType # 加载预训练模型 model = models.resnet18(pretrained=True) model.eval() # 定义输入张量 dummy_input = torch.randn(1, 3, 224, 224) # 导出为 ONNX 模型 torch.onnx.export(model, dummy_input, "resnet18.onnx", verbose=True) # 加载 ONNX 模型 onnx_model = onnx.load("resnet18.onnx") # 动态量化为 INT8 quantized_model = quantize_dynamic( onnx_model, "resnet18_int8.onnx", weight_type=QuantType.QInt8 ) ```
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值