学习一下迁移学习……
1. InceptionV3

"""
python=3.6
tf=2.3.1
"""
import os
import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.applications import *
from tensorflow.keras.preprocessing.image import *
from tensorflow import keras as keras
import numpy as np
# import cv2
from PIL import Image
import matplotlib.pyplot as plt
# 2. Load the data
# Single dataset root. Every split/class directory below is derived from it, so
# relocating the dataset only requires editing this one line. On Windows
# os.path.join produces exactly the backslash-separated paths that were
# previously spelled out in full.
data_root = r"C:\Users\Administrator\PycharmProjects\pythonProject\01_tf2_somking_calling\data"

# Split directories
train_dir = os.path.join(data_root, "train")
val_dir = os.path.join(data_root, "val")  # validation set is generated by 01_train_test_splict.py
test_dir = os.path.join(data_root, "test")  # test set: the images whose labels we need to predict

# Per-class directories inside the train and val splits
train_normal = os.path.join(train_dir, "normal")
train_phone = os.path.join(train_dir, "calling")
train_smoke = os.path.join(train_dir, "smoking")
val_normal = os.path.join(val_dir, "normal")
val_phone = os.path.join(val_dir, "calling")
val_smoke = os.path.join(val_dir, "smoking")

# Number of directory entries per class (every entry returned by os.listdir is
# counted, exactly as before), evaluated in the same order as the original.
train_normal_num, train_phone_num, train_smoke_num = (
    len(os.listdir(class_dir))
    for class_dir in (train_normal, train_phone, train_smoke)
)
val_normal_num, val_phone_num, val_smoke_num = (
    len(os.listdir(class_dir))
    for class_dir in (val_normal, val_phone, val_smoke)
)
train_all = train_normal_num + train_phone_num + train_smoke_num
val_all = val_normal_num + val_phone_num + val_smoke_num
test_num = len(os.listdir(test_dir))

# Summary of the dataset sizes
print("train normal number: ", train_normal_num)
print("train phone_num: ", train_phone_num)
print("train_smoke_num: ", train_smoke_num)
print("all train images: ", train_all)
print("val normal number: ", val_normal_num)
print("val phone_num: ", val_phone_num)
print("val_smoke_num: ", val_smoke_num)
print("all val images: ", val_all)
print("all test images: ", test_num)
# 3. Hyperparameters
# Training schedule
batch_size = 512  # as large as GPU memory / host memory allows
epochs = 1000  # kept large on purpose; the original note says early stopping is used
# Image dimensions and number of target classes
height, width = 299, 299
num_classes = 3
# 4. Data augmentation
# Read the training data and apply data augmentation
# NOTE(review): as written this only binds the `tf.keras` module itself; the
# expression looks truncated in this copy — confirm against the full script.
train_datagen = tf.keras

最低0.47元/天 解锁文章
3765

被折叠的 条评论
为什么被折叠?



