Kaggle入门:Digit Recognizer
数据是经典的lecun大神的mnist手写数字数据集。
一、下载数据
数据分为三部分:训练数据、测试数据、提交数据格式示例
数据格式和lecun网站的数据格式有所不同,这里所给的数据为csv格式每一张数字图片为28*28像素的灰度图,文件每一行代表756个像素点。
二、数据读取
这里使用pandas库读取数据, 由于灰度图像像素值在0-255之间,所以将其转化为uint8类型。
mnist.py
import pandas as pd
import numpy as np
train_data_file = 'train.csv'
test_data_file = 'test.csv'
train_data = pd.read_csv(train_data_file).as_matrix().astype(np.uint8)
test_data = pd.read_csv(test_data_file).as_matrix().astype(np.uint8)
def extract_images_and_labels(dataset, validation = False):
#需要将数据转化为[image_num, x, y, depth]格式
images = dataset[:, 1:].reshape(-1, 28, 28, 1)
#由于label为0~9,将其转化为一个向量.如将0 转换为 [1,0,0,0,0,0,0,0,0,0]
labels_dense = dataset[:, 0]
num_labels = labels_dense.shape[0]
index_offset = np.arange(num_labels) * 10
labels_one_hot = np.zeros((num_labels, 10))
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
if validation:
num_images = images.shape[0]
divider = num_images - 200
return images[:divider], labels_one_hot[:divider], images[divider+1:], labels_one_hot[divider+1:]
else:
return images, labels_one_hot