使用 pytorch训练自己的图片分类模型

文章介绍了如何利用预训练的VGG-16模型进行迁移学习,将模型扩展以识别蚂蚁和蜜蜂。通过修改模型结构,添加新类别,并使用数据增强提升模型性能,最终在两个阶段实现了72%的训练集和更高验证集的分类精度。

如何自己训练一个图片分类模型,如果一切从头开始,对于一般公司或个人基本是难以实现的。其实,我们可以利用一个现有的图片分类模型,加上新的分类,这种方式叫做迁移学习,就是把现有的模式知识,转移到新的模型。Pytorch 官网提供已经训练好的模型,可以在此基础上训练自己的模型。我们用的模型是 VGG 分类模型,首先,先运行一个已经训练好的模型可做 1000 个分类。

安装依赖

# 去官网根据系统进行下载
pip3 install torch torchvision torchaudio
pip3 install tqdm

现有模型进行图片识别

可以去百度上下载一个狗或者鸟的图片,运行下面的程序进行识别。

# 导入软件包
import numpy as np
import json
from PIL import Image

import torch
import torchvision
from torchvision import models, transforms

#生成VGG-16模型的实例
use_pretrained = True  # 使用已经训练好的参数
net = models.vgg16(pretrained=use_pretrained)
net.eval()  # 设置为推测模式

# 对输入图片进行预处理的类
class BaseTransform():
    """
    调整图片的尺寸,并对颜色进行规范化。

    Attributes
    ----------
    resize : int
       指定调整尺寸后图片的大小
    mean : (R, G, B)
       各个颜色通道的平均值
    std : (R, G, B)
       各个颜色通道的标准偏差
    """

    def __init__(self, resize, mean, std):
        self.base_transform = transforms.Compose([
            transforms.Resize(resize),  #将较短边的长度作为resize的大小
            transforms.CenterCrop(resize),  #从图片中央截取resize × resize大小的区域
            transforms.ToTensor(),  #转换为Torch张量
            transforms.Normalize(mean, std)  #颜色信息的正规化
        ])

    def __call__(self, img):
        return self.base_transform(img)

# 根据输出结果对标签进行预测的后处理类
class ILSVRCPredictor():
    """
    根据ILSVRC数据,从模型的输出结果计算出分类标签

    Attributes
    ----------
    class_index : dictionary
           将类的index与标签名关联起来的字典型变量
    """

    def __init__(self, class_index):
        self.class_index = class_index

    def predict_max(self, out):
        """
        获得概率最大的ILSVRC分类标签名

        Parameters
        ----------
        out : torch.Size([1, 1000])
            从Net中输出结果

        Returns
        -------
        predicted_label_name : str
            预测概率最高的分类标签的名称
        """
        maxid = np.argmax(out.detach().numpy())
        predicted_label_name = self.class_index[str(maxid)][1]

        ret
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值