用C语言实现MNIST识别的FCNN网络 A MNIST recognition FCNN in C language
一、准备mnist训练和测试数据集【1】
从 https://www.kaggle.com/datasets 下载mnist手写数字数据集。
- 数据集包含60,000个训练数据和10,000个测试数据。
- 数据集包含了0-9共10类手写数字图片,每张图片都是28x28大小的灰度图。
- MNIST数据集包含四个部分:
训练集图像:train-images-idx3-ubyte.gz(9.9MB,包含60000个样本)
训练集标签:train-labels-idx1-ubyte.gz(29KB,包含60000个标签)
测试集图像:t10k-images-idx3-ubyte.gz(1.6MB,包含10000个样本)
测试集标签:t10k-labels-idx1-ubyte.gz(5KB,包含10000个标签)
【2】
二、设计神经网络
采用三层全连接神经网络FCNN:
- 输入层: 一共28x28=784个神经元;
- 中间层: 一共512个神经元,激活函数用sigmoid;
- 输出层: 一共10个神经元,分别对应0-9数字的概率,激活函数用sigmoid。
三、训练和测试神经网络
- 用60,000个数据训练:训练过程中先用大的learning_rate,再用小的,提高训练速度;
- 用10,000个数据测试。
四、C语言代码实现(C code)
#include "stdio.h"
#include "stdlib.h"
#include "math.h"
#define ROW_PIX_NUM 28
#define COL_PIX_NUM 28
#define TOTAL_PIX_NUM (ROW_PIX_NUM * COL_PIX_NUM) // 784
#define TOTAL_TRAIN_NUM 60000
#define TOTAL_TEST_NUM 10000
#define INPUT_NUM TOTAL_PIX_NUM
#define HIDDEN_NUM 512
#define OUTPUT_NUM 10
#define LEARNING_RATE_BIG 0.1
#define LEARNING_RATE_SMALL 0.01
#define EPOCHS 100000
#define MIN_EPOCH 90000
static unsigned int epoch;
static unsigned int tmp;
static unsigned int train_magic_number1, train_number_of_images, train_number_of_rows, train_number_of_columns;
static unsigned int train_images[TOTAL_TRAIN_NUM][TOTAL_PIX_NUM];
static unsigned int train_magic_number2, train_number_of_items;
static unsigned int train_labels[TOTAL_TRAIN_NUM];
static unsigned int test_magic_number1, test_number_of_images, test_number_of_rows, test_number_of_columns;
static unsigned int test_images[TOTAL_TEST_NUM][TOTAL_PIX_NUM];
static unsigned int test_magic_number2, test_number_of_items;
static unsigned int test_labels[TOTAL_TEST_NUM];
static float weights0[INPUT_NUM][HIDDEN_NUM]; //input --> hidden
static float weights1[HIDDEN_NUM][OUTPUT_NUM]; //hidden --> output
static float hiddens[HIDDEN_NUM];
static float outputs[OUTPUT_NUM];
static float errs_output[OUTPUT_NUM], errs_hidden[HIDDEN_NUM];
static float learning_rate = LEARNING_RATE_SMALL;
//对应int32大小的成员 的转换 范例
unsigned int swapInt32(unsigned int value)
{
return