MNIST手写数字识别一

本文介绍了如何使用MNIST数据集进行手写数字分类,涵盖了数据准备、标签编码、数据划分、批量读取、模型构建(包括逻辑回归与Softmax)、训练过程、交叉熵损失函数应用及最终的模型评估和预测。通过一步步详解,读者将理解如何构建和优化深度学习模型进行手写识别任务。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

手写分类识别一

MNIST手写数字识别:分类问题
一、数据准备
  1. MNIST数据集来自美国国家标准与技术研究所,National Institute of Standards and Technology(NIST)

  2. 数据集由来自250个不同人手写的数字构成,其中50%是高中学生,50%来自人口普查局(the Census Bureau)的工作人员

  3. MNIST数据集可在http://yann.lecun.com/exdb/mnist/获取

  4. Tensorflow提供了数据集读取方法

    import tensorflow as tf
    import tensorflow.examples.tutorials.mnist.input_data as input_data
    mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)#读取数据(数据集缩放的文件目录,标签的数据格式)
    

    如果在读取时指定目录不存在,则会自动去下载,需要等待一定时间,如果存在了,则直接读取。

  5. 了解MNNIST手写数字识别数据集

    print('训练集train数量:',mnist.train.num_examples,",验证集validation数量:",mnist.validatio.num_examples,"测试集test数量:",mnist.test.num_examples)
    
    print("train images shape:",mnist.train.images.shape,"labels shaple:",mnist.train.labels.shape)
    
  6. 可视化image

    import matplotlib.pyplot as plt
    def plot_image(image):
    	plt.imshow(image.reshape(28,28),cmap="binary")
    	plt.show()
    #打印下标为20000的图片
    plot_image(mnist.train.images[20000])
    
  7. 思考:以下代码会输出什么图像?

    plt.imshow(mnist.train.images[20000].reshape(14,56),cmap="binary")
    plt.show()
    
二、标签数据和独热i编码
  1. 认识标签labeel

    mnist.train.labels[1]
    
  2. 独热编码(one hot encoding)

    一种稀疏向量,其中:一个元素设为1,所有其他元素均为0

    独热编码常用于表示拥有有限个可能值的字符串或标示符。

    例如:假设某个植物学数据集记录了15000个不同的物种,其中每个物种都用独一无二的字符串标识符来表示。在特征工程过程中,可能需要将这些字符串标识符编码为独热向量,向量的大小为15000

  3. 为什么要采用one hot编码?

    • 将离散特征的取值扩展到了欧式空间,离散特征的某个取值就对应欧式空间的某个点。
    • 机器学习算法中,特征之间距离的计算或相似度的常用计算方法都是基于欧式空间的。
    • 将离散型特征使用one hot 编码,会让特征之间的距离计算更加合理。
三、数据集的划分
  1. 构建和训练机器学习模型是希望对新的数据做出良好预测

  2. 如何去保证训练的实效,可以应对前所未见的数据呢?

    一种方法是将数据集分为两个子集:训练集(用于训练模型的子集),测试集(用于测试模型的子集)。

  3. 在测试集上表现是否良好是衡量能否在新数据上表现良好的有用指标,前提是:一、测试集足够大。二、不会反复使用相同的测试集来作假。

  4. 拆分数据:将当个数据集拆分为一个训练集和一个测试集。

    确保测试集满足以下两个条件:一、规模足够大,可产生具有统计意义的结果。二、能代表整个数据集,测试集的特征应该与训练集的特征相同。

  5. 新的划分:将数据集划分为三个子集,可以大幅度降低拟合的发生几率:使用验证集评估训练集的效果。在模型”通过“验证集之后,使用测试集再次检查评估结果。

  6. 读取验证数据

    print(" validation images:",mnist.validation.images.shape,"labels:",mnist.validation.labels.shape)
    
  7. 读取测试数据

    print("test images:",mnist.test.images.shape,"labels:",mnist.test.labels.shape)
    
四、数据的批量读取
  1. 一次批量读取多条数据

    print(mnist.train.labels[0:10])
    
  2. next_batch()实现内部会对数据集先做shuffle(洗牌)/(打乱)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值