TL: Implementing CIFAR-10 Classification with Transfer Learning

This post shows how to classify the CIFAR-10 dataset with a pretrained InceptionV3 model, covering data preprocessing, model construction, the training strategy, and result evaluation. Data augmentation via ImageDataGenerator is used to improve model performance.


Straight to the code.
Quite a few pretrained models are referenced here; pick whichever backbone suits your experiment.
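If you switch between the backbones listed in the script below often, one option (a sketch of my own, not part of the original post; it only relies on the standard tf.keras.applications constructors) is to map model names to their constructors and pick one by key:

from tensorflow.keras.applications import (
    ResNet50, Xception, VGG16, MobileNetV2, InceptionV3, DenseNet121)

# Map a name to its tf.keras.applications constructor (hypothetical helper).
BACKBONES = {
    "resnet50": ResNet50,
    "xception": Xception,
    "vgg16": VGG16,
    "mobilenet_v2": MobileNetV2,
    "inception_v3": InceptionV3,
    "densenet121": DenseNet121,
}

def build_base_model(name="inception_v3"):
    # Same keyword arguments as used in the script below.
    return BACKBONES[name](include_top=False, weights="imagenet", input_shape=None)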

import numpy as np
from tensorflow.keras.utils import to_categorical, normalize
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.layers import GlobalAveragePooling2D, UpSampling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

def run():
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()

    ##############################################
    # Data preprocessing: scale pixel values to [0, 1]
    X_train = X_train.astype('float32') / 255
    X_test = X_test.astype('float32') / 255

    # Alternative: use tf.keras.utils.normalize instead of simple rescaling
    # X_train = normalize(X_train)
    # X_test = normalize(X_test)
    ##############################################

    # One-hot encode the labels
    num_classes = 10
    y_train = to_categorical(y_train, num_classes = num_classes)
    y_test = to_categorical(y_test, num_classes = num_classes)

    print(X_train.shape)
    print(y_train.shape)
    print(X_test.shape)
    print(y_test.shape)

    # from tensorflow.keras.applications.resnet50 import ResNet50
    # base_model = ResNet50(
    #     include_top = False,
    #     weights = "imagenet",
    #     input_shape = None
    # )

    # from tensorflow.keras.applications.xception import Xception
    # base_model = Xception(
    #     include_top = False,
    #     weights = "imagenet",
    #     input_shape = None
    # )

    # from tensorflow.keras.applications.vgg16 import VGG16
    # base_model = VGG16(
    #     include_top = False,
    #     weights = "imagenet",
    #     input_shape = None
    # )

    # from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
    # base_model = MobileNetV2(
    #     include_top = False,
    #     weights = "imagenet",
    #     input_shape = None
    # )

    # Backbone used in this post: InceptionV3 pretrained on ImageNet,
    # loaded without its classification head
    from tensorflow.keras.applications.inception_v3 import InceptionV3
    base_model = InceptionV3(
        include_top = False,
        weights = "imagenet",
        input_shape = None
    )

    # from tensorflow.keras.applications.densenet import DenseNet121
    # base_model = DenseNet121(
    #     include_top = False,
    #     weights = "imagenet",
    #     input_shape = None
    # )

    ###################################################################
    # Classification head on top of the pretrained base model.
    # Note: Keras' InceptionV3 expects inputs no smaller than 75x75,
    # so the 32x32 CIFAR-10 images are upsampled to 96x96 first.
    inputs = Input(shape=(32, 32, 3))
    x = UpSampling2D(size=(3, 3))(inputs)
    x = base_model(x)
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation = 'relu')(x)
    predictions = Dense(num_classes, activation = 'softmax')(x)

    # Define the full model
    model = Model(inputs = inputs, outputs = predictions)

    model.compile(
        optimizer = Adam(),
        loss = 'categorical_crossentropy',
        metrics = ["acc"]
    )

    # model.summary()
    print("Model defined with {} layers".format(len(model.layers)))

    # EarlyStopping: stop when val_loss stops improving
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=10,
        verbose=1
    )

    # Reduce Learning Rate
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,
        patience=3,
        verbose=1
    )

    ###################################################################
    # Training parameters
    p_batch_size = 128
    p_epochs = 5

    ###################################################################
    # Training with data augmentation
    # Prepare the images with ImageDataGenerator
    train_gen = ImageDataGenerator(
        featurewise_center=True,
        featurewise_std_normalization=True,
        width_shift_range=0.125,
        height_shift_range=0.125,
        horizontal_flip=True)
    test_gen = ImageDataGenerator(
        featurewise_center=True,
        featurewise_std_normalization=True)

    # Fit the feature-wise statistics (mean/std) on the training set
    for data in (train_gen, test_gen):
        data.fit(X_train)

    history = model.fit(
        train_gen.flow(X_train, y_train, batch_size=p_batch_size),
        epochs=p_epochs,
        steps_per_epoch=X_train.shape[0] // p_batch_size,
        validation_data=test_gen.flow(X_test, y_test, batch_size=p_batch_size),
        validation_steps=X_test.shape[0] // p_batch_size,
        callbacks=[early_stopping, reduce_lr])

    # Plot the training curves
    plt.plot(history.history['acc'], label='acc')
    plt.plot(history.history['val_acc'], label='val_acc')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(loc='best')
    plt.show()

    # Save the model
    model.save("model{}.h5".format(p_epochs))

    ###################################################################
    # Evaluate on the test set
    test_loss, test_acc = model.evaluate(
        test_gen.flow(X_test, y_test, batch_size=p_batch_size),
        steps=10)
    print('test_loss: {:.3f}\ntest_acc: {:.3f}'.format(test_loss, test_acc))

if __name__ == "__main__":
    run()
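The run() above fine-tunes the entire network, ImageNet weights included, from the first epoch. A common transfer-learning variant (not part of the original code, just a sketch) is to train the new head first with the backbone frozen, then unfreeze it and fine-tune with a smaller learning rate. The helper below is hypothetical and would be used inside run(), with the model and base_model defined there:

from tensorflow.keras.optimizers import Adam

# Hypothetical helper: freeze/unfreeze the pretrained backbone, then re-compile
# so the change in trainable weights takes effect.
def set_backbone_trainable(model, base_model, trainable, learning_rate):
    for layer in base_model.layers:
        layer.trainable = trainable
    model.compile(optimizer=Adam(learning_rate=learning_rate),
                  loss='categorical_crossentropy',
                  metrics=['acc'])

# Stage 1: backbone frozen, train only the new Dense head, e.g.
#   set_backbone_trainable(model, base_model, trainable=False, learning_rate=1e-3)
#   model.fit(...)
# Stage 2: unfreeze everything and fine-tune with a smaller learning rate, e.g.
#   set_backbone_trainable(model, base_model, trainable=True, learning_rate=1e-4)
#   model.fit(...)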

Training/validation accuracy curves from the run above (with limited time, the number of epochs was set to 8 for that run).
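To reuse the trained network later, the .h5 file written by model.save() can be reloaded with load_model(). Below is a minimal inference sketch (assuming the file name model5.h5 produced by p_epochs = 5 above); because the generators use feature-wise centering and std normalization, the same statistics have to be re-fitted and applied to any image before predicting:

import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Rebuild the feature-wise statistics exactly as during training.
(X_train, _), (X_test, y_test) = cifar10.load_data()
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
gen = ImageDataGenerator(featurewise_center=True, featurewise_std_normalization=True)
gen.fit(X_train)

# Load the saved model and predict on a standardized batch.
model = load_model("model5.h5")  # file name assumes p_epochs = 5 as in run()
batch = gen.standardize(X_test[:8].copy())  # standardize() modifies in place, hence the copy
pred = np.argmax(model.predict(batch), axis=1)
print("predicted:", pred, "ground truth:", y_test[:8].ravel())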
