Paper: https://arxiv.org/pdf/2111.15193.pdf
Code: https://github.com/OliverRensu/Shunted-Transformer
The model variants are defined in SSA.py as factory functions decorated with @register_model.
The flow is as follows:
step1: model = ShuntedTransformer()
@register_model
def shunted_t(pretrained=False, **kwargs):
    # Each list holds one value per stage (four stages in total).
    model = ShuntedTransformer(
        patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[1, 2, 4, 1], sr_ratios=[8, 4, 2, 1], num_conv=0,
        **kwargs)
    model.default_cfg = _cfg()
    return model
step2: class ShuntedTransformer()
class ShuntedTransformer(nn.Module):
    """Some simple definitions are omitted."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2,