【pytorch花分类】使用torchvision的resnet18
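需提前下载好花卉数据集，并划分为训练集和验证集，按 torchvision ImageFolder 的格式存放，即每个类别一个子目录：flowers/train/<类别名>/*.jpg 与 flowers/val/<类别名>/*.jpg（本文为 5 个类别，例如 daisy）。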
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
# Dataset locations. Each split is expected in ImageFolder layout:
# <split_dir>/<class_name>/<image files>.
data_dir = '/home/ruoji/MNN/data/flowers'
train_dir = data_dir + '/train'
val_dir = data_dir + '/val'
# Where the weights with the best validation accuracy are written.
best_model_path = '/home/ruoji/MNN/data/resnet18_flowers_best.pth'

# Standard ImageNet per-channel mean/std, as used with torchvision's pretrained models.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the summed loss.

    Each batch's mean loss is multiplied by its batch size, so dividing the
    returned value by the dataset size gives the per-sample average even when
    the last batch is smaller than the others.
    """
    model.train()
    running_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    return running_loss


def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader``.

    Returns 0.0 for an empty loader instead of dividing by zero.
    """
    model.eval()
    num_correct = 0
    num_total = 0
    # Evaluation needs no autograd graph; no_grad saves memory and time.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            num_correct += (predicted == labels).sum().item()
            num_total += labels.size(0)
    return num_correct / num_total if num_total else 0.0


def main(num_epochs=10):
    """Fine-tune an ImageNet-pretrained ResNet-18 on the flower dataset.

    Prints the training loss and validation accuracy after every epoch. The
    state_dict is saved to ``best_model_path`` whenever validation accuracy
    improves on the best value seen so far.
    """
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ]),
    }

    # Load data.
    train_data = datasets.ImageFolder(train_dir, data_transforms['train'])
    val_data = datasets.ImageFolder(val_dir, data_transforms['val'])
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=4)

    # Load the pretrained backbone and replace the classifier head.
    # NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
    model = models.resnet18(pretrained=True)
    # Size the head from the class folders found on disk rather than hard-coding 5.
    num_classes = len(train_data.classes)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    # Multiply the learning rate by 0.1 every 7 epochs.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    best_acc = 0.0
    for epoch in range(num_epochs):
        running_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        scheduler.step()
        epoch_loss = running_loss / len(train_data)
        epoch_acc = evaluate(model, val_loader, device)
        print('Epoch {} - Loss: {:.4f} Acc: {:.4f}'.format(epoch + 1, epoch_loss, epoch_acc))
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            torch.save(model.state_dict(), best_model_path)
    print('Finished Training')


# The guard is required when DataLoader uses worker processes on spawn-based
# platforms, and it keeps imports of this module free of side effects.
if __name__ == '__main__':
    main()
导出onnx模型
import torch
import torchvision
# Rebuild the fine-tuned architecture: ResNet-18 with a classifier head whose
# output size must match the number of class folders used in training (5 here).
num_classes = 5
model = torchvision.models.resnet18(num_classes=num_classes)
# Load the checkpoint the training script writes (it only saves the *_best.pth
# file). map_location='cpu' lets weights saved from a CUDA model load on a
# machine without a GPU.
state_dict = torch.load('/home/ruoji/MNN/data/resnet18_flowers_best.pth', map_location='cpu')
model.load_state_dict(state_dict)
# Export in inference mode so BatchNorm uses its running statistics.
model.eval()

# Dummy input that fixes the traced shape (batch, 3, 224, 224), matching the
# 224x224 validation crops. The batch dimension is made dynamic below.
batch_size = 1
dummy_input = torch.randn(batch_size, 3, 224, 224)
input_name = 'input'
output_name = 'output'

# Convert the model to ONNX format.
torch.onnx.export(
    model,
    dummy_input,
    'resnet18_flower_onnx.onnx',
    input_names=[input_name],
    output_names=[output_name],
    verbose=True,
    opset_version=11,
    # Mark dim 0 of both input and output as a symbolic 'batch_size'.
    dynamic_axes={input_name: {0: 'batch_size'},
                  output_name: {0: 'batch_size'}},
)
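ONNX 转 MNN 并进行 INT8 离线量化。注意：quant_flower.json 中的 mean/normal 应与训练时的 Normalize 保持一致。MNN 对 0–255 的输入按 (像素值 - mean) × normal 进行归一化，因此在 "format":"RGB" 下应取 mean = [0.485, 0.456, 0.406] × 255，normal = 1 / ([0.229, 0.224, 0.225] × 255)。另外，校准图片（path）最好覆盖所有类别，而不只是 daisy。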
# Convert the exported ONNX model into an FP32 MNN model.
# --keepInputFormat presumably keeps the ONNX input layout (NCHW) rather than
# MNN's internal layout; confirm against the MNNConvert documentation.
mnnconvert -f ONNX --bizCode MNN --modelFile resnet18_flower_onnx.onnx --MNNModel resnet18_flower_fp32.mnn --keepInputFormat
# Offline (post-training) INT8 quantization.
# Arguments: <FP32 input model> <INT8 output model> <calibration config JSON>.
mnnquant resnet18_flower_fp32.mnn resnet18_flower_int8.mnn quant_flower.json
{
"format":"RGB",
"mean":[
123.675,
116.28,
103.53
],
"normal":[
0.017125,
0.017507,
0.017429
],
"width":224,
"height":224,
"path":"/home/ruoji/MNN/data/flowers/val/daisy",
"used_image_num":50,
"feature_quantize_method":"KL",
"weight_quantize_method":"MAX_ABS"
}
测试精度
pth: 0.9478

onnx : 0.9341

mnn_fp32: 0.9341

mnn_int8: 0.9341 离线量化

本文介绍了如何使用 torchvision 提供的 ResNet-18 预训练模型在花卉数据集上进行微调。主要步骤包括：定义数据预处理，加载训练集和验证集，并把全连接层替换为 5 分类输出。训练完成后，先将模型导出为 ONNX 格式，再转换为 MNN 模型并进行 INT8 量化，以便部署到移动设备上。
3060

被折叠的 条评论
为什么被折叠?



