python机器学习手写字体识别_Python 3 利用机器学习模型 进行手写体数字检测

该博客介绍了一种使用Python和机器学习模型识别手写数字的方法,涉及Logistic Regression、Linear SVC、MLPClassifier和SGDClassifier。首先,生成手写体数字数据集并提取特征,然后通过sklearn库训练和测试模型,展示不同模型在不同样本量下的精度变化。最终,通过保存模型进行单张图像的预测,并绘制了样本数与精度的关系图。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

0.引言

介绍了如何生成手写体数字的数据,提取特征,借助 sklearn 机器学习模型建模,进行识别手写体数字 1-9 模型的建立和测试。

用到的几种模型:

1. LR,Logistic Regression,                (线性模型)中的逻辑斯特回归

2. Linear SVC,Support Vector Classification,      (支持向量机)中的线性支持向量分类

3. MLPC,Multi-Layer Perceptron Classification,       (神经网络)多层感知机分类

4. SGDC,Stochastic Gradient Descent Classification,   (线性模型)随机梯度法求解

手写体的识别是一个 分类问题,提取图像特征作为模型输入,输出到标记数字 1-9;

主要内容:

1. 生成手写体数字数据集;

2. 提取图像特征存入 CSV;

3. 利用机器学习建立和测试手写体数字识别模型;

得到不同样本量训练下,几种机器学习模型精度随样本的变化关系曲线:

图 0  不同样本数目下的四种模型的测试精度( 数据集大小从 100 到 5800,间隔 100 )

1.开发环境

python:  3.6.3

import PIL, cv2, pandas, numpy, os, csv, random

需要调用的 sklearn 库:

1 from sklearn.linear_model import LogisticRegression #线性模型中的 Logistic 回归模型

2 from sklearn.linear_model import SGDClassifier #线性模型中的随机梯度下降模型

3 from sklearn.svm import LinearSVC #SVM 模型中的线性 SVC 模型

4 from sklearn.neural_network import MLPClassifier #神经网络模型中的多层网络模型

2.整体设计思路

图 1 整体的框架设计

工程的目的,是想利用机器学习模型去训练识别生成的随机验证码图像(单个数字 1-9 ),通过以下三个步骤实现:

1. 生成手写体数据集

2. 提取特征向量写入 CSV

3. sklearn 模型训练和测试

图 2 整体的设计流程

3. 编程过程

3.1 生成多张单个验证码图像 ( generate_folders.py, generate_handwritten_numbers.py )

        

图 3 生成的多张单个验证码图像

思路就是 random 随机生成数字 1-9,然后利用PIL的画笔工具进行画图,对图像进行扭曲,然后根据随机数的真实标记 1-9,保存到对应文件夹内,用标记+序号命名。

1 draw = ImageDraw.Draw(im) #画笔工具

3.2提取特征向量写入 CSV ( get_features.py )

这一步是提取图像中的特征。生成的单个图像是 30*30 即 900 个像素点的;

为了降低维度,没有选择 900 个像素点每点的灰度作为输入,而是选取了 30 行每行的黑点数,和 30 列每列的黑点数作为输入,这样降到了 60 维。

(a) 提取 900 维特征

(b) 提取 60 维特征

图 4 提取图像特征

特征的提取也比较简单,逐行逐列计算然后计数求和:

1 defget_feature(img):2 #提取特征

3 #30*30的图像,

4

5 width, height =img.size6

7 globalpixel_cnt_list8 pixel_cnt_list=[]9

10 height = 30

11 for y inrange(height):12 pixel_cnt_x =013 for x inrange(width):14 #print(img.getpixel((x,y)))

15 if img.getpixel((x, y)) == 0: #黑点

16 pixel_cnt_x += 1

17

18 pixel_cnt_list.append(pixel_cnt_x)19

20 for x inrange(width):21 pixel_cnt_y =022 for y inrange(height):23 if img.getpixel((x, y)) == 0: #黑点

24 pixel_cnt_y += 1

25

26 pixel_cnt_list.append(pixel_cnt_y)27

28 return pixel_cnt_list

所以我们接下来需要做的工作是,遍历访问文件夹 num_1-9 中的所有图像文件,进行特征提取,然后写入 CSV 文件中:

1 with open(path_csv+"tmp.csv", "w", newline="") as csvfile:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值