# Load the MNIST training set from its IDX ("ubyte") files into NumPy arrays.
#
# After this script runs, two arrays are available at module level:
#   img_array   -- uint8, shape (img_num, hang * lie); each row holds the
#                  flattened pixel values of one image (784 for 28x28 images)
#   label_array -- uint8, 1-D; each value is the label of the corresponding image
import torch as t
import torch.nn as nn
import numpy as np
import os.path
import struct

train_img_dir_ubyte = r'E:\code\LnNet\train-images.idx3-ubyte'
train_label_dir_ubyte = r'E:\code\LnNet\train-labels.idx1-ubyte'
# Size of the image file in bytes (header included). Decoding below no longer
# depends on it, but the name is kept for any code that reads it.
train_size = os.path.getsize(train_img_dir_ubyte)

with open(train_img_dir_ubyte, 'rb') as trf:
    # The image header is four big-endian unsigned 32-bit integers (16 bytes):
    # magic number, image count, rows, columns. An earlier attempt unpacked the
    # first 12 bytes as single bytes ('>12B'), which is wrong because the
    # header fields are multi-byte integers.
    magic, img_num, hang, lie = struct.unpack('>IIII', trf.read(16))
    # Everything after the header is pixel data, one unsigned byte per pixel.
    img_csv_data = trf.read()

# Decode the raw bytes directly instead of expanding them into a huge tuple of
# Python ints first. The row width comes from the header rather than a
# hard-coded 784; reshape raises ValueError if the byte count does not match
# img_num * hang * lie. The copy makes the array writable (frombuffer views
# over bytes are read-only).
img_array = (
    np.frombuffer(img_csv_data, dtype=np.uint8)
    .reshape(img_num, hang * lie)
    .copy()
)
# Debugging aids (the plot needs matplotlib.pyplot imported as plt):
# print(img_array.shape)
# plt.imshow(img_array[0].reshape(hang, lie))
# plt.show()

with open(train_label_dir_ubyte, 'rb') as tb:
    # The label header is two big-endian unsigned 32-bit integers (8 bytes):
    # magic number and label count. Note this rebinds `magic`.
    magic, label_num = struct.unpack('>II', tb.read(8))
    # The remainder of the file is one unsigned byte per label.
    label = tb.read()

label_array = np.frombuffer(label, dtype=np.uint8).copy()
# MNIST 数据处理
# 最新推荐文章于 2023-09-24 05:00:00 发布