pytorch快手上手——从读取数据到ncnn部署

本文介绍了如何使用PyTorch处理数据集,训练模型,然后将模型转换为NCNN进行图片分类预测。详细讲解了从torch2yaml、定义dataset类,到编译配置NCNN、转换模型,最后实现NCNN预测的过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

1. 读取数据集

1.1 torch2yaml

1. 2定义一个dataset类

2. 训练模型

3. 图片分类预测

3.1 python中图片预测

3.2 使用ncnn进行图片分类预测

3.2.1 编译和配置ncnn

3.2.2 pytorch模型转ncnn模型

3.2.2 使用ncnn进行预测

4. 代码链接

5. 参考链接


1. 读取数据集

pytorch默认的数据存放格式如下图所示,使用torchvision.datasets.ImageFolder函数可以直接生成数据和标签,ImageFolder内部使用PIL进行数据读取,在python中使用没有问题,但部署时一般都是采用opencv读取图片的,容易造成数据的不统一,故在python中可以直接使用opencv读取图片代码如下:

1.1 torch2yaml

根据图片存储生成存放图片路径和标签的yaml文件,详细代码如下:

import os
import yaml

# class与label对应关系
class2label = {"bird": 0, "boat": 1, "cake": 2, "jellyfish": 3, "king_crab": 4}

# 生成图片路径
def gen_imgdir(dir="train"):
    txt = []
    for classdir in os.listdir("./mini_imagenet/" + dir):
        for imgdir in os.listdir("./mini_imagenet/" + dir + "/" + classdir):
            txt.append("./mini_imagenet/" + dir + "/" + classdir + "/" + imgdir)
    with open(dir + '.yaml', "w", encoding="utf-8") as f1:
        yaml.dump(txt, f1)

# 生成图片标签
def gen_label(dir="train"):
    txt = []
    with open(dir + ".yaml", "r") as f2:
        con = f2.read()
        C = yaml.load(con, Loader=yaml.FullLoader)
        for i in C:
            label = i.split("/")[3]
            txt.append(class2label[label])
    with open(dir + '_label.yaml', "w", encoding="utf-8") as f2:
        yaml.dump(txt, f2)



if __name__ == '__main__':
    # 生成图片路径和标签
    # for i in {"train", "test", "val"}:
    #     dir = i
    #     gen_imgdir(dir)
    #     gen_label(dir)
    # 下面代码在生成images文件夹下的yaml文件时使用
    # txt = []
    # for classdir in os.listdir("./images"):
    #     txt.append("./images" "/" + classdir)
    # with open("./images/predict" + '.yaml', "w", encoding="utf-8") as f1:
    #     yaml.dump(txt, f1)
    # txt = []
    # with open("./images/predict" + '.yaml', "r") as f2:
    #     con = f2.read()
    #     C = yaml.load(con, Loader=yaml.FullLoader)
    #     for i in C:
    #         label = i.split("/")[3]
    #         txt.append(class2label[label])
    # with open("./images/predict" + '_label.yaml', "w", encoding="utf-8") as f2:
    #     yaml.dump(txt, f2)
    pass

1. 2定义一个dataset类

有了上面的yaml文件,就可以随手写一个dataset类读取和加载数据

from torch.utils.data import Dataset
import os
import cv2
import warnings
import numpy as np
import torch
import yaml
import albumentations as A
warnings.filterwarnings("ignore")


# -------------------------------------------------定义数据类---------------------------------------------------
class mydataset(Dataset):
    def __init__(self,
                 trans=True,
                 mode="train",
                 img_size=(224, 224)):
        # -------------------------------------------初始化数据&标签-----------------------------------------------------
        super().__init__()
        self.img_size = img_size
        self.trans = trans
        with open(mode + '.yaml', "r") as f1:
            content = f1.read()
            self.dir = yaml.load(content, Loader=yaml.FullLoader)
        with open(mode + '_label.yaml', "r") as f2:
            content = f2.read()
            self.label = yaml.load(content, Loader=yaml.FullLoader)

    # ----------------------------------------------------获取数据---------------------------------------------------
    def __getitem__(self, idx):
        imgsdir = self.dir[idx]
        img = cv2.imread(imgsdir)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # --------------------------------------------数据增强----------------------------------------------------
        if self.trans:
            img = self.img_aug(img)
        # ---------------------------------------------生成数据-----------------------------------------------------
        img = cv2.resize(img, self.img_size)
        img = img.transpose(2, 0, 1).astype('float32')
        label = np.array(self.label[idx])
        return torch.tensor(img), torch.tensor(label.astype('int64'))

    # ----------------------------------------------------获取数据量-------------------------------------------------

    def __len__(self):
        return len(self.dir)

    # -----------------------------------------------------数据增强--------------------------------------------------
    def img_aug(self, img):
        return A.normalize(img=img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

    # -----------------------------------------------------图像显示--------------------------------------------------
    def show_normalize(self, idx):
        for i in idx:
            imgsdir = self.dir[i]
            img = cv2.imread(imgsdir)
            cv2.imshow('origin_img', img)
            cv2.moveWin
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值