目录
1. 读取数据集
PyTorch 默认的数据存放格式如下图所示（即“根目录/类别名/图片”的目录结构）。使用 torchvision.datasets.ImageFolder 类可以直接生成数据和标签。ImageFolder 内部使用 PIL 读取图片，在 Python 中使用没有问题；但部署时一般都采用 OpenCV 读取图片，而 PIL 与 OpenCV 在通道顺序（PIL 为 RGB，OpenCV 为 BGR）等细节上并不相同，容易造成训练与部署时数据不一致。因此在 Python 中也直接使用 OpenCV 读取图片，代码如下：
1.1 torch2yaml
根据图片的存储目录，生成分别存放图片路径和对应标签的 yaml 文件，详细代码如下：
import os
import yaml
# Class-name -> integer-label mapping. Each label is the class's position in
# the tuple below; class names are the sub-directory names of a split.
class2label = {
    name: label
    for label, name in enumerate(("bird", "boat", "cake", "jellyfish", "king_crab"))
}
# Generate the image-path list of one split.
def gen_imgdir(dir="train"):
    """Write every image path of split ``dir`` to ``<dir>.yaml``.

    Expects the layout ``./mini_imagenet/<dir>/<class>/<image>``. Each path is
    stored in that same form. Paths are always joined with "/" (not
    ``os.path.join``) because ``gen_label`` recovers the class name by
    splitting on "/".

    Directory listings are sorted, so the generated file is identical on every
    run and file system; ``os.listdir`` returns entries in arbitrary order. The
    sample order therefore is too.

    Entries are skipped rather than crashing or polluting the list:
    - at the class level, anything that is not a directory;
    - at the image level, anything that is not a regular file.

    Args:
        dir: Split name, e.g. "train", "val" or "test". (It shadows the
            ``dir`` builtin; the name is kept for keyword callers.)
    """
    split_root = "./mini_imagenet/" + dir
    txt = []
    for classdir in sorted(os.listdir(split_root)):
        class_root = split_root + "/" + classdir
        if not os.path.isdir(class_root):
            # e.g. a stray .DS_Store next to the class folders; listing it
            # as a directory would raise NotADirectoryError.
            continue
        for imgdir in sorted(os.listdir(class_root)):
            img_path = class_root + "/" + imgdir
            if os.path.isfile(img_path):
                txt.append(img_path)
    with open(dir + '.yaml', "w", encoding="utf-8") as f1:
        yaml.dump(txt, f1)
# Generate the label list of one split.
def gen_label(dir="train"):
    """Write the integer labels of split ``dir`` to ``<dir>_label.yaml``.

    Reads the path list written by ``gen_imgdir`` from ``<dir>.yaml``. Each
    image's parent-directory name is mapped through ``class2label``, so the
    i-th label belongs to the i-th path.

    Args:
        dir: Split name, e.g. "train", "val" or "test".

    Raises:
        KeyError: If a parent-directory name is not a key of ``class2label``.
    """
    # The file is written as UTF-8 by gen_imgdir and holds a plain list of
    # strings, so the safe loader is sufficient.
    with open(dir + ".yaml", "r", encoding="utf-8") as f2:
        paths = yaml.safe_load(f2)
    # Paths look like "./mini_imagenet/<split>/<class>/<file>". Taking the
    # parent directory ([-2]) gives the same result as the fixed index [3] for
    # these paths, and keeps working if the root prefix changes depth.
    txt = [class2label[path.split("/")[-2]] for path in paths]
    with open(dir + '_label.yaml', "w", encoding="utf-8") as f2:
        yaml.dump(txt, f2)
if __name__ == '__main__':
    # Running this file as a script currently does nothing: both generation
    # recipes below are kept commented out. Uncomment the one you need.
    #
    # Recipe 1 - build <split>.yaml (image paths) and <split>_label.yaml
    # (labels) for every split of ./mini_imagenet:
    # for split in ("train", "test", "val"):
    #     gen_imgdir(split)
    #     gen_label(split)
    #
    # Recipe 2 - build ./images/predict.yaml and ./images/predict_label.yaml
    # for the entries directly under ./images.
    # NOTE(review): for paths of the form "./images/<name>", split("/")
    # yields only 3 parts, so index [3] below raises IndexError; the label
    # extraction must be fixed before this recipe can run.
    # txt = []
    # for classdir in os.listdir("./images"):
    #     txt.append("./images" "/" + classdir)
    # with open("./images/predict" + '.yaml', "w", encoding="utf-8") as f1:
    #     yaml.dump(txt, f1)
    # txt = []
    # with open("./images/predict" + '.yaml', "r") as f2:
    #     con = f2.read()
    #     C = yaml.load(con, Loader=yaml.FullLoader)
    # for i in C:
    #     label = i.split("/")[3]
    #     txt.append(class2label[label])
    # with open("./images/predict" + '_label.yaml', "w", encoding="utf-8") as f2:
    #     yaml.dump(txt, f2)
    pass
