Mnist数据集分析
def load_data():
f = gzip.open('/var/tmp/mnist.pkl.gz', 'rb')
training_data, test_data = cPickle.load(f)
f.close()
return (training_data, test_data)
def load_data_wrapper():
tr_d, te_d = load_data()
training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
training_results = [vectorized_result(y) for y in tr_d[1]]
training_data = zip(training_inputs, training_results)
test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
test_data = zip(test_inputs, te_d[1])
return (training_data, test_data)
那么这样处理之后,我们的数据类型为:
[([784,1],[10,1]),…,()],总返回类型为list,然后tuple中又包含tuple,子tuple中第一个为一个二维数组,784*1,第二个为一个二位数组,10*1。
if __name__=="__main__":
training_data,test_data=load_data_wrapper()
print(training_data[0][0].shape)
print(training_data[0][1].shape)
运行结果:
training_data,test_data=load_data_wrapper()
print(type(test_data))
print(test_data[0][0].shape)
print(test_data[0][1].shape)
因此我们可以看到mnist数据集的[([784*1],i),……,()]
搭建网络之后输出的最后一层的维度为: