本文所用数据集与8.训练自己的数据集(1):TFRecords编写与读取一样,此处不再赘诉。
1.将图片文件制作成Dataset数据集
1.1获取文件名与标签
import tensorflow as tf
import os
from PIL import Image
import numpy as np
#生成图片与对应标签的列表
def load_sample(sample_dir):
#图片名列表
lfilenames = []
#标签名列表
labelnames = []
#遍历文件夹
for (dirpath,dirnames,filenames) in os.walk(sample_dir):
#遍历图片
for filename in filenames:
#每张图片的路径名
filename_path = os.sep.join([dirpath,filename])
#添加文件名
lfilenames.append(filename_path)
#添加文件名对应的标签
labelnames.append(dirpath.split('/')[-1])
#生成标签名列表
lab = list(sorted(set(labelnames)))
#生成标签字典
labdict = dict(zip(lab,list(range(len(lab)))))
#生成与图片对应的标签列表
labels = [labdict[i] for i in labelnames]
#图片与标签字典
image_label_dict = dict(zip(lfilenames,labels))
#将文件名与标