来自邹博机器学习课件,自己根据教学内容又做了少量改动,但部分问题仍不太清楚,故发表于此以供日后探讨。
'''
利用SVM进行手写数字识别
'''
import pandas as pd
import numpy as np
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from PIL import Image
import matplotlib.pyplot as plt
# Script entry point: load the optdigits train/test splits and export a sample digit image.
if __name__ == '__main__':
# Both files are read as header-less CSV from the current working directory.
data_train = pd.read_csv('14.optdigits.tra', header=None)
data_test = pd.read_csv('14.optdigits.tes', header=None)
# Every column except the last is a feature; the last column is the class label.
x_train = data_train.iloc[:, :-1]
y_train = data_train.iloc[:, -1]
x_test = data_test.iloc[:, :-1]
y_test = data_test.iloc[:, -1]
# Convert the feature ndarray into images: each row is reshaped into an 8x8 grid
# (so each row presumably has 64 feature columns -- reshape raises otherwise) and cast to uint8.
x_image = x_train.values.reshape(-1, 8, 8).astype(np.uint8)
x_label = y_train.values
# Save training sample #5 as a PNG for a quick visual check.
# Multiplying by 15 stretches the small raw values toward the 0-255 grey range.
# NOTE(review): optdigits pixels are presumably counts in 0..16, which map to 0..240 -- confirm;
# any value above 17 would wrap around in this uint8 arithmetic.
# Subtracting from 255 inverts the image so larger values (presumably the pen strokes) come out
# dark on a light background, matching the gray_r colormap used when plotting below.
Image.fromarray(255-x_image[5]*15).save('./test.png')
'''原始图片*15,不知道干嘛,反正这样做就会让图片变得清晰'''
# 绘制出前16个图像
plt.figure(figsize=(12, 12))
for i in range(16):
plt.subplot(4, 4, i+1)
plt.imshow(x_image[i], cmap=plt.cm.gray_r)
plt.title