2.simple_models

This article works through TensorFlow in practice, covering linear regression, logistic regression, loss functions, Dropout, TensorBoard, CNNs, and RNNs with hands-on examples, and also explains saving and loading models, image classification with the Inception network, and captcha generation and recognition.


1. Linear regression

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
x_data=np.linspace(-0.5,0.5,200)[:,np.newaxis]#200 samples as a (200,1) column
noise=np.random.normal(0,0.02,x_data.shape)#Gaussian noise with the same (200,1) shape
y_data=np.square(x_data)+noise
#1. Define placeholders for x and y
x=tf.placeholder(tf.float32,[None,1])#(batch, feature dimension)
y=tf.placeholder(tf.float32,[None,1])
#2. Define the network (1 input, 10 hidden, 1 output neuron)
weights_L1=tf.Variable(tf.random_normal([1,10]))
biases_L1=tf.Variable(tf.zeros([1,10]))
wx_plus_b_L1=tf.matmul(x,weights_L1)+biases_L1
L1=tf.nn.tanh(wx_plus_b_L1)
weights_L2=tf.Variable(tf.random_normal([10,1]))
biases_L2=tf.Variable(tf.zeros([1,1]))
wx_plus_b_L2=tf.matmul(L1,weights_L2)+biases_L2
prediction=tf.nn.tanh(wx_plus_b_L2)
#3. Minimize the loss and update the parameters
loss=tf.reduce_mean(tf.square(y-prediction))
train_step=tf.train.GradientDescentOptimizer(0.1).minimize(loss)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(2000):
        sess.run(train_step,feed_dict={x:x_data,y:y_data})
    prediction_value=sess.run(prediction,feed_dict={x:x_data})
#plot the fit
plt.figure()
plt.scatter(x_data,y_data)
plt.plot(x_data,prediction_value,'r-',lw=5)
plt.show()

(Figure: scatter plot of the noisy training data with the fitted curve drawn in red.)

2. MNIST logistic regression

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
batch_size=100
n_batch=mnist.train.num_examples//batch_size
#1. Define placeholders
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])
#2. Network structure (784 inputs, 10 outputs)
w=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))
prediction=tf.nn.softmax(tf.matmul(x,w)+b)
#3. Cost function and optimizer (quadratic cost here; section 3 covers cross-entropy, which suits classification better)
loss=tf.reduce_mean(tf.square(y-prediction))
train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)
#4. Accuracy, evaluated on the test set
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(20):
        for batch in range(n_batch):
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print('iter'+str(epoch)+' test acc: '+str(acc))
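The accuracy op above compares the argmax of the one-hot labels with the argmax of the predictions and averages the resulting booleans. A NumPy sketch with made-up values:

```python
import numpy as np

# one-hot labels and predicted distributions (illustrative values)
labels = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
preds = np.array([[0.1, 0.8, 0.1], [0.2, 0.7, 0.1], [0.1, 0.2, 0.7]])

# mirror tf.equal(tf.argmax(y,1), tf.argmax(prediction,1)) and the cast
correct = np.equal(np.argmax(labels, 1), np.argmax(preds, 1))
accuracy = np.mean(correct.astype(np.float32))  # fraction of correct rows
```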

3. Loss functions

1. Loss functions --------- classic losses -------- cross-entropy: cross-entropy measures the distance between two probability distributions and is one of the most widely used losses in classification. The cross-entropy of p expressed through q is:

H(p, q) = -Σ_x p(x) · log q(x)

Softmax turns the network's raw forward-pass outputs into a probability distribution: the original outputs serve as confidences from which new outputs are generated, and the new outputs satisfy all the requirements of a probability distribution.
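That mapping can be sketched in NumPy (a hand-rolled softmax on illustrative logits, not the TensorFlow op itself):

```python
import numpy as np

def softmax(logits):
    # subtract the max first for numerical stability; the result is unchanged
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / np.sum(exp)

probs = softmax(np.array([2.0, 1.0, 0.1]))  # sums to 1, order preserved
```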

Cross-entropy is not symmetric (H(p,q) != H(q,p)): it measures how hard it is to express distribution p by means of distribution q. Because the correct answer is the desired outcome, when cross-entropy is used as a network's loss, p stands for the correct labels and q for the predictions. Since cross-entropy measures the distance between two distributions, a smaller value means the distributions are closer. The TensorFlow implementation:
cross_entropy=-tf.reduce_mean(y_*tf.log(tf.clip_by_value(y,1e-10,1.0)))

Here y_ is the correct answer and y the prediction. tf.clip_by_value() replaces every value below 1e-10 with 1e-10 and every value above 1 with 1. In TensorFlow, * means element-wise multiplication, not matrix multiplication.
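A NumPy sketch of the same computation, with an illustrative one-hot label (here each example is summed over classes and then averaged over the batch, one common variant):

```python
import numpy as np

y_ = np.array([[0.0, 0.0, 1.0]])   # correct answer p (one-hot)
y = np.array([[0.1, 0.2, 0.7]])    # prediction q

# clip to avoid log(0), then multiply element-wise, as described in the text
clipped = np.clip(y, 1e-10, 1.0)
cross_entropy = -np.mean(np.sum(y_ * np.log(clipped), axis=1))
```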

Because cross-entropy is usually used together with softmax regression, TensorFlow wraps the two operations into one:

cross_entropy=tf.nn.softmax_cross_entropy_with_logits(labels=y_,logits=y)

This single call yields the cross-entropy after softmax regression has been applied.

For classification problems with exactly one correct answer, TensorFlow provides tf.nn.sparse_softmax_cross_entropy_with_logits to speed the computation up further:
cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y_,1),logits=y)

2. Loss functions --------- classic losses -------- mean squared error (MSE):

MSE(y, y') = (1/n) · Σ_{i=1..n} (y_i - y'_i)^2

where y_i is the correct answer for the i-th sample in a batch and y'_i is the network's prediction for it. The TensorFlow implementation:
mse=tf.reduce_mean(tf.square(y_-y))

3. Loss functions --------- custom losses -----

tf.greater(A,B) returns the result of A > B as booleans.

tf.where(C,A,B) returns A wherever C is True and B wherever it is False (older TensorFlow releases called this tf.select; it was renamed tf.where).

Both functions operate element-wise.
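A typical use is an asymmetric regression loss that penalizes under- and over-prediction differently. A NumPy sketch (the unit costs a and b and the sample values are made up):

```python
import numpy as np

pred = np.array([10.0, 12.0, 8.0])   # predictions (illustrative)
real = np.array([9.0, 11.0, 10.0])   # ground truth (illustrative)
a, b = 10.0, 1.0                     # cost per unit of under- / over-prediction

# element-wise select, mirroring tf.where(tf.greater(pred, real), ..., ...)
loss = np.sum(np.where(pred > real, (pred - real) * b, (real - pred) * a))
```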

4. Dropout against overfitting

1. Define keep_prob; 2. apply dropout to each hidden layer
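The effect of the two steps can be sketched in NumPy: tf.nn.dropout keeps each unit with probability keep_prob and scales the survivors by 1/keep_prob so the expected activation is unchanged (inverted dropout).

```python
import numpy as np

def dropout(activations, keep_prob, seed=0):
    # zero out each unit with probability 1 - keep_prob, then rescale
    rng = np.random.default_rng(seed)
    mask = rng.random(activations.shape) < keep_prob
    return activations * mask / keep_prob

out = dropout(np.ones((4, 5)), keep_prob=0.7)  # entries are 0 or 1/0.7
```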

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
batch_size=100
n_batch=mnist.train.num_examples//batch_size
keep_prob=tf.placeholder(tf.float32)#(1. define keep_prob)
#1. Define placeholders
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])
#2. Network structure (784 inputs, hidden layers of 2000 and 1000, 10 outputs)
w1=tf.Variable(tf.truncated_normal([784,2000],stddev=0.1))
b1=tf.Variable(tf.zeros([2000])+0.1)
l1=tf.nn.tanh(tf.matmul(x,w1)+b1)
l1_drop=tf.nn.dropout(l1,keep_prob)#(2. apply dropout to the layer)

w2=tf.Variable(tf.truncated_normal([2000,1000],stddev=0.1))
b2=tf.Variable(tf.zeros([1000])+0.1)
l2=tf.nn.tanh(tf.matmul(l1_drop,w2)+b2)
l2_drop=tf.nn.dropout(l2,keep_prob)

w3=tf.Variable(tf.truncated_normal([1000,10],stddev=0.1))
b3=tf.Variable(tf.zeros([10])+0.1)
prediction=tf.nn.softmax(tf.matmul(l2_drop,w3)+b3)

#3. Cost function and optimizer
loss=tf.reduce_mean(tf.square(y-prediction))
train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)
#4. Accuracy, evaluated on the test set
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(20):
        for batch in range(n_batch):
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.7})
        acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
        print('iter'+str(epoch)+' test acc: '+str(acc))

5. TensorBoard:

1. Graph structure: 1. define name scopes, 2. create a FileWriter
2. Variable statistics: 1. a summary helper, 2. call it on each variable, 3. merge all summaries, 4. run the merged op, 5. write the summaries to the FileWriter

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
batch_size=100
n_batch=mnist.train.num_examples//batch_size
#2.1 summary helper
def variable_summaries(var):
    with tf.name_scope('summaries'):
        with tf.name_scope('mean'):
            mean=tf.reduce_mean(var)
            tf.summary.scalar('mean',tf.reduce_mean(var))
        with tf.name_scope('stddev'):
            tf.summary.scalar('stddev',tf.sqrt(tf.reduce_mean(tf.square(var-mean))))
        with tf.name_scope('max'):
            tf.summary.scalar('max',tf.reduce_max(var))
        with tf.name_scope('min'):
            tf.summary.scalar('min',tf.reduce_min(var))
        with tf.name_scope('histogram'):
            tf.summary.histogram('histogram',var)
with tf.name_scope('input'):#(1.1 define name scopes)
    x=tf.placeholder(tf.float32,[None,784],name='x_input')
    y=tf.placeholder(tf.float32,[None,10],name='y_input')
with tf.name_scope('layer'):
    w=tf.Variable(tf.zeros([784,10]),name='w')
    variable_summaries(w)#2.2 call the helper on the variable
    b=tf.Variable(tf.zeros([10]),name='b')
    variable_summaries(b)
    prediction=tf.nn.softmax(tf.matmul(x,w)+b,name='pred')
    
with tf.name_scope('loss'):
    loss=tf.reduce_mean(tf.square(y-prediction),name='loss')
    variable_summaries(loss)
with tf.name_scope('train'):
    train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)
with tf.name_scope('accuracy'):
    with tf.name_scope('correct_pred'):
        correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
    with tf.name_scope('accuracy'):
        accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
        variable_summaries(accuracy)
merged=tf.summary.merge_all()#2.3 merge all summaries
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer=tf.summary.FileWriter('logs/',sess.graph)#(1.2 create the FileWriter)
    for epoch in range(20):
        for batch in range(n_batch):
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            summary,_=sess.run([merged,train_step],feed_dict={x:batch_xs,y:batch_ys})#2.4 run the merged op
        writer.add_summary(summary,epoch)#2.5 write the summaries to the FileWriter
        acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print('iter'+str(epoch)+' test acc: '+str(acc))

6.CNN

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
batch_size=100
n_batch=mnist.train.num_examples//batch_size
keep_prob=tf.placeholder(tf.float32)
def weight_variable(shape):#weight initialization
    initial=tf.truncated_normal(shape,stddev=0.1)
    return tf.Variable(initial)
def bias_variable(shape):#bias initialization
    initial=tf.constant(0.1,shape=shape)
    return tf.Variable(initial)
def conv2d(x,w):#convolution layer
    return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME')
def max_pool_2x2(x):#2x2 max-pooling layer
    return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
#Input placeholders
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])
x_image=tf.reshape(x,[-1,28,28,1])
#First block: input 28x28x1, output 14x14x32
w_conv1=weight_variable([5,5,1,32])#5x5 kernels, 1 input channel, 32 output channels
b_conv1=bias_variable([32])#weights and biases of the first convolutional layer
h_conv1=tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1)
h_pool1=max_pool_2x2(h_conv1)
#Second block: input 14x14x32, output 7x7x64
w_conv2=weight_variable([5,5,32,64])
b_conv2=bias_variable([64])
h_conv2=tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)
h_pool2=max_pool_2x2(h_conv2)
h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])#flatten
#Fully connected layers
w_fc1=weight_variable([7*7*64,1024])
b_fc1=bias_variable([1024])
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)
w_fc2=weight_variable([1024,10])
b_fc2=bias_variable([10])
prediction=tf.matmul(h_fc1_drop,w_fc2)+b_fc2
#Loss, optimizer, accuracy
cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
accuracy=tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction,1),tf.argmax(y,1)),tf.float32))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(20):
        for batch in range(n_batch):
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.7})
        acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
        print('iter '+str(epoch)+' testing accuracy: '+str(acc))
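The shape bookkeeping in the comments above (28x28 -> 14x14 -> 7x7) follows from SAME padding: the convolutions keep the spatial size, and each stride-2 pooling outputs ceil(size/stride). A quick check:

```python
def same_out_size(size, stride):
    # output size with SAME padding: ceil(size / stride)
    return -(-size // stride)

s1 = same_out_size(28, 2)  # after the first 2x2 max-pool
s2 = same_out_size(s1, 2)  # after the second 2x2 max-pool
```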

7.RNN

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)
n_inputs=28
max_time=28
lstm_size=100
n_classes=10
batch_size=50
n_batch=mnist.train.num_examples//batch_size
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])
weights=tf.Variable(tf.truncated_normal([lstm_size,n_classes],stddev=0.1))
biases=tf.Variable(tf.constant(0.1,shape=[n_classes]))
def RNN(x,weights,biases):
    inputs=tf.reshape(x,[-1,max_time,n_inputs])#(batch, time steps, inputs per step)
    lstm_cell= tf.nn.rnn_cell.BasicLSTMCell(lstm_size)
    outputs,final_state=tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)
    results=tf.matmul(final_state[1],weights)+biases#final_state is a (c, h) tuple; [1] is the hidden state h
    return results
prediction=RNN(x,weights,biases)
cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
accuracy=tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y,1),tf.argmax(prediction,1)),tf.float32))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(5):
        for batch in range(n_batch):
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print('iter '+str(epoch)+' testing accuracy: '+str(acc))
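The reshape inside RNN() treats each 28-pixel image row as one time step, so every flattened MNIST image becomes a length-28 sequence of 28-dimensional inputs:

```python
import numpy as np

batch = np.zeros((50, 784))      # a batch of flattened MNIST images
seq = batch.reshape(-1, 28, 28)  # (batch, max_time, n_inputs)
```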

8. Saving a model

1. Create a Saver object; 2. call its save method

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
batch_size=100
n_batch=mnist.train.num_examples//batch_size
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])
w=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))
prediction=tf.nn.softmax(tf.matmul(x,w)+b)
loss=tf.reduce_mean(tf.square(y-prediction))
train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
saver=tf.train.Saver()#1. (create the Saver object)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(20):
        for batch in range(n_batch):
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print('iter'+str(epoch)+' test acc: '+str(acc))
    saver.save(sess,'net/my_net.ckpt')#2. (call the save method)

9. Loading a model

1. Create a Saver object; 2. call its restore method

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
batch_size=100
n_batch=mnist.train.num_examples//batch_size
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])
w=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))
prediction=tf.nn.softmax(tf.matmul(x,w)+b)
loss=tf.reduce_mean(tf.square(y-prediction))
train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
saver=tf.train.Saver()#1. (create the Saver object)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))#accuracy with freshly initialized (untrained) variables
    saver.restore(sess,'net/my_net.ckpt')#2. (call the restore method)
    print(sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))#accuracy after restoring the trained weights

10. Downloading the Inception-v3 model

import tensorflow as tf
import os
import tarfile
import requests
#1. Model download URL
inception_pretrain_model_url='http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
#2.1 Directory to store the model
inception_pretrain_model_dir='inception_model'
if not os.path.exists(inception_pretrain_model_dir):
    os.makedirs(inception_pretrain_model_dir)
#2.2 Directory for the graph summary
log_dir='inception_log'
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
#3.1 Model file name and local path
filename=inception_pretrain_model_url.split('/')[-1]
filepath=os.path.join(inception_pretrain_model_dir,filename)
#3.2 Path of the graph definition; classify_image_graph_def.pb is Google's pretrained model, extracted into the model directory
inception_graph_def_file=os.path.join(inception_pretrain_model_dir,'classify_image_graph_def.pb')
#4.1 Download the model and extract it
if not os.path.exists(filepath):
    print('download:',filename)
    r=requests.get(inception_pretrain_model_url,stream=True)
    with open(filepath,'wb') as f:
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
print('finish:',filename)
tarfile.open(filepath,'r:gz').extractall(inception_pretrain_model_dir)
#4.2 Load the graph definition and save the graph structure
with tf.Session() as sess:
    #create a graph holding Google's pretrained model
    with tf.gfile.FastGFile(inception_graph_def_file,'rb') as f:
        graph_def=tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def,name='')
    #save the graph structure
    writer=tf.summary.FileWriter(log_dir,sess.graph)
    writer.close()

11. Classifying images with the Inception network

import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt
class NodeLookup(object):#map numeric class ids to human-readable labels
    def __init__(self):
        label_lookup_path='inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'
        uid_lookup_path='inception_model/imagenet_synset_to_human_label_map.txt'
        self.node_lookup=self.load(label_lookup_path,uid_lookup_path)#load the two mapping files
    def load(self,label_lookup_path,uid_lookup_path):
        uid_to_human={}#uid -> human-readable label
        proto_as_ascii_lines=tf.gfile.GFile(uid_lookup_path).readlines()
        for line in proto_as_ascii_lines:#parse the first file
            uid=line.strip('\n').split('\t')[0]#uid
            human_string=line.strip('\n').split('\t')[1]#human-readable label
            uid_to_human[uid]=human_string
        node_id_to_uid={}#class id -> uid
        proto_as_ascii=tf.gfile.GFile(label_lookup_path).readlines()#parse the second file
        for line in proto_as_ascii:
            if line.startswith('  target_class:'):
                target_class=int(line.split(':')[1])#class id
            if line.startswith('  target_class_string:'):
                target_class_string=line.split(':')[1]#uid
                node_id_to_uid[target_class]=target_class_string[2:-2]
        node_id_to_name={}#class id -> human-readable label
        for key,val in node_id_to_uid.items():
            name=uid_to_human[val]
            node_id_to_name[key]=name
        return node_id_to_name
    def id_to_string(self,node_id):
        if node_id not in self.node_lookup:
            return ''
        return self.node_lookup[node_id]
#Create a graph to hold Google's pretrained model
with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb','rb') as f:
    graph_def=tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def,name='')
with tf.Session() as sess:
    softmax_tensor=sess.graph.get_tensor_by_name('softmax:0')#fetch the model's output tensor by name
    for root,dirs,files in os.walk('images/'):#walk every file under this directory
        for file in files:
            image_data=tf.gfile.FastGFile(os.path.join(root,file),'rb').read()
            predictions=sess.run(softmax_tensor,{'DecodeJpeg/contents:0':image_data})#predict
            predictions=np.squeeze(predictions)#flatten the result to 1-D
            image_path=os.path.join(root,file)
            print(image_path)
            img=Image.open(image_path)
            plt.imshow(img)
            plt.axis('off')
            plt.show()
            top_k=predictions.argsort()[-5:][::-1]
            node_lookup=NodeLookup()
            for node_id in top_k:
                human_string=node_lookup.id_to_string(node_id)
                score=predictions[node_id]
                print('%s (score=%.5f)\n'%(human_string,score))

12. Multi-task labels: converting the image files to TFRecord format

#1. Generate captcha images
from captcha.image import ImageCaptcha
import numpy as np
from PIL import Image
import random
import sys
number=['0','1','2','3','4','5','6','7','8','9']
def random_captcha_text(char_set=number,captcha_size=4):
    captcha_text=[]
    for i in range(captcha_size):
        c=random.choice(char_set)
        captcha_text.append(c)
    return captcha_text
def gen_captch_text_and_image():
    captcha_text=random_captcha_text()
    captcha_text=''.join(captcha_text)#join the list of characters into one string
    image=ImageCaptcha()
    captcha=image.generate(captcha_text)
    image.write(captcha_text,'captcha/images/'+captcha_text+'.jpg')
num=1000
if __name__=='__main__':
    for i in range(num):
        gen_captch_text_and_image()
        sys.stdout.write('\r>> Creating image %d/%d'%(i+1,num))
        sys.stdout.flush()
    sys.stdout.write('\n')
    sys.stdout.flush()
    print('done generating images')
import os
import tensorflow as tf
#2. Convert the image files to TFRecord format
_NUM_TEST=50
_RANDOM_SEED=0
DATASET_DIR='D:/DeepLearning/Github小项目/0.studay/5.tensorflow/captcha/images/'
TFRECORD_DIR='D:/DeepLearning/Github小项目/0.studay/5.tensorflow/captcha/'
def _dataset_exists(dataset_dir):
    for split_name in ['train','test']:
        output_filename=os.path.join(dataset_dir,split_name+'.tfrecords')
        if not tf.gfile.Exists(output_filename):
            return False
    return True
def _get_filenames_and_classes(dataset_dir):
    photo_filenames=[]
    for filename in os.listdir(dataset_dir):
        path=os.path.join(dataset_dir,filename)
        photo_filenames.append(path)
    return photo_filenames
def int64_feature(values):
    if not isinstance(values,(tuple,list)):
        values=[values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
def image_to_tfexample(image_data,label0,label1,label2,label3):
    return tf.train.Example(features=tf.train.Features(feature={'image':bytes_feature(image_data),
                                                               'label0':int64_feature(label0),
                                                               'label1':int64_feature(label1),
                                                               'label2':int64_feature(label2),
                                                               'label3':int64_feature(label3)}))
def _convert_dataset(split_name,filenames,dataset_dir):
    assert split_name in ['train','test']
    with tf.Session() as sess:
        output_filename=os.path.join(TFRECORD_DIR,split_name+'.tfrecords')
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
            for i,filename in enumerate(filenames):
                try:
                    sys.stdout.write('\r>>Converting image %d/%d'%(i+1,len(filenames)))
                    sys.stdout.flush()
                    image_data=Image.open(filename)
                    image_data=image_data.resize((224,224))
                    image_data=np.array(image_data.convert('L'))#convert to grayscale
                    image_data=image_data.tobytes()
                    labels=filename.split('/')[-1][0:4]#the four digits in the file name are the labels
                    num_labels=[]
                    for j in range(4):
                        num_labels.append(int(labels[j]))
                    example=image_to_tfexample(image_data,num_labels[0],num_labels[1],num_labels[2],num_labels[3])
                    tfrecord_writer.write(example.SerializeToString())
                except IOError as e:
                    print('could not read:',filename)
                    print('error:',e)
                    print('skip it \n')
    sys.stdout.write('\n')
    sys.stdout.flush()

if _dataset_exists(TFRECORD_DIR):
    print('tfrecord files already exist')
else:
    photo_filenames=_get_filenames_and_classes(DATASET_DIR)
    random.seed(_RANDOM_SEED)
    random.shuffle(photo_filenames)
    training_filenames=photo_filenames[_NUM_TEST:]
    testing_filenames=photo_filenames[:_NUM_TEST]
    _convert_dataset('train',training_filenames,DATASET_DIR)
    _convert_dataset('test',testing_filenames,DATASET_DIR)
    print('tfrecord files generated')
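The label extraction above (filename.split('/')[-1][0:4]) assumes forward-slash paths; os.path.basename is a separator-agnostic alternative. A sketch with a made-up file name:

```python
import os

filename = 'captcha/images/3754.jpg'      # illustrative path
labels = os.path.basename(filename)[0:4]  # the four captcha digits
num_labels = [int(c) for c in labels]     # per-digit labels for the four tasks
```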
self.clear_canvas() # 更新性能表格 self.load_performance_data() except Exception as e: self.status_var.set(f"模型加载失败: {str(e)}") self.model_label.config(text="模型加载失败") def init_default_model(self): """初始化默认模型""" self.model_var.set(self.available_models[0][1]) self.change_model(self.available_models[0][0]) def load_performance_data(self): """加载性能数据""" results = self.model_factory.evaluate_all_models(self.digits) # 清空表格 for item in self.performance_tree.get_children(): self.performance_tree.delete(item) # 添加数据 for i, result in enumerate(results): tag = "highlight" if i == 0 else "" self.performance_tree.insert( "", tk.END, values=(result["模型名称"], result["准确率"]), tags=(tag,) ) self.performance_tree.tag_configure("highlight", background="#e6f7ff") def show_performance_chart(self): """显示性能图表""" results = self.model_factory.evaluate_all_models(self.digits) # 提取有效结果 valid_results = [] for result in results: try: accuracy = float(result["准确率"]) valid_results.append((result["模型名称"], accuracy)) except ValueError: continue if not valid_results: messagebox.showinfo("提示", "没有可用的性能数据") return # 排序 valid_results.sort(key=lambda x: x[1], reverse=True) models, accuracies = zip(*valid_results) # 创建图表 plt.figure(figsize=(10, 5)) bars = plt.barh(models, accuracies, color='#2196F3') plt.xlabel('准确率', fontsize=10) plt.ylabel('模型', fontsize=10) plt.title('模型性能对比', fontsize=12) plt.xlim(0, 1.05) # 添加数值标签 for bar in bars: width = bar.get_width() plt.text( width + 0.01, bar.get_y() + bar.get_height()/2, f'{width:.4f}', ha='left', va='center', fontsize=8 ) plt.tight_layout() plt.show() def save_as_training_sample(self): """保存为训练样本""" if not self.has_drawn: self.status_var.set("请先绘制数字再保存") return img_array = self.preprocess_image() if img_array is None: return # 弹出标签输入窗口 label_window = tk.Toplevel(self.root) label_window.title("输入标签") label_window.geometry("300x120") label_window.transient(self.root) label_window.grab_set() tk.Label( label_window, text="请输入数字标签 (0-9):", font=("Arial", 10) 
).pack(pady=10) entry = tk.Entry(label_window, font=("Arial", 12), width=5) entry.pack(pady=5) entry.focus_set() def save_with_label(): try: label = int(entry.get()) if label < 0 or label > 9: raise ValueError("标签必须是0-9的数字") self.custom_data.append((img_array.tolist(), label)) self.status_var.set(f"已保存数字 {label} (共 {len(self.custom_data)} 个样本)") label_window.destroy() except ValueError as e: self.status_var.set(f"保存错误: {str(e)}") tk.Button( label_window, text="保存", command=save_with_label, width=10 ).pack(pady=5) def save_all_training_data(self): """保存全部训练数据""" if not self.custom_data: self.status_var.set("没有训练数据可保存") return file_path = filedialog.asksaveasfilename( defaultextension=".csv", filetypes=[("CSV文件", "*.csv")], initialfile="custom_digits.csv", title="保存训练集" ) if not file_path: return try: with open(file_path, 'w', newline='', encoding='utf-8') as f: writer = csv.writer(f) writer.writerow([f'pixel{i}' for i in range(64)] + ['label']) for img_data, label in self.custom_data: writer.writerow(img_data + [label]) self.status_var.set(f"已保存 {len(self.custom_data)} 个样本到 {os.path.basename(file_path)}") except Exception as e: self.status_var.set(f"保存失败: {str(e)}") def load_training_data(self): """加载训练数据""" file_path = filedialog.askopenfilename( filetypes=[("CSV文件", "*.csv")], title="加载训练集" ) if not file_path: return try: self.custom_data = [] with open(file_path, 'r', newline='', encoding='utf-8') as f: reader = csv.reader(f) next(reader) # 跳过标题 for row in reader: if len(row) != 65: continue img_data = [float(pixel) for pixel in row[:64]] label = int(row[64]) self.custom_data.append((img_data, label)) self.status_var.set(f"已加载 {len(self.custom_data)} 个样本") except Exception as e: self.status_var.set(f"加载失败: {str(e)}") def run(self): """运行应用""" self.root.mainloop() if __name__ == "__main__": digits = load_digits() root = tk.Tk() app = HandwritingBoard(root, ModelFactory, digits) app.run() 根据这个代码重新生成合理代码,里面的组件存在遮挡情况