Transfer learning with TensorFlow Hub
- TensorFlow Hub是共享预训练模型组件的一种方式。 有关预训练模型的可搜索列表,请参见TensorFlow模块中心。
https://hub.tensorflow.google.cn/ - 本教程演示:如何将TensorFlow Hub与tf.keras结合使用。 如何使用TensorFlow Hub进行图像分类,如何进行简单的迁移学习。
import matplotlib.pylab as plt
import tensorflow as tf
#!pip install -q tensorflow-hub
#!pip install -q tensorflow-datasets
#豆瓣的速度,谁用谁知道..
#pip install tensorflow-hub -i https://pypi.douban.com/simple --user
#pip install tensorflow-datasets -i https://pypi.douban.com/simple --user
import tensorflow_hub as hub
from tensorflow.keras import layers
1.An ImageNet classifier
- Download the classifier(此例用mobilenet_v2) Use hub.module to load a mobilenet, and tf.keras.layers.Lambda to wrap it up as a keras layer.
- Any TensorFlow 2 compatible image classifier URL from hub.tensorflow.google.cn will work here.
classifier_url ="https://hub.tensorflow.google.cn/google/tf2-preview/mobilenet_v2/classification/2"
IMAGE_SHAPE = (224, 224)
classifier = tf.keras.Sequential([
hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE+(3,))#(224,224,3)
])
2.Run it on a single image
Download a single image to try the model on.
import numpy as np
import PIL.Image as Image
grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)
grace_hopper
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg
65536/61306 [================================] - 0s 3us/step

grace_hopper = np.array(grace_hopper)/255.0
grace_hopper.shape
# grace_hopper
(224, 224, 3)
#Add a batch dimension, and pass the image to the model.
result = classifier.predict(grace_hopper[np.newaxis, ...])
result.shape
(1, 1001)
#The result is a 1001 element vector of logits, rating the probability of each class for the image.
#So the top class ID can be found with argmax:
predicted_class = np.argmax(result[0], axis=-1)
predicted_class
819
#Decode the predictions
#We have the predicted class ID, Fetch the ImageNet labels, and decode the predictions
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')</

本教程介绍了如何使用TensorFlowHub结合tf.keras进行图像分类。首先,展示了如何加载预训练的MobileNet模型进行单张图片的预测。接着,利用TensorFlowHub进行简单迁移学习,对TensorFlow花朵数据集进行训练,调整模型以识别新的类别。通过几个训练迭代,模型在任务上取得了进步。最后,导出训练好的模型并验证其预测一致性。
最低0.47元/天 解锁文章
3990

被折叠的 条评论
为什么被折叠?