1.2 定义一个 Dataset 类
有了上面生成的 yaml 文件，就可以编写一个 Dataset 类来读取和加载数据：
from torch.utils.data import Dataset
import os
import cv2
import warnings
import numpy as np
import torch
import yaml
import albumentations as A
# Silence every warning category. Warning filters are global interpreter
# state, so this also hides warnings raised by other modules once this
# module has been imported.
warnings.simplefilter("ignore")
# -------------------------------------------------定义数据类---------------------------------------------------
class mydataset(Dataset):
def __init__(self,
             trans=True,
             mode="train",
             img_size=(224, 224)):
    """Load the image paths and labels of one split.

    Args:
        trans: If True, ``__getitem__`` passes every RGB image through
            ``img_aug`` before resizing it.
        mode: Split name. Paths are read from ``<mode>.yaml`` and labels
            from ``<mode>_label.yaml``, both relative to the working
            directory.
        img_size: ``dsize`` handed to ``cv2.resize``, i.e. (width, height).

    Raises:
        ValueError: If the two yaml files hold a different number of entries.
    """
    # ---- initialise data & labels ----
    super().__init__()
    self.img_size = img_size
    self.trans = trans
    # As produced by gen_imgdir / gen_label, both files hold plain lists
    # (str paths / int labels) written as UTF-8. The safe loader suffices,
    # and the files are read back with the same encoding.
    with open(mode + '.yaml', "r", encoding="utf-8") as f1:
        self.dir = yaml.safe_load(f1)
    with open(mode + '_label.yaml', "r", encoding="utf-8") as f2:
        self.label = yaml.safe_load(f2)
    # Samples pair self.dir[i] with self.label[i]. Different lengths mean the
    # two files are out of sync, e.g. only one of them was regenerated.
    if len(self.dir) != len(self.label):
        raise ValueError(
            f"{mode}.yaml lists {len(self.dir)} images but "
            f"{mode}_label.yaml lists {len(self.label)} labels"
        )
# ---- fetch one sample ----
def __getitem__(self, idx):
    """Return sample ``idx`` as an ``(image, label)`` pair of tensors.

    Processing steps:
    1. Read the image with OpenCV, which returns BGR, and convert it to RGB.
    2. If ``self.trans`` is set, pass it through ``img_aug``.
    3. Resize it to ``self.img_size``.
    4. Return it as a float32 tensor in channel-first (C, H, W) layout.

    The label is ``self.label[idx]`` converted to an int64 tensor.

    Raises:
        OSError: If OpenCV cannot read the image file.
    """
    imgsdir = self.dir[idx]
    img = cv2.imread(imgsdir)
    # cv2.imread does not raise on a missing/undecodable file; it returns
    # None, which would otherwise surface as an opaque cvtColor error.
    if img is None:
        raise OSError(f"cv2.imread failed to read image: {imgsdir}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # ---- preprocessing ----
    # NOTE(review): with trans=False the pixels stay in the 0-255 range
    # (no normalization); confirm train and eval use the same setting.
    if self.trans:
        img = self.img_aug(img)
    # ---- build tensors ----
    img = cv2.resize(img, self.img_size)
    img = img.transpose(2, 0, 1).astype('float32')  # HWC -> CHW
    label = np.array(self.label[idx])
    return torch.tensor(img), torch.tensor(label.astype('int64'))
# ----------------------------------------------------获取数据量-------------------------------------------------
def __len__(self):
return len(self.dir)
# ---- preprocessing (normalization) ----
def img_aug(self, img, mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225), max_pixel_value=255.0):
    """Normalize an image channel-wise with the given mean/std.

    Despite the name, no augmentation is performed. The result is
    ``(img - mean * max_pixel_value) / (std * max_pixel_value)`` as float32,
    the same formula albumentations' ``normalize`` helper applies.

    It is computed directly with NumPy instead of calling ``A.normalize``.
    That is a functional helper outside albumentations' transform API, whose
    location and signature have changed between releases.

    Args:
        img: Image array whose last axis holds the channels (an HxWx3 RGB
            array when called from ``__getitem__``).
        mean: Per-channel means in [0, 1] units. The defaults are the values
            this class has always used (the common ImageNet statistics).
        std: Per-channel standard deviations in [0, 1] units; defaults as
            above.
        max_pixel_value: Pixel scale of ``img`` (255 for uint8 input).

    Returns:
        A float32 array; HxWx3 for an HxWx3 input.
    """
    mean = np.asarray(mean, dtype=np.float32) * max_pixel_value
    std = np.asarray(std, dtype=np.float32) * max_pixel_value
    # Multiply by the reciprocal (as albumentations 1.x does) rather than
    # dividing, so float32 rounding matches the previous implementation.
    return (img.astype(np.float32) - mean) * np.reciprocal(std, dtype=np.float32)
# -----------------------------------------------------图像显示--------------------------------------------------
def show_normalize(self, idx):
for i in idx:
imgsdir = self.dir[i]
img = cv2.imread(imgsdir)
cv2.imshow('origin_img', img)
cv2.moveWin