AI in Practice: Transfer Learning for Image Classification with ResNet

This article shows how to apply transfer learning: a pretrained ResNet101 model is used for feature extraction on the cats_vs_dogs dataset to build an image classifier. By adjusting the learning rate and the model head, the model reaches good classification results on this dataset.


Transfer learning covers two approaches:

  • 1. Feature Extraction
  • 2. Fine-Tuning

This article is based on TensorFlow 2.0: it loads the cats_vs_dogs dataset, creates a base model with tf.keras.applications, and uses ResNet101 for Feature Extraction.
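
As a quick orientation before the full script, here is a minimal sketch of the Feature Extraction pattern (the complete, runnable code follows under "Core code"): the pretrained ResNet101 base is frozen and only a small classification head is trained on top.

# Minimal Feature Extraction pattern (sketch only; see the full script below).
import tensorflow as tf

base = tf.keras.applications.ResNet101(input_shape=(160, 160, 3),
                                        include_top=False,   # drop the ImageNet classifier
                                        weights='imagenet')
base.trainable = False                                        # freeze the convolutional base

model = tf.keras.Sequential([
    base,
    tf.keras.layers.GlobalAveragePooling2D(),                 # pool feature maps to one vector
    tf.keras.layers.Dense(1)                                  # single logit: cat vs. dog
])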

  • Core code:
'''
Transfer learning with a pretrained ConvNet: resnet101

Reference:
https://tensorflow.google.cn

Downloaded weights are cached at:
~/.keras/models/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5
'''

from __future__ import absolute_import, division, print_function, unicode_literals
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

keras = tf.keras


# Dataset

# Data preprocessing

# Data download
# Use TensorFlow Datasets to load the cats and dogs dataset.
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

# Split the original train split 80% / 10% / 10% into train, validation and test.
# Note: subsplit() belongs to older tensorflow_datasets releases; newer releases
# express the same split with strings such as 'train[:80%]'.
SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)

(raw_train, raw_validation, raw_test), metadata = tfds.load(
                    'cats_vs_dogs', split=list(splits),
                    with_info=True, as_supervised=True)
                    
print(raw_train)
print(raw_validation)
print(raw_test)
                    

get_label_name = metadata.features['label'].int2str

for image, label in raw_train.take(2):
    plt.figure()
    plt.imshow(image)
    plt.title(get_label_name(label))
  

#Format the Data
IMG_SIZE = 160 # All images will be resized to 160x160

def format_example(image, label):
    image = tf.cast(image, tf.float32)
    image = (image/127.5) - 1                         # scale pixels to [-1, 1]
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label
  
#shuffle and batch the data
train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)


BATCH_SIZE = 2  # use 32 for a full training run
SHUFFLE_BUFFER_SIZE = 1000

train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)

# Inspect the shape of one batch of images.
for image_batch, label_batch in train_batches.take(1):
    pass
print(image_batch.shape)


#Create the base model
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

# Create the base model from the pre-trained model ResNet101
base_model = tf.keras.applications.ResNet101(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
                                               
feature_batch = base_model(image_batch)
print(feature_batch.shape)

base_model.trainable = False   # freeze the convolutional base; only the new head will train
base_model.summary()

#Add a classification head
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)


prediction_layer = keras.layers.Dense(1)   # single raw logit for the binary cat/dog decision
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)

model = tf.keras.Sequential([
                  base_model,
                  global_average_layer,
                  prediction_layer ])


#Compile the model
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate),
              # from_logits=True because the Dense(1) head outputs raw logits
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()

#Train the model
# Number of examples in each split (80% / 10% / 10% of the full train split).
num_train, num_val, num_test = (
                  metadata.splits['train'].num_examples*weight/10
                  for weight in SPLIT_WEIGHTS )

initial_epochs = 1  # use 10 for a full training run

history = model.fit(train_batches,
                    epochs=initial_epochs,
                    validation_data=validation_batches)


# Save weights to an HDF5 file
model.save_weights('transfer_learning-resnet101-model-cats-dogs.h5', save_format='h5')

# Restore the model's state later with:
#model.load_weights('transfer_learning-resnet101-model-cats-dogs.h5')


#Learning curves
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
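
The script above implements only the first approach, Feature Extraction. For completeness, below is a hedged sketch of the second approach mentioned earlier, Fine-Tuning: after the initial training, unfreeze the upper layers of the base model and keep training at a much lower learning rate. The layer cutoff (300) and the reduced learning rate are illustrative assumptions, not values from the original run.

# Fine-Tuning sketch (illustrative; the cutoff layer and learning rate are assumptions).
base_model.trainable = True                       # unfreeze the convolutional base

fine_tune_at = 300                                # keep the earliest layers frozen
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# Recompile with a much lower learning rate so the pretrained weights change only slightly.
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

fine_tune_epochs = 10
history_fine = model.fit(train_batches,
                         epochs=initial_epochs + fine_tune_epochs,
                         initial_epoch=initial_epochs,
                         validation_data=validation_batches)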

  • Run output
    • 1. Sample images from the cats_vs_dogs dataset (figures omitted)
    • 2. Learning curves (figure omitted)
    • 3. Results after training for 1 epoch (figure omitted)
    • 4. Model weights saved to ./transfer_learning-resnet101-model-cats-dogs.h5
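
To reuse the saved weights later, rebuild the same model architecture first and then call load_weights. As a hedged sketch, prediction on one test batch might look like this (the logit > 0 threshold follows from the single Dense(1) output trained with from_logits=True):

# Sketch: restore the saved weights and classify one test batch.
# Assumes the model above has been re-created with the same architecture.
model.load_weights('transfer_learning-resnet101-model-cats-dogs.h5')

for image_batch, label_batch in test_batches.take(1):
    logits = model.predict(image_batch)            # shape (BATCH_SIZE, 1), raw logits
    preds = (logits > 0).astype(int).flatten()     # logit > 0  ->  label 1
    print('predicted:', [get_label_name(int(p)) for p in preds])
    print('actual:   ', [get_label_name(int(l)) for l in label_batch.numpy()])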