<think>我们面对的是一个较长的代码文件,主要功能是实现一个增量学习(IncrementalLearning)的模型训练过程,特别是在少样本增量学习(Few-Shot Class-IncrementalLearning,FSCIL)场景下。代码使用了MXNet框架。下面我将对代码进行逐段解读,并解释整体原理。###整体目标该代码的目标是训练一个模型,使其能够逐步学习新的类别(每个新类别只有很少的样本,即few-shot),同时保持对已学习类别的识别能力(避免灾难性遗忘)。代码中实现了一种结合了多种技术的方法,包括:
-基础训练(basesession):在基础类别上训练模型。
-增量训练(incrementalsessions):在基础类别上逐步加入新类别(每个session增加5个新类别),每个新类别只有k个样本(k-shot)。-使用多种损失函数来平衡新旧类别的学习,包括:-锚点损失(Anchor Loss, AL):使用之前类别的锚点(代表点)来约束特征空间,防止遗忘。-神经气体(NeuralGas,NG)损失:通过最小化和最大化损失来调整特征空间,保持类内紧凑和类间分离。-蒸馏损失(Probability Distillation Loss, PDL):利用旧模型输出的概率分布作为软目标来指导新模型的学习。-交叉熵损失(Cross-EntropyLoss)等。###代码结构1. **导入模块和设置参数**:导入必要的库,解析命令行参数。2.**模型选择**:根据参数选择模型(如quick_cnn或resnet18)。3.**数据加载**:使用自定义的数据加载器加载CIFAR100数据集,并按照增量学习的要求划分基础类别和新类别。4.**训练循环**:-对于每个session(包括基础session和增量session)进行训练。-在每个session中,使用不同的学习率调度策略。-在训练过程中,根据session的不同应用不同的损失函数。5. **测试**:在每个session训练后,测试模型在基础类别和新类别上的准确率。
6. **模型保存和结果记录**:保存模型参数,记录准确率并绘制图表。
###详细代码解读```pythonimport matplotlibmatplotlib.use('Agg')#设置matplotlib为非交互式后端,用于在无显示环境下保存图像import argparse, time, loggingimportosimport numpyas npimportmxnetas mxfrommxnetimport gluon,ndfrom mxnet importautogradas agfrommxnet.gluon.data.visionimport transformsimportsyssys.path.append('/home/ubuntu/PycharmProjects/pythonProject/fscil-master/')#添加自定义模块路径import modelfrom model.cifar_quick importquick_cnn#导入quick_cnn模型from model.cifar_resnet_v1 importcifar_resnet20_v1#导入resnet模型fromgluoncv.utilsimport makedirs#创建目录from gluoncv.data importtransforms asgcv_transforms#图像变换fromdataloader.dataloaderimport NC_CIFAR100, merge_datasets#自定义数据加载器from tools.utils importLinearWarmUp#自定义学习率预热from tools.utils importDataLoader#自定义数据加载器from tools.utils importparse_args#解析参数fromtools.plotimport plot_pr,plot_all_sess#绘图工具fromtools.loss importDistillationSoftmaxCrossEntropyLoss,NG_Min_Loss,NG_Max_Loss#自定义损失函数fromtools.ng_anchorimport prepare_anchor#准备锚点import jsonfromtools.utilsimport select_best, select_best2,select_best3#选择最佳模型#解析命令行参数opt= parse_args()#设置批量大小batch_size= opt.batch_size#设置GPU环境num_gpus= len(opt.gpus.split(','))batch_size*= max(1, num_gpus)
context= [mx.gpu(int(i))for iin opt.gpus.split(',')]num_workers =opt.num_workers#选择模型model_name= opt.model#根据模型名称初始化模型if model_name=='quick_cnn':classes=60#初始类别数(基础类别60个)#是否固定卷积层ifopt.fix_conv:fix_layers=3fix_fc=Falseelse:fix_layers=0fix_fc= Falsenet =quick_cnn(classes,fix_layers,fix_fc=fix_fc,fw=opt.fw)#创建quick_cnn模型feature_size =64#特征维度elif model_name=='resnet18':classes=60feature_size =64net =cifar_resnet20_v1(classes=classes,wo_bn=opt.wo_bn,fw=opt.fw)#创建resnet模型else:raise KeyError('network keyerror')#如果提供了预训练模型,则加载if opt.resume_from:net.load_parameters(opt.resume_from, ctx= context)
#根据参数选择数据集DATASET =eval(opt.dataset)#优化器设置optimizer= 'nag'#使用NAG优化器#保存周期save_period= opt.save_period#绘图保存路径plot_path= opt.save_plot_dir#创建保存结果的目录save_dir= time.strftime('./experimental_result/{}/{}/%Y-%m-%d-%H-%M-%S'.format(opt.dataset,model_name), time.localtime())
save_dir =save_dir+ opt.save_namemakedirs(save_dir)#配置日志logger= logging.getLogger()logger.setLevel(logging.INFO)log_save_dir= os.path.join(save_dir,'log.txt')fh =logging.FileHandler(log_save_dir)fh.setLevel(logging.INFO)logger.addHandler(fh)logger.info(opt)#记录参数#测试函数deftest(ctx, val_data,net,sess):metric= mx.metric.Accuracy()for i, batchin enumerate(val_data):data =gluon.utils.split_and_load(batch[0],ctx_list=ctx, batch_axis=0)label= gluon.utils.split_and_load(batch[1],ctx_list=ctx, batch_axis=0)outputs= [net(X, sess)[1] forX indata]#前向传播,获取输出(第二个元素是分类结果)metric.update(label,outputs)returnmetric.get()
#训练函数def train(net, ctx):if isinstance(ctx,mx.Context):ctx =[ctx]#如果没有预训练模型,则初始化模型if notopt.resume_from:net.initialize(mx.init.Xavier(), ctx=ctx)#使用一个随机输入初始化模型(确保模型参数已初始化)ifopt.dataset =='NC_CIFAR100':n =mx.nd.zeros(shape=(1,3,32,32),ctx=ctx[0])else:raise KeyError('dataset keyerror')form inrange(9):#为什么是9?因为模型有9个session(0-8)?这里可能是为了初始化所有session的层net(n,m)#学习率调度函数(包含预热)defmakeSchedule(start_lr,base_lr,length,step,factor):schedule= mx.lr_scheduler.MultiFactorScheduler(step=step,factor=factor)schedule.base_lr =base_lrschedule= LinearWarmUp(schedule,start_lr=start_lr,length=length)returnschedule#训练参数设置sesses= list(np.arange(opt.sess_num))#session列表,从0到opt.sess_num-1epochs= [opt.epoch]*opt.sess_num#每个session的训练轮数lrs= [opt.base_lrs]+[opt.lrs]*(opt.sess_num-1)#每个session的学习率lr_decay =opt.lr_decay#学习率衰减因子base_decay_epoch =[int(i)for iin opt.base_decay_epoch.split(',')] +[np.inf]#基础session的学习率衰减轮次lr_decay_epoch =[base_decay_epoch]+[[opt.inc_decay_epoch,np.inf]]*(opt.sess_num-1)#每个session的学习率衰减轮次#损失函数权重AL_weight= opt.AL_weightmin_weight =opt.min_weightoce_weight= opt.oce_weightpdl_weight= opt.pdl_weightmax_weight =opt.max_weighttemperature= opt.temperature#蒸馏温度#是否使用各种损失函数use_AL= opt.use_AL#锚点损失use_ng_min =opt.use_ng_min# NG最小损失(类内紧凑)use_ng_max= opt.use_ng_max#NG最大损失(类间分离)ng_update =opt.ng_update#是否更新NG锚点use_oce =opt.use_oce#旧类别交叉熵损失use_pdl= opt.use_pdl#概率蒸馏损失use_nme= opt.use_nme#相似性损失(与锚点的特征相似性)use_warmUp= opt.use_warmUp#学习率预热use_ng =opt.use_ng#使用神经气体(NG)方法fix_conv =opt.fix_conv#是否固定卷积层(仅训练全连接层)fix_epoch =opt.fix_epoch#固定卷积层的轮数#增量学习参数c_way =opt.c_way#每个session新增类别数k_shot= opt.k_shot#每个新类别的样本数base_acc= opt.base_acc#基础模型的准确率(如果从预训练模型开始)select_best_method= opt.select_best#选择最佳模型的方法init_class =60#初始类别数(基础类别)anchor_num =400#锚点数量(每个类别的锚点数?)#记录准确率acc_dict ={}#记录每个session的验证准确率(基础类别和新类别)all_best_e =[]#记录每个session的最佳轮次#如果不是maxhead模型,则重新初始化全连接层(增量学习时,新类别的分类层需要初始化)ifmodel_name[-7:] !='maxhead':#重新初始化多个全连接层(用于不同session的分类)net.fc3.initialize(mx.init.Normal(sigma=0.001), ctx=ctx, force_reinit=True)net.fc4.initialize(mx.init.Normal(sigma=0.001),ctx=ctx,force_reinit=True)net.fc5.initialize(mx.init.Normal(sigma=0.001),ctx=ctx,force_reinit=True)net.fc6.initialize(mx.init.Normal(sigma=0.001),ctx=ctx,force_reinit=True)net.fc7.initialize(mx.init.Normal(sigma=0.001),ctx=ctx,force_reinit=True)net.fc8.initialize(mx.init.Normal(sigma=0.001),ctx=ctx,force_reinit=True)net.fc9.initialize(mx.init.Normal(sigma=0.001),ctx=ctx,force_reinit=True)net.fc10.initialize(mx.init.Normal(sigma=0.001),ctx=ctx,force_reinit=True)#开始每个session的训练for sessin sesses:logger.info('session :%d'%sess)#创建学习率调度器schedule= makeSchedule(start_lr=0, base_lr=lrs[sess], length=5, step=lr_decay_epoch[sess], factor=lr_decay)#如果是基础session(session0)并且有预训练模型,则跳过训练ifsess==0 andopt.resume_from:acc_dict[str(sess)]= list()acc_dict[str(sess)].append([base_acc,0])#记录基础准确率,新类别准确率为0(因为还没有新类别)all_best_e.append(0)#最佳轮次为0continue#数据预处理(图像变换)#对于quick_cnn模型,在第一个增量session(session1)且使用锚点损失时,使用简单的变换(无数据增强)ifsess ==1and model_name =='quick_cnn'anduse_AL:transform_train =transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.4914,0.4822,0.4465],[0.2023,0.1994,0.2010])])transform_test= transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.4914,0.4822,0.4465], [0.2023,0.1994,0.2010])])anchor_trans =transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.4914,0.4822,0.4465], [0.2023,0.1994,0.2010])])else:#其他情况使用包含数据增强的变换transform_train =transforms.Compose([gcv_transforms.RandomCrop(32,pad=4),transforms.RandomFlipLeftRight(),transforms.ToTensor(),transforms.Normalize([0.5071,0.4866,0.4409],[0.2009,0.1984,0.2023])])transform_test =transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5071,0.4866,0.4409], [0.2009,0.1984,0.2023])])anchor_trans =transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5071,0.4866,0.4409], [0.2009,0.1984,0.2023])])#准备锚点(用于锚点损失、蒸馏损失等)if use_ALor use_nmeor use_pdlor use_oce:if sess!=0:#基础session不需要锚点if ng_update ==True:#更新锚点if sess==1:#第一个增量session#准备锚点(使用之前的基础类别数据)update_anchor1, bmu,variances=\prepare_anchor(DATASET,logger,anchor_trans,num_workers,feature_size,net,ctx,use_ng,init_class)update_anchor_data =DataLoader(update_anchor1,anchor_trans,update_anchor1.__len__(),num_workers,shuffle=False)#根据方差调整(论文中的技巧?)ifopt.ng_var:idx_1 =np.where(variances.asnumpy()>0.5)idx_2= np.where(variances.asnumpy() <0.5)variances[idx_1]=0.9variances[idx_2]=1else:#后续sessionbase_class= init_class +(sess -1) *5#已学习类别数new_class =list(init_class+ (sess-1)*5 +(np.arange(5)))#新增类别new_set =DATASET(train=True,fine_label=True,fix_class=new_class, base_class=base_class,logger=logger)#合并之前的锚点数据集和新类别数据集update_anchor2= merge_datasets(update_anchor1,new_set)update_anchor_data= DataLoader(update_anchor2, anchor_trans,update_anchor2.__len__(),num_workers,shuffle=False)elif(sess==1):#不更新锚点,但第一次增量session需要准备锚点update_anchor,bmu, variances=\prepare_anchor(DATASET,logger,anchor_trans,num_workers,feature_size,net,ctx,use_ng,init_class)update_anchor_data =DataLoader(update_anchor,anchor_trans,update_anchor.__len__(),num_workers,shuffle=False)ifopt.ng_var:idx_1 =np.where(variances.asnumpy()>0.5)idx_2= np.where(variances.asnumpy() <0.5)variances[idx_1]=0.9variances[idx_2]=1#计算锚点的特征(用于损失函数)forbatch inupdate_anchor_data:anc_data= gluon.utils.split_and_load(batch[0],ctx_list=[ctx[0]], batch_axis=0)anc_label= gluon.utils.split_and_load(batch[1],ctx_list=[ctx[0]], batch_axis=0)with ag.pause():#不记录梯度anchor_feat, anchor_logit= net(anc_data[0],sess-1)#使用上一个session的模型计算特征和logitanchor_feat =[anchor_feat]anchor_logit= [anchor_logit]#创建优化器trainer =gluon.Trainer(net.collect_params(), optimizer,{'learning_rate':lrs[sess], 'wd':opt.wd,'momentum':opt.momentum})#评估指标metric =mx.metric.Accuracy()train_metric= mx.metric.Accuracy()loss_fn= gluon.loss.SoftmaxCrossEntropyLoss()#交叉熵损失#初始化其他损失函数ifuse_nme:defloss_fn_disG(f1, f2,weight):f1 =f1.reshape(anchor_num,-1)f2 =f2.reshape(anchor_num,-1)similar= mx.nd.sum(f1*f2,1)#计算余弦相似度return(1-similar)*weight#相似度越高损失越小digG_weight= opt.nme_weight#相似性损失的权重if use_AL:#对于quick_cnn,每个session使用不同的权重if model_name =='quick_cnn':AL_w= [120,75,120,100,50,60,90,90]AL_weight =AL_w[sess-1]else:AL_weight=opt.AL_weight#是否使用方差加权ifopt.ng_var:def l2lossVar(feat,anc,weight,var):dim= feat.shape[1]feat= feat.reshape(-1, dim)anc =anc.reshape(-1,dim)loss= mx.nd.square(feat -anc)#平方差loss= loss* weight* var#乘以权重和方差return mx.nd.mean(loss, axis=0, exclude=True)loss_fn_AL =l2lossVarelse:loss_fn_AL =gluon.loss.L2Loss(weight=AL_weight)# L2损失if use_pdl:#蒸馏损失(软目标交叉熵)loss_fn_pdl= DistillationSoftmaxCrossEntropyLoss(temperature=temperature,hard_weight=0, weight=pdl_weight)ifuse_oce:#旧类别交叉熵损失loss_fn_oce =gluon.loss.SoftmaxCrossEntropyLoss(weight=oce_weight)if use_ng_max:#NG最大损失(类间分离)loss_fn_max= NG_Max_Loss(lmbd=max_weight,margin=0.5)ifuse_ng_min:# NG最小损失(类内紧凑)min_loss= NG_Min_Loss(num_classes=opt.c_way,feature_size=feature_size,lmbd=min_weight, #center weight=0.01 inthe paperctx=ctx[0])min_loss.initialize(mx.init.Xavier(magnitude=2.24),ctx=ctx,force_reinit=True)#初始化center_trainer= gluon.Trainer(min_loss.collect_params(),optimizer="sgd",optimizer_params={"learning_rate":opt.ng_min_lr}) #优化器#学习率衰减计数lr_decay_count=0#数据加载#基础session(session0)加载基础类别数据ifopt.cum andsess==1 :#累积模式:在第一个增量session时,合并基础数据和第一个增量数据base_class= list(np.arange(init_class))joint_data =DATASET(train=True,fine_label=True,c_way=init_class, k_shot=500, fix_class=base_class, logger=logger)if sess==0 :base_class =list(np.arange(init_class))new_class =list(init_class+ (np.arange(5)))#第一个增量session的5个新类别base_data= DATASET(train=True, fine_label=True, c_way=init_class,k_shot=500,fix_class=base_class,logger=logger)bc_val_data= DataLoader(DATASET(train=False,fine_label=True,fix_class=base_class,logger=logger),transform_test,100,num_workers,shuffle=False)nc_val_data =DataLoader(DATASET(train=False,fine_label=True,fix_class=new_class, base_class=len(base_class),logger=logger),transform_test,100,num_workers,shuffle=False)else:#增量sessionbase_class =list(np.arange(init_class+ (sess-1)*5))#已学习类别new_class= list(init_class +(sess-1)*5 +(np.arange(5)))#新增类别train_data_nc =DATASET(train=True,fine_label=True,c_way=c_way, k_shot=k_shot,fix_class=new_class, base_class=len(base_class),logger=logger)bc_val_data= DataLoader(DATASET(train=False,fine_label=True,fix_class=base_class,logger=logger),transform_test,100,num_workers,shuffle=False)nc_val_data =DataLoader(DATASET(train=False,fine_label=True,fix_class=new_class, base_class=len(base_class),logger=logger),transform_test,100,num_workers,shuffle=False)#构建训练数据集if sess==0:train_data =DataLoader(base_data, transform_train,min(batch_size, base_data.__len__()), num_workers, shuffle=True)else:if opt.cum: #累积模式:合并所有已学习类别的数据joint_data =merge_datasets(joint_data, train_data_nc)train_data =DataLoader(joint_data,transform_train, min(batch_size,joint_data.__len__()),num_workers,shuffle=True)elif opt.use_all_novel: #使用所有新类别的数据(不包含基础类别)if sess==1:novel_data= train_data_ncelse:novel_data= merge_datasets(novel_data,train_data_nc)train_data= DataLoader(novel_data, transform_train,min(batch_size, novel_data.__len__()), num_workers, shuffle=True)else: #基本模式:只使用当前session的新类别数据train_data =DataLoader(train_data_nc, transform_train,min(batch_size, train_data_nc.__len__()), num_workers, shuffle=True)#训练循环for epochin range(epochs[sess]):tic =time.time()train_metric.reset()metric.reset()train_loss, train_anchor_loss,train_oce_loss=0,0,0train_disg_loss,train_pdl_loss, train_min_loss=0,0,0train_max_loss=0num_batch =len(train_data)#设置学习率(预热或衰减)ifuse_warmUp:lr= schedule(epoch)#使用预热调度trainer.set_learning_rate(lr)else:lr =trainer.learning_rateif epoch== lr_decay_epoch[sess][lr_decay_count]:#达到衰减轮次trainer.set_learning_rate(trainer.learning_rate*lr_decay)lr_decay_count +=1#在增量session的初期固定卷积层(仅训练全连接层)if sess!=0and epoch<fix_epoch:fix_cnn =fix_convelse:fix_cnn =False#遍历训练数据fori,batch inenumerate(train_data):data= gluon.utils.split_and_load(batch[0],ctx_list=ctx, batch_axis=0)label =gluon.utils.split_and_load(batch[1], ctx_list=ctx,batch_axis=0)all_loss =list()#存储所有损失withag.record():#记录梯度#前向传播output_feat,output =net(data[0],sess,fix_cnn)#输出特征和分类结果output_feat =[output_feat]output =[output]#计算交叉熵损失loss =[loss_fn(yhat,y)for yhat,y inzip(output, label)]all_loss.extend(loss)#如果使用相似性损失(NME)ifuse_nme:#计算当前样本在锚点上的特征anchor_h =[net(X,sess,fix_cnn)[0]for Xin anc_data]#计算相似性损失disg_loss= [loss_fn_disG(a_h, a, weight=digG_weight) fora_h, ain zip(anchor_h,anchor_feat)]all_loss.extend(disg_loss)#如果使用NG最大损失(类间分离)ifsess >0and use_ng_max:max_loss =[loss_fn_max(feat, label, feature_size,epoch,sess,init_class) forfeat,label inzip(output_feat, label)]all_loss.extend(max_loss[0])#如果使用锚点损失(AL)if sess>0 anduse_AL:anchor_h =[net(X,sess,fix_cnn)[0]for Xin anc_data]#当前模型计算的锚点特征ifopt.ng_var:#使用方差加权的L2损失anchor_loss =[loss_fn_AL(anchor_h[0], anchor_feat[0], AL_weight,variances)]all_loss.extend(anchor_loss)else:anchor_loss= [loss_fn_AL(a_h, a) fora_h, ain zip(anchor_h,anchor_feat)]all_loss.extend(anchor_loss)#如果使用NG最小损失(类内紧凑)ifsess >0and use_ng_min:loss_min =min_loss(output_feat[0],label[0])#计算最小损失all_loss.extend(loss_min)#如果使用概率蒸馏损失(PDL)ifsess >0and use_pdl:anchor_l= [net(X, sess, fix_cnn)[1] forX inanc_data]#当前模型对锚点的分类输出anchor_l= [anchor_l[0][:,:60+(sess-1)*5]]#只取旧类别的部分#旧模型的软目标(温度缩放后的概率)soft_label= [mx.nd.softmax(anchor_logit[0][:,:60+(sess-1)*5] /temperature)]#计算蒸馏损失pdl_loss =[loss_fn_pdl(a_h,a,soft_a) fora_h, a, soft_a inzip(anchor_l, anc_label,soft_label)]all_loss.extend(pdl_loss)#如果使用旧类别交叉熵损失(OCE)if sess>0 anduse_oce:anchorp =[net(X,sess,fix_cnn)[1]for Xin anc_data]#当前模型对锚点的分类输出#计算交叉熵损失(锚点数据上的旧类别)oce_Loss =[loss_fn_oce(ap,a)for ap, ain zip(anchorp,anc_label)]all_loss.extend(oce_Loss)#对损失取平均(以便反向传播)all_loss =[nd.mean(l) forl inall_loss]#反向传播ag.backward(all_loss)#更新模型参数trainer.step(1,ignore_stale_grad=True)#如果使用NG最小损失,则更新其参数(类中心)ifuse_ng_min:center_trainer.step(opt.c_way*opt.k_shot)#记录损失train_loss +=sum([l.sum().asscalar() forl inloss])ifsess >0and use_AL:train_anchor_loss+= sum([al.mean().asscalar()for alin anchor_loss])ifsess >0and use_oce:train_oce_loss+= sum([al.mean().asscalar()for alin oce_Loss])ifsess >0and use_nme:train_disg_loss+= sum([al.mean().asscalar()for alin disg_loss])if sess>0 anduse_pdl:train_pdl_loss +=sum([al.mean().asscalar() foral inpdl_loss])ifsess >0and use_ng_min:train_min_loss+= sum([al.mean().asscalar()for alin loss_min])ifsess >0and use_ng_max:train_max_loss+= sum([al.mean().asscalar()for alin max_loss[0]])#更新训练准确率train_metric.update(label, output)#计算平均损失train_loss /=batch_size* num_batchname,acc =train_metric.get()#测试基础类别和新类别name,bc_val_acc =test(ctx, bc_val_data, net, sess)name,nc_val_acc =test(ctx, nc_val_data, net, sess)#记录准确率ifepoch==0:acc_dict[str(sess)]=list()acc_dict[str(sess)].append([bc_val_acc,nc_val_acc])#计算整体准确率(加权平均)if sess==0:overall= bc_val_accelse:overall= (bc_val_acc*(init_class+(sess-1)*5)+nc_val_acc*5)/(init_class+sess*5)#日志记录logger.info('[Epoch%d] lr=%.4f train=%.4f| val(base)=%.4f val(novel)=%.4foverall=%.4f |loss=%.8f ancloss=%.8f ''pdlloss:%.8f oce loss: %.8ftime:%.8f'%(epoch, lr,acc,bc_val_acc,nc_val_acc,overall,train_loss, train_anchor_loss/AL_weight,train_pdl_loss/pdl_weight,train_oce_loss/oce_weight,time.time()-tic))if use_nme:logger.info('digG loss:%.8f'%(train_disg_loss/digG_weight))if use_ng_min:logger.info('min_loss:%.8f'%(train_min_loss/min_weight))ifuse_ng_max:logger.info('max_loss:%.8f'% (train_max_loss /max_weight))#保存模型if save_periodand save_dir and(epoch+1)% save_period==0:net.save_parameters('%s/sess-%s-%d.params'%(save_dir, model_name,epoch))#选择最佳模型(根据验证准确率)select =eval(select_best_method)#选择最佳模型的方法(select_best, select_best2,select_best3)best_e= select(acc_dict,sess)#最佳轮次logger.info('bestselect :base:%fnovel:%f'%(acc_dict[str(sess)][best_e][0],acc_dict[str(sess)][best_e][1]))#如果是quick_cnn且使用锚点损失,则在选择最佳模型后重新加载该模型ifuse_AL andmodel_name=='quick_cnn':reload_path ='%s/sess-%s-%d.params'% (save_dir, model_name,best_e)net.load_parameters(reload_path,ctx=context)#记录当前session的最佳轮次all_best_e.append(best_e)#加载当前session的最佳模型reload_path ='%s/sess-%s-%d.params'%(save_dir,model_name, best_e)net.load_parameters(reload_path, ctx=context)#保存准确率字典with open('%s/acc_dict.json'%save_dir, 'w')asjson_file:json.dump(acc_dict, json_file)#绘制当前session的准确率曲线plot_pr(acc_dict,sess,save_dir)#绘制所有session的准确率曲线plot_all_sess(acc_dict,save_dir,all_best_e)#主函数defmain():#是否使用混合精度(hybridize)if opt.mode== 'hybrid':net.hybridize()train(net, context)
if__name__ =='__main__':main()```
###原理总结该代码实现了一个增量学习框架,用于少样本增量学习(FSCIL)。主要特点包括:1.**基础训练**:在大量基础类别数据上训练模型。
2. **增量训练**:逐步添加新类别(每个session增加5个类别),每个新类别只有k个样本(k-shot)。
3. **防止遗忘**:-**锚点损失(AL)**:使用基础类别的锚点(代表点)来约束特征空间,防止模型遗忘。-**概率蒸馏损失(PDL)**:利用旧模型输出的概率分布作为软目标,指导新模型的学习。- **旧类别交叉熵损失(OCE)**:在锚点数据上计算旧类别的交叉熵损失。4. **特征空间调整**:-**神经气体(NG)损失**:包括最小化损失(类内紧凑)和最大化损失(类间分离),以优化特征空间。- **相似性损失(NME)**:保持新样本与锚点的特征相似性。5.**训练技巧**:- **固定卷积层**:在增量训练的初期固定卷积层,只训练全连接层,以减少过拟合。-**学习率预热和衰减**:使用学习率预热和按轮次衰减。- **累积训练**:可以选择累积所有已学习类别的数据一起训练,或者只使用新类别数据。###增量学习过程1. **Session0**:训练基础类别(60个类别)。
2. **Session1-8**:每个session增加5个新类别(每个新类别5个样本),共8个增量session,最终达到100个类别。###评估每个session训练后,测试模型在基础类别和新类别上的准确率,并计算整体准确率。通过绘制准确率曲线来展示模型性能。###注意事项-代码中使用了大量的命令行参数来控制训练过程,如模型选择、损失函数开关、权重设置等。-数据加载器`NC_CIFAR100`是自定义的,用于按增量学习要求划分CIFAR100数据集。
-模型结构(如`quick_cnn`)定义在`model`目录下。
###运行运行该代码需要配置好MXNet环境,并准备好CIFAR100数据集。通过命令行参数指定训练配置。###相关问题§§相关问题§§:1.增量学习中如何防止灾难性遗忘?除了代码中提到的锚点损失、蒸馏损失,还有哪些方法?2.代码中神经气体(NG)损失的具体原理是什么?它是如何工作的?
3.在少样本增量学习场景中,数据增强有哪些特殊考虑?
4.如何设计一个增量学习的实验来比较不同方法的性能?常用的评价指标有哪些?</think>### 整体代码解读
这段代码实现了一个**增量学习(Incremental Learning)系统**,特别关注**少样本增量学习(Few-Shot Class-Incremental Learning, FSCIL)**场景。系统使用MXNet框架,在CIFAR-100数据集上训练模型,使其能够逐步学习新类别(每次增加5个新类别),每个新类别只有少量样本(k-shot),同时保持对已学习类别的识别能力。
#### 核心原理
1. **增量学习框架**:
- 基础阶段(Session 0):在60个基础类别上训练模型
- 增量阶段(Session 1-8):每次增加5个新类别(每个类别只有k个样本)
- 最终达到100个类别(60基础 + 8×5增量)
2. **关键技术**:
- **锚点损失(Anchor Loss)**:使用旧类别代表点约束特征空间
- **神经气体(Neural Gas)**:维护特征空间的拓扑结构
- **知识蒸馏(Knowledge Distillation)**:旧模型指导新模型学习
- **类内紧凑/类间分离**:通过NG-Min/NG-Max损失优化特征空间
3. **防止灾难性遗忘**:
- 特征空间正则化
- 旧类别概率蒸馏
- 固定卷积层参数
### 逐行代码解读
```python
import matplotlib
matplotlib.use('Agg') # 设置matplotlib为非交互模式,用于服务器环境生成图片
import argparse, time, logging
import os
import numpy as np
import mxnet as mx
from mxnet import gluon, nd
from mxnet import autograd as ag
from mxnet.gluon.data.vision import transforms
import sys
sys.path.append('/home/ubuntu/PycharmProjects/pythonProject/fscil-master/') # 添加自定义模块路径
import model
from model.cifar_quick import quick_cnn # 导入快速CNN模型
from model.cifar_resnet_v1 import cifar_resnet20_v1 # 导入ResNet模型
from gluoncv.utils import makedirs # 创建目录工具
from gluoncv.data import transforms as gcv_transforms # 图像增强变换
# 自定义模块导入
from dataloader.dataloader import NC_CIFAR100, merge_datasets # 增量学习数据集加载器
from tools.utils import LinearWarmUp # 学习率预热策略
from tools.utils import DataLoader # 数据加载器
from tools.utils import parse_args # 参数解析器
from tools.plot import plot_pr, plot_all_sess # 结果可视化
from tools.loss import DistillationSoftmaxCrossEntropyLoss, NG_Min_Loss, NG_Max_Loss # 自定义损失函数
from tools.ng_anchor import prepare_anchor # 神经气体锚点准备
import json
from tools.utils import select_best, select_best2, select_best3 # 模型选择策略
# 解析命令行参数
opt = parse_args()
# 基础设置
batch_size = opt.batch_size
num_gpus = len(opt.gpus.split(','))
batch_size *= max(1, num_gpus) # 根据GPU数量调整batch size
context = [mx.gpu(int(i)) for i in opt.gpus.split(',')] # GPU上下文
num_workers = opt.num_workers # 数据加载线程数
# 模型选择
model_name = opt.model
if model_name == 'quick_cnn':
classes = 60 # 初始类别数
# 卷积层固定设置
if opt.fix_conv:
fix_layers = 3
fix_fc = False
else:
fix_layers = 0
fix_fc = False
net = quick_cnn(classes, fix_layers, fix_fc=fix_fc, fw=opt.fw) # 创建快速CNN
feature_size = 64 # 特征维度
elif model_name == 'resnet18':
classes = 60
feature_size = 64
net = cifar_resnet20_v1(classes=classes, wo_bn=opt.wo_bn, fw=opt.fw) # 创建ResNet
else:
raise KeyError('network key error')
# 加载预训练模型
if opt.resume_from:
net.load_parameters(opt.resume_from, ctx=context)
# 数据集选择
DATASET = eval(opt.dataset) # 通常是NC_CIFAR100
# 优化器设置
optimizer = 'nag' # Nesterov加速梯度下降
save_period = opt.save_period # 模型保存周期
plot_path = opt.save_plot_dir # 绘图保存路径
# 创建结果保存目录
save_dir = time.strftime('./experimental_result/{}/{}/%Y-%m-%d-%H-%M-%S'.format(opt.dataset, model_name), time.localtime())
save_dir = save_dir + opt.save_name
makedirs(save_dir)
# 日志配置
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_save_dir = os.path.join(save_dir, 'log.txt')
fh = logging.FileHandler(log_save_dir)
fh.setLevel(logging.INFO)
logger.addHandler(fh)
logger.info(opt) # 记录配置参数
# 测试函数
def test(ctx, val_data, net, sess):
"""评估模型在验证集上的性能"""
metric = mx.metric.Accuracy()
for i, batch in enumerate(val_data):
data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
outputs = [net(X, sess)[1] for X in data] # 获取分类输出
metric.update(label, outputs)
return metric.get()
# 主训练函数
def train(net, ctx):
if isinstance(ctx, mx.Context):
ctx = [ctx]
# 模型初始化
if not opt.resume_from:
net.initialize(mx.init.Xavier(), ctx=ctx)
# 使用虚拟输入初始化模型
if opt.dataset == 'NC_CIFAR100':
n = mx.nd.zeros(shape=(1,3,32,32),ctx=ctx[0])
else:
raise KeyError('dataset keyerror')
for m in range(9): # 初始化所有session层
net(n,m)
# 学习率调度器
def makeSchedule(start_lr, base_lr, length, step, factor):
"""创建带预热的学习率调度"""
schedule = mx.lr_scheduler.MultiFactorScheduler(step=step, factor=factor)
schedule.base_lr = base_lr
schedule = LinearWarmUp(schedule, start_lr=start_lr, length=length)
return schedule
# 训练参数配置
sesses = list(np.arange(opt.sess_num)) # session列表 [0,1,...,8]
epochs = [opt.epoch] * opt.sess_num # 每个session的训练轮数
lrs = [opt.base_lrs] + [opt.lrs] * (opt.sess_num - 1) # 学习率配置
lr_decay = opt.lr_decay # 学习率衰减因子
base_decay_epoch = [int(i) for i in opt.base_decay_epoch.split(',')] + [np.inf]
lr_decay_epoch = [base_decay_epoch] + [[opt.inc_decay_epoch, np.inf]] * (opt.sess_num - 1)
# 损失函数权重
AL_weight = opt.AL_weight # 锚点损失权重
min_weight = opt.min_weight # NG-Min损失权重
oce_weight = opt.oce_weight # 旧类别交叉熵权重
pdl_weight = opt.pdl_weight # 概率蒸馏损失权重
max_weight = opt.max_weight # NG-Max损失权重
temperature = opt.temperature # 蒸馏温度
# 功能开关
use_AL = opt.use_AL # 是否使用锚点损失
use_ng_min = opt.use_ng_min # 是否使用NG-Min损失
use_ng_max = opt.use_ng_max # 是否使用NG-Max损失
ng_update = opt.ng_update # 是否更新神经气体节点
use_oce = opt.use_oce # 是否使用旧类别交叉熵
use_pdl = opt.use_pdl # 是否使用概率蒸馏
use_nme = opt.use_nme # 是否使用相似性损失
use_warmUp = opt.use_warmUp # 是否使用学习率预热
use_ng = opt.use_ng # 是否使用神经气体
fix_conv = opt.fix_conv # 是否固定卷积层
fix_epoch = opt.fix_epoch # 固定卷积层的轮数
# 增量学习参数
c_way = opt.c_way # 新增类别数(通常为5)
k_shot = opt.k_shot # 每个新类别的样本数
base_acc = opt.base_acc # 基础模型准确率
select_best_method = opt.select_best # 最佳模型选择方法
init_class = 60 # 初始类别数
anchor_num = 400 # 锚点数量
# 结果记录
acc_dict = {} # 存储每个session的准确率
all_best_e = [] # 存储每个session的最佳epoch
# 增量session的FC层初始化
if model_name[-7:] != 'maxhead':
net.fc3.initialize(mx.init.Normal(sigma=0.001), ctx=ctx, force_reinit=True)
# ... 其他FC层初始化
# 主训练循环(按session进行)
for sess in sesses:
logger.info('session : %d' % sess)
# 创建学习率调度器
schedule = makeSchedule(start_lr=0, base_lr=lrs[sess], length=5,
step=lr_decay_epoch[sess], factor=lr_decay)
# 基础session的特殊处理
if sess == 0 and opt.resume_from:
acc_dict[str(sess)] = list()
acc_dict[str(sess)].append([base_acc, 0]) # [基础类别acc, 新类别acc]
all_best_e.append(0)
continue
# 数据增强配置
if sess == 1 and model_name == 'quick_cnn' and use_AL:
# 简单变换(无增强)
transform_train = transforms.Compose([...])
else:
# 完整数据增强
transform_train = transforms.Compose([
gcv_transforms.RandomCrop(32, pad=4),
transforms.RandomFlipLeftRight(),
...
])
# 锚点准备(神经气体)
if use_AL or use_nme or use_pdl or use_oce:
if sess != 0: # 基础session不需要锚点
if ng_update: # 更新神经气体节点
if sess == 1: # 第一次增量session
# 准备初始锚点
update_anchor1, bmu, variances = prepare_anchor(...)
update_anchor_data = DataLoader(...)
# 方差调整
if opt.ng_var: ...
else: # 后续session
# 合并新旧锚点
update_anchor2 = merge_datasets(update_anchor1, new_set)
update_anchor_data = DataLoader(...)
elif sess == 1: # 不更新但首次准备锚点
update_anchor, bmu, variances = prepare_anchor(...)
update_anchor_data = DataLoader(...)
# 计算锚点特征
for batch in update_anchor_data:
with ag.pause(): # 不计算梯度
anchor_feat, anchor_logit = net(anc_data[0], sess-1)
# 创建优化器
trainer = gluon.Trainer(net.collect_params(), optimizer, {...})
# 评估指标
metric = mx.metric.Accuracy()
train_metric = mx.metric.Accuracy()
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss() # 基础损失
# 自定义损失初始化
if use_nme: # 相似性损失
def loss_fn_disG(f1, f2, weight): ...
if use_AL: # 锚点损失
if opt.ng_var: # 方差加权
def l2lossVar(feat, anc, weight, var): ...
else:
loss_fn_AL = gluon.loss.L2Loss(weight=AL_weight)
if use_pdl: # 概率蒸馏损失
loss_fn_pdl = DistillationSoftmaxCrossEntropyLoss(...)
if use_oce: # 旧类别交叉熵
loss_fn_oce = gluon.loss.SoftmaxCrossEntropyLoss(weight=oce_weight)
if use_ng_min: # NG-Min损失(类内紧凑)
min_loss = NG_Min_Loss(...)
min_loss.initialize(...)
center_trainer = gluon.Trainer(...) # 单独优化器
if use_ng_max: # NG-Max损失(类间分离)
loss_fn_max = NG_Max_Loss(...)
lr_decay_count = 0 # 学习率衰减计数器
# 数据加载
if sess == 0: # 基础session
base_data = DATASET(train=True, fine_label=True, ...)
bc_val_data = DataLoader(...) # 基础类别验证集
nc_val_data = DataLoader(...) # 新类别验证集
else: # 增量session
train_data_nc = DATASET(...) # 新类别训练数据
bc_val_data = DataLoader(...)
nc_val_data = DataLoader(...)
# 训练数据组合策略
if sess == 0:
train_data = DataLoader(base_data, ...)
else:
if opt.cum: # 累积模式(合并新旧数据)
joint_data = merge_datasets(joint_data, train_data_nc)
train_data = DataLoader(joint_data, ...)
elif opt.use_all_novel: # 仅使用新数据
novel_data = merge_datasets(novel_data, train_data_nc)
train_data = DataLoader(novel_data, ...)
else: # 基本模式(仅当前新数据)
train_data = DataLoader(train_data_nc, ...)
# 当前session的训练循环
for epoch in range(epochs[sess]):
# 训练指标重置
train_metric.reset()
metric.reset()
# 损失初始化
train_loss, train_anchor_loss, ... = 0, 0, 0
# 学习率调整
if use_warmUp:
lr = schedule(epoch)
trainer.set_learning_rate(lr)
else:
...
# 卷积层固定策略
if sess != 0 and epoch < fix_epoch:
fix_cnn = fix_conv
else:
fix_cnn = False
# 批次训练
for i, batch in enumerate(train_data):
data = ... # 分割数据到GPU
label = ...
with ag.record(): # 自动微分记录
output_feat, output = net(data[0], sess, fix_cnn) # 前向传播
# 基础交叉熵损失
loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
all_loss.extend(loss)
# 各种自定义损失
if use_nme: ... # 相似性损失
if use_ng_max: ... # NG-Max损失
if use_AL: ... # 锚点损失
if use_ng_min: ... # NG-Min损失
if use_pdl: ... # 概率蒸馏
if use_oce: ... # 旧类别交叉熵
all_loss = [nd.mean(l) for l in all_loss] # 平均损失
# 反向传播和优化
ag.backward(all_loss)
trainer.step(...)
if use_ng_min:
center_trainer.step(...) # NG-Min单独优化
# 损失记录
train_loss += ...
if use_AL: train_anchor_loss += ...
...
# 更新训练指标
train_metric.update(label, output)
# 计算平均损失和准确率
train_loss /= batch_size * num_batch
name, acc = train_metric.get()
# 验证评估
name, bc_val_acc = test(ctx, bc_val_data, net, sess) # 基础类别
name, nc_val_acc = test(ctx, nc_val_data, net, sess) # 新类别
# 记录结果
if epoch == 0:
acc_dict[str(sess)] = list()
acc_dict[str(sess)].append([bc_val_acc, nc_val_acc])
# 计算整体准确率
if sess == 0:
overall = bc_val_acc
else:
overall = (bc_val_acc * (init_class + (sess-1)*5) + nc_val_acc * 5) / (init_class + sess*5)
# 日志记录
logger.info(f'[Epoch {epoch}] lr={lr:.4f} ...')
# 模型保存
if save_period and (epoch + 1) % save_period == 0:
net.save_parameters(f'{save_dir}/sess-{model_name}-{epoch}.params')
# 选择最佳模型
select = eval(select_best_method)
best_e = select(acc_dict, sess)
logger.info(f'best select: base: {acc_dict[str(sess)][best_e][0]} novel: {acc_dict[str(sess)][best_e][1]}')
# 特殊处理:重新加载最佳模型
if use_AL and model_name == 'quick_cnn':
net.load_parameters(f'{save_dir}/sess-{model_name}-{best_e}.params', ctx=context)
# session结束处理
all_best_e.append(best_e)
net.load_parameters(f'{save_dir}/sess-{model_name}-{best_e}.params', ctx=context)
# 保存结果
with open(f'{save_dir}/acc_dict.json', 'w') as json_file:
json.dump(acc_dict, json_file)
# 可视化
plot_pr(acc_dict, sess, save_dir)
# 所有session结束后的可视化
plot_all_sess(acc_dict, save_dir, all_best_e)
# 主函数
def main():
if opt.mode == 'hybrid': # 混合精度模式
net.hybridize()
train(net, context)
if __name__ == '__main__':
main()
```
### 关键组件详解
1. **增量学习流程**:
```python
for sess in sesses: # 遍历所有session
if sess == 0: # 基础训练
else: # 增量训练
```
- Session 0:60个基础类别训练
- Session 1-8:每次增加5个新类别(k-shot)
2. **神经气体锚点**:
```python
update_anchor, bmu, variances = prepare_anchor(...)
```
- 维护特征空间的拓扑结构
- 作为旧类别的代表点
- 用于约束特征空间变化
3. **损失函数组合**:
```python
# 基础交叉熵
loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
# 锚点损失(特征空间约束)
if use_AL: ...
# 概率蒸馏(旧类别知识迁移)
if use_pdl: ...
# 类内紧凑/类间分离
if use_ng_min: ... # 类内紧凑
if use_ng_max: ... # 类间分离
```
多种损失协同工作防止灾难性遗忘
4. **训练策略**:
```python
if sess != 0 and epoch < fix_epoch:
fix_cnn = True # 固定卷积层
```
- 增量训练初期固定特征提取器
- 只调整分类层参数
- 减少过拟合风险
### 系统工作流程
1. **初始化**:
- 解析配置参数
- 准备数据集和模型
- 配置日志和结果目录
2. **基础训练**(Session 0):
- 在60个类别上训练模型
- 保存基础模型
3. **增量训练**(Session 1-8):
- 加载前session模型
- 准备新增类别的少量样本
- 更新神经气体锚点
- 使用组合损失进行训练
- 评估新旧类别性能
- 保存最佳模型
4. **结果分析**:
- 记录每个session的准确率
- 可视化性能曲线
- 生成最终报告
### 创新点总结
1. **神经气体锚点**:动态维护特征空间结构
2. **损失函数组合**:多角度约束特征空间
3. **渐进式训练策略**:
- 卷积层固定
- 学习率预热
- 早停策略
4. **灵活的数据策略**:
- 累积模式
- 仅新数据模式
- 混合模式
### 适用场景
- 需要持续学习新类别的视觉系统
- 样本稀缺的新类别学习
- 防止灾难性遗忘的场景
- 资源受限的边缘设备学习