SwinTransformer model conversion: porting a PyTorch model to Keras

This post walks through converting a PyTorch implementation of SwinTransformer to a Keras version, covering the key steps involved (tensor reshaping, the window partition/reverse operations, and the attention mechanism) and providing the corresponding code.
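Before walking through the layers, it is worth keeping in mind the one detail that drives most of the actual weight transfer: torch.nn.Linear stores its weight as [out_features, in_features], while a Keras Dense kernel is [in_features, out_features], so every Linear/Dense weight has to be transposed when it is copied over. The snippet below is only a minimal sketch of that pattern with a stand-in layer; the names are hypothetical and it is not the article's conversion script.

import numpy as np
import tensorflow as tf
import torch

# hypothetical example layers: a PyTorch Linear and the Keras Dense that
# should receive its weights
pt_linear = torch.nn.Linear(96, 384)
keras_dense = tf.keras.layers.Dense(384)
keras_dense.build((None, 96))                      # create kernel/bias variables first

# nn.Linear weight is [out, in]; Keras Dense expects [in, out], so transpose it
keras_dense.set_weights([
    pt_linear.weight.detach().numpy().T,
    pt_linear.bias.detach().numpy(),
])

# both layers should now agree to within float rounding
x = np.random.rand(2, 96).astype("float32")
diff = np.abs(pt_linear(torch.from_numpy(x)).detach().numpy()
              - keras_dense(tf.constant(x)).numpy()).max()
print(diff)                                        # on the order of 1e-6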


import tensorflow as tf
from tensorflow.keras import initializers, layers


def window_partition(x, window_size: int):
    """
    Partition a feature map into non-overlapping windows of size window_size.
    Args:
        x: (B, H, W, C)
        window_size (int): window size (M)

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = tf.reshape(x, [B, H // window_size, window_size, W // window_size, window_size, C])
    # transpose: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    # reshape:   [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
    windows = tf.reshape(x, [-1, window_size, window_size, C])
    return windows
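A quick shape check (my own example, not from the original post): partitioning a [2, 56, 56, 96] feature map with a 7x7 window should produce 2 * (56/7) * (56/7) = 128 windows.

x = tf.random.normal([2, 56, 56, 96])          # [B, H, W, C]
windows = window_partition(x, window_size=7)
print(windows.shape)                           # (128, 7, 7, 96)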

def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Reassemble the individual windows back into a single feature map.
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): window size (M)
        H (int): height of the feature map
        W (int): width of the feature map

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # reshape: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = tf.reshape(windows, [B, H // window_size, W // window_size, window_size, window_size, -1])
    # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    # reshape: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
    x = tf.reshape(x, [B, H, W, -1])
    return x
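Since window_reverse is the exact inverse of window_partition (both are pure reshape/transpose ops), a partition-then-reverse round trip should return the input unchanged. That makes a handy sanity check after porting; the check below is mine, not part of the original code.

x = tf.random.normal([2, 56, 56, 96])
windows = window_partition(x, window_size=7)                 # [2*8*8, 7, 7, 96]
x_back = window_reverse(windows, window_size=7, H=56, W=56)  # [2, 56, 56, 96]
print(bool(tf.reduce_all(tf.equal(x, x_back))))              # True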

class PatchMerging(layers.Layer):
    def __init__(self, dim: int, norm_layer=layers.LayerNormalization, name=None):
        super(PatchMerging, self).__init__(name=name)
        self.dim = dim
        self.reduction = layers.Dense(2 * dim,
                                      use_bias=False,
                                      kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
                                      name="reduction")
        self.norm = norm_layer(epsilon=1e-6, name="norm")

    def call(self, x, H, W):
        """
        x: [B, H*W, C]
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = tf.reshape(x, [B, H, W, C])

        # padding: if H or W of the input feature map is not a multiple of 2, pad it
        pad_input = (H % 2 != 0) or (W % 2 != 0)
        if pad_input:
            paddings = tf.constant([[0, 0],
                                    [0, 1],
                                    [0, 1],
                                    [0, 0]])
            x = tf.pad(x, paddings)

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = tf.concat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = tf.reshape(x, [B, -1, 4 * C])    # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)                # [B, H/2*W/2, 2*C]

        return x
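At the Swin-T stage-1 resolution (H = W = 56, C = 96), PatchMerging halves the spatial resolution and doubles the channel count. A small usage sketch with dummy data, shapes chosen by me to match that config:

layer = PatchMerging(dim=96)
x = tf.random.normal([2, 56 * 56, 96])   # [B, H*W, C]
y = layer(x, H=56, W=56)
print(y.shape)                           # (2, 784, 192) -> [B, H/2*W/2, 2*C]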

class MLP(layers.Layer):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    k_ini = initializers.TruncatedNormal(stddev=0.02)
    b_ini = initializers.Zeros()

    def __init__(self, in_features, mlp_ratio=4.0, drop=0., name=None):
        super(MLP, self).__init__(name=name)
        self.fc1 = layers.Dense(int(in_features * mlp_ratio), name="fc1",
                                kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.act = layers.Activation("gelu")
        self.fc2 = layers.Dense(in_features, name="fc2",
                                kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.drop = layers.Dropout(drop)

    def call(self, x, training=None):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x, training=training)
        x = self.fc2(x)
        x = self.drop(x, training=training)
        return x
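The MLP leaves the token shape untouched: with the default mlp_ratio of 4, fc1 expands C to 4C and fc2 projects back to C. A quick check with dummy shapes of my choosing:

mlp = MLP(in_features=96)                # hidden size = 4 * 96 = 384
x = tf.random.normal([2, 49, 96])        # [num_windows*B, Mh*Mw, C]
y = mlp(x, training=False)
print(y.shape)                           # (2, 49, 96), same as the input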

class WindowAttention(layers.Layer):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        attn_drop_ratio (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop_ratio (float, optional): Dropout ratio of output. Default: 0.0
    """
    k_ini = initializers.GlorotUniform()
    b_ini = initializers.Zeros()

    def __init__(self, dim, window_size, num_heads, qkv_bias=True,
                 attn_drop_ratio=0., proj_drop_ratio=0., name=None):
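The defining piece of W-MSA is the relative position bias mentioned in the docstring: each head adds a learnable bias gathered from a table of size (2*Mh - 1) * (2*Mw - 1), indexed by the relative offset between every pair of positions in a window. The NumPy sketch below shows how that index is typically precomputed in Swin implementations; it is an illustration, not code from this article.

import numpy as np

Mh, Mw = 7, 7                                        # window size (Mh, Mw)
coords = np.stack(np.meshgrid(np.arange(Mh), np.arange(Mw), indexing="ij"))   # [2, Mh, Mw]
coords_flatten = coords.reshape(2, -1)               # [2, Mh*Mw]
# pairwise offsets between every pair of positions inside one window
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]     # [2, Mh*Mw, Mh*Mw]
relative_coords = relative_coords.transpose(1, 2, 0)                          # [Mh*Mw, Mh*Mw, 2]
# shift offsets to start at 0, then fold the 2D offset into a single index
relative_coords[..., 0] += Mh - 1
relative_coords[..., 1] += Mw - 1
relative_coords[..., 0] *= 2 * Mw - 1
relative_position_index = relative_coords.sum(-1)    # [Mh*Mw, Mh*Mw]
print(relative_position_index.shape, relative_position_index.max())           # (49, 49) 168

During attention, this index gathers a per-head bias from the (2*Mh - 1) * (2*Mw - 1) table and the result is added to the q @ k^T scores before the softmax.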
