程序解码MNIST数据集并返回两个tuple

该博客介绍了如何使用Python解析MNIST数据集的idx3-ubyte和idx1-ubyte文件,包括训练集和测试集的图像及标签。通过gzip打开文件,利用struct模块进行解包,最终返回训练和测试数据的两个tuple。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

#encoding:utf-8
import numpy as np
import struct
import gzip

a = 'D:/src/tensorflow/data_sets/MNIST_data/train-images-idx3-ubyte.gz'
b = 'D:/src/tensorflow/data_sets/MNIST_data/train-labels-idx1-ubyte.gz'
c = 'D:/src/tensorflow/data_sets/MNIST_data/t10k-images-idx3-ubyte.gz'
d = 'D:/src/tensorflow/data_sets/MNIST_data/t10k-labels-idx1-ubyte.gz'

def decode_idx3_ubyte(idx3_ubyte_file):
    fp = gzip.open(idx3_ubyte_file,'rb')
    bin_data = fp.read()
    
    offset = 0
    fmt_header = '>iiii'
    magic_number,num_images,num_rows,num_cols = struct.unpack_from(fmt_header,bin_data,offset)
    print ('魔数:%d, 图片数量:%d张,图片大小:%d%d' % (magic_number,num_images,num_rows,num_cols))
    
    image_size = num_rows*num_cols
    offset += struct.calcsize(fmt_header)
    fmt_image = '>'+str(image_size)+'B'
    images = np.empty((num_images,num_rows,num_cols))
    for i in range(num_images):
        if (i+1)%10000 == 0:
            print ('已解析 %d' % (i+1) + '张')
        images[i] = np.array(struct.unpack_from(fmt_image,bin_data,offset)).reshape((num_rows,num_cols))
        offset += struct.calcsize(fmt_image)
    return images

def decode_idx1_ubyte(idx1_ubyte_file):
    fp = gzip.open(idx1_ubyte_file,'rb')
    bin_data = fp.read()
    
    offset = 0
    fmt_header = '>ii'
    magic_number,num_images = struct.unpack_from(fmt_header,bin_data,offset)
    print ('魔数:%d, 图片数量:%d张' % (magic_number,num_images))

    offset += struct.calcsize(fmt_header)
    fmt_image = '>B'
    labels = np.empty(num_images)
    for i in range(num_images):
        if (i+1)%10000 == 0:
            print ('已解析 %d' % (i+1) + '张')
        labels[i] = struct.unpack_from(fmt_image,bin_data,offset)[0]
        offset += struct.calcsize(fmt_image)
    return labels

def load_mnist_data(train_images=a,train_labels=b,test_images=c,test_labels=d):
    """decode the mnist local train and test data."""
    t_images = decode_idx3_ubyte(train_images)
    t_labels = decode_idx1_ubyte(train_labels)
    v_images = decode_idx3_ubyte(test_images)
    v_labels = decode_idx1_ubyte(test_labels)
    return (t_images,t_labels),(v_images,v_labels)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值