搞了好久好久,公式推导+网络设计就推了20多页草稿纸
花了近10天
程序近1k行,各种debug要人命,只能不断地做单元测试+梯度检验
因为C++只有加减乘除,所以对这个网络模型不能有一丝丝的模糊,每一步都要理解得很透彻
挺考验能力的,很庆幸我做出来了,这个是第二版,第一版也写了1k行,写完才发现,模型错了,只能全删掉重新写
算是一次修行
网络的设计,编代码时的各种考虑,debug记录,我不想整理了,有问题的直接私信我吧
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <random>
#include <vector>
using namespace std;
// Euler's number e, as a literal truncated to 12 decimal places.
constexpr double E = 2.718281828459;
// Small tolerance value (original comment: "tiny value"); its use is not
// visible in this section.
constexpr double EPS = 3e-3;

// Dimensions of a raw MNIST image.
constexpr int MNIST_HEIGHT = 28;
constexpr int MNIST_WIDTH = 28;

// Dimensions of the network input.
constexpr int INPUT_HEIGHT = 32;
constexpr int INPUT_WIDTH = 32;

// Padding amount; presumably the per-side padding that takes 28 -> 32
// (since (32 - 28) / 2 == 2) — TODO confirm against the loader.
constexpr int OFFSET = 2;

// Number of bytes per label.
constexpr int LABEL_BYTE = 1;

// Size of the output layer.
constexpr int OUT_SIZE = 10;

// Presumably the number of training / test samples used — TODO confirm.
constexpr int NUM_TRAIN = 20;
constexpr int NUM_TEST = 20;
/*------------ Matrix type and helpers ------------*/
// Dense matrix stored as a vector of rows: m[i] is row i and m[i][j] is the
// cell in row i, column j. The type itself does not force rows to share a length.
using Matrix = vector<vector<double>>;
// Resizes `mat` to `row` x `col`.
// Cells that already exist keep their values; newly created cells are
// value-initialised (0.0) by vector::resize.
// Throws a C-string if either dimension is zero. Always returns true.
inline bool reshapeMatrix(Matrix &mat, size_t row, size_t col)
{
    // size_t is unsigned, so "<= 0" can only mean "== 0".
    if (row == 0 || col == 0)
        throw "reshapeMatrix: row <= 0 || col <= 0";
    mat.resize(row);
    for (auto &line : mat)
        line.resize(col);
    return true;
}
// Resizes `mat` to `row` x `col` and overwrites every cell with `val`;
// none of the previous contents survive.
// Throws a C-string if either dimension is zero. Always returns true.
inline bool reshapeMatrix(Matrix &mat, size_t row, size_t col, double val)
{
    // size_t is unsigned, so "<= 0" can only mean "== 0".
    if (row == 0 || col == 0)
        throw "reshapeMatrix: row <= 0 || col <= 0";
    mat.resize(row);
    // assign() replaces each row with exactly `col` copies of `val`.
    for (auto &line : mat)
        line.assign(col, val);
    return true;
}
//矩阵二维卷积
inline void convMatrix(const Matrix &a, const Matrix &b, Matrix &res)
{
if (b.size() > a.size() || b[0].size() > a[0].size())
throw "convMatrix: b is larger than a";
reshapeMatrix(res, a.size() - b.size() + 1, a[0].size() - b[0].size() + 1, 0);
for (int i = 0; i < res.size(); i++)
for (int j = 0; j < res[0].size(); j++)
{
//遍历卷积矩阵
for (int _i = 0; _i < b.size(); _i++)
for (int _j = 0; _j < b[0].size(); _j++)
res[i][j] += a[_i + i][_j + j] * b[_i][_j];
}
return;
}
// Matrix addition.
// In-place element-wise sum: a += b.
// Throws a C-string if the row counts or the widths of row 0 differ.
inline void plusMatrix(Matrix &a, const Matrix &b)
{
    const bool sameShape = a.size() == b.size() && a[0].size() == b[0].size();
    if (!sameShape)
        throw "plusMatrix: shape don't match";
    const size_t rows = a.size();
    const size_t cols = a[0].size();
    for (size_t r = 0; r < rows; ++r)
        for (size_t c = 0; c < cols; ++c)
            a[r][c] += b[r][c];
}
// Adds the scalar `val` to every cell of `a` in place. Each row is walked up
// to the width of row 0. An empty matrix is left untouched.
inline void plusMatrix(Matrix &a, double val)
{
    if (a.empty())
        return;
    const size_t width = a[0].size();
    for (auto &row : a)
        for (size_t c = 0; c < width; ++c)
            row[c] += val;
}
inline void plusMatrix(const Matrix &a, const Matrix &b, Matrix &res)
{
if (a.size() != b.size() || a[0].size() != b[0].size())
throw "plusMatrix: shape don't match";
reshapeMatrix(res, a.size(), a[0].size());
for (int i = 0; i < res.size(); i++)
for (int j = 0; j < res[0].size(); j++)
res[i][j] = a[i][j] + b[i][j];
return;
}
//矩阵减法
inline void minusMatrix(const Matrix &a, const Matrix &b, Matrix &res)
{
if (a.size() != b.size() || a[0].size() != b[0].size())
throw "plusMatrix: shape don't match";
reshapeMatrix(res, a.size(), a[0].size());
for (int i = 0; i < res.size(); i++)
for (int j = 0; j < res[0].size(); j++)
res[i][j] = a[i][j] - b[i][j];
return;
}
//矩阵乘法
inline void multiplyMatrix(const Matrix &a, const Matrix &b, Matrix &res)
{
if (a[0].size() != b.size())
throw "multiplyMatrix: a.col != b.row";
reshapeMatrix(res, a.size(), b[0].size(), 0);
for (int i = 0; i < res.size(); i++)
for (int j = 0; j < res[0].size(); j++)
for (int k = 0; k < a[0].size(); k++)
res[i][j] += a[i][k] * b[k][j];
return;
}
// Matrix-by-scalar multiplication.
// Scales every cell of `mat` in place by `val`. Each row is walked up to the
// width of row 0. An empty matrix is left untouched.
inline void multiplyMatrix(Matrix &mat, double val)
{
    if (mat.empty())
        return;
    const size_t width = mat[0].size();
    for (auto &row : mat)
        for (size_t c = 0; c < width; ++c)
            row[c] *= val;
}
inline void multiplyMatrix(double val, const Matrix &mat, Matrix &res)
{
reshapeMatrix(res, mat.size(), mat[0].size());
for (int i = 0; i < res.size(); i++)
for (int j = 0; j < res[0].size(); j++)
res[i][j] = mat[i][j] * val;
return;
}
//矩阵点乘
void matmulMatrix(const Matrix &a, const Matrix &b, Matrix &res)
{
if (a.size() != b.size() || a[0].size() != b[0].size())
throw "matmulMatrix: shape don't match";
reshapeMatrix(res, a.size(), a[0].size());
for (int i = 0; i < res.size(); i++)
for (int j = 0; j < res[0].size(); j++)
res[i][j] = a[i][j] * b[i][j];
return;
}
//矩阵池化,步长=大小
inline void downSampleMatrix(const Matrix &mat, size_t height, size_t width, Matrix &res)
{
if (mat.size() % height != 0 || mat[0].size() % width != 0)
throw "downSampleMatrix: height/width don't match matrix";
reshapeMatrix(res, mat.size() / height, mat[0].size() / width);
for (int i = 0; i < res.size(); i++)
for (int j = 0; j < res[0].size(); j++)
{
//求和
int row_b = i * height;
int row_e = (i + 1) * height;
int col_b