首先MNIST数据集主要分为4个部分,分别是:
t10k-images-idx3-ubyte test image (10,000)
t10k-labels-idx1-ubyte test label (10,000)
train-images-idx3-ubyte train image (60,000)
train-labels-idx1-ubyte train label (60,000)
官网链接:
http://yann.lecun.com/exdb/mnist
首先下载数据,用urllib.request.urlretrieve下载数据,然后读取数据,并将其转换成 npy file:
import urllib
MNIST_URL = 'http://yann.lecun.com/exdb/mnist'
MNIST_FLOAT_TRAIN = 'train-images-idx3-ubyte'
DATA_DIR = 'dataset'
urllib.request.urlretrieve(url=f'{MNIST_URL}/{MNIST_FLOAT_TRAIN}.gz', filename=local_filename + '.gz')
with gzip.open(local_filename + '.gz', 'rb') as f:
file_content = f.read()
with open(local_filename, 'wb') as f: # byte data
f.write(file_content)
读取出来的数据的第一个四字节(int) 是magic number (例如Ox0000 0801
),第二,三,四字节分别是image numbers, image height, image weight,知道结构后便可以开始解析MNIST数据集了。
with open(local_filename, 'rb') as f:
f.seek(4) # skip magic number
nimages, rows, cols = struct.unpack('>iii', f.read(12)) # 大端读取3个4字节,i表示int,4字节
dim = rows * cols
images = np.fromfile(f, dtype=np.dtype(np.ubyte))
images = (images / 255.0).astype('float32').reshape((nimages, dim))
这样就可以得到 nimages 个 image了。其他文件的解析大同小异,可自行思考。