Keras在训练过程中根据epoch的值更换loss function

不停止训练根据epoch值更换loss function

以大于50epoch时将binary cross-entropy loss转换为focal loss为例。

1. 代码实现

Net Module

class Net(object):
    def __init__(self,epochs,img_rows,img_cols):
    	# initialize your parameters
        self.epochs = epochs
        self.img_rows = img_rows
        self.img_cols = img_cols
    def network_architecture(self):
    	# define your network architecture
    	inputs = Input(shape=(self.img_rows, self.img_cols),name='data')
    	outputs = Dense(10)
    	model = Model(inputs = [inputs], outputs = [outputs])
    	# initialize current_epoch which will be used in next training loops
    	self.current_epoch = K.variable(value=0)
    	# compile the model
    	model.compile(optimizer = Adam(learning_rate = 1e-4), loss = bce_focal_loss_consequence(change_epoch=50,current_epoch=self.current_epoch,gamma=2.,alpha=.25))
    	return model

Callback Module

class WarmUpCallback(Callback):
    def __init__(self,current_epoch):
        self.current_epoch = current_epoch

    def on_epoch_end(self, epoch, logs=None):
        K.set_value(self.current_epoch, epoch+1) 

Custom Loss Module

def bce_focal_loss_consequence(change_epoch,current_epoch,gamma,alpha):
    def bce_focal(y_true,y_pred):
        bool_case_1 = K.less(current_epoch,change_epoch)
        if bool_case_1:
            loss = binary_crossentropy(y_true,y_pred)
        else:
            loss = binary_focal_loss(gamma,alpha)(y_true,y_pred)
        return loss
    return bce_focal

Fit Module

my_callbacks = [WarmUpCallback(current_epoch=self.current_epoch)]
model.fit(training_generator, epochs=self.epochs,
			validation_data=validation_generator,callbacks=my_callbacks)

2. Loss Curve

在这里插入图片描述
如上图所示,如我们预期,Loss在50epoch时由于换了loss function而发生了突变,接着继续训练。

References:

https://github.com/keras-team/keras/issues/2595
https://stackoverflow.com/questions/42787181/variationnal-auto-encoder-implementing-warm-up-in-keras

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值