模块出处
[arXiv 25] [link] [code] RWKV-UNet: Improving UNet with Long-Range Cooperation for Effective Medical Image Segmentation
模块名称
Inverted Residual RWKV (IR-RWKV)
模块作用
用于视觉（vision）任务的 RWKV 结构
模块结构
模块代码
注：下方代码依赖的 CUDA/C++ 扩展源码（cuda/wkv_op.cpp、cuda/wkv_cuda.cu）请从作者原仓库获取。
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from timm.layers.activations import *
from functools import partial
from timm.layers import DropPath, create_act_layer, LayerType
from typing import Callable, Dict, Optional, Type
from torch.utils.cpp_extension import load
# Compile-time constant forwarded to the WKV CUDA kernel as the ``Tmax``
# preprocessor macro (see ``-DTmax`` below); presumably the maximum sequence
# length the kernel supports — confirm against the kernel source.
T_MAX = 1024
# Module-level in-place flag; not referenced in the visible part of this file.
inplace = True
# JIT-compile and load the WKV CUDA extension at import time, so importing this
# module requires a working CUDA toolchain.
# NOTE(review): the source paths are relative to the current working directory,
# so the import only succeeds when run from the directory containing ``cuda/``.
wkv_cuda = load(
    name="wkv",
    sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
    verbose=True,
    extra_cuda_cflags=[
        '-res-usage',
        '--maxrregcount 60',
        '--use_fast_math',
        '-O3',
        '-Xptxas -O3',
        # Kept on one line: a replacement field split across lines inside a
        # single-quoted f-string is a SyntaxError before Python 3.12.
        f'-DTmax={T_MAX}',
    ],
)
def get_norm(norm_layer='in_1d'):
    """Return the normalization-layer constructor registered under ``norm_layer``.

    Every entry except ``'none'`` is a ``functools.partial`` with ``eps=1e-6``
    pre-bound; call the result with the layer's size argument, e.g.
    ``get_norm('bn_2d')(channels)``.

    Args:
        norm_layer: Registry key. One of ``'none'``, ``'in_1d'``, ``'in_2d'``,
            ``'in_3d'``, ``'bn_1d'``, ``'bn_2d'``, ``'bn_3d'``, ``'gn'``,
            ``'ln_1d'``.

    Returns:
        A callable that builds an ``nn.Module``.

    Raises:
        KeyError: If ``norm_layer`` is not a registered key; the message lists
            the valid keys.

    Note:
        ``'gn'`` wraps ``nn.GroupNorm``, whose constructor takes
        ``(num_groups, num_channels)``, so unlike the other entries it cannot
        be built from a single size argument.
    """
    eps = 1e-6
    norm_dict = {
        'none': nn.Identity,
        'in_1d': partial(nn.InstanceNorm1d, eps=eps),
        'in_2d': partial(nn.InstanceNorm2d, eps=eps),
        'in_3d': partial(nn.InstanceNorm3d, eps=eps),
        'bn_1d': partial(nn.BatchNorm1d, eps=eps),
        'bn_2d': partial(nn.BatchNorm2d, eps=eps),
        'bn_3d': partial(nn.BatchNorm3d, eps=eps),
        'gn': partial(nn.GroupNorm, eps=eps),
        'ln_1d': partial(nn.LayerNorm, eps=eps),
    }
    try:
        return norm_dict[norm_layer]
    except KeyError:
        # Same exception type as before, but tell the caller what is accepted.
        raise KeyError(
            f"unknown norm_layer {norm_layer!r}; expected one of {sorted(norm_dict)}"
        ) from None
def get_act(act_layer='relu'):
    """Return the activation-layer class registered under ``act_layer``.

    Entries other than ``nn.Identity``, ``nn.ReLU``, ``nn.ReLU6`` and
    ``nn.SiLU`` are names brought in by the ``timm.layers.activations`` star
    import at the top of this file.

    Args:
        act_layer: Registry key. One of ``'none'``, ``'sigmoid'``, ``'swish'``,
            ``'mish'``, ``'hsigmoid'``, ``'hswish'``, ``'hmish'``, ``'tanh'``,
            ``'relu'``, ``'relu6'``, ``'prelu'``, ``'gelu'``, ``'silu'``.

    Returns:
        The activation class (not an instance). Callers in this file
        instantiate it as ``get_act(name)(inplace=...)``.

    Raises:
        KeyError: If ``act_layer`` is not a registered key; the message lists
            the valid keys.
    """
    act_dict = {
        'none': nn.Identity,
        'sigmoid': Sigmoid,
        'swish': Swish,
        'mish': Mish,
        'hsigmoid': HardSigmoid,
        'hswish': HardSwish,
        'hmish': HardMish,
        'tanh': Tanh,
        'relu': nn.ReLU,
        'relu6': nn.ReLU6,
        'prelu': PReLU,
        'gelu': GELU,
        'silu': nn.SiLU,
    }
    try:
        return act_dict[act_layer]
    except KeyError:
        # Same exception type as before, but tell the caller what is accepted.
        raise KeyError(
            f"unknown act_layer {act_layer!r}; expected one of {sorted(act_dict)}"
        ) from None
class ConvNormAct(nn.Module):
    """Conv2d -> normalization -> activation, with an optional residual shortcut.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels.
        kernel_size: Square convolution kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation.
        groups: Convolution groups.
        bias: Whether the convolution has a bias term.
        skip: Request an identity shortcut. It is enabled only when
            ``dim_in == dim_out``. The residual add additionally requires the
            conv to preserve spatial size (e.g. ``stride == 1`` with an odd
            kernel).
        norm_layer: Key passed to ``get_norm``; built as ``get_norm(...)(dim_out)``.
        act_layer: Key passed to ``get_act``; built with ``inplace=inplace``.
        inplace: Forwarded to the activation constructor.
        drop_path_rate: DropPath probability on the residual branch. Only
            applied when the shortcut is active; ``0`` uses ``nn.Identity``.
    """

    def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,
                 skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):
        super().__init__()
        self.has_skip = skip and dim_in == dim_out
        # Pad from the *effective* kernel extent. For dilation == 1 this is
        # exactly the original ceil((kernel_size - stride) / 2). For
        # dilation > 1 the old formula under-padded, which shrank stride-1
        # outputs and broke the residual add in forward().
        effective_kernel = dilation * (kernel_size - 1) + 1
        padding = math.ceil((effective_kernel - stride) / 2)
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)
        self.norm = get_norm(norm_layer)(dim_out)
        self.act = get_act(act_layer)(inplace=inplace)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def forward(self, x):
        """Apply conv -> norm -> act and add the (drop-path'd) shortcut if enabled."""
        shortcut = x
        # conv() returns a new tensor, so an in-place activation below cannot
        # overwrite `shortcut`.
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x
class SE(nn.Module):
def __init__(
self,
in_chs