机器学习_sklearn-SVM实现手写数字识别

本文通过使用sklearn库中的MNIST数据集,构建并优化了一个手写数字识别模型。首先,进行了数据预处理,包括数据加载、打乱和划分。接着,采用SVM模型进行初步训练,并分析了模型的准确性和混淆矩阵。发现模型性能有待提升后,通过调整模型参数,特别是C值,以提高模型精度。最终,引入高斯核SVM,显著提升了模型的整体准确率至98%以上,但对特定数字如4、9的识别仍有改进空间。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1、 数据准备

from sklearn.model_selection import StratifiedShuffleSplit
import pandas as pd
import numpy  as np
from sklearn.datasets import fetch_mldata

class Data_need():
	def __init__(self, percent, data_name):
		self.percent = percent
		self.data_name = data_name

	def get_data(self):
		data_home = r'D:\Python_data\python Data\sklearn'
		mnist = fetch_mldata(self.data_name, data_home=data_home)
		return mnist['data'], mnist['target']

	## 打乱数据集
	def random_data(self, x, y):
		mnist_train, mnist_test = 0, 0
		## 创建DataFrame
		data_y = pd.DataFrame(y, columns=['y'])
		n = len(x[0])
		data_x = pd.DataFrame(x, columns=list(range(n)))
		mnist_data = pd.merge(data_x, data_y, right_index=True, left_index=True)
		## 分层取样
		split = StratifiedShuffleSplit(n_splits=1, test_size = self.percent, random_state=42)
		for train_index, test_index in split.split(mnist_data, mnist_data['y']):
			mnist_train = mnist_data.loc[train_index,:]
			mnist_test = mnist_data.loc[test_index,:]
		return mnist_train, mnist_test

	def train_test_data(self, train, test):
		# 将像素数据变为二值变量
		return (np.array(train.iloc[:,:-1]) != 0)*1, np.array(train['y']), (np.array(test.iloc[:,:-1])!= 0)*1, np.array(test['y'])


if __name__ == '__main__':
	data_need = Data_need(0.3, 'MNIST original')
	x, y = data_need.get_data()
	train, test = data_need.random_data(x, y)
	x_train_in, y_train_in, x_test_in, y_test_in = data_need.train_test_data(train, test)

	

2、查看数据及模型训练

模型采用ovr (ova)SMV模型

from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
plt.style.use('ggplot')

def to_plot(num, n):
	"""
	num: 想要绘制的数值
	n :第几个样本
	"""
	plt_x_array = x_train_in[y_train_in == num]
	some_digit = plt_x_array[n]
	some_digit_image = some_digit.reshape(28, 28)
	plt.imshow(some_digit_image, cmap=plt.cm.binary, interpolation='nearest')
	plt.axis('off')
	plt.show()


if __name__ == '__main__' :
	to_plot(8, 10)
	ova_svm_clf = LinearSVC(loss='hinge', C=5, multi_class='ovr')
	ova_svm_clf.fit(x_train_in, y_train_in)
	## 交叉验证出预测
	y_prd = cross_val_predict(ova_svm_clf, x_train_in, y_train_in, cv=3)
	## 评估 混淆矩阵
	conf_m = confusion_matrix(y_train_in, y_prd)


在这里插入图片描述

3、模型评估


### 整体的准确率
def clf_correct(y_train, y_prd):
	return sum((y_train - y_prd) == 0) / len(y_train)


class plot_conf_m():
	def __init__(self, conf_m):
		self.conf_m = conf_m

	def plt_conf_m(self):
		## 用matshow()函数绘制出混淆矩阵
		plt.matshow(self.conf_m, cmap=plt.cm.gray)

	def plt_error_conf_m(self):
		## 关注误差数据的图像呈现
		row_sums = self.conf_m.sum(axis=1, keepdims=True)
		norm_conf_m = self.conf_m / row_sums
		## 用0 将正确分类覆盖 查看那个类分类特别不准
		np.fill_diagonal(norm_conf_m, 0)
		plt.matshow(norm_conf_m, cmap=plt.cm.gray)


if __name__ == '__main__':
	print("整体准确性:{}".format(clf_correct(y_train_in, y_prd)))
	plt_confm = plot_conf_m(conf_m)
	plt_confm.plt_conf_m(), plt.title("Focus on the correct prediction")
	plt_confm.plt_error_conf_m(), plt.title("Focus on the error prediction")
	plt.show()

