Transfer learning with TensorFlow Hub
- TensorFlow Hub是共享预训练模型组件的一种方式。 有关预训练模型的可搜索列表,请参见TensorFlow模块中心。
https://hub.tensorflow.google.cn/ - 本教程演示:如何将TensorFlow Hub与tf.keras结合使用。 如何使用TensorFlow Hub进行图像分类,如何进行简单的迁移学习。
import matplotlib.pylab as plt
import tensorflow as tf
#!pip install -q tensorflow-hub
#!pip install -q tensorflow-datasets
#豆瓣的速度,谁用谁知道..
#pip install tensorflow-hub -i https://pypi.douban.com/simple --user
#pip install tensorflow-datasets -i https://pypi.douban.com/simple --user
import tensorflow_hub as hub
from tensorflow.keras import layers
1.An ImageNet classifier
- Download the classifier (此例用mobilenet_v2): use hub.KerasLayer to load a MobileNet and wrap it up as a Keras layer.
- Any TensorFlow 2 compatible image classifier URL from hub.tensorflow.google.cn will work here.
# TF Hub handle of the pretrained MobileNet V2 ImageNet classifier.
classifier_url = "https://hub.tensorflow.google.cn/google/tf2-preview/mobilenet_v2/classification/2"
# Spatial size of the images fed to the model; a channel axis of 3 is appended below.
IMAGE_SHAPE = (224, 224)
# Wrap the hub module as the single layer of a Keras model with input shape (224, 224, 3).
classifier = tf.keras.Sequential()
classifier.add(hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE + (3,)))
2.Run it on a single image
Download a single image to try the model on.
import numpy as np
import PIL.Image as Image
# Download the sample photo, then open it and resize it to the classifier's input size.
image_path = tf.keras.utils.get_file(
    'image.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg',
)
grace_hopper = Image.open(image_path).resize(IMAGE_SHAPE)
grace_hopper
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg
65536/61306 [================================] - 0s 3us/step
# Convert the PIL image to an array and divide by 255.0 to rescale the pixel values.
pixel_values = np.array(grace_hopper)
grace_hopper = pixel_values / 255.0
grace_hopper.shape
# grace_hopper
(224, 224, 3)
# The model works on batches, so prepend a batch axis of length 1 before predicting.
single_image_batch = grace_hopper[np.newaxis, ...]
result = classifier.predict(single_image_batch)
result.shape
(1, 1001)
# The result holds a 1001-element vector of logits (one score per class) for the image.
# The top class ID is the index of the largest logit in the first (only) row.
predicted_class = result[0].argmax(axis=-1)
predicted_class
819
# Decode the predictions.
# We have the predicted class ID; fetch the ImageNet labels file so the ID can be
# mapped to a human-readable class name (one label per line).
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
# Read through a context manager so the file handle is closed as soon as the
# labels have been loaded instead of being left open.
with open(labels_path) as labels_file:
    imagenet_labels = np.array(labels_file.read().splitlines())
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt
16384/10484 [==============================================] - 0s 21us/step
# Look up the label for the predicted class ID, then show the image with it as the title.
predicted_class_name = imagenet_labels[predicted_class]
plt.imshow(grace_hopper)
plt.axis('off')
_ = plt.title("Prediction: " + predicted_class_name.title())
3.Simple transfer learning
Using TF Hub it is simple to retrain the top layer of the model to recognize the classes in our dataset.
Dataset
For this example you will use the TensorFlow flowers dataset:
# Download the flowers archive and unpack it (untar=True).
data_root = tf.keras.utils.get_file(
    fname='flower_photos',
    origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True,
)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 571s 2us/step
# Rescale pixel values by 1/255 and read images from the dataset directory,
# resized to the classifier's input size.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1 / 255)
image_data = image_generator.flow_from_directory(
    directory=str(data_root),
    target_size=IMAGE_SHAPE,
)
Found 3670 images belonging to 5 classes.
# `image_data` is an iterator yielding (image_batch, label_batch) pairs.
# Take only the first pair and report the shape of each array.
for image_batch, label_batch in image_data:
    print("Image batch shape: ", image_batch.shape)
    print("Label batch shape: ", label_batch.shape)
    break
Image batch shape: (32, 224, 224, 3)
Label batch shape: (32, 5)
# Run the classifier on a batch of images.
# `image_batch` is the batch drawn from `image_data` above; the shape of the
# prediction array is displayed afterwards.
result_batch = classifier.predict(image_batch)
result_batch.shape
(32, 1001)
# Pick the highest-scoring class ID for every image, then map the IDs to label strings.
top_class_ids = np.argmax(result_batch, axis=-1)
predicted_class_names = imagenet_labels[top_class_ids]
predicted_class_names
array(['daisy', 'candle', 'daisy', 'hook', 'cardoon', 'rapeseed',
'hummingbird', 'daisy', 'bee', 'tray', 'strawberry', 'bonnet',
'plunger', 'sea anemone', 'sulphur butterfly', 'cardoon',
'picket fence', 'daisy', 'daisy', 'sea urchin', 'spider web',
'daisy', 'sea urchin', 'daisy', 'daisy', 'coral fungus', 'daisy',
'daisy', 'daisy', 'daisy', 'spider web', 'hip'], dtype='<U30')
# Check how the ImageNet predictions line up with the images:
# draw the first 30 images of the batch in a 6x5 grid, each titled with its predicted label.
plt.figure(figsize=(10, 9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
    grid_position = n + 1  # subplot positions are 1-based
    plt.subplot(6, 5, grid_position)
    plt.imshow(image_batch[n])
    plt.title(predicted_class_names[n])
    plt.axis('off')
_ = plt.suptitle("ImageNet predictions")
4.Download the headless model
TensorFlow Hub also distributes models without the top classification layer. These can be used to easily do transfer learning.
Any TensorFlow 2 compatible image feature vector URL from hub.tensorflow.google.cn will work here.
# TF Hub handle of the headless MobileNet V2 (feature-vector variant).
feature_extractor_url = "https://hub.tensorflow.google.cn/google/tf2-preview/mobilenet_v2/feature_vector/2"
# Create the feature extractor as a Keras layer taking (224, 224, 3) inputs.
feature_extractor_layer = hub.KerasLayer(
    feature_extractor_url,
    input_shape=(224, 224, 3),
)
# Apply it to the image batch and print the shape of the resulting features.
feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)
(32, 1280)
# Freeze the variables in the feature extractor layer so that training only
# modifies the new classifier layer added on top of it.
feature_extractor_layer.trainable = False
5.Attach a classification head
# Stack the hub feature extractor and a new Dense classification layer with one
# output unit per class found in the dataset directory.
model = tf.keras.Sequential()
model.add(feature_extractor_layer)
model.add(layers.Dense(image_data.num_classes))
model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
keras_layer_1 (KerasLayer) (None, 1280) 2257984
_________________________________________________________________
dense (Dense) (None, 5) 6405
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
# Sanity check: call the new model directly on the image batch and inspect the
# shape of its output.
predictions = model(image_batch)
predictions.shape
TensorShape([32, 5])
6.Train the model
# Configure training. The Dense head has no activation, so the loss is told to
# expect logits; accuracy is tracked under the key 'acc'.
model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['acc'],
)
class CollectBatchStats(tf.keras.callbacks.Callback):
    """Keras callback that records the loss and accuracy of every training batch.

    Attributes:
        batch_losses: list of the 'loss' value logged at the end of each batch.
        batch_acc: list of the 'acc' value logged at the end of each batch.
    """

    def __init__(self):
        # Let the base Callback set up its own state before adding ours.
        super().__init__()
        self.batch_losses = []
        self.batch_acc = []

    def on_train_batch_end(self, batch, logs=None):
        """Store this batch's loss and accuracy, then reset the model's metrics.

        Args:
            batch: index of the batch that just finished (unused).
            logs: metric dict for the batch; must contain the keys 'loss' and
                'acc' (the metric name used when the model was compiled).
        """
        self.batch_losses.append(logs['loss'])
        self.batch_acc.append(logs['acc'])
        # Clear metric state so the values logged for the next batch are not
        # accumulated together with the batches already recorded.
        self.model.reset_metrics()
# Number of batches needed to cover every sample once. np.ceil returns a float,
# while Keras documents steps_per_epoch as an integer, so convert explicitly.
steps_per_epoch = int(np.ceil(image_data.samples / image_data.batch_size))
# Callback instance that collects per-batch loss and accuracy during fit().
batch_stats_callback = CollectBatchStats()
history = model.fit(image_data, epochs=2,
                    steps_per_epoch=steps_per_epoch,
                    callbacks=[batch_stats_callback])
Epoch 1/2
115/115 [==============================] - 73s 631ms/step - loss: 0.4730 - acc: 0.8125
Epoch 2/2
115/115 [==============================] - 74s 639ms/step - loss: 0.1535 - acc: 1.0000
# Now, after even just a few training iterations, we can already see that the model is making progress on the task.
# Plot the per-batch training loss collected by the callback, with the y-axis fixed to [0, 2].
plt.figure()
plt.xlabel("Training Steps")
plt.ylabel("Loss")
plt.ylim(0, 2)
plt.plot(batch_stats_callback.batch_losses)
[<matplotlib.lines.Line2D at 0x7f4605df5908>]
# Plot the per-batch training accuracy collected by the callback, with the y-axis fixed to [0, 1].
plt.figure()
plt.xlabel("Training Steps")
plt.ylabel("Accuracy")
plt.ylim(0, 1)
plt.plot(batch_stats_callback.batch_acc)
[<matplotlib.lines.Line2D at 0x7f4605d9c208>]
7.Check the predictions
# To redo the plot from before, build the list of class names ordered by class
# index, title-cased for display.
index_by_name = image_data.class_indices
class_names = np.array(
    [name.title() for name in sorted(index_by_name, key=index_by_name.get)]
)
class_names
array(['Daisy', 'Dandelion', 'Roses', 'Sunflowers', 'Tulips'],
dtype='<U10')
# Run the image batch through the retrained model, then reduce both the
# predictions and the one-hot labels to class indices.
predicted_batch = model.predict(image_batch)
predicted_id = predicted_batch.argmax(axis=-1)
label_id = label_batch.argmax(axis=-1)
# Translate the predicted indices into class names.
predicted_label_batch = class_names[predicted_id]
# Draw the first 30 images of the batch in a 6x5 grid, titled with the model's
# prediction: green when it matches the true label, red otherwise.
plt.figure(figsize=(10, 9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
    plt.subplot(6, 5, n + 1)
    plt.imshow(image_batch[n])
    if predicted_id[n] == label_id[n]:
        color = "green"
    else:
        color = "red"
    plt.title(predicted_label_batch[n].title(), color=color)
    plt.axis('off')
_ = plt.suptitle("Model predictions (green: correct, red: incorrect)")
8.Export your model
Now that you’ve trained the model, export it as a saved model:
import time

# Name the export directory after the current time, truncated to whole seconds.
t = time.time()
timestamp = int(t)
export_path = "/tmp/saved_models/{}".format(timestamp)
# Write the model to that directory using the 'tf' save format.
model.save(export_path, save_format='tf')
export_path
INFO:tensorflow:Assets written to: /tmp/saved_models/1600927981/assets
INFO:tensorflow:Assets written to: /tmp/saved_models/1600927981/assets
'/tmp/saved_models/1600927981'
# Reload the exported model and confirm it gives the same results: predict with
# both models on the same batch and show the largest absolute difference.
reloaded = tf.keras.models.load_model(export_path)
result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)
difference = reloaded_result_batch - result_batch
abs(difference).max()
0.0