一.数据集中常用的高层API函数:
- dataset.map(func)
- dataset.shuffle(buffer_size)#随机打乱数据的顺序,buffer_size参数等价于tf.train.shuffle_batch中的min_after_dequeue参数。作用是保证每次打乱数据时的最低数量,要是打乱的数量过少,则打乱作用就不大了
- dataset.batch(batch_size)#将数据组合成batch
- dataset.repeat(N)#将数据集重复N份,repeat只代表重复相同的处理操作,并不会记录前一个epoch的处理结果。
二.具体使用如下:
import tensorflow as tf
train_files = tf.train.match_filenames_once('tfrecords/train_file-*')
test_file = tf.train.match_filenames_once('tfrecords/test_file-*')
def preprocess_for_train(image,height,width,bbox):
pass
def build_net(input):pass
def calc_loss(logit,label):pass
def parse(record):
features = tf.parse_single_example(
record,
features={
'image':tf.FixedLenFeature([],tf.string),
'label':tf.FixedLenFeature([],tf.int64),
'height':tf.FixedLenFeature([],tf.int64),
'width':tf.FixedLenFeature([],tf.int64),
'channels':tf.FixedLenFeature([],tf.int64)
}
)
image_data = tf.decode_raw(features['image'],tf.uint8)
image_data.set_shape([features['height'],features['width'],features['channels']])
label = features['label']
return image_data,label
if __name__ == '__main__':
image_size = 299
batch_size = 100
shuffle_buffer = 10000
#读取数据集
dataset = tf.data.TFRecordDataset(train_files)
#将tfrecord转化成image,label的格式
dataset = dataset.map(parse)
#对image进行预处理
dataset = dataset.map(
lambda image,label:(preprocess_for_train(image,image_size,image_size,None),label)
)
#将数据集的顺序打乱,shuffle_buffer指定了队列中最少的元素个数
dataset = dataset.shuffle(shuffle_buffer)
#指定每次从迭代器中读出的数据个数,默认为1
dataset = dataset.batch(batch_size)
num_epoches = 10
#将数据集中的数据重复num_epoches次,由于之前使用了shuffle,因此每个副本的顺序都不一定相同
dataset = dataset.repeat(num_epoches)
iterator = dataset.make_initializable_iterator()
image_batch,label_batch = iterator.get_next()
learning_rate = 0.01
#构建网络,得到结果
logit = build_net(image_batch)
#结算损失
loss = calc_loss(logit,label_batch)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
test_dataset = tf.data.TFRecordDataset(test_file)
test_dataset = test_dataset.map(parse).map(
lambda image,label:(tf.image.resize_images(image,[image_size,image_size]),label)
)
test_dataset = test_dataset.batch(batch_size)
test_iterator = test_dataset.make_initializable_iterator()
test_image_batch,test_label_batch = test_dataset.get_next()
test_logit = build_net(test_image_batch)
prediction = tf.argmax(test_logit,1)
with tf.Session() as sess:
#使用filename_match_once函数要初始化local_variables
#使用迭代器要初始化iterator.initializer
#global_variables一般都会初始化
sess.run([tf.global_variables_initializer(),tf.local_variables_initializer(),iterator.initializer])
while True:
try:
sess.run(train_step)
except:
break
sess.run(test_iterator.initializer)
test_results = []
test_labels = []
while True:
try:
pred,label = sess.run([prediction,test_label_batch])
test_results.extend(pred)
test_labels.extend(label)
except:
break
correct = [float(y==y_) for (y,y_) in zip(test_results,test_labels)]
acc = sum(correct)/len(correct)
print(acc)
!!!上面例子中的循环体并不是指定循环运行次数,而是使用while(True) try-except的形式来将所有数据遍历一遍(即一个epoch)。这是因为在动态指定输入数据时,不同来源的数据数量大小难以预知,而这个方法我们不必提前知道数据量的大小。