用self.modules()方法批量初始化模型权重
用self.modules()可以遍历组成网络的所有模块,以及这些模块的后代模块。
Example:
创建一个网络,其中包括一个预先定义的DoubleConv类
class DoubleConv(nn.Module): def __init__(self,in_channels,out_channels): super(DoubleConv,self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels,out_channels,3,1,1,bias=False), nn.BatchNorm2d(out_channels), nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), ) def forward(self,x): return self.conv(x) class Normal_Down_Sampling(nn.Module): def __init__(self, in_channels, out_channels): super(Normal_Down_Sampling, self).__init__() self.conv = nn.Sequential( DoubleConv(in_channels,in_channels), nn.Conv2d(in_channels, out_channels, 7, 2), # 7*7 step