Creating Checkpoints with Keras

This article shows how to save model weights with checkpoints in Keras and how to visualize the training process with TensorBoard. A worked example demonstrates how to configure callbacks for automatic model saving and training monitoring.


1. Introduction

A checkpoint is a snapshot of system state that can be used directly. Here, a checkpoint holds the model's weights, which can be used for prediction or to resume training.
Keras provides checkpointing through its callback mechanism (callbacks).
TensorBoard is a tool for visualizing training; Keras also provides a corresponding callback for it.
The example below covers both cases.

2. Example

#!/usr/bin/env python
# encoding: utf-8

import pandas as pd
import numpy as np
import matplotlib.pylab as plt
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from keras.callbacks import ModelCheckpoint
from keras.callbacks import TensorBoard
# Read the data
dataf = pd.read_csv("./data/data_complete.csv")[['sump']].values[0:2*288]
#print(dataf)
# Build sliding windows: each input is `timesteps` consecutive values and
# the target is the value that immediately follows the window.
def create_dataset(dataset, timesteps):
    datax=[]
    datay=[]
    for each in range(len(dataset)-timesteps):
        x = dataset[each:each+timesteps,0]
        y = dataset[each+timesteps,0]
        datax.append(x)
        datay.append(y)
    return np.array(datax),np.array(datay)

# Build the train and test sets
scaler = MinMaxScaler(feature_range=(0,1))
dataf = scaler.fit_transform(dataf)
trainsize = int(len(dataf)*0.7)
train = dataf[0:trainsize]
test = dataf[trainsize:len(dataf)]
timesteps = 288
trainx, trainy = create_dataset(train, timesteps)
testx, testy = create_dataset(test,timesteps)
#print(trainx)
#print(trainx.shape)
# Reshape to [samples, timesteps, features] for the LSTM
trainx = np.reshape(trainx,(trainx.shape[0],timesteps,1))
testx = np.reshape(testx, (testx.shape[0],timesteps,1))
#print(trainx)
#print(trainx.shape)
# LSTM model
model = Sequential()
model.add(LSTM(4,input_shape=(timesteps,1)))
model.add(Dense(1))
#model.load_weights("")  # a checkpoint weight file can be loaded here to resume training (see the sketch after the training log below)
model.compile(loss="mean_squared_error", optimizer="adam", metrics=["accuracy"])  # metrics must include accuracy so that val_acc is available for the checkpoint monitor
#checkpoint
filepath = "weights-improvement-{epoch:02d}--{val_acc:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor="val_acc", verbose=1,
                             save_best_only=True, mode="max")  # ModelCheckpoint example: save only when val_acc improves
callbacks_list = [checkpoint]
tensorboard = TensorBoard(log_dir="log")  # TensorBoard example: write event files to ./log
callbacks_tensor = [tensorboard]
model.fit(trainx, trainy, epochs=3, batch_size=5,
          validation_split=0.25, callbacks=callbacks_tensor)  # use callbacks=callbacks_list to enable checkpointing instead
#
# #predict
# train_predict = model.predict(trainx)
# test_predict = model.predict(testx)
# #invert
# train_predict = scaler.inverse_transform(train_predict)
# trainy = scaler.inverse_transform([trainy])
# print(train_predict[0],trainy[0])
#
# test_predict = scaler.inverse_transform(test_predict)
# testy = scaler.inverse_transform([testy])
# #error
# train_score = np.sqrt(mean_squared_error(trainy[0],train_predict[:,0]))
# test_score = np.sqrt(mean_squared_error(testy[0],test_predict[:,0]))
# print("train score RMSE: %.2f"% train_score)
# print("test score RMSE: %.2f"% test_score)
#
# #plot
# # shift train predictions for plotting
# trainPredictPlot = np.empty_like(dataf)
# trainPredictPlot[:, :] = np.nan
# trainPredictPlot[timesteps:len(train_predict)+timesteps, :] = train_predict
#
# # shift test predictions for plotting
# testPredictPlot = np.empty_like(dataf)
# testPredictPlot[:,:] = np.nan
# testPredictPlot[len(train_predict)+(timesteps*2):len(dataf), :] = test_predict
#
# # plot baseline and predictions
# plt.plot(scaler.inverse_transform(dataf))
# plt.plot(trainPredictPlot)
# plt.plot(testPredictPlot)
# plt.show()

checkpoint:
Train on 86 samples, validate on 29 samples
Epoch 1/3

 5/86 [>…] - ETA: 22s - loss: 9.5544e-04 - acc: 0.0000e+00
10/86 [>…] - ETA: 12s - loss: 9.2816e-04 - acc: 0.1000
15/86 [==>…] - ETA: 8s - loss: 7.5216e-04 - acc: 0.0667

Epoch 00001: val_acc improved from -inf to 0.00000, saving model to weights-improvement-01--0.00.hdf5
Epoch 2/3
Epoch 00003: val_acc did not improve from 0.00000
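
To pick up from a saved checkpoint, load the weights file back into an identically defined model, as the commented-out load_weights line in the example suggests. Below is a minimal sketch, assuming a checkpoint file named weights-improvement-01--0.00.hdf5 (as produced by the run above) exists in the working directory:

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM

# Rebuild the exact same architecture used during training.
timesteps = 288
model = Sequential()
model.add(LSTM(4, input_shape=(timesteps, 1)))
model.add(Dense(1))

# Load the checkpointed weights (filename assumed from the run above).
model.load_weights("weights-improvement-01--0.00.hdf5")
model.compile(loss="mean_squared_error", optimizer="adam", metrics=["accuracy"])

# The model can now predict directly (e.g. model.predict(testx)) or
# continue training from the restored weights with another model.fit(...) call.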

tensorboard:
A log folder is created, containing a TensorBoard event file.
From the directory of the .py script, run: tensorboard --logdir=log
Then open localhost:6006 in a browser to view the training visualization.
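
Both callbacks can also be passed to fit() in a single list, so that checkpointing and TensorBoard logging happen in the same run. A minimal sketch, reusing the checkpoint and tensorboard objects defined in the example above:

# Run checkpointing and TensorBoard logging together.
callbacks_all = [checkpoint, tensorboard]
model.fit(trainx, trainy, epochs=3, batch_size=5,
          validation_split=0.25, callbacks=callbacks_all)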

3. References

How to Checkpoint Deep Learning Models in Keras
