merge_bn: Merging a BN Layer into a Conv or FC Layer (Principle and Implementation)

This article explains the role of the BN (Batch Normalization) layer in deep learning and its effect on training speed, discusses the overhead BN layers introduce at inference time, and presents a method for merging a BN layer into the preceding convolutional or fully connected layer to improve inference efficiency.


Reference: http://blog.youkuaiyun.com/diye2008/article/details/78492181

1. Why BN merging is necessary:

A BN (batch-norm) layer is a common technique for accelerating the training of deep networks. It is usually placed after a convolutional (conv) layer or a fully connected (FC) layer, where it normalizes the activations and speeds up convergence. But while BN plays a useful role during training, at inference time it amounts to extra layers that slow down the overall computation and occupy additional memory or GPU memory. It would therefore be ideal to merge each BN layer into the adjacent conv or FC layer, and that is exactly the work described in this article.


2. The mathematics of BN merging:

A BN layer usually sits immediately after a conv (or FC) layer, although in some networks the BN layer is placed before the conv (or FC) layer instead. Let us consider the two cases separately.


2.1 The case where the BN layer follows the conv layer

The principle of the merge follows from the data transformation that a BN layer performs (illustrated in the original post by two figures).
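Since the original figures are not reproduced here, the transformation they depicted, the standard BN computation, can be written out directly. At inference time, for each channel, BN normalizes its input $x$ using the mean $\mu$ and variance $\sigma^2$ accumulated during training (moving averages), then applies a learned scale $\gamma$ and shift $\beta$:

$$
y \;=\; \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \;+\; \beta
$$

where $\epsilon$ is a small constant added for numerical stability. The key observation is that at inference time $\mu$, $\sigma^2$, $\gamma$, and $\beta$ are all fixed constants, so the BN layer is just a per-channel affine transform, and an affine transform composed with a linear layer is still a linear layer.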

The derivation in the referenced post contains an error; the corrected derivation is as follows:
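What follows is a reconstruction of that derivation (the original presented it as images), using the BN formula above. Consider one output channel of the conv (or FC) layer, which computes $x = W x_0 + b$ on its input $x_0$. Feeding this into the BN layer gives

$$
y \;=\; \gamma \cdot \frac{(W x_0 + b) - \mu}{\sqrt{\sigma^2 + \epsilon}} \;+\; \beta
\;=\; \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\, W\, x_0 \;+\; \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} \;+\; \beta .
$$

This is again a conv (or FC) layer, with merged parameters

$$
W_{\text{merged}} \;=\; \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\, W ,
\qquad
b_{\text{merged}} \;=\; \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} \;+\; \beta .
$$

Since $\gamma$, $\beta$, $\mu$, and $\sigma^2$ are per-channel scalars, $W_{\text{merged}}$ simply rescales each output channel's filter. If the original layer has no bias term, take $b = 0$. After replacing $W, b$ with $W_{\text{merged}}, b_{\text{merged}}$, the BN layer can be deleted with no change to the network's output.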


Code:

https://github.com/chuanqi305/MobileNet-SSD

merge_bn.py
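merge_bn.py in that repository performs this fold on the Caffe model used by MobileNet-SSD (in Caffe, BN is expressed as a BatchNorm layer followed by a Scale layer holding $\gamma$ and $\beta$, so the script folds both into the preceding convolution). As a minimal, framework-agnostic sketch of the same idea, assuming only the formulas derived above (the function `fold_bn_into_conv` and all variable names below are illustrative, not the repository's code):

```python
import numpy as np

def fold_bn_into_conv(W, b, gamma, beta, mean, var, eps=1e-5):
    """Fold a BN layer that follows a conv/FC layer into that layer.

    W:     weights with shape (out_channels, ...); BN acts per output channel
    b:     bias with shape (out_channels,); pass zeros if the layer has none
    gamma, beta, mean, var: BN scale/shift and running stats, (out_channels,)
    Returns (W_merged, b_merged) such that BN(layer(x)) == merged_layer(x).
    """
    scale = gamma / np.sqrt(var + eps)                  # per-channel multiplier
    # Broadcast the per-channel scale over the remaining weight dimensions.
    W_merged = W * scale.reshape(-1, *([1] * (W.ndim - 1)))
    b_merged = (b - mean) * scale + beta                # fold mean/shift into bias
    return W_merged, b_merged

# Numerical sanity check on an FC layer (a conv works the same per channel).
rng = np.random.default_rng(0)
W, b = rng.standard_normal((4, 3)), rng.standard_normal(4)
gamma, beta = rng.standard_normal(4), rng.standard_normal(4)
mean, var = rng.standard_normal(4), rng.random(4) + 0.1
x = rng.standard_normal(3)
y_bn = gamma * ((W @ x + b) - mean) / np.sqrt(var + 1e-5) + beta  # layer -> BN
W_m, b_m = fold_bn_into_conv(W, b, gamma, beta, mean, var)
assert np.allclose(y_bn, W_m @ x + b_m)                 # merged layer matches
```

One practical caveat: the fold must use the BN layer's inference-time statistics (the moving averages) and the same $\epsilon$ value the framework uses; otherwise the merged model's outputs will drift slightly from the original.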
