python制作自己的数据集

本文介绍了如何使用Python和TensorFlow创建一个TFRecords数据集,该数据集包含17种鲜花的图像,每种80张。文章详细阐述了从文件组织、使用os和PIL库处理图像到生成TFRecords文件的过程,以及读取和验证TFRecords文件内容的步骤。
部署运行你感兴趣的模型镜像

一、数据集介绍

点击打开链接17_Category_Flower 是一个不同种类鲜花的图像数据,包含 17 不同种类的鲜花,每类 80 张该类鲜花的图片,鲜花种类是英国地区常见鲜花。下载数据后解压文件,然后将不同的花剪切到对应的文件夹,如下图所示:

每个文件夹下面有80个图片文件。

二、使用的工具

首先是在tensorflow框架下,然后介绍一下用到的两个库,一个是os,一个是PIL。PILPython Imaging Library)是 Python 中最常用的图像处理库,而Image类又是 PIL库中一个非常重要的类,通过这个类来创建实例可以有直接载入图像文件,读取处理过的图像和通过抓取的方法得到的图像这三种方法。

三、代码实现

我们是通过TFRecords来创建数据集的,TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件(label)。

1、制作TFRecords文件

import os
import tensorflow as tf
from PIL import Image  # 注意Image,后面会用到
import matplotlib.pyplot as plt
import numpy as np

cwd = 'D:\PyCharm Community Edition 2017.2.3\Work\google_net\jpg\\'
classes = {'daffodil', 'snowdrop', 'lilyvalley', 'bluebell', 'crocus', 'iris', 'tigerlily', 'tulip', 'fritiuary',
           'sunflower', 'daisy', 'coltsfoot', 'dandelion', 'cowslip', 'buttercup', 'windflower', 'pansy'}  # 花为 设定 17 类
writer = tf.python_io.TFRecordWriter("flower_train.tfrecords")  # 要生成的文件

for index, name in enumerate(classes):
    class_path = cwd + name + '\\'
    for img_name in os.listdir(class_path):
        img_path = class_path + img_name  # 每一个图片的地址
        img = Image.open(img_path)
        img = img.resize((224, 224))
        img_raw = img.tobytes()  # 将图片转化为二进制格式
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))  # example对象对label和image数据进行封装
        writer.write(example.SerializeToString())  # 序列化为字符串
writer.close()


首先将文件移动到对应的路径:

D:\PyCharm Community Edition 2017.2.3\Work\google_net\jpg

然后对每个文件下的图片进行读写和相应的大小惊醒改变,具体过程是使用tf.train.Example来定义我们要填入的数据格式,其中label即为标签,也就是最外层的文件夹名字,img_raw为易经理二进制化的图片。然后使用tf.python_io.TFRecordWriter来写入。基本的,一个Example中包含FeaturesFeatures里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList 或者ByteList,或者Int64List。就这样,我们把相关的信息都存到了一个文件中,所以前面才说不用单独的label文件。而且读取也很方便。

执行完以上代码就会出现如下图所示的TF文件


2、读取TFRECORD文件

制作完文件后,将该文件读入到数据流中,具体代码如下:

def read_and_decode(filename):  # 读入dog_train.tfrecords
    filename_queue = tf.train.string_input_producer([filename])  # 生成一个queue队列
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # 返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })  # 将image数据和label取出来

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [224, 224, 3])  # reshape为128*128的3通道图片
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5  # 在流中抛出img张量
    label = tf.cast(features['label'], tf.int32)  # 在流中抛出label张量
    return img, label

注意,feature的属性“label”和“img_raw”名称要和制作时统一 ,返回的img数据和label数据一一对应。

3、显示tfrecord格式的图片

为了知道TF 文件的具体内容,或者是怕图片对应的label出错,可以将数据流以图片的形式读出来并保存以便查看,具体的代码如下:

filename_queue = tf.train.string_input_producer(["flower_train.tfrecords"])  # 读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)  # 返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw': tf.FixedLenFeature([], tf.string),
                                   })  # 取出包含image和label的feature对象
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [224, 224, 3])
label = tf.cast(features['label'], tf.int32)
label = tf.one_hot(label, 17, 1, 0)
with tf.Session() as sess:  # 开始一个会话
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(100):
        example, l = sess.run([image, label])  # 在会话中取出image和label
        img = Image.fromarray(example, 'RGB')  # 这里Image是之前提到的
        img.save(cwd + str(i) + '_''Label_' + str(l) + '.jpg')  # 存下图片
        print(example, l)
    coord.request_stop()
    coord.join(threads)

执行以上代码后,当前项目对应的文件夹下会生成100张图片,还有对应的label,如下图所示:


在这里我们可以看到,前80个图片文件的label是1,后20个图片的label是2。 由此可见,我们一开始制作tfrecord文件时,图片分类正确。

完整代码如下:

import os
import tensorflow as tf
from PIL import Image  # 注意Image,后面会用到
import matplotlib.pyplot as plt
import numpy as np

cwd = 'D:\PyCharm Community Edition 2017.2.3\Work\google_net\jpg\\'
classes = {'daffodil', 'snowdrop', 'lilyvalley', 'bluebell', 'crocus', 'iris', 'tigerlily', 'tulip', 'fritiuary',
           'sunflower', 'daisy', 'coltsfoot', 'dandelion', 'cowslip', 'buttercup', 'windflower', 'pansy'}  # 花为 设定 17 类
writer = tf.python_io.TFRecordWriter("flower_train.tfrecords")  # 要生成的文件

for index, name in enumerate(classes):
    class_path = cwd + name + '\\'
    for img_name in os.listdir(class_path):
        img_path = class_path + img_name  # 每一个图片的地址
        img = Image.open(img_path)
        img = img.resize((224, 224))
        img_raw = img.tobytes()  # 将图片转化为二进制格式
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))  # example对象对label和image数据进行封装
        writer.write(example.SerializeToString())  # 序列化为字符串
writer.close()


def read_and_decode(filename):  # 读入dog_train.tfrecords
    filename_queue = tf.train.string_input_producer([filename])  # 生成一个queue队列
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # 返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })  # 将image数据和label取出来

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [224, 224, 3])  # reshape为128*128的3通道图片
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5  # 在流中抛出img张量
    label = tf.cast(features['label'], tf.int32)  # 在流中抛出label张量
    return img, label


filename_queue = tf.train.string_input_producer(["flower_train.tfrecords"])  # 读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)  # 返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw': tf.FixedLenFeature([], tf.string),
                                   })  # 取出包含image和label的feature对象
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [224, 224, 3])
label = tf.cast(features['label'], tf.int32)
label = tf.one_hot(label, 17, 1, 0)
with tf.Session() as sess:  # 开始一个会话
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(100):
        example, l = sess.run([image, label])  # 在会话中取出image和label
        img = Image.fromarray(example, 'RGB')  # 这里Image是之前提到的
        img.save(cwd + str(i) + '_''Label_' + str(l) + '.jpg')  # 存下图片
        print(example, l)
    coord.request_stop()
    coord.join(threads)

本人也是刚刚学习深度学习,能力有限,不足之处请见谅,欢迎大牛一起讨论,共同进步!

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论 31
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值