TensorFlow中的ResNet残差网络实战(1)

本文介绍了如何使用TensorFlow实现ResNet模型,包括数据预处理、模型定义、编译与训练,以及在MNIST数据集上的应用。重点讲解了ResNetBlock和ResNet结构,展示了从导入库到模型评估的完整流程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.导入相应的库

import os 
import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
from tensorflow.keras import datasets,losses,Sequential,optimizers

2.加载MNIST数据集:

(x_train,y_train),(x_test,y_test)=tf.keras.datasets.mnist.load_data()
x_train,x_test=x_train.astype('float')/255.0,x_test.astype('float32')/255.0

Batch_Size=32
x_train=tf.reshape(x_train,[-1,28,28,1])
x_test=tf.reshape(x_test,[-1,28,28,1])

3.转one_hot编码:

y_train=tf.one_hot(y_train,depth=10).numpy()
y_test=tf.one_hot(y_test,depth=10).numpy()

4.设置全局参数:

EPOCHES=1
batch_Size=32
learning_rate=0.001

5.定义ResNet模型:

#定义每一个ResNetBlock,其中一组中包含多个ResNetBlock,每一个ResNetBlock又包含几层卷积层
class ResNetBlock(tf.keras.Model):
    def __init__(self,filter_num,stride=1):
        super(ResNetBlock,self).__init__()
        self.conv1=tf.keras.layers.Conv2D(filter_num,kernel_size=[3,3],strides=stride,padding='same')
        self.bn1=tf.keras.layers.BatchNormalization()
        self.relu=tf.keras.layers.Activation('relu')
        
        self.conv2=tf.keras.layers.Conv2D(filter_num,kernel_size=[3,3],strides=1,padding='same')
        self.bn2=tf.keras.layers.BatchNormalization()
        
        #当经过卷积层之后,发现经过卷积层之前的shape和经过卷积之后的shape不相同,
        #可以通过1*1的卷积层将它们的shape转换为一样的,再叠加。相当于当前的X通过乘以一个W,转换为相同的shape
        if stride!=1:
            self.downSample=Sequential([
                tf.keras.layers.Conv2D(filter_num,kernel_size=[1,1],strides=stride)
            ])
        #如果相同的话就直接叠加,不需要转换
        else:
            self.downSample=lambda x:x
    def call(self,inputs,training=None):
        #通过当前的第一层卷积层
        out=self.conv1(inputs)
        out=self.bn1(out)
        out=self.relu(out)
        
        #通过当前的第二层卷积层
        out=self.conv2(out)
        out=self.bn2(out)
        
        #最后经过一个f(x)+x
        identity=self.downSample(inputs)
        
        output=tf.keras.layers.add([identity,out])
        output=tf.nn.relu(output)
        return output
class ResNet(tf.keras.Model):
    def __init__(self,layers_num,num_classes=10):
        super(ResNet,self).__init__()
        #开始的输入层经过一个3*3,步长为1的卷积层和最大池化层
        self.stem=Sequential([
            tf.keras.layers.Conv2D(64,kernel_size=[3,3],strides=[1,1]),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            
            tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=[1,1],padding='same') 
        ])
        #通过第一个组
        self.layer1=self.build_resblock(64,layers_num[0])
        #通过第二个组
        self.layer2=self.build_resblock(128,layers_num[1],stride=2)
        #通过第三个组
        self.layer3=self.build_resblock(256,layers_num[2],stride=2)
        #通过第四个组
        self.layer4=self.build_resblock(512,layers_num[3],stride=2)
        
        #经过全局平均池化层
        self.avgPool=tf.keras.layers.GlobalAveragePooling2D()
        
        #经过最后的全连接层
        self.fc=tf.keras.layers.Dense(num_classes)
        
    #这个函数是在构建一个组中的ResNetBlock
    def build_resblock(self,filter_num,blocks,stride=1):
        res_block=Sequential([
            ResNetBlock(filter_num,stride)
        ]) 
        for i in range(1,blocks):
            res_block.add(ResNetBlock(filter_num,stride=1))
        return res_block
    
    def call(self,inputs,training=None):
        x=self.stem(inputs)
        
        x=self.layer1(x)
        x=self.layer2(x)
        x=self.layer3(x)
        x=self.layer4(x)
        
        out=self.avgPool(x)
        
        output=self.fc(out)
        return output

6.调用模型:

def resnet18():
    return ResNet([2,2,2,2])
def resnet34():
    return ResNet([3,4,6,3])

7.模型编译和优化器选择

model=resnet18()
optimizer=optimizers.Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer,loss=losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
model.build(input_shape=(None,28,28,1))
model.summary()

在这里插入图片描述
8.模型训练

model.fit(x_train,y_train,epochs=EPOCHES,validation_data=(x_test,y_test),verbose=1,batch_size=batch_Size)
model.save_weights('ResNet_34.h5')

在这里插入图片描述

9. 使用测试集数据评估误差和准确率

loss,accuracy=model.evaluate(x_test,y_test,batch_size=batch_Size,verbose=1)

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值