0.引言
介绍了如何生成手写体数字的数据,提取特征,借助 sklearn 机器学习模型建模,进行识别手写体数字 1-9 模型的建立和测试。
用到的几种模型:
1. LR,Logistic Regression, (线性模型)中的逻辑斯特回归
2. Linear SVC,Support Vector Classification, (支持向量机)中的线性支持向量分类
3. MLPC,Multi-Layer Perceptron Classification, (神经网络)多层感知机分类
4. SGDC,Stochastic Gradient Descent Classification, (线性模型)随机梯度法求解
手写体的识别是一个 分类问题,提取图像特征作为模型输入,输出到标记数字 1-9;
主要内容:
1. 生成手写体数字数据集;
2. 提取图像特征存入 CSV;
3. 利用机器学习建立和测试手写体数字识别模型;
得到不同样本量训练下,几种机器学习模型精度随样本的变化关系曲线:
图 0 不同样本数目下的四种模型的测试精度( 数据集大小从 100 到 5800,间隔 100 )
1.开发环境
python: 3.6.3
import PIL, cv2, pandas, numpy, os, csv, random
需要调用的 sklearn 库:
1 from sklearn.linear_model import LogisticRegression #线性模型中的 Logistic 回归模型
2 from sklearn.linear_model import SGDClassifier #线性模型中的随机梯度下降模型
3 from sklearn.svm import LinearSVC #SVM 模型中的线性 SVC 模型
4 from sklearn.neural_network import MLPClassifier #神经网络模型中的多层网络模型
2.整体设计思路
图 1 整体的框架设计
工程的目的,是想利用机器学习模型去训练识别生成的随机验证码图像(单个数字 1-9 ),通过以下三个步骤实现:
1. 生成手写体数据集
2. 提取特征向量写入 CSV
3. sklearn 模型训练和测试
图 2 整体的设计流程
3. 编程过程
3.1 生成多张单个验证码图像 ( generate_folders.py, generate_handwritten_numbers.py )
图 3 生成的多张单个验证码图像
思路就是 random 随机生成数字 1-9,然后利用PIL的画笔工具进行画图,对图像进行扭曲,然后根据随机数的真实标记 1-9,保存到对应文件夹内,用标记+序号命名。
1 draw = ImageDraw.Draw(im) #画笔工具
3.2提取特征向量写入 CSV ( get_features.py )
这一步是提取图像中的特征。生成的单个图像是 30*30 即 900 个像素点的;
为了降低维度,没有选择 900 个像素点每点的灰度作为输入,而是选取了 30 行每行的黑点数,和 30 列每列的黑点数作为输入,这样降到了 60 维。
(a) 提取 900 维特征
(b) 提取 60 维特征
图 4 提取图像特征
特征的提取也比较简单,逐行逐列计算然后计数求和:
1 defget_feature(img):2 #提取特征
3 #30*30的图像,
4
5 width, height =img.size6
7 globalpixel_cnt_list8 pixel_cnt_list=[]9
10 height = 30
11 for y inrange(height):12 pixel_cnt_x =013 for x inrange(width):14 #print(img.getpixel((x,y)))
15 if img.getpixel((x, y)) == 0: #黑点
16 pixel_cnt_x += 1
17
18 pixel_cnt_list.append(pixel_cnt_x)19
20 for x inrange(width):21 pixel_cnt_y =022 for y inrange(height):23 if img.getpixel((x, y)) == 0: #黑点
24 pixel_cnt_y += 1
25
26 pixel_cnt_list.append(pixel_cnt_y)27
28 return pixel_cnt_list
所以我们接下来需要做的工作是,遍历访问文件夹 num_1-9 中的所有图像文件,进行特征提取,然后写入 CSV 文件中:
1 with open(path_csv+"tmp.csv", "w", newline="") as csvfile: