神经网络参数初始化对最终结果有重大影响

本文介绍了一个使用TensorFlow实现的手写数字识别模型,并通过调整参数初始化的方法显著提高了模型的准确率,从94%提升到了98%以上。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

神经网络参数初始化对最终的结果有重大影响。下面的代码是从github改编的,一开始使用正态分布初始化参数,stddev设置为1.0,结果正确率始终徘徊在94%,设置为0.01后,最终正确率达到了98%以上。所以参数的初始化需要跑代码测试哪种效果好。

import tensorflow as tf
import numpy as np
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)

trX,trY,teX,teY=mnist.train.images,mnist.train.labels,mnist.test.images,mnist.test.labels

def init_weight(shape):
    return tf.Variable(tf.random_normal(shape,mean=0.0,stddev=0.01),dtype=tf.float32)#stddev设置为1,准确率只有94%
def model(X,weight_1,weight_2,weight_o,keep_prob_input,keep_prob_hidden):
    input_drop=tf.nn.dropout(X,keep_prob=keep_prob_input)
    
    fc1=tf.nn.relu(tf.matmul(input_drop,weight_1))
    fc1_drop=tf.nn.dropout(fc1,keep_prob=keep_prob_hidden)
    
    fc2=tf.nn.relu(tf.matmul(fc1_drop,weight_2))
    fc2_drop=tf.nn.dropout(fc2,keep_prob=keep_prob_hidden)
    
    return tf.matmul(fc2_drop,weight_o)

x=tf.placeholder(tf.float32,shape=[None,784])
y=tf.placeholder(tf.float32,[None,10])

weight_1=init_weight([784,625])
weight_2=init_weight([625,625])
weight_o=init_weight([625,10])
keep_prob_input=tf.placeholder(tf.float32)
keep_prob_hidden=tf.placeholder(tf.float32)
fc_output=model(x,weight_1,weight_2,weight_o,keep_prob_input,keep_prob_hidden)

cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=fc_output,labels=y))
train=tf.train.RMSPropOptimizer(0.001).minimize(cross_entropy)
pred=tf.argmax(fc_output,1)
mini_batch=100
epoch=100
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch_i in range(epoch):
          for batch_i in range(mnist.train.num_examples//mini_batch):
                sess.run(train,feed_dict={x:trX[batch_i*mini_batch:(batch_i+1)*mini_batch],y:trY[batch_i*mini_batch:(batch_i+1)*mini_batch],keep_prob_hidden:0.5,keep_prob_input:0.8})
        
        
         accuracy=np.mean(np.argmax(teY,axis=1)==sess.run(pred,feed_dict={x:teX,y:teY,keep_prob_hidden:1,keep_prob_input:1}))
          print(accuracy)


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值