以VGG模型为例:
class LambdaBase(nn.Sequential):
    """Container that pairs a plain callable with optional sub-modules.

    The positional ``*args`` are registered as children exactly as
    ``nn.Sequential`` would register them, but they are not chained: every
    child receives the same input (see ``forward_prepare``). Subclasses
    decide how ``lambda_func`` is applied to the prepared value.

    Args:
        fn: Callable stored as ``self.lambda_func``.
        *args: Sub-modules forwarded to ``nn.Sequential.__init__``.
    """

    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Run every child module on ``input`` and collect the results.

        Returns:
            A list with one output per child module, in registration order.
            When there are no children (or the list is otherwise falsy) the
            original ``input`` is returned unchanged.
        """
        outputs = [module(input) for module in self._modules.values()]
        # With no registered children, pass the input straight through so the
        # wrapped function operates on it directly.
        return outputs if outputs else input
class Lambda(LambdaBase):
    """Module that applies ``lambda_func`` once to the prepared input."""

    def forward(self, input):
        """Return ``lambda_func`` applied to ``forward_prepare(input)``.

        ``forward_prepare`` yields the list of child-module outputs, or
        ``input`` itself when no child modules are registered.
        """
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)
class LambdaMap(LambdaBase):
    """Module that maps ``lambda_func`` over the prepared input."""

    def forward(self, input):
        """Return a list with ``lambda_func`` applied to each prepared item.

        The items are the elements of ``forward_prepare(input)``: the child
        module outputs, or the elements obtained by iterating ``input``
        itself when no child modules are registered.
        """
        prepared = self.forward_prepare(input)
        return [self.lambda_func(item) for item in prepared]
class LambdaReduce(LambdaBase):
    """Module that folds the prepared input with ``lambda_func``."""

    def forward(self, input):
        """Return the left fold of ``forward_prepare(input)`` by ``lambda_func``.

        ``lambda_func`` must accept two arguments (accumulator, next item).
        The items are the child module outputs, or the elements obtained by
        iterating ``input`` itself when no child modules are registered.
        """
        # ``reduce`` is not a builtin on Python 3; importing it here keeps the
        # class usable regardless of what the surrounding file imports.
        from functools import reduce

        prepared = self.forward_prepare(input)
        return reduce(self.lambda_func, prepared)
使用:
# VGG-style network assembled as one flat nn.Sequential. Layer order is kept
# exactly as before, so the integer sub-module keys are unchanged.
VGG_Face_torch = nn.Sequential(
    # Stage 1: 3 -> 64 channels.
    nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    # ceil_mode=True rounds the pooled output size up instead of using the
    # default floor; the same setting is used by every pooling layer below.
    nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=True),
    # Stage 2: 64 -> 128 channels.
    nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=True),
    # Stage 3: 128 -> 256 channels.
    nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=True),
    # Stage 4: 256 -> 512 channels.
    nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=True),
    # Stage 5: 512 -> 512 channels.
    nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=True),
    # Flatten: keep the leading (batch) dimension and collapse the rest;
    # view() acts as a reshape here.
    Lambda(lambda x: x.view(x.size(0), -1)),
    # Each Linear is preceded by a guard that turns a 1-D tensor into a
    # single-row 2-D tensor via view(1, -1); other ranks pass through as-is.
    # NOTE(review): 25088 == 512 * 7 * 7, which presumably corresponds to
    # 224x224 inputs - confirm against the caller.
    nn.Sequential(
        Lambda(lambda x: x if x.dim() != 1 else x.view(1, -1)),
        nn.Linear(in_features=25088, out_features=4096),
    ),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Sequential(
        Lambda(lambda x: x if x.dim() != 1 else x.view(1, -1)),
        nn.Linear(in_features=4096, out_features=4096),
    ),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    # Final 2622-way linear output.
    nn.Sequential(
        Lambda(lambda x: x if x.dim() != 1 else x.view(1, -1)),
        nn.Linear(in_features=4096, out_features=2622),
    ),
)