# 网络 (Network)
# 开头 (Stem)
# Block
# 结尾 (Head)
# 数据 (Data)
# 生成tfrecords (Generate the TFRecord files)
import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
import glob
import random
def img2tfrecord(img_path, label, save_path):
    """Serialize images and their labels into a single TFRecord file.

    Every record is a ``tf.train.Example`` with two features: ``image``
    (the bytes of the image file exactly as read from disk, still encoded)
    and ``label`` (a single int64 value).

    Args:
        img_path: Sequence of image file paths.
        label: Sequence of labels indexed in parallel with ``img_path``;
            each entry is converted with ``int()`` before being stored.
        save_path: Path of the TFRecord file to write.
    """
    # Make sure the destination folder exists before opening the writer.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # The context manager closes the writer even if reading or serializing
    # one of the images raises part-way through.
    with tf.io.TFRecordWriter(save_path) as writer:
        for index, path in enumerate(img_path):
            with open(path, 'rb') as f:
                image = f.read()
            features = {
                'image': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image])),
                'label': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(label[index])])),
            }
            tf_example = tf.train.Example(
                features=tf.train.Features(feature=features))
            writer.write(tf_example.SerializeToString())
def create_val_tfrecords(val_img_dir, class_label):
    """Write every ``*.jpg`` found under ``val_img_dir`` to one TFRecord file.

    ``val_img_dir`` is expected to contain one sub-directory per class; the
    sub-directory name is looked up in ``class_label`` to obtain the label.
    The output is ``./tfrecords/<name>.tfrecoeds`` where ``<name>`` is the
    last path component of ``val_img_dir``.

    Args:
        val_img_dir: Root directory of the validation images.
        class_label: Mapping from class (sub-directory) name to integer label.

    Raises:
        KeyError: If a sub-directory name is not a key of ``class_label``.
    """
    val_img_path = []
    val_label = []
    for class_name in os.listdir(val_img_dir):
        image_dir = val_img_dir + '/' + class_name
        # Ignore stray files (e.g. OS metadata) sitting next to the class folders.
        if not os.path.isdir(image_dir):
            continue
        this_label = class_label[class_name]
        this_img_path = glob.glob(image_dir + '/*.jpg')
        val_img_path.extend(this_img_path)
        val_label.extend([this_label] * len(this_img_path))
        print('finish ' + '{} to tfrecoed'.format(image_dir))

    # Name the file after the directory itself rather than the raw path, so
    # './val' and './val/' both give './tfrecords/val.tfrecoeds'.
    set_name = os.path.basename(os.path.normpath(val_img_dir))
    save_path = './tfrecords/{}.tfrecoeds'.format(set_name)
    img2tfrecord(val_img_path, val_label, save_path)
    print('write ' + save_path)
def create_train_tfrecords(train_img_dir, class_label, test_radio):
    """Split the training images per class and write train/test TFRecord files.

    ``train_img_dir`` is expected to contain one sub-directory per class; the
    sub-directory name is looked up in ``class_label`` to obtain the label.
    For each class the ``*.jpg`` files are shuffled and the last
    ``int(count * test_radio)`` of them are held out as the test split.

    The training split is written to ``./tfrecords/<name>.tfrecoeds`` (where
    ``<name>`` is the last path component of ``train_img_dir``) and the test
    split to ``./tfrecords/test.tfrecoeds``.

    Args:
        train_img_dir: Root directory of the training images.
        class_label: Mapping from class (sub-directory) name to integer label.
        test_radio: Fraction of each class that is held out for testing.

    Raises:
        KeyError: If a sub-directory name is not a key of ``class_label``.
    """
    train_img_path = []
    train_label = []
    test_img_path = []
    test_label = []
    for class_name in os.listdir(train_img_dir):
        image_dir = train_img_dir + '/' + class_name
        # Ignore stray files (e.g. OS metadata) sitting next to the class folders.
        if not os.path.isdir(image_dir):
            continue
        this_label = class_label[class_name]
        this_img_path = glob.glob(image_dir + '/*.jpg')
        random.shuffle(this_img_path)

        total = len(this_img_path)
        test_nb = min(total, max(0, int(total * test_radio)))
        # Split on an explicit index: slicing with ``-test_nb`` breaks when
        # test_nb is 0 (``[:-0]`` is empty and ``[-0:]`` is the whole list),
        # which would push every image of the class into the test split.
        split_at = total - test_nb
        this_train = this_img_path[:split_at]
        this_test = this_img_path[split_at:]

        train_img_path.extend(this_train)
        train_label.extend([this_label] * len(this_train))
        test_img_path.extend(this_test)
        test_label.extend([this_label] * len(this_test))
        print('finish ' + '{} to tfrecoed'.format(image_dir))

    # Name the file after the directory itself rather than the raw path, and
    # report the path that is really written.
    set_name = os.path.basename(os.path.normpath(train_img_dir))
    train_save_path = './tfrecords/{}.tfrecoeds'.format(set_name)
    img2tfrecord(train_img_path, train_label, train_save_path)
    print('write ' + train_save_path)

    test_save_path = './tfrecords/test.tfrecoeds'
    img2tfrecord(test_img_path, test_label, test_save_path)
    print('write ' + test_save_path)
