# Train the model, then save it so that later runs can reuse it without retraining the parameters.
from tensorflow import keras
import tensorflow as tf
import tensorflow.python.keras
# Fetch the Fashion-MNIST dataset through Keras. load_data() is unpacked
# here as a (train, test) pair, each an (images, labels) tuple.
fashion_mnist = keras.datasets.fashion_mnist
(
    (train_images, train_labels),
    (test_images, test_labels),
) = fashion_mnist.load_data()

# Display names for the ten clothing categories. Presumably indexed by the
# integer label value (verify against the dataset docs); not used by the
# code visible in this file.
class_name = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Sneaker',
    'Shirt',
    'Bag',
    'Ankle boot',
]
# Build a small feed-forward classifier as a Sequential stack of layers.
model = keras.Sequential([
    # Flatten each 28x28 input into a 1-D vector.
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    # Ten softmax outputs, one per class.
    keras.layers.Dense(10, activation='softmax'),
])

# The sparse categorical loss is used, so labels are passed as given by
# load_data() rather than being one-hot encoded first.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# NOTE(review): the images are fed in exactly as loaded, with no rescaling.
# Any code that later calls predict() on this model must feed data in the
# same form.
model.fit(train_images, train_labels, epochs=10)
# Measure loss and accuracy on the test split.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('\nTest accuracy:', test_acc)

# Run the trained model over every test image.
predictions = model.predict(test_images)
# A bare `predictions.shape` expression is only echoed in a notebook/REPL;
# when run as a script its value is discarded, so print it explicitly.
print(predictions.shape)
print(predictions[0])

# Save the trained model (network architecture plus weights). The file
# conventionally uses the .h5 suffix. Later runs can load this file instead
# of training again.
model.save('fashion_model.h5')
# Load the saved model and use it for prediction.
import tensorflow as tf
from tensorflow.python.keras import layers
from tensorflow import keras
# Restore the model from 'fashion_model.h5'; nothing is trained here.
model = keras.models.load_model('fashion_model.h5')

# Load the dataset again; only the test images are used below.
fashion_mnist = keras.datasets.fashion_mnist
(
    (train_images, train_labels),
    (test_images, test_labels),
) = fashion_mnist.load_data()

# Predict with the restored model and show the output shape and the first
# prediction.
predictions = model.predict(test_images)
print(predictions.shape)
print(predictions[0])