LeNet神经网络理论与实战

本文介绍了LeNet卷积神经网络的基本结构与实现方法,该网络由两个卷积层和三个全连接层组成,使用sigmoid作为激活函数,并采用最大池化进行下采样。文章还提供了基于TensorFlow实现的具体代码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

          LeNet卷积神经网络是LeGun于1998年提出,是卷积神经网络的开篇之作,通过共享卷积核减少了网络的参数,在统计卷积神经网络的层数时,一般只统计卷积计算层和全连接计算层,其余操作可以认为时卷积计算层的附属。LeNet一共有五层网络,有两个卷积三个全连接。

由于LeNet提出的时候还没有BN操作,所以这儿不需要BN操作。LeNet时代,sigmoid是主流的激活函数。用2*2的池化核,步长为2进行最大池化,不使用全零填充,LeNet时代同样还没有Dropout。

根据CBAPD写出代码的样子如下所示:

 

import numpy as np
import tensorflow as tf
import os
from matplotlib import pyplot as plt
import PySide2
from tensorflow.keras.layers import Conv2D,BatchNormalization,Activation,MaxPooling2D,Dropout,Flatten,Dense
from tensorflow.keras import Model


dirname = os.path.dirname(PySide2.__file__)
plugin_path = os.path.join(dirname, 'plugins', 'platforms')
os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = plugin_path
np.set_printoptions(threshold=np.inf)  # 设置打印出所有参数,不要省略

mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 由于fashion数据集是三维的(60000, 28, 28),而cifar10 数据集是四维的,而此网络是用来识别四维的数据所所以需要将3维的输入扩展维4维的输入
x_train = np.expand_dims(x_train, axis=3)
x_test = np.expand_dims(x_test, axis=3)


class LeNet5(Model):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.c1 = Conv2D(filters=6, kernel_size=(5,5), padding='valid')
        self.a1 = Activation('sigmoid')
        self.p1 = MaxPooling2D(pool_size=(2,2), strides=2, padding='valid')

        self.c2 = Conv2D(filters=16, kernel_size=(5, 5), padding='valid')
        self.a2 = Activation('sigmoid')
        self.p2 = MaxPooling2D(pool_size=(2, 2), strides=2, padding='valid')

        self.flatten = Flatten()
        self.f1 = Dense(120, activation='sigmoid')
        self.f2 = Dense(84, activation='sigmoid')
        self.f3 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.c1(x)
        x = self.a1(x)
        x = self.p1(x)

        x = self.c2(x)
        x = self.a2(x)
        x = self.p2(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.f2(x)
        y = self.f3(x)

        return y

model = LeNet5()

print(model(x_test).shape)

model.compile(optimizer='adam',
              loss=tf.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = './checkpoint/mnist.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('------------------------load the model---------------------')
    model.load_weights(checkpoint_save_path)


cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=10, epochs=5,
                    validation_data=(x_test, y_test),
                    validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

############################  show  #############################
# 显示训练集和验证集的acc和loss曲线
acc=history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI炮灰

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值