# Dataset roots: each one holds a sub-directory per class.
train_img_dir = './train'
val_img_dir = './val'

# Class (sub-directory) name -> integer label, numbered in listing order.
label = {
    name: index
    for index, name in enumerate(
        ('daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'))
}

# Value passed as the held-out fraction to create_train_tfrecords.
test_radio = 0.1

if __name__ == '__main__':
    create_val_tfrecords(val_img_dir, label)
    create_train_tfrecords(train_img_dir, label, test_radio)
# 读tfrecords并训练 (Read the TFRecord files and train)
import os
import tensorflow as tf
def parse_tfrecord_example(example_queue):
    """Turn one serialized example into an ``(image, one-hot label)`` pair.

    The example is parsed with the module-level ``object_feature`` schema.
    The ``image`` bytes are decoded as a 3-channel JPEG, resized to 224x224
    and divided by 255; the ``label`` is cast to int32 and one-hot encoded
    over 5 classes.

    Args:
        example_queue: A serialized ``tf.train.Example``.

    Returns:
        Tuple ``(image, label)`` of tensors.
    """
    parsed = tf.io.parse_single_example(
        serialized=example_queue, features=object_feature)

    pixels = tf.image.decode_jpeg(parsed['image'], channels=3)
    pixels = tf.image.resize(pixels, (224, 224)) / 255.

    class_index = tf.cast(parsed['label'], dtype=tf.int32)
    one_hot = tf.one_hot(class_index, depth=5, on_value=1.0, off_value=0.0)
    return pixels, one_hot
def conv2d_nb(input, nb_filter, kernel_size = (1,1),strides = 1, padding = 'same'):
    """Apply Conv2D -> BatchNormalization -> LeakyReLU(0.1) to ``input``.

    Args:
        input: Tensor fed to the convolution.
        nb_filter: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        padding: Convolution padding mode.

    Returns:
        The activated output tensor.
    """
    conv_layer = tf.keras.layers.Conv2D(
        filters=nb_filter,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        # L2 weight decay on the convolution kernel.
        kernel_regularizer=tf.keras.regularizers.l2(0.0001),
    )
    features = conv_layer(input)
    normalized = tf.keras.layers.BatchNormalization()(features)
    return tf.keras.layers.LeakyReLU(alpha=0.1)(normalized)
def ADD(input,residual):
    """Add a shortcut tensor to a residual tensor, projecting it when needed.

    When the two tensors differ in channel count or in spatial size, the
    shortcut is first passed through a 1x1 convolution whose filter count
    matches ``residual`` and whose strides are the rounded height/width
    ratios of the two tensors.

    Args:
        input: Shortcut tensor; indexed as (batch, height, width, channels).
        residual: Output of the residual branch, indexed the same way.

    Returns:
        The element-wise sum of the (possibly projected) shortcut and
        ``residual``.
    """
    # NOTE(review): the ratios below need static height/width; a ``None``
    # spatial dimension would fail here -- confirm callers always pass
    # fully-defined shapes.
    input_shape = tf.keras.backend.int_shape(input)
    residual_shape = tf.keras.backend.int_shape(residual)
    equal_channels = input_shape[3] == residual_shape[3]
    x = round(input_shape[1] / residual_shape[1])
    y = round(input_shape[2] / residual_shape[2])

    # Project on a spatial mismatch as well as a channel mismatch; otherwise
    # a strided residual branch with unchanged channels cannot be added.
    if x > 1 or y > 1 or not equal_channels:
        input = tf.keras.layers.Conv2D(filters=residual_shape[3],
                                       kernel_size=(1, 1),
                                       strides=(x, y),
                                       padding='same')(input)
    return tf.keras.layers.add([input, residual])
