ANN处理mnist字体库

本文介绍了一个使用OpenCV和Python实现的手写数字识别系统。该系统通过加载MNIST数据集进行训练,并采用ANN(人工神经网络)进行模式识别。文中详细展示了从数据准备到模型训练及测试的过程,并提供了完整的代码实现。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

作者很懒 还是先知贴个代码,open3 + python 自己体会
# decoding:utf-8
import os
import cv2
from cv2.cv2 import *
import codecs
from cv2.ml import VAR_ORDERED
from canny import *
from find_contours import *
import numpy as np
import cPickle
import gzip
# decoding:utf-8

def revel(a):
    list = []
    for i in a:
        list.append(i[0])
    return list

def load_data():
    mnist = gzip.open(os.path.join('data', 'mnist.pkl.gz'), 'rb')
    training_data, classification_data, test_data = cPickle.load(mnist)
    mnist.close()
    return training_data, classification_data, test_data

def wrap_data():
    tr_d, va_d, te_d = load_data()
    # print type(tr_d), type(va_d), type(te_d)
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = zip(training_inputs, training_results)
    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = zip(validation_inputs, va_d[1])
    test_input = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = zip(test_input, te_d[1])
    return training_data, validation_data, test_data

def vectorized_result(j):
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e

def create_ANN(hidden=20):
    ann = cv2.ml.ANN_MLP_create()
    ann.setLayerSizes(np.array([64, hidden, 10]))
    ann.setTrainMethod(cv2.ml.ANN_MLP_RPROP)
    ann.setActivationFunction(cv2.ml.ANN_MLP_IDENTITY)
    ann.setTermCriteria((cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 20, 1))
    return ann

def train(ann, samples=10000, epochs=1):
    tr, val, test = wrap_data()
    t_d = np.loadtxt('train_data.txt', np.float32)
    m_d = np.loadtxt('train_result.txt', np.int32)
    for x in xrange(epochs):
        ann.train(t_d, cv2.ml.ROW_SAMPLE, m_d)
    # for x in xrange(epochs):
    #     counter = 0
    #     for img in tr:
    #         if counter > samples:
    #             break
    #         if counter % 1000 == 0:
    #             print "Epoch %d : Trained %d/%d" % (x, counter, samples)
    #         counter += 1
    #         data, digit = img
    #         t_d = np.loadtxt(train_file, np.float32)
    #         m_d = np.loadtxt(test_file, np.int32)
    #         # print 'data', np.array([data], dtype=np.float32).reshape(28, 28),\
    #         #                digit
    #         ann.train(np.array([data.ravel()], dtype=np.float32),\
    #         cv2.ml.ROW_SAMPLE, np.array([digit.ravel()], dtype=np.float32))
    #         # print '看一下训练数据的young', np.array([data.ravel()], dtype=np.float32)
    #         # cv2.imshow('img', np.array([data.ravel()], dtype=np.float32).reshape(28, 28))
    #         # while cv2.waitKey() is not 27:
    #         #     pass
    #         # cv2.destroyWindow('img')
    #     print 'Epoch %d complete' % x
    return ann, test

def test(ann, test_data):
    # for i in range(10):
    #     name = ['sample']
    #     name.append(str(i))
    #     sample = np.array([test_data[i][0].ravel()], dtype=np.float32).reshape(28, 28)
    #     cv2.imshow(str(''.join(name)), sample)
    #     while (cv2.waitKey()!=27):
    #         pass
    #     print ann.predict(np.array([test_data[i][0].ravel()], dtype=np.float32))
    sample = np.array([test_data[4][0].ravel()], dtype=np.float32).reshape(28, 28)
    print 'sample', sample
    cv2.imshow('sample', sample)
    while (cv2.waitKey()!=27):
        pass
    sample_tem = cv2.resize(sample, (8, 8), interpolation=cv2.INTER_CUBIC)
    print ann.predict(np.array([sample_tem.ravel()], dtype=np.float32))

def predict(ann, sample):
    resized = sample.copy()
    rows, cols = resized.shape
    if (rows != 28 or cols != 28) and rows * cols > 0:
        resized = cv2.resize(resized, (28, 28), interpolation=cv2.INTER_CUBIC)
    return ann.predict(np.array([resized.ravel()], dtype=np.float32))

def main():
    save_path = os.path.join('data', 'best_ann_test_rp+ident+ddddd')
    if os.path.exists(save_path):
        print 'find it'
        ann = cv2.ml.ANN_MLP_load(save_path)
        a, b, test_data = wrap_data()
    else:
        ann, test_data = train(create_ANN(58), 50000, 10)
        ann.save(save_path)
    test(ann, test_data)

if __name__ == '__main__':
    main()
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值