修改之前
首先贴出来训练部分的代码:
def train(self, train_generator, validation_generator, pre_model_path=None):
'''
:param train_generator: 训练集
:param validation_generator: 测试集
:param pre_model_path: 预训练模型,在之前模型上继续训练,目前仅支持h5模型
'''
# 在已有模型基础上继续训练
if pre_model_path:
self.model = load_model(pre_model_path)
# 配置模型
with open(pjoin(TXT_DIR, 'message.txt'), 'r') as f:
_, TRAIN_SIZE, VAL_SIZE, _ = list(map(int, f.readline().split(',')))
STEP_PER_EPOCH = TRAIN_SIZE // BATCH_SIZE + 1
VALIDATION_STEPS = VAL_SIZE // BATCH_SIZE + 1
optimizer = optimizers.RMSprop(lr=LEARNING_RATE)
self.model.compile(loss='mse', optimizer=optimizer, metrics=['mae'])
# 训练
self.history = self.model.fit(train_generator, steps_per_epoch=STEP_PER_EPOCH, epochs=EPOCH,