达摩院20年提出的TResNet的代码解读
本文侧重于TResNet与resnet的比较来展开
Tresnet 的 forward代码
分为三层,首先图像通入self.body。此层主要是一系列的对于图像卷积。
然后通入self.global_pool,其实这一层主要是作者自己写了一下平均池化的module,然后作者说这样重写了之后可以使得GPU计算速度增加五倍。。不过论文也没有讲清楚为什么速度会增加。(个人认为是它重写之后可以避免池化之后还要缩减维度的操作。具体代码见FastGlobalAvgPool2d)然后最后通入一个head,这就是一个全联接层
def forward(self, x):
x = self.body(x)
self.embeddings = self.global_pool(x)
logits = self.head(self.embeddings)
return logits
class FastGlobalAvgPool2d(nn.Module):
def __init__(self, flatten=False):
super(FastGlobalAvgPool2d, self).__init__()
self.flatten = flatten
def forward(self, x):
if self.flatten:
in_size = x.size()
return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
else:
return x.view(x.