tf2.0学习笔记2

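本文介绍了使用 TensorFlow 2.0 进行图像分类的完整流程,包括数据加载与预处理(缩放、归一化、乱序)、训练集/验证集划分、基于 MobileNetV2 的模型构建、训练与验证,以及通过回调(ModelCheckpoint、EarlyStopping、TensorBoard)保存模型和监控训练。作者展示了如何以 `train-val` 目录下各子目录名作为类别标签,构建并训练一个图像分类模型,涵盖了数据管道、模型结构和训练策略(注意:代码中并未使用数据增强)。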
import datetime
starttime = datetime.datetime.now()
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback, TensorBoard
import pathlib
import os
import random
# Let TensorFlow grow GPU memory on demand instead of reserving it all at once.
# This must be set before TensorFlow initialises the GPU.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# tf.test.is_gpu_available() is deprecated in TF 2.x; listing the physical GPU
# devices answers the same question without building a test session.
gpu_ok = bool(tf.config.experimental.list_physical_devices("GPU"))
print("tf version:", tf.__version__)
print("use GPU", gpu_ok)
####################
# Expected layout: data_dir/<class_name>/<image file>; the name of each
# sub-directory is used as the class label.
data_dir = "train-val"
data_root = pathlib.Path(data_dir)
for item in data_root.iterdir():
    print(item)

# "*/*" also matches nested directories; only regular files are samples,
# otherwise tf.io.read_file would fail on a directory path later on.
all_image_paths = [path for path in data_root.glob("*/*") if path.is_file()]

image_count = len(all_image_paths)

print("all num:", image_count)
all_image_paths = [str(path) for path in all_image_paths]
# Shuffle once up front so the skip/take train/validation split is random.
random.shuffle(all_image_paths)

# Sorted class-directory names give a stable name -> index mapping.
label_names = sorted(item.name for item in data_root.iterdir() if item.is_dir())
print(label_names)

label_to_index = {name: index for index, name in enumerate(label_names)}
print(label_to_index)

# The label of each image is the name of its parent directory.
all_image_labels = [pathlib.Path(path).parent.name for path in all_image_paths]
print(all_image_labels[:5], len(all_image_labels))

item_labels = [label_to_index[label] for label in all_image_labels]
print(all_image_paths[:5], len(all_image_paths))
print(item_labels[:5], len(item_labels))

def load_and_preporocess_image(path):
    """Read one image file and convert it into a model-ready tensor.

    Args:
        path: scalar string tensor holding the image file path (this
            function is mapped over a dataset of path strings).

    Returns:
        float32 tensor of shape (96, 96, 3) with values in [-1, 1].
    """
    raw_bytes = tf.io.read_file(path)
    # Decode as JPEG, forcing three colour channels.
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)
    pixels = tf.cast(tf.image.resize(pixels, [96, 96]), tf.float32)
    # Map [0, 255] -> [0, 1] -> [-1, 1], the same range that
    # tf.keras.applications.mobilenet_v2.preprocess_input produces.
    return 2 * (pixels / 255.0) - 1

path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)  # dataset of file paths
AUTOTUNE = tf.data.experimental.AUTOTUNE
image_ds = path_ds.map(load_and_preporocess_image, num_parallel_calls=AUTOTUNE)
label_ds = tf.data.Dataset.from_tensor_slices(item_labels)
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
print(image_label_ds)

# Hold out 20% of the (already shuffled) samples for validation.
test_count = int(image_count * 0.2)
train_count = image_count - test_count
train_data = image_label_ds.skip(test_count)  # training set: everything after the held-out part
test_data = image_label_ds.take(test_count)  # validation set
BATCH_SIZE = 8
EPOCHS = 20
# Dataset.shuffle rejects buffer_size=0, which train_count // 20 yields for
# fewer than 20 training images. repeat(-1) loops forever, so the length of an
# epoch is controlled by steps_per_epoch below.
train_data = train_data.shuffle(buffer_size=max(1, train_count // 20)).repeat(-1)
train_data = train_data.batch(BATCH_SIZE)
train_data = train_data.prefetch(buffer_size=AUTOTUNE)

test_data = test_data.batch(BATCH_SIZE)

# MobileNetV2 backbone without its classifier head (tf.keras.applications
# defaults to weights='imagenet').
mobile_net = tf.keras.applications.MobileNetV2(input_shape=(96, 96, 3),
                                               include_top=False)
inputs = tf.keras.Input(shape=(96, 96, 3))
x = mobile_net(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x1 = tf.keras.layers.Dense(1024, activation="relu")(x)
x2 = tf.keras.layers.Dense(2048, activation="relu")(x1)
# One softmax output per class directory.
out_item = tf.keras.layers.Dense(len(label_names),
                                 activation="softmax")(x2)
model = tf.keras.Model(inputs=inputs,
                       outputs=[out_item])
model.summary()
# Labels are integer class indices, hence the sparse cross-entropy loss.
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["acc"])

# At least one step per epoch, otherwise Keras cannot run an epoch when there
# are fewer training images than BATCH_SIZE.
train_steps = max(1, train_count // BATCH_SIZE)
# Ceiling division: batch() keeps the final partial batch, so floor division
# skipped up to BATCH_SIZE - 1 validation images, and gave 0 steps when there
# were fewer than BATCH_SIZE of them.
test_steps = (test_count + BATCH_SIZE - 1) // BATCH_SIZE

# patience must be smaller than EPOCHS. With the previous patience=100 and
# 20 epochs, early stopping (and with it restore_best_weights) could never fire.
earlystop = EarlyStopping(monitor='val_loss', patience=5, mode='min',
                          restore_best_weights=True)

tensorboard = TensorBoard(log_dir='./log', write_graph=True, update_freq='epoch')
# Make sure the checkpoint directory exists before the first .h5 save.
os.makedirs("model", exist_ok=True)
filepath = os.path.join("model/", "chepai-{epoch:02d}-{val_loss:.2f}.h5")
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1,
                             save_best_only=True,
                             mode='min', save_weights_only=False)

# fit_generator is deprecated in TF 2.x; model.fit accepts tf.data datasets
# directly. shuffle/workers/use_multiprocessing are documented as ignored for
# tf.data input (shuffling is done in the pipeline above), so they are dropped.
history = model.fit(train_data,
                    epochs=EPOCHS,
                    steps_per_epoch=train_steps,
                    validation_data=test_data,
                    validation_steps=test_steps,
                    callbacks=[checkpoint, earlystop, tensorboard],
                    verbose=2)


endtime = datetime.datetime.now()
print("Start Trainning:", starttime)
print("Finished Trainning:", endtime)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值