Swin Transformer v1 model: concrete code implementation of the Swin Transformer model - CSDN blog
Swin Transformer MLP model: Swin Transformer variant 1: replacing multi-head attention with an MLP - CSDN blog
I. Changes in the v2 model:
1. Placement of normalization differs (both orderings are sketched after item 2):
v1 normalizes first, then applies attention, then adds the residual connection (pre-norm). Advantages: more stable training, a shorter gradient propagation path, and suitability for deep networks.
v2 applies attention first, then normalizes, then adds the residual connection (res-post-norm). Advantages: stronger expressive power and better convergence behavior.
2. The feed-forward branch is reordered the same way:
In v1: normalize first, then the MLP, then the residual connection.
In v2: the MLP first, then normalize, then the residual connection.
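To make the two orderings concrete, here is a minimal, self-contained sketch. It uses nn.MultiheadAttention and a plain MLP as stand-ins for the real window attention and the block's MLP, and the class names are hypothetical, not from the official repository:

import torch.nn as nn

class BlockV1Style(nn.Module):
    """Pre-norm (Swin v1 ordering): norm -> sublayer -> residual."""
    def __init__(self, dim, num_heads=4):  # dim must be divisible by num_heads
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        y = self.norm1(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]  # residual after attention
        x = x + self.mlp(self.norm2(x))                    # residual after MLP
        return x

class BlockV2Style(nn.Module):
    """Res-post-norm (Swin v2 ordering): sublayer -> norm -> residual."""
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        x = x + self.norm1(self.attn(x, x, x, need_weights=False)[0])  # norm after attention
        x = x + self.norm2(self.mlp(x))                                # norm after MLP
        return x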
3. The relative position bias is computed differently:
In v1 the relative position bias is discrete: each relative offset between positions indexes its own entry in a learned bias table.
v2 implements a continuous relative position bias: a small MLP maps log-spaced relative coordinates to the bias values, so the bias can generalize across window sizes. A sketch follows.
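A minimal sketch of the continuous bias, following the log-spaced coordinate scheme from the Swin v2 paper. The class and its name are illustrative; the official repository implements the same idea inside its WindowAttention module:

import torch
import torch.nn as nn

class ContinuousPositionBias(nn.Module):
    def __init__(self, num_heads, window_size=(7, 7)):
        super().__init__()
        # Small MLP: 2-d relative coordinate -> one bias value per head.
        self.cpb_mlp = nn.Sequential(
            nn.Linear(2, 512, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_heads, bias=False),
        )
        # Table of all relative offsets within the window, normalized to [-1, 1]
        # and log-scaled so that large offsets are compressed.
        h = torch.arange(-(window_size[0] - 1), window_size[0], dtype=torch.float32)
        w = torch.arange(-(window_size[1] - 1), window_size[1], dtype=torch.float32)
        coords = torch.stack(torch.meshgrid(h, w, indexing="ij"), dim=-1)  # (2H-1, 2W-1, 2)
        coords[..., 0] /= window_size[0] - 1
        coords[..., 1] /= window_size[1] - 1
        coords = torch.sign(coords) * torch.log2(coords.abs() * 8 + 1.0) / 3.0  # log-spacing
        self.register_buffer("relative_coords_table", coords)

    def forward(self):
        # (2H-1, 2W-1, num_heads): a continuous bias for every relative offset.
        return self.cpb_mlp(self.relative_coords_table)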
4. The projection to q, k, v differs:
In v1 a single fully connected layer (with a shared bias) produces qkv, which is then split into q, k and v.
In v2 the qkv layer is created without a built-in bias; instead, separate learnable bias vectors are added to q and v only (k receives a zero bias) before splitting, as sketched below.
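A minimal sketch of the v2-style projection, modeled on the q_bias / v_bias trick in the official implementation; the wrapper class around it here is hypothetical:

import torch
import torch.nn as nn
import torch.nn.functional as F

class QKVProjection(nn.Module):
    def __init__(self, dim, num_heads, qkv_bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=False)  # no built-in bias
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = self.v_bias = None

    def forward(self, x):  # x: (B, N, C)
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # Bias for q and v only; k's slot stays zero (and receives no gradient).
            qkv_bias = torch.cat((self.q_bias,
                                  torch.zeros_like(self.v_bias, requires_grad=False),
                                  self.v_bias))
        qkv = F.linear(x, self.qkv.weight, qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        return qkv.unbind(0)  # q, k, v: each (B, num_heads, N, head_dim)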
II. Concrete implementation of the changes:
1. The normalization placement change:
v1 model, in SwinTransformerBlock:
def forward(self, x):
    # H and W are the number of patches along height and width
    H, W = self.input_resolution
    # B: batch size, L: sequence length, C: number of channels
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"

    shortcut = x
    x = self.norm1(x)  # pre-norm: normalize before attention
    x = x.view(B, H, W, C)

    # Cyclic window shift
    if self.shift_size > 0:
        if not self.fused_window_process:
            # Shift the feature map towards the top-left by shift_size pixels
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            # Partition: (B, H, W, C) -> (nW*B, window_size, window_size, C), where nW = H//window_size * W//window_size
            x_windows = window_partition(shifted_x, self.window_size)
        # else:
        #     x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
    else:
        shifted_x = x
        # Partition: (B, H, W, C) -> (nW*B, window_size, window_size, C), where nW = H//window_size * W//window_size
        x_windows = window_partition(shifted_x, self.window_size)

    # Reshape: (nW*B, window_size, window_size, C) -> (nW*B, window_size*window_size, C)
    x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

    # Window attention
    attn_windows = self.attn(x_windows, mask=self.attn_mask)

    # Reshape back: (nW*B, window_size*window_size, C) -> (nW*B, window_size, window_size, C)
    attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

    # Reverse cyclic window shift
    if self.shift_size > 0:
        if not self.fused_window_process:
            # Merge the windows, (nW*B, window_size, window_size, C) -> (B, H, W, C),
            # then shift back towards the bottom-right by shift_size pixels
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        # else:
        #     x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
    else:
        # Merge only, no shift
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # (B, H', W', C)
        x = shifted_x

    x = x.view(B, H * W, C)
    x = shortcut + self.drop_path(x)  # residual after attention

    # Feed-forward network: pre-norm, then MLP, then residual
    x = x + self.drop_path(self.mlp(self.norm2(x)))
    return x
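For reference, a minimal sketch of the window_partition / window_reverse helpers the code above relies on, consistent with the shapes given in the comments (this mirrors the helpers in the official repository):

import torch

def window_partition(x, window_size):
    """(B, H, W, C) -> (nW*B, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Fold the window grid into the batch dimension.
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """(nW*B, window_size, window_size, C) -> (B, H, W, C)."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)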
v2 model, in SwinTransformerBlock:
def forward(self, x):
    H, W = self.input_resolution
    # B: batch size, L: sequence length (number of patches, equal to H*W), C: number of channels
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"

    shortcut = x
    # Note: unlike v1, there is no norm here; v2 normalizes after attention
    x = x.view(B, H, W, C)
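The excerpt breaks off at this point. For completeness, a hedged sketch of how the rest of the v2 forward proceeds in the official implementation: the cyclic shift, window partitioning, attention, and window merging are the same as in the v1 code above, and the difference is exactly the norm placement described in section I:

    # ... cyclic shift, window_partition, self.attn, window_reverse as in v1 ...

    x = x.view(B, H * W, C)
    # Res-post-norm: normalize the attention output, then add the residual
    x = shortcut + self.drop_path(self.norm1(x))

    # Feed-forward branch: MLP first, then norm, then the residual connection
    x = x + self.drop_path(self.norm2(self.mlp(x)))
    return x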