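# 测试时间 (timing measurements)
# Columns presumably: GPU, model, input size, batch size, latency — TODO confirm.
# 1070 sge_resnet34 416 1 25ms,比原版慢了12ms (12 ms slower than the original version)
# 1070 sge_resnet34 416 8 90ms,比原版慢了40ms (40 ms slower than the original version)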
import time
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch
# Public names exported by `from <this module> import *`.
# Each entry is presumably an SGE-ResNet constructor defined later in this
# file (not visible here) — verify every name listed actually exists.
__all__ = [
    'sge_resnet18',
    'sge_resnet34',
    'sge_resnet50',
    'sge_resnet101',
    'sge_resnet152',
]
class SpatialGroupEnhance(nn.Module):
def __init__(self, groups = 64):
super(SpatialGroupEnhance, self).__init__()