TensorFlow 数据加载器

构建图像分类数据集
本文介绍了一种构建图像分类数据集的方法：通过 Python 脚本自动读取不同类别的图像文件，将其转换为适用于深度学习模型的格式，并实现数据增强与批处理功能。

自己看的

import numpy as np
import os
from PIL import Image
import random
# Set TensorFlow's native-log threshold via its environment variable.
# NOTE(review): level '2' is documented by TensorFlow to hide INFO and WARNING
# messages, and presumably must be set before TensorFlow is imported by the
# code that uses this module -- confirm against the training script.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Dataset root folders, resolved relative to the current working directory.
# Each root is scanned for one sub-directory per category (see
# Data.get_imgs_labels).
train_file = './data/train'
test_file = './data/test'

class Data:
    """In-memory image-classification dataset with random batch sampling.

    On construction the train or test folder is scanned (one sub-directory
    per category), every image path is recorded with its integer label, and
    all images are then decoded eagerly into memory.

    Two sampling modes are offered:

    * ``get_epoch_batch`` draws from the images already held in memory.
    * ``get_batch`` draws paths and decodes those images from disk on demand.

    Both sample uniformly *with replacement*, so a batch may contain the
    same image more than once.
    """

    # Sub-directory name -> integer class label.
    _LABELS = {
        "airplane": 0,
        "automobile": 1,
        "bird": 2,
        "cat": 3,
        "deer": 4,
        "dog": 5,
        "horse": 6,
        "ship": 7,
        "truck": 8,
    }

    def __init__(self, batch_size, isTrain = True):
        """Scan the dataset folder and load every image into memory.

        Args:
            batch_size: Number of samples returned by each batch call.
            isTrain: When true read from ``train_file``, otherwise from
                ``test_file``.
        """
        self.batch_size = batch_size    # samples per batch
        self.classific = 9      # number of categories (one-hot width)
        self.isTrain = isTrain  # True -> training split, False -> test split
        # [path, label] record for every image found on disk.
        self.ls_imgs_labels = self.get_imgs_labels()
        # Decoded images and their one-hot labels, index-aligned.
        self.epoch_imgs, self.epoch_labels = self.get_epoch()

    def get_classific(self):
        """Return the number of categories."""
        return self.classific

    @staticmethod
    def _load_image(path):
        """Decode the image at *path* and divide its pixel values by 255."""
        # The context manager releases the file handle as soon as the pixel
        # data has been copied out, instead of waiting for garbage collection.
        with Image.open(path) as img:
            return np.asarray(img) / 255.0

    def _one_hot(self, label):
        """Return a list of ``self.classific`` zeros with a 1 at *label*."""
        encoded = [0] * self.classific
        encoded[label] = 1
        return encoded

    def _sample_indices(self, length):
        """Draw ``batch_size`` uniform random indices in ``[0, length)``."""
        # random.random() is kept (rather than randrange) so that seeded runs
        # consume the generator exactly as before.
        return [int(random.random() * length) for _ in range(self.batch_size)]

    def get_imgs_labels(self):
        """Collect the path and label of every image in the dataset folder.

        Entries of the dataset folder whose name is not a known category are
        reported and skipped, so a stray file or directory no longer aborts
        the scan. Directory listings are sorted so that the record order does
        not depend on the file system.

        Returns:
            A list of ``[image_path, label]`` records.
        """
        if self.isTrain:
            folder = train_file
            print("训练文件")
        else:
            folder = test_file
            print("测试文件")
        folder = os.path.join(os.getcwd(), os.path.normpath(folder))

        ls_imgs = []
        # One sub-directory per category.
        for category in sorted(os.listdir(folder)):
            label = self._LABELS.get(category)
            if label is None:
                print("没有这个种类 :", category)
                continue
            curr_folder = os.path.join(folder, category)
            # Every file inside a category folder is treated as an image.
            for pic in sorted(os.listdir(curr_folder)):
                ls_imgs.append([os.path.join(curr_folder, pic), label])
        print("全部图片记录完成    %d" % len(ls_imgs))
        return ls_imgs

    def get_epoch(self):
        """Decode every recorded image and build its one-hot label.

        Returns:
            A tuple ``(imgs, labels)`` of index-aligned lists: the scaled
            pixel arrays and the matching one-hot label lists.
        """
        print("begin load data")
        imgs = []
        labels = []
        for path, label in self.ls_imgs_labels:
            imgs.append(self._load_image(path))
            labels.append(self._one_hot(label))
        print("data load successful")
        return imgs, labels

    def get_epoch_batch(self):
        """Sample ``batch_size`` items from the images held in memory.

        Returns:
            A tuple ``(imgs, labels)`` of index-aligned lists.
        """
        picks = self._sample_indices(len(self.epoch_labels))
        curr_imgs = [self.epoch_imgs[i] for i in picks]
        curr_labels = [self.epoch_labels[i] for i in picks]
        return curr_imgs, curr_labels

    def get_batch(self):
        """Sample ``batch_size`` records and decode their images from disk.

        Returns:
            A tuple ``(imgs, labels)`` of index-aligned lists: the scaled
            pixel arrays and the matching one-hot label lists.
        """
        # Choose all records first, then decode only the chosen images.
        picks = self._sample_indices(len(self.ls_imgs_labels))
        imgs = []
        labels = []
        for path, label in (self.ls_imgs_labels[i] for i in picks):
            imgs.append(self._load_image(path))
            labels.append(self._one_hot(label))
        return imgs, labels

 

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值