def basic_block(input,nb_filter):
    """Build one residual unit: two 3x3 conv stages plus a shortcut.

    Args:
        input: Tensor entering the unit; also used as the shortcut.
        nb_filter: Number of filters of both convolution stages.

    Returns:
        The result of ``ADD(input, <second conv stage output>)``.
    """
    first = conv2d_nb(input, nb_filter=nb_filter, kernel_size=(3, 3), strides=1)
    second = conv2d_nb(first, nb_filter=nb_filter, kernel_size=(3, 3), strides=1)
    return ADD(input, second)
def residual_block(input,nb_filter,nb_restnet):
    """Chain ``nb_restnet`` basic blocks, each with ``nb_filter`` filters.

    Args:
        input: Tensor entering the first basic block.
        nb_filter: Filter count handed to every basic block.
        nb_restnet: How many basic blocks to stack.

    Returns:
        Output of the last basic block (``input`` itself if none are built).
    """
    x = input
    for _ in range(nb_restnet):
        x = basic_block(x, nb_filter)
    return x
def restnet_18(input_shape = (224,224,3),nb_classs = 5):
    """Assemble the classifier model and print its summary.

    Layout: a strided 3x3 conv stage and a 3x3 max-pool, then four residual
    stages of two basic blocks each (64, 128, 256 and 512 filters), global
    average pooling and a softmax dense layer.

    Args:
        input_shape: Shape of one input image.
        nb_classs: Number of units of the final softmax layer.

    Returns:
        The (uncompiled) ``tf.keras.Model``.
    """
    inputs = tf.keras.layers.Input(input_shape)

    # Stem: strided convolution followed by max pooling.
    x = conv2d_nb(inputs, nb_filter=64, kernel_size=(3, 3), strides=2,
                  padding='same')
    x = tf.keras.layers.MaxPool2D(pool_size=(3, 3), strides=2,
                                  padding='same')(x)

    # Four residual stages, two basic blocks per stage.
    for stage_filters in (64, 128, 256, 512):
        x = residual_block(x, nb_filter=stage_filters, nb_restnet=2)

    # Classification head.
    x = tf.keras.layers.GlobalAvgPool2D()(x)
    output_ = tf.keras.layers.Dense(nb_classs, activation='softmax')(x)

    model1 = tf.keras.Model(inputs, output_)
    model1.summary()
    return model1
# Schema of one stored example; parse_tfrecord_example reads it as a global.
object_feature = {
    'image': tf.io.FixedLenFeature([], dtype=tf.string),
    'label': tf.io.FixedLenFeature([], dtype=tf.int64),
}

# Locate the three record files by the part of their name before the first dot.
Tfrecords_path = os.listdir('./tfrecords/')
train_tfrecords_name = None
test_tfrecords_name = None
val_tfrecords_name = None
for file_name in Tfrecords_path:
    prefix = file_name.split('.')[0]
    if prefix == 'val':
        val_tfrecords_name = './tfrecords/' + file_name
    elif prefix == 'test':
        test_tfrecords_name = './tfrecords/' + file_name
    elif prefix == 'train':
        train_tfrecords_name = './tfrecords/' + file_name

# Fail with a clear message instead of a NameError further down.
missing = [split_name for split_name, path in (('train', train_tfrecords_name),
                                               ('test', test_tfrecords_name),
                                               ('val', val_tfrecords_name))
           if path is None]
if missing:
    raise FileNotFoundError(
        'no tfrecords file found in ./tfrecords/ for: ' + ', '.join(missing))

tfrecords2train_dataset = tf.data.TFRecordDataset(train_tfrecords_name)
tfrecords2test_dataset = tf.data.TFRecordDataset(test_tfrecords_name)
tfrecords2val_dataset = tf.data.TFRecordDataset(val_tfrecords_name)

train_dataset = tfrecords2train_dataset.map(parse_tfrecord_example)
test_dataset = tfrecords2test_dataset.map(parse_tfrecord_example)
val_dataset = tfrecords2val_dataset.map(parse_tfrecord_example)

# Batching is done on the datasets themselves.
train = train_dataset.shuffle(buffer_size=4000).batch(10)
test = test_dataset.batch(batch_size=1)
val = val_dataset.batch(batch_size=1)

model = restnet_18()
model.compile(optimizer='adam',
              loss=tf.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])
# Do not pass ``batch_size`` here: Keras rejects it when the input is an
# already-batched tf.data.Dataset.
model.fit(train, epochs=300, validation_data=test)
model.evaluate(val)
# 效果很差,有待进一步学习 (Results are poor; further study is needed.)