1.Dropout原理
Droupout是指在深度网络的训练中,以一定的概率随机地“临时丢弃”一部分神经元。
具体来讲,Dropout作用于每份小批量训练数据,由于其随机丢弃部分神经元的机制,相当于每次迭代都在训练不同结构的神经网络。类似于Bagging方法,dropout可被认为是一种实用的大规模深度神经网络的模型集成算法。
Dropout的具体实现中,要求某个神经元节点激活值以一定的概率p被“丢弃”,即该神经元暂时停止工作。对于任意神经元,每次训练中都与一组随机挑选的不同的神经元集合共同进行优化,这个过程会减弱全体神经元之间的联合适应性,减少过拟合的风险,增加泛化能力。
在训练时,随机选出隐藏层的神经元,然后将其删除。被删除的神经元不再进行信号的传递,训练时,每传递一次数据,就会随机选择要删除的神经元。然后,测试时,虽然会传递所有的神经元信号,但是对于各个神经元的输出,要乘上训练时的删除比例后再输出。
2 实现L1、L2正则化
参考:
https://blog.youkuaiyun.com/guyuealian/article/details/88426648
正则化的类:
class Regularization(torch.nn.Module):
def __init__(self,model,weight_decay,p=2):
'''
:param model 模型
:param weight_decay:正则化参数
:param p: 范数计算中的幂指数值,默认求2范数,
当p=0为L2正则化,p=1为L1正则化
'''
super(Regularization, self).__init__()
if weight_decay <= 0:
print("param weight_decay can not <=0")
exit(0)
self.model=model
self.weight_decay=weight_decay
self.p=p
self.weight_list=self.get_weight(model)
self.weight_info(self.weight_list)
def to(self,device):
'''
指定运行模式
:param device: cude or cpu
:return:
'''
self.device=device
super().to(device)
return self
def forward(self, model):
self.weight_list=self.get_weight(model)#获得最新的权重
reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)
return reg_loss
def get_weight(self,model):
'''
获得模型的权重列表
:param model:
:return:
'''
weight_list = []
for name, param in model.named_parameters():
if 'weight' in name:
weight = (name, param)
weight_list.append(weight)
return weight_list
def regularization_loss(self,weight_list, weight_decay, p=2):
'''
计算张量范数
:param weight_list:
:param p: 范数计算中的幂指数值,默认求2范数
:param weight_decay:
:return:
'''
# weight_decay=Variable(torch.FloatTensor([weight_decay]).to(self.device),requires_grad=True)
# reg_loss=Variable(torch.FloatTensor([0.]).to(self.device),requires_grad=True)
# weight_decay=torch.FloatTensor([weight_decay]).to(self.device)
# reg_loss=torch.FloatTensor([0.]).to(self.device)
reg_loss=0
for name, w in weight_list:
l2_reg = torch.norm(w, p=p)
reg_loss = reg_loss + l2_reg
reg_loss=weight_decay*reg_loss
return reg_loss
def weight_info(self,weight_list):
'''
打印权重列表信息
:param weight_list:
:ret