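# -*- coding: utf-8 -*-
# 2017年robomaster比赛中,大神符环节使用的是手写数字;本脚本用 SVM 识别这些数字,并在 MNIST 测试集上评估正确率。
# In the 2017 RoboMaster competition, the "Big Buff" (大神符) stage used
# handwritten digits; this script recognizes them with an SVM and evaluates
# its accuracy on the MNIST test set.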
# decoding:utf-8
import os
import cv2
import numpy as np
import codecs
from cv2.ml import VAR_ORDERED
import codecs
from cv2.ml import VAR_ORDERED
from canny import *
from find_contours import *
import numpy as np
import cPickle
import gzip
def vectorized_result(j):
    """Return a (10, 1) float64 one-hot column vector with 1.0 at row ``j``.

    ``wrap_data`` uses this to turn an integer training label into a target
    vector.
    """
    target = np.zeros((10, 1), dtype=np.float64)
    # With a scalar j, this selects row j of the single column.
    target[j] = 1.0
    return target
def load_data(path=None):
    """Load the three pickled MNIST splits from a gzip-compressed file.

    Args:
        path: location of the gzipped pickle. Defaults to
            ``data/mnist.pkl.gz`` relative to the current working directory.

    Returns:
        A ``(training_data, validation_data, test_data)`` triple, exactly as
        stored in the pickle. NOTE(review): each split is presumably an
        ``(images, labels)`` pair, since ``wrap_data`` indexes ``[0]`` and
        ``[1]``. Confirm against the data file.

    The file is unpickled, so only load files you trust.
    """
    if path is None:
        path = os.path.join('data', 'mnist.pkl.gz')
    # ``with`` closes the file even if unpickling raises.
    with gzip.open(path, 'rb') as mnist:
        training_data, validation_data, test_data = cPickle.load(mnist)
    return training_data, validation_data, test_data
def _to_column_vectors(images):
    """Reshape each flat 784-value image into a (784, 1) column vector."""
    return [np.reshape(x, (784, 1)) for x in images]


def wrap_data():
    """Load MNIST and pair every reshaped image with its label.

    Images are reshaped into (784, 1) column vectors. Training labels are
    expanded into (10, 1) one-hot vectors via ``vectorized_result``.
    Validation and test labels are kept exactly as stored in the pickle.

    Returns:
        ``(training_data, validation_data, test_data)``, each a list of
        ``(input, label)`` tuples.
    """
    tr_d, va_d, te_d = load_data()
    # list(...) keeps every split indexable and sized on Python 3 too, where
    # zip() is lazy. svm_test calls len() and indexes into test_data.
    training_data = list(zip(_to_column_vectors(tr_d[0]),
                             [vectorized_result(y) for y in tr_d[1]]))
    validation_data = list(zip(_to_column_vectors(va_d[0]), va_d[1]))
    test_data = list(zip(_to_column_vectors(te_d[0]), te_d[1]))
    return training_data, validation_data, test_data
def train_svm(train_file='train_data.txt', test_file='train_result.txt',
              degree=3):
    """Train an OpenCV C-SVC with a polynomial kernel from two text files.

    Args:
        train_file: whitespace-separated text file, one sample per row,
            loaded as float32 features.
        test_file: text file of int32 class labels, one per sample. Despite
            the parameter name, these are the *training* labels.
        degree: polynomial kernel degree. OpenCV's default degree is 0, which
            SVM_POLY rejects at train time, so a positive value must be set.
            3 is only a starting point; tune it for your data.

    Returns:
        The trained ``cv2.ml.SVM`` model.
    """
    svm = cv2.ml.SVM_create()
    svm.setType(cv2.ml.SVM_C_SVC)
    # Tune the SVM parameters yourself. The POLY kernel needs at least a
    # positive degree.
    svm.setKernel(cv2.ml.SVM_POLY)
    svm.setDegree(degree)
    # ndmin keeps a single-sample file shaped as one row (1, M) instead of
    # collapsing it to 1-D, and keeps a single label as a 1-element array.
    samples = np.loadtxt(train_file, np.float32, ndmin=2)
    labels = np.loadtxt(test_file, np.int32, ndmin=1)
    train_data = cv2.ml.TrainData_create(samples, cv2.ml.ROW_SAMPLE, labels)
    svm.train(train_data)
    return svm
def svm_test(svm, test_data):
    """Print and return the classifier's accuracy on ``(input, label)`` pairs.

    Args:
        svm: trained model. ``svm.predict`` must accept a (1, N) float32
            array and return ``(retval, results)`` with the predicted label at
            ``results[0][0]``, as ``cv2.ml.SVM.predict`` does.
        test_data: sequence or iterable of pairs whose element 0 is the input
            array and element 1 is the integer label, for example the
            ``test_data`` returned by ``wrap_data``.

    Returns:
        The accuracy as a float in [0, 1], or 0.0 when ``test_data`` is empty.

    NOTE(review): a sample labelled 0 is always counted as correct, whatever
    the prediction (behaviour kept from the original). If the intent is to
    ignore the digit 0, exclude those samples from the denominator instead.
    """
    total = 0
    correct = 0
    for pair in test_data:
        features, label = pair[0], pair[1]
        _, result = svm.predict(np.array([features.ravel()], dtype=np.float32))
        if result[0][0] == label or label == 0:
            correct += 1
        total += 1
    # Guard against an empty test set instead of dividing by zero.
    accuracy = correct * 1.0 / total if total else 0.0
    # '正确率' means 'accuracy'. The output text matches the original print.
    print('正确率  %s' % accuracy)
    return accuracy
def svm_predict(svm, sample):
    """Classify one 2-D image with ``svm`` and return ``svm.predict``'s result.

    A non-empty image that is not already 28x28 is first resized to 28x28
    with bicubic interpolation. The image is then flattened into a single
    float32 row before prediction.
    """
    height, width = sample.shape
    must_resize = (height, width) != (28, 28) and height * width > 0
    if must_resize:
        patch = cv2.resize(sample, (28, 28), interpolation=cv2.INTER_CUBIC)
    else:
        patch = sample
    # np.array(..., dtype=float32) builds a fresh array, so the caller's
    # image is never modified.
    row = np.array([patch.ravel()], dtype=np.float32)
    return svm.predict(row)
if __name__ == '__main__':
    tr, val, test = wrap_data()
    # The file name is a placeholder ("come up with a file name yourself").
    # Replace it with a real model file name before use.
    save_path = os.path.join('data', '自己想个文件名')
    if not os.path.exists(save_path):
        # No cached model yet: train one from train_svm's text files and
        # save it for later runs.
        svm = train_svm()
        svm.save(save_path)
    else:
        print('find it')
        svm = cv2.ml.SVM_load(save_path)
    # Only the test split from wrap_data() is used for evaluation.
    svm_test(svm, test)