##  整体准确性:0.902795918367347

从下面两个混淆矩阵中可以看出 错误分类分布比较平均,还待提高,所以增大C 进行重新拟合
在这里插入图片描述

4、模型修正及预测

1. 模型修正

if __name__ == '__main__' :
	ova_svm_clf_fix = LinearSVC(loss='hinge', C=10, multi_class='ovr')
	ova_svm_clf_fix.fit(x_train_in, y_train_in)
	## 交叉验证出预测
	y_prd_fix = cross_val_predict(ova_svm_clf_fix, x_train_in, y_train_in, cv=3)
	## 评估 混淆矩阵
	conf_m_fix = confusion_matrix(y_train_in, y_prd_fix)

	print("整体准确性:{}".format(clf_correct(y_train_in, y_prd_fix)))
	plt_confm_fix = plot_conf_m(conf_m_fix)
	plt_confm_fix.plt_conf_m(), plt.title("Focus on the correct prediction")
	plt_confm_fix.plt_error_conf_m(), plt.title("Focus on the error prediction")
	plt.show()

## 整体准确率0.91204

增大C 虽然提高了整体的准确率,对准确率并没有明显好转,可见线性核对该数据分类效果不明显。所以改用高斯核进行拟合。
在这里插入图片描述

from sklearn.svm import SVC
from sklearn.metrics import classification_report

if __name__ == '__main__': # ova
	ova_svm_clf_rbf = SVC(kernel='rbf',gamma = 'auto', C = 15, cache_size= 8000, decision_function_shape = 'ovr')
	ova_svm_clf_rbf.fit(x_train_in, y_train_in)
	y_prd_rbf = ova_svm_clf_rbf.predict(x_train_in)
	print('整体准确率{}'.format(clf_correct(y_train_in, y_prd_rbf))) # 0.90
	conf_m_rbf = confusion_matrix(y_train_in, y_prd_rbf)
	plt_confm_rbf = plot_conf_m(conf_m_rbf)
	plt_confm_rbf.plt_conf_m(), plt.title("Focus on the correct prediction")
	plt_confm_rbf.plt_error_conf_m(), plt.title("Focus on the error prediction")
	plt.show()
	# 输出详细报告
	print(classification_report(y_train_in, y_prd_rbf))

"""
# 整体准确率:0.9831632653061224
             precision    recall  f1-score   support
        0.0       0.99      0.99      0.99      4832
        1.0       0.99      0.99      0.99      5514
        2.0       0.98      0.99      0.99      4893
        3.0       0.98      0.97      0.97      4999
        4.0       0.98      0.98      0.98      4777
        5.0       0.98      0.98      0.98      4419
        6.0       0.99      0.99      0.99      4813
        7.0       0.98      0.98      0.98      5105
        8.0       0.98      0.98      0.98      4777
        9.0       0.98      0.97      0.97      4871
avg / total       0.98      0.98      0.98     49000

"""

高斯核的准确率明显提升了,但对9和4 与 3和5 的识别还是不是十分精确
在这里插入图片描述

2. 模型预测

if __name__ == '__main__' :
	y_test_prd = ova_svm_clf_fix.predict(x_test)
	print("整体准确性:{}".format(clf_correct(y_train, y_test_prd)))
	plt_confm_test = plot_conf_m(conf_m)
	plt_confm_test.plt_conf_m(), plt.title("Focus on the correct prediction")
	plt_confm_test.plt_error_conf_m(), plt.title("Focus on the error prediction")
	plt.show()
	# 输出详细报告
	print(classification_report(y_test_in, y_test_prd))

"""
# 整体准确性:0.9615238095238096
            precision    recall  f1-score   support
        0.0       0.97      0.99      0.98      2071
        1.0       0.97      0.98      0.98      2363
        2.0       0.96      0.97      0.96      2097
        3.0       0.95      0.95      0.95      2142
        4.0       0.96      0.96      0.96      2047
        5.0       0.96      0.94      0.95      1894
        6.0       0.97      0.98      0.97      2063
        7.0       0.97      0.96      0.97      2188
        8.0       0.95      0.95      0.95      2048
        9.0       0.94      0.94      0.94      2087
avg / total       0.96      0.96      0.96     21000

"""

在这里插入图片描述

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Scc_hy

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值