该代码的GitHub地址:
https://github.com/peter-u-diehl/stdp-mnist/blob/master/Diehl%26Cook_spiking_MNIST.py
import numpy as np
import matplotlib.cm as cmap
import time
import os.path
import scipy
import cPickle as pickle
import brian_no_units #import it to deactivate unit checking --> This should NOT be done for testing/debugging
import brian as b
from struct import unpack
from brian import *
# Location of the MNIST data. This string is prepended verbatim to the IDX
# file names (e.g. MNIST_data_path + 'train-images.idx3-ubyte'), so a
# non-empty value must end with a path separator. The empty string makes the
# file names relative to the current working directory.
MNIST_data_path = ''
首先导入模块,定义数据集的路径变量。
def get_labeled_data(picklename, bTrain = True):
    """Load MNIST images and labels, caching the parsed result as a pickle.

    If '<picklename>.pickle' exists, it is unpickled and returned as-is.
    Otherwise the raw IDX files found under MNIST_data_path are parsed, the
    result is written to '<picklename>.pickle', and then returned.

    Args:
        picklename: Path of the cache file without the '.pickle' suffix.
        bTrain: If true, read the 'train-*' files, otherwise the 't10k-*'
            files.

    Returns:
        A dict with the keys 'x' (uint8 array of shape (N, rows, cols)),
        'y' (uint8 array of shape (N, 1)), 'rows' and 'cols'. When the
        cache file exists, whatever object it contains is returned.

    Raises:
        Exception: If the image and label counts in the headers differ, or
            if a file holds fewer data bytes than its header announces.
    """
    pickle_path = '%s.pickle' % picklename
    if os.path.isfile(pickle_path):
        # NOTE(security): unpickling can execute arbitrary code, so only load
        # cache files that were produced by this function.
        # The cache is written in binary mode, so it is read in binary mode.
        with open(pickle_path, 'rb') as cache_file:
            return pickle.load(cache_file)

    prefix = 'train' if bTrain else 't10k'
    images_path = MNIST_data_path + prefix + '-images.idx3-ubyte'
    labels_path = MNIST_data_path + prefix + '-labels.idx1-ubyte'

    # The IDX files are plain (uncompressed) binary files.
    with open(images_path, 'rb') as images, open(labels_path, 'rb') as labels:
        # Image header: magic number, image count, rows, cols (big-endian).
        images.read(4)  # skip the magic_number
        number_of_images = unpack('>I', images.read(4))[0]
        rows = unpack('>I', images.read(4))[0]
        cols = unpack('>I', images.read(4))[0]
        # Label header: magic number, label count (big-endian).
        labels.read(4)  # skip the magic_number
        N = unpack('>I', labels.read(4))[0]
        if number_of_images != N:
            raise Exception('number of labels did not match the number of images')
        # Read each payload in a single call instead of byte by byte.
        image_bytes = images.read(N * rows * cols)
        label_bytes = labels.read(N)

    if len(image_bytes) != N * rows * cols or len(label_bytes) != N:
        raise Exception('MNIST file is truncated: fewer data bytes than the header announces')

    # frombuffer returns a read-only view onto the bytes object; copy so that
    # callers receive ordinary writable arrays.
    x = np.frombuffer(image_bytes, dtype=np.uint8).reshape(N, rows, cols).copy()
    y = np.frombuffer(label_bytes, dtype=np.uint8).reshape(N, 1).copy()

    data = {'x': x, 'y': y, 'rows': rows, 'cols': cols}
    with open(pickle_path, 'wb') as cache_file:
        pickle.dump(data, cache_file)
    return data
获取含标签数据的函数。
函数功能:获得带标签的数据
输入:picklename(pickle缓存文件名,不含“.pickle”后缀)、bTrain(是否读取训练集,否则读取测试集)
输出:字典,包含'x'(图片数组)、'y'(0-9的标签数组)、'rows'(图片行数)和'cols'(图片列数)
注:官方MNIST数据集有60000个训练样本和10000个测试样本,为IDX格式,IDX格式形式如下:
magic number
size in dimension 0
size in dimension 1
size in dimension 2
.....
size in dimension N
data
魔法数字是个整数,前两个字节总是0,第三个字节表示数据的类型:
0x08: unsigned byte
0x09: signed byte
0x0B: short (2 bytes)
0x0C: int (4 bytes)
0x0D: float (4 bytes)
0x0E: double (8 bytes)
第四个字节表示矩阵的维度。
接着便是每个维度的尺寸,用四字节的整数表示。
images.read(4)跳过了MNIST数据集的魔法数字的四个字节。
unpack是struct模块中的函数,用法是unpack(fmt, string)。代码中的'>'指定按大端字节序(big-endian)解析数据;'I'表示把4字节的C类型unsigned int转换为Python的整数。
获取到MNIST数据集中的数据后转换为numpy类型数组。
pickle提供了一个简单的持久化功能,可以将对象以文件的形式存放在磁盘上,dump方法:
pickle.dump(obj, file[, protocol])
def get_matrix_from_file(fileName):
    """Expand a saved connection list into a dense (n_src, n_tgt) matrix.

    The array loaded from fileName is read as rows of
    (source index, target index, value). The matrix dimensions are chosen
    from single characters of the file name, counted from its end.

    NOTE(review): relies on the names ending, n_input, n_e and n_i being
    defined at module level; they are not part of this block.

    Args:
        fileName: Path handed to np.load; its characters also select the
            source and target sizes.

    Returns:
        A float array of shape (n_src, n_tgt), zero wherever the file has
        no entry.
    """
    # Positions are counted from the end of the name, skipping `ending`
    # plus four more characters.
    tail = len(ending) + 4

    if fileName[-4 - tail] == 'X':
        n_src = n_input
    elif fileName[-3 - tail] == 'e':
        n_src = n_e
    else:
        n_src = n_i

    n_tgt = n_e if fileName[-1 - tail] == 'e' else n_i

    entries = np.load(fileName)
    print('%s %s' % (entries.shape, fileName))

    dense = np.zeros((n_src, n_tgt))
    # An array of shape (0,) means there is nothing to fill in.
    if entries.shape != (0,):
        src_idx = np.int32(entries[:, 0])
        tgt_idx = np.int32(entries[:, 1])
        dense[src_idx, tgt_idx] = entries[:, 2]
    return dense