针对MNIST数据集，基于tf2.keras构建神经网络（NN）模型并进行训练和预测。
部分脚本如下（完整脚本见笔者GitHub）：
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Sequential, regularizers, Model
from utils.utils_tools import clock, get_ministdata
import warnings
warnings.filterwarnings(action='ignore')
class MyModel(Model):
    """Small fully-connected classifier: Dense -> BN -> Dense(L2) -> BN -> Dense.

    All constructor arguments are optional; the defaults reproduce the original
    architecture (256 -> 32 -> 10 units, L2 factor 0.06 on the second layer).

    Args:
        hidden_units: width of the first ReLU Dense layer.
        bottleneck_units: width of the second, L2-regularised ReLU Dense layer.
        l2_factor: L2 penalty applied to the second layer's kernel.
        num_classes: number of output units / classes.
        **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).

    The output layer uses softmax so every row is a probability distribution,
    which is what the ``categorical_crossentropy`` loss used to compile this
    model expects (a sigmoid output is not normalised across classes).
    """

    def __init__(self, hidden_units=256, bottleneck_units=32, l2_factor=0.06,
                 num_classes=10, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        # Attribute names are kept unchanged so existing weights/checkpoints still map.
        self.fc0 = layers.Dense(hidden_units, activation='relu')
        self.bn1 = layers.BatchNormalization()
        self.fc1 = layers.Dense(bottleneck_units, activation='relu',
                                kernel_regularizer=regularizers.l2(l2_factor))
        self.bn2 = layers.BatchNormalization()
        self.fc2 = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs_, training=None):
        """Run the forward pass and return per-class probabilities.

        Args:
            inputs_: batch of flattened feature vectors.
            training: forwarded to the BatchNormalization layers so they use
                batch statistics while training and moving averages at inference.
        """
        x = self.fc0(inputs_)
        x = self.bn1(x, training=training)
        x = self.fc1(x)
        x = self.bn2(x, training=training)
        return self.fc2(x)
def build_nnmodel(input_shape=(None, 784)):
    """Create, build and compile a ``MyModel`` instance.

    Args:
        input_shape: shape passed to ``Model.build``; the default expects
            batches of 784-element feature vectors.

    Returns:
        The compiled model (categorical cross-entropy loss, Adam optimizer,
        accuracy metric), ready for ``fit``.
    """
    model = MyModel()
    model.build(input_shape=input_shape)
    compile_kwargs = dict(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    model.compile(**compile_kwargs)
    return model
if __name__ == '__main__':
    # NOTE(review): the layout of get_ministdata()'s DataFrame is not visible here;
    # the code below assumes every column except the last is a feature and the last
    # column holds the integer digit label — confirm against utils.utils_tools.
    mnistdf = get_ministdata()

    # 75% of the rows form the held-out (validation/test) split and the remaining 25%
    # are used for training — the ratio of the original experiment is kept. A fixed
    # random_state makes the split reproducible between runs.
    mnist_te = mnistdf.sample(frac=0.75, random_state=42)
    mnist_tr = mnistdf.drop(index=mnist_te.index)

    # Fix the one-hot width explicitly so train and test targets always agree with
    # each other and with the model's 10-unit output layer, even if a split happens
    # to miss the highest label.
    num_classes = 10
    mnist_tr_x = mnist_tr.iloc[:, :-1].values
    mnist_tr_y = tf.keras.utils.to_categorical(mnist_tr.iloc[:, -1].values,
                                               num_classes=num_classes)
    mnist_te_x = mnist_te.iloc[:, :-1].values
    mnist_te_y = tf.keras.utils.to_categorical(mnist_te.iloc[:, -1].values,
                                               num_classes=num_classes)

    nnmodel = build_nnmodel()
    # Early stopping watches *training* accuracy: the validation data below is the
    # held-out test split, so monitoring it would let test data steer training.
    stop = tf.keras.callbacks.EarlyStopping(monitor='accuracy', min_delta=0.0001)
    history = nnmodel.fit(mnist_tr_x, mnist_tr_y,
                          validation_data=(mnist_te_x, mnist_te_y),
                          epochs=10,
                          callbacks=[stop])

    # Best training accuracy across epochs. Format it directly: rounding to two
    # decimals first and then printing three (as before) turned 0.973 into "0.970".
    acc_final = max(history.history['accuracy'])
    print(f"acc:{acc_final:.3f}")

    predict_ = np.argmax(nnmodel.predict(mnist_te_x), axis=1)
    te_y = np.argmax(mnist_te_y, axis=1)
    print(predict_)
    print(te_y)
    # Fraction of correctly classified held-out samples — this is accuracy, not AUC.
    print(f'test acc: {np.mean(predict_ == te_y):.3f}')
"""
14000/14000 [==============================] - 10s 741us/sample - loss: 0.1472 - accuracy: 0.9730 - val_loss: 0.2162 - val_accuracy: 0.9557
Epoch 9/10
13952/14000 [============================>.] - ETA: 0s - loss: 0.1520 - accuracy: 0.9720
Epoch 00009: accuracy did not improve from 0.97300
14000/14000 [==============================] - 23s 2ms/sample - loss: 0.1522 - accuracy: 0.9719 - val_loss: 0.2390 - val_accuracy: 0.9506
acc:0.970
[2 2 8 ... 7 6 6]
[2 2 8 ... 7 6 6]
auc: 0.951, cost: 170.946
"""