keras 分批训练 详解2 - keras 进阶教程

本文详细介绍使用Keras的train_on_batch方法进行分批训练的技巧,这种方法比fit_generator更直观、灵活,适合在线训练和增量学习。文章通过具体代码示例展示了如何手动循环epoch,遍历数据集并生成batch,最后调用train_on_batch进行训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

keras 分批训练2

今天讲的是如何使用keras进行分批训练(也叫增量训练增量学习在线训练批量训练)的第二种方法,上一种方法在这里:
https://blog.youkuaiyun.com/weixin_42744102/article/details/87272950

上一次讲的是fit_generator的方法,那个方法搞不清每层的名字就很容易报错、计算错了step和batch也很麻烦,需要自己写的生成器和该方法密切耦合,其实不太友好;那么今天我们讲讲另一种方法:train_on_batch

使用方法:

model.train_on_batch(x, y)

使用方法很简单,只需要传入一个batch的data和target就可以,有两个可选参数可以调整权重,一般不用填。
这个方法的好处就是简单易用而且直观,不需要处理fit_genrator的各种step、layer name的问题,不那么强耦合;直观之处就在于,整个方法的作用就是送一个batch的数据进去训练,调用一次就是用batch训练一次,很灵活,个人非常推荐使用。

当然,这个方法需要结合自己实现的一些别的东西,才能完成训练,你需要手动循环epoch次,然后每个循环里面嵌套一个循环,这个便利整个数据集,产生一个一个的batch,并在产生了这些batch之后,调用train_on_batch方法进行训练。这些步骤在fit_genrator中是实现了的,但是它的强耦合导致对开发不太友好,因此还不如自己实现

好了,贴上我一个工程里面的样例代码:

    for epoch in range(EPOCH):

        print('epoch', epoch)

        print(int(data_amo*(1.-VALIDATE_SPLIT)))

        for b_idx in range(0, int(data_amo*(1.-VALIDATE_SPLIT)), BATCH):

            with open('random_data') as f:

                data_gram_sentence = f.readlines()[b_idx:b_idx+BATCH]

            with open('random_target') as f:

                data_target = list(f.readlines())[b_idx:b_idx+BATCH]

            train_x = []

            train_y = []

            for sentence_gram_index in range(len(data_gram_sentence)):

                sentence_gram = data_gram_sentence[sentence_gram_index]

                grams = sentence_gram[:-1].split(' ')

                valid = 0

                sentence_vector = np.zeros(NUM_FEATURES)

                for gram in grams:

                    if gram in model:

                        valid += 1

                        sentence_vector += model[gram]

                if valid != 0:

                    sentence_vector = sentence_vector / valid

                train_x.append(sentence_vector)

                target_single = np.zeros(len(data_li))

                # print('*'*10)

                # print(one_hot_dict[data_target[sentence_gram_index][:-1]])

                target_single[int(one_hot_dict[data_target[sentence_gram_index][:-1]])] = 1.

                # target_single[one_hot_dict[[data_target[sentence_gram_index][:-1]]]] = 1.

                # print(target_single)

                train_y.append(target_single)

            train_x = np.array(train_x)

            train_y = np.array(train_y)

            ks_model.train_on_batch(train_x, train_y, sample_weight=None, class_weight=None)

有不懂或是发现我的疏漏错误的,欢迎随时联系我:1012950361@qq.com

下次见~

ValueError: Unrecognized keyword arguments passed to Embedding: {'batch_input_shape': [64, None]} --------------------------------------------------------------------------- ValueError Traceback (most recent call last) Cell In[15], line 2 1 if __name__ == "__main__": ----> 2 main() Cell In[11], line 23, in main() 16 dataset = create_dataset( 17 text_as_int, 18 seq_length=SEQ_LENGTH, 19 batch_size=BATCH_SIZE 20 ) 22 # 构建模型 ---> 23 model = build_model( 24 vocab_size=len(vocab), 25 embedding_dim=EMBEDDING_DIM, 26 rnn_units=RNN_UNITS, 27 batch_size=BATCH_SIZE 28 ) 30 # 编译模型 31 model.compile(optimizer='adam', loss=loss) Cell In[8], line 4, in build_model(vocab_size, embedding_dim, rnn_units, batch_size) 1 def build_model(vocab_size, embedding_dim, rnn_units, batch_size): 2 """构建RNN文本生成模型""" 3 model = tf.keras.Sequential([ ----> 4 tf.keras.layers.Embedding( 5 vocab_size, 6 embedding_dim, 7 batch_input_shape=[batch_size, None] 8 ), 9 tf.keras.layers.GRU( 10 rnn_units, 11 return_sequences=True, 12 stateful=True, 13 recurrent_initializer='glorot_uniform', 14 reset_after=False # 提高GPU性能 15 ), 16 tf.keras.layers.Dropout(0.2), # 添加dropout减少过拟合 17 tf.keras.layers.Dense(vocab_size) 18 ]) 19 return model File /usr/local/lib/python3.11/site-packages/keras/src/layers/core/embedding.py:100, in Embedding.__init__(self, input_dim, output_dim, embeddings_initializer, embeddings_regularizer, embeddings_constraint, mask_zero, weights, lora_rank, lora_alpha, **kwargs) 96 if input_length is not None: 97 warnings.warn( 98 "Argument `input_length` is deprecated. Just remove it." 99 ) --> 100 super().__init__(**kwargs) 101 self.input_dim = input_dim 102 self.output_dim = output_dim File /usr/local/lib/python3.11/site-packages/keras/src/layers/layer.py:291, in Layer.__init__(self, activity_regularizer, trainable, dtype, autocast, name, **kwargs) 289 self._input_shape_arg = input_shape_arg 290 if kwargs: --> 291 raise ValueError( 292 "Unrecognized keyword arguments " 293 f"passed to {self.__class__.__name__}: {kwargs}" 294 ) 296 self._path = None # Will be determined in `build_wrapper` 297 self.built = False ValueError: Unrecognized keyword arguments passed to Embedding: {'batch_input_shape': [64, None]}
最新发布
06-22
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值