cifar10代码

本文档详细介绍了如何在机器学习项目中处理CIFAR-10数据集,包括数据输入的步骤。源码经过修正,分为数据输入、模型定义、模型训练和性能测试四个独立的.py文件,确保程序结构清晰。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在源码的基础上进行了部分改正。

一个标准的机器学习程序,应该包括数据输入、定义模型本身、模型训练和模型性能测试四大部分,可以分成四个.py文件。

(一)数据输入部分(input_dataset.py

#coding:utf-8
import os
import tensorflow as tf
# 原图像的尺度为32*32,但根据常识,信息部分通常位于图像的中央,这里定义了以中心裁剪后图像的尺寸
fixed_height = 24
fixed_width = 24
# cifar10数据集的格式,训练样例集和测试样例集分别为50k和10k
train_samples_per_epoch = 50000
test_samples_per_epoch = 10000
data_dir='./cifar-10-batches-bin' # 定义数据集所在文件夹路径
batch_size=128 #定义每次参数更新时,所使用的batch的大小

def read_cifar10(filename_queue):
    # 定义一个空的类对象,类似于c语言里面的结构体定义
    class Image(object):
        pass
    image = Image()
    image.height=32
    image.width=32
    image.depth=3
    label_bytes = 1
    image_bytes = image.height*image.width*image.depth
    Bytes_to_read = label_bytes+image_bytes
    # 定义一个Reader,它每次能从文件中读取固定字节数
    reader = tf.FixedLengthRecordReader(record_bytes=Bytes_to_read)
    # 返回从filename_queue中读取的(key, value)对,key和value都是字符串类型的tensor,并且当队列中的某一个文件读完成时,该文件名会dequeue
    image.key, value_str = reader.read(filename_queue)
    # 解码操作可以看作读二进制文件,把字符串中的字节转换为数值向量,每一个数值占用一个字节,在[0, 255]区间内,因此out_type要取uint8类型
    value = tf.decode_raw(bytes=value_str, out_type=tf.uint8)
    # 从一维tensor对象中截取一个slice,类似于从一维向量中筛选子向量,因为value中包含了label和feature,故要对向量类型tensor进行'parse'操作
    image.label = tf.slice(input_=value, begin=[0], size=[label_bytes])# begin和size分别表示待截取片段的起点和长度
    data_mat = tf.slice(input_=value, begin=[label_bytes], size=[image_bytes])
    data_mat = tf.reshape(data_mat, (image.depth, image.height, image.width)) #这里的维度顺序,是依据cifar二进制文件的格式而定的
    transposed_value = tf.transpose(data_mat, perm=[1, 2, 0]) #对data_mat的维度进行重新排列,返回值的第i个维度对应着data_mat的第perm[i]维
    image.mat = transposed_value
    return image

def get_batch_samples(img_obj, min_samples_in_queue, batch_size, shuffle_flag):
# tf.train.shuffle_batch()函数用于随机地shuffling 队列中的tensors来创建batches(也即每次可以读取多个data文件中的样例构成一个batch)。这个函数向当前Graph中添加了下列对象:
# *创建了一个shuffling queue,用于把‘tensors’中的tensors压入该队列;
# *一个dequeue_many操作,用于根据队列中的数据创建一个batch;
# *创建了一个QueueRunner对象,用于启动一个进程压数据到队列
# capacity参数用于控制shuffling queue的最大长度;min_after_dequeue参数表示进行一次dequeue操作后队列中元素的最小数量,可以用于确保batch中
# 元素的随机性;num_threads参数用于指定多少个threads负责压tensors到队列;enqueue_many参数用于表征是否tensors中的每一个tensor都代表一个样例
# tf.train.batch()与之类似,只不过顺序地出队列(也即每次只能从一个data文件中读取batch),少了随机性。

    if shuffle_flag == False:
        image_batch, label_batch = tf.train.batch(tensors=img_obj,
                                                  batch_size=batch_size,
                                                  num_threads=4,
                                                  capacity=min_samples_in_queue+3*batch_size)
    else:
        image_batch, label_batch = tf.train.shuffle_batch(tensors=img_obj,
                                                          batch_size=batch_size,
                                                          num_threads=4,
                                                          min_after_dequeue=min_samples_in_queue,
                                                          capacity=min_samples_in_queue+3*batch_size)
    tf.summary.image('input_image', image_batch) #输出预处理后图像的summary缓存对象,用于在session中写入到事件文件中//tf.summary.image('input_image', image_batch, max_images=6)
    return image_batch, tf.reshape(label_batch, shape=[batch_size])

def preprocess_input_data():
#这部分程序用于对训练数据集进行‘数据增强’操作,通过增加训练集的大小来防止过拟合
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值