keras 分批训练2
今天讲的是如何使用keras进行分批训练(也叫增量训练、增量学习、在线训练、批量训练)的第二种方法,上一种方法在这里:
https://blog.youkuaiyun.com/weixin_42744102/article/details/87272950
上一次讲的是fit_generator的方法,那个方法搞不清每层的名字就很容易报错、计算错了step和batch也很麻烦,需要自己写的生成器和该方法密切耦合,其实不太友好;那么今天我们讲讲另一种方法:train_on_batch
使用方法:
model.train_on_batch(x, y)
使用方法很简单,只需要传入一个batch的data和target就可以,有两个可选参数可以调整权重,一般不用填。
这个方法的好处就是简单易用而且直观,不需要处理fit_genrator的各种step、layer name的问题,不那么强耦合;直观之处就在于,整个方法的作用就是送一个batch的数据进去训练,调用一次就是用batch训练一次,很灵活,个人非常推荐使用。
当然,这个方法需要结合自己实现的一些别的东西,才能完成训练,你需要手动循环epoch次,然后每个循环里面嵌套一个循环,这个便利整个数据集,产生一个一个的batch,并在产生了这些batch之后,调用train_on_batch方法进行训练。这些步骤在fit_genrator中是实现了的,但是它的强耦合导致对开发不太友好,因此还不如自己实现
好了,贴上我一个工程里面的样例代码:
for epoch in range(EPOCH):
print('epoch', epoch)
print(int(data_amo*(1.-VALIDATE_SPLIT)))
for b_idx in range(0, int(data_amo*(1.-VALIDATE_SPLIT)), BATCH):
with open('random_data') as f:
data_gram_sentence = f.readlines()[b_idx:b_idx+BATCH]
with open('random_target') as f:
data_target = list(f.readlines())[b_idx:b_idx+BATCH]
train_x = []
train_y = []
for sentence_gram_index in range(len(data_gram_sentence)):
sentence_gram = data_gram_sentence[sentence_gram_index]
grams = sentence_gram[:-1].split(' ')
valid = 0
sentence_vector = np.zeros(NUM_FEATURES)
for gram in grams:
if gram in model:
valid += 1
sentence_vector += model[gram]
if valid != 0:
sentence_vector = sentence_vector / valid
train_x.append(sentence_vector)
target_single = np.zeros(len(data_li))
# print('*'*10)
# print(one_hot_dict[data_target[sentence_gram_index][:-1]])
target_single[int(one_hot_dict[data_target[sentence_gram_index][:-1]])] = 1.
# target_single[one_hot_dict[[data_target[sentence_gram_index][:-1]]]] = 1.
# print(target_single)
train_y.append(target_single)
train_x = np.array(train_x)
train_y = np.array(train_y)
ks_model.train_on_batch(train_x, train_y, sample_weight=None, class_weight=None)
有不懂或是发现我的疏漏错误的,欢迎随时联系我:1012950361@qq.com
下次见~