Swin Transformer变体2:Swin Transformerv2模型

Swin Transformer v1模型:Swin Transformer模型具体代码实现-优快云博客

Swin Transformer mlp模型:Swin Transformer 变体1:使用MLP代替多头注意力机制-优快云博客

一、v2模型改动:

1、标准化位置不同:

v1模型是先标准化,再attn,之后残差连接:优点训练更稳定、梯度传播路径更短、适合深层网络

v2是先attn,再标准化,之后残差连接:优点是表达能力更强、收敛性能更好

2、前馈连接网络架构不同:

v1模型是先标准化,再mlp,之后残差连接

v2是先mlp,再标准化,之后残差连接

3、相对位置偏置的计算方式不同:

v1模型中相对位置偏置是离散的,即每个位置的相对位置偏置是离散的,通过参数学习得到

v2实现了一个连续相对位置偏置,即每个位置的相对位置偏置是连续的,通过参数学习得到,这里使用了mlp来学习相对位置偏置

4、映射得到q、k、v的方式不同:

v1模型中是直接使用全连接层得到qkv,然后再分割得到q、k、v

v2模型中是使用一个全连接层(可以自由设置偏置)得到qkv,然后再分割得到q、k、v

二、改动具体实现:

1、标准化位置改动:

v1模型:在SwinTransformerBlock中:

def forward(self, x):
        #H和W是patch的高个数和宽个数
        H, W = self.input_resolution
        #B: 批次大小, L:序列长度,C:通道数
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # 循环移动窗口
        if self.shift_size > 0:
            if not self.fused_window_process:
                #移动窗口,将窗口向右下角移动shift_size个像素
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # 分区:(B, H, W, C) -> (nW*B, window_size, window_size, C),其中nW=H//window_size*W//window_size
                x_windows = window_partition(shifted_x, self.window_size)  
            # else:
            #     x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # 分区:(B, H, W, C) -> (nW*B, window_size, window_size, C),其中nW=H//window_size*W//window_size
            x_windows = window_partition(shifted_x, self.window_size)  
        #重塑形状:(nW*B, window_size, window_size, C) -> (nW*B, window_size*window_size, C)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  

        # 注意力机制
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  

        # 重塑形状:(nW*B, window_size*window_size, C) -> (nW*B, window_size, window_size, C)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        # 反向循环移动窗口
        if self.shift_size > 0:
            if not self.fused_window_process:
                #合并窗口,将窗口向左上角移动shift_size个像素,形状由(nW*B, window_size, window_size, C) -> (B, H, W, C)
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            # else:
            #     x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            #只合并、不移动
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        # 前馈网络
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

v2模型:在SwinTransformerBlock中:

def forward(self, x):
        H, W = self.input_resolution
        #L是窗口的数量
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = x.v
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值