SVM处理mnist字体库

该博客介绍了一个在2017年robomaster比赛中应用SVM处理MNIST字体库的示例。通过加载数据、数据预处理、训练SVM模型、测试模型准确率,并保存模型供后续使用,展示了SVM在图像识别中的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

  2017年robomaster比赛中,大神符环节使用的是
# decoding:utf-8
import os
import cv2
import numpy as np
import codecs
from cv2.ml import VAR_ORDERED
import codecs
from cv2.ml import VAR_ORDERED
from canny import *
from find_contours import *
import numpy as np
import cPickle
import gzip


def vectorized_result(j):
   e = np.zeros((10, 1))
   e[j] = 1.0
   return e


def load_data():
   mnist = gzip.open(os.path.join('data', 'mnist.pkl.gz'), 'rb')
   training_data, classification_data, test_data = cPickle.load(mnist)
   mnist.close()
   return training_data, classification_data, test_data


def wrap_data():
   tr_d, va_d, te_d = load_data()
   # print type(tr_d), type(va_d), type(te_d)
   training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
   training_results = [vectorized_result(y) for y in tr_d[1]]
   training_data = zip(training_inputs, training_results)
   validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
   validation_data = zip(validation_inputs, va_d[1])
   test_input = [np.reshape(x, (784, 1)) for x in te_d[0]]
   test_data = zip(test_input, te_d[1])
   return training_data, validation_data, test_data


def train_svm(train_file='train_data.txt', test_file= 'train_result.txt'):
   svm = cv2.ml.SVM_create()
   svm.setType(cv2.ml.SVM_C_SVC)
   #自己设置一下SVM参数
   svm.setKernel(cv2.ml.SVM_POLY)
   t_d = np.loadtxt(train_file, np.float32)
   m_d = np.loadtxt(test_file, np.int32)
   train_data = cv2.ml.TrainData_create(t_d, cv2.ml.ROW_SAMPLE, m_d)
   svm.train(train_data)
   return svm


def svm_test(svm, test_data):
   le = len(test_data)
   sum_tem = 0
   for i in range(le):
      sample = np.array([test_data[i][0].ravel()], dtype=np.float32).reshape(28, 28)
      a, b =svm.predict(np.array([test_data[i][0].ravel()], dtype=np.float32))
      if b[0][0] == test_data[i][1] or test_data[i][1] == 0:
         sum_tem += 1
   print '正确率 ', float(sum_tem * 1.0/ le)


def svm_predict(svm, sample):
   resized = sample.copy()
   rows, cols = resized.shape
   if (rows != 28 or cols != 28) and rows * cols > 0:
      resized = cv2.resize(resized, (28, 28), interpolation=cv2.INTER_CUBIC)
   return svm.predict(np.array([resized.ravel()], dtype=np.float32))


if __name__ == '__main__':
   tr, val, test = wrap_data()
   save_path = os.path.join('data', '自己想个文件名')
   if os.path.exists(save_path):
      print 'find it'
      svm = cv2.ml.SVM_load(save_path)
   else:
      svm = train_svm()
      svm.save(save_path)
   svm_test(svm, test)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值