Hyperparameters: epoch=200, batch_size=24 (as in the original paper). If you don't have enough GPU memory, you can reduce batch_size to 12 or 6 to save memory.
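For readers who want to automate that choice, below is a minimal sketch that halves the batch size based on total GPU memory; the memory thresholds are rough guesses, not values from the paper:

import torch

EPOCHS = 200
BATCH_SIZE = 24  # value from the original paper

# Halve the batch size on smaller GPUs (24 -> 12 -> 6), as suggested above.
if torch.cuda.is_available():
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
    if total_gb < 8:        # rough threshold, tune for your hardware
        BATCH_SIZE = 6
    elif total_gb < 12:
        BATCH_SIZE = 12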
import torch
from torch import nn
from timm.models.layers import trunc_normal_  # required by _init_weights below

def no_weight_decay():
    # Exact parameter names to exclude from weight decay.
    return {'absolute_pos_embed'}

def no_weight_decay_keywords():
    # Keyword substrings: any parameter whose name contains one of these is excluded.
    return {'relative_position_bias_table'}
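These two sets are typically consumed when building the optimizer: matching parameters are placed in a group with weight_decay=0, mirroring the standard Swin Transformer recipe. A minimal sketch; the helper name build_param_groups and the 0.05 decay value are illustrative assumptions, not from this article:

def build_param_groups(model, weight_decay=0.05):
    # 0.05 is the common Swin default (assumption here; adjust to your setup).
    skip = no_weight_decay()
    skip_keywords = no_weight_decay_keywords()
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1-D tensors (biases, norm weights), exact-name matches, and keyword
        # matches are excluded from weight decay.
        if (param.ndim == 1 or name.endswith('.bias')
                or name in skip
                or any(k in name for k in skip_keywords)):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]

# usage: optimizer = torch.optim.AdamW(build_param_groups(model), lr=1e-4)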
def _init_weights(m):
    if isinstance(m, nn.Linear):
        # Truncated normal init for linear weights, zero for biases.
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # LayerNorm starts out as the identity transform.
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
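_init_weights is meant to be passed to nn.Module.apply, which calls it recursively on every submodule. A quick check on a toy module (illustrative only):

toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
toy.apply(_init_weights)
assert torch.allclose(toy[1].weight, torch.ones(8))  # LayerNorm weight set to 1.0
assert torch.allclose(toy[1].bias, torch.zeros(8))   # LayerNorm bias set to 0
# For the real model: model = Swin_Unet(...); model.apply(_init_weights)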
class Swin_Unet(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, num_classes,
                 embed_dim, depths, num_heads,
                 window_size, mlp_ratio, qkv_bias, qk_scale,
                 drop_rate, attn_drop_rate, drop_path_rate):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)  # one encoder stage per entry in depths
        self.embed_dim = embed_dim
        # Channel width at the bottleneck: embed_dim doubles at each stage.
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.num_features_up = int(embed_dim * 2)
        self.mlp_ratio = mlp_ratio

        # Patch partition and linear embedding:
        # split the image into non-overlapping patches and project each one to embed_dim.
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_channels=in_channels,
            embed_dim=embed_dim, norm_layer=nn.LayerNorm)
        self.patches_resolution = self.patch_embed.patches_resolution
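To make the bookkeeping above concrete, here is the arithmetic with common Swin-Tiny-style values; img_size=224, patch_size=4, embed_dim=96 and depths=[2, 2, 2, 2] are assumptions for illustration, not values fixed by this article:

img_size, patch_size, embed_dim = 224, 4, 96
depths = [2, 2, 2, 2]
num_layers = len(depths)                               # 4 stages
patches_resolution = [img_size // patch_size] * 2      # [56, 56] token grid
num_features = int(embed_dim * 2 ** (num_layers - 1))  # 96 * 8 = 768 at the bottleneck
num_features_up = int(embed_dim * 2)                   # 192 on the decoder side
print(patches_resolution, num_features, num_features_up)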

This article walks through the structure of the Swin-Unet model, covering the encoder and decoder design and how to adjust batch_size when GPU memory is limited.