MNIST进阶之next_batch()

本文介绍了一个自定义的next_batch()函数,用于从数据集中提取批次数据,并提供了完整的代码实现及示例。该函数适用于自行创建的数据集,可以有效辅助机器学习项目的训练过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

对于训练数据分batch十分有必要,但是MNIST官网给出的程序采用了封装函数,本文写了一个next_batch()适用于自己制作的数据集。

import numpy as np


class DataSet(object):

    def __init__(self, images, labels, num_examples):
        self._images = images
        self._labels = labels
        self._epochs_completed = 0  # 完成遍历轮数
        self._index_in_epochs = 0   # 调用next_batch()函数后记住上一次位置
        self._num_examples = num_examples  # 训练样本数

    def next_batch(self, batch_size, fake_data=False, shuffle=True):
        start = self._index_in_epochs

        if self._epochs_completed == 0 and start == 0 and shuffle:
            index0 = np.arange(self._num_examples)
            print(index0)
            np.random.shuffle(index0)
            print(index0)
            self._images = np.array(self._images)[index0]
            self._labels = np.array(self._labels)[index0]
            print(self._images)
            print(self._labels)
            print("-----------------")

        if start + batch_size > self._num_examples:
            self._epochs_completed += 1
            rest_num_examples = self._num_examples - start
            images_rest_part = self._images[start:self._num_examples]
            labels_rest_part = self._labels[start:self._num_examples]
            if shuffle:
                index = np.arange(self._num_examples)
                np.random.shuffle(index)
                self._images = self._images[index]
                self._labels = self._labels[index]
            start = 0
            self._index_in_epochs = batch_size - rest_num_examples
            end = self._index_in_epochs
            images_new_part = self._images[start:end]
            labels_new_part = self._labels[start:end]
            return np.concatenate((images_rest_part, images_new_part), axis=0), np.concatenate(
                (labels_rest_part, labels_new_part), axis=0)

        else:
            self._index_in_epochs += batch_size
            end = self._index_in_epochs
            return self._images[start:end], self._labels[start:end]


if __name__ == '__main__':
    input = ['a', 'b', '1', '2', '*', '3', 'c', '&', '#']
    output = ["Letter", "Letter", "Number", "Number", "Symbol", "Number", "Letter", "Symbol", "Symbol"]
    ds = DataSet(input, output, 9)
    for i in range(3):
        image_batch, label_batch = ds.next_batch(4)
        print(image_batch)
        print(label_batch)

现在来看结果:

 

这行代码来自经典的 MNIST 数据集加载和训练过程,主要用于从 MNIST 训练集中获取一批数据样本(batch),以便用于模型训练。下面详细解释一下它的作用及使用方法。 ### 解释 - **`mnist.train.next_batch(batchSize)`**:该函数是从MNIST数据集中读取下一批指定大小的数据样本。其中 `batchSize` 是一个整数,表示每次取出多少条记录作为批次数据来进行一次迭代训练。 - **返回值**: 返回两个NumPy数组 `(x, y)` - `x`: 形状为 `[batch_size, 784]` 的二维张量,包含每个图像像素点的灰度值(0到1之间)。每幅图像是28×28=784维特征向量; - `y`: 形状为 `[batch_size, num_classes(10)]` 的标签矩阵,通常是一个one-hot编码形式的目标类别。 例如,如果你设置 `batchSize = 64` ,那么将一次性取得64个手写数字图片以及对应的正确答案标签。 ### 示例代码片段 ```python import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 加载MNIST数据集 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # 定义批量大小 batchSize = 100 for i in range(5): # 这里仅做示范循环几次 # 获取下一个批处理数据 batch_x, batch_y = mnist.train.next_batch(batchSize) print(f"Batch {i+1}:") print("X shape:", batch_x.shape) # 输出像形如 (100, 784) print("Y shape:", batch_y.shape) # 输出像形如 (100, 10) ``` 上述代码演示了如何多次调用 `next_batch()` 来逐次获得不同批次的数据。实际训练过程中会结合优化算法、前馈神经网络等组件完成对整个训练集的学习。 需要注意的是,随着 TensorFlow 版本不断更新演进,最新版本中已经不再推荐直接使用这种方式访问MNIST数据集,而是转向更为通用的数据管道构建方式,比如使用 `tf.data.Dataset` API 。不过理解这一经典操作仍然是很有帮助的。 ---
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值