Daily Attention Learning 22: Inverted Residual RWKV

Module Source

[arXiv 25] [link] [code] RWKV-UNet: Improving UNet with Long-Range Cooperation for Effective Medical Image Segmentation


Module Name

Inverted Residual RWKV (IR-RWKV)


Module Purpose

An RWKV structure adapted for vision: RWKV-style long-range token mixing wrapped in an inverted-residual (MobileNetV2-style) block, used in RWKV-UNet for medical image segmentation.


Module Structure

(IR-RWKV module structure diagram; see the original paper)


Module Code

Note: for the C++/CUDA extension (cuda/wkv_op.cpp, cuda/wkv_cuda.cu), please refer to the authors' original repository. A pure-PyTorch sketch of what the WKV kernel computes is given after the listing.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from timm.layers.activations import *
from functools import partial
from timm.layers import DropPath, create_act_layer, LayerType
from typing import Callable, Dict, Optional, Type
from torch.utils.cpp_extension import load


T_MAX = 1024
inplace = True
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math',
                                                 '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])


def get_norm(norm_layer='in_1d'):
	eps = 1e-6
	norm_dict = {
		'none': nn.Identity,
		'in_1d': partial(nn.InstanceNorm1d, eps=eps),
		'in_2d': partial(nn.InstanceNorm2d, eps=eps),
		'in_3d': partial(nn.InstanceNorm3d, eps=eps),
		'bn_1d': partial(nn.BatchNorm1d, eps=eps),
		'bn_2d': partial(nn.BatchNorm2d, eps=eps),
		# 'bn_2d': partial(nn.SyncBatchNorm, eps=eps),
		'bn_3d': partial(nn.BatchNorm3d, eps=eps),
		'gn': partial(nn.GroupNorm, eps=eps),
		'ln_1d': partial(nn.LayerNorm, eps=eps),
		# 'ln_2d': partial(LayerNorm2d, eps=eps),
	}
	return norm_dict[norm_layer]
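# Example: get_norm('bn_2d')(64) builds nn.BatchNorm2d(64, eps=1e-6); the factory returns
# a class / partial that is then instantiated with the channel count.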


def get_act(act_layer='relu'):
	act_dict = {
		'none': nn.Identity,
		'sigmoid': Sigmoid,
		'swish': Swish,
		'mish': Mish,
		'hsigmoid': HardSigmoid,
		'hswish': HardSwish,
		'hmish': HardMish,
		'tanh': Tanh,
		'relu': nn.ReLU,
		'relu6': nn.ReLU6,
		'prelu': PReLU,
		'gelu': GELU,
		'silu': nn.SiLU
	}
	return act_dict[act_layer]
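# Example: get_act('silu')(inplace=True) builds nn.SiLU(inplace=True); 'none' maps to
# nn.Identity, which simply ignores the inplace keyword.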


class ConvNormAct(nn.Module):
	def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,
				 skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):
		super(ConvNormAct, self).__init__()
		self.has_skip = skip and dim_in == dim_out
		padding = math.ceil((kernel_size - stride) / 2)
		self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)
		self.norm = get_norm(norm_layer)(dim_out)
		self.act = get_act(act_layer)(inplace=inplace)
		self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
	
	def forward(self, x):
		shortcut = x
		x = self.conv(x)
		x = self.norm(x)
		x = self.act(x)
		if self.has_skip:
			x = self.drop_path(x) + shortcut
		return x
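# ConvNormAct fuses Conv2d + normalization + activation; with skip=True and dim_in == dim_out
# a drop-path regularized residual is added, e.g. ConvNormAct(64, 64, 3, skip=True, act_layer='silu').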
      

class SE(nn.Module):
    # NOTE: the original post's snippet is cut off after `in_chs`; the body below is a
    # standard squeeze-and-excitation gate (timm-style), which the imported
    # create_act_layer helper supports; it is not necessarily the authors' exact code.
    def __init__(self, in_chs, rd_ratio=0.25, act_layer='relu', gate_layer='sigmoid'):
        super(SE, self).__init__()
        rd_channels = max(1, round(in_chs * rd_ratio))
        self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
        self.act1 = create_act_layer(act_layer, inplace=True)
        self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # squeeze (global average pool) -> excite (two 1x1 convs) -> channel-wise gate
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        return x * self.gate(x_se)
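The wkv_cuda extension loaded at the top of the listing implements the WKV recurrence that RWKV-style spatial-mix layers call into. For readers who cannot build the CUDA extension, here is a minimal pure-PyTorch sketch of what that kernel computes, intended only as a sanity check. It assumes the usual RWKV-v4 / Vision-RWKV calling convention (w is the per-channel decay already negated, i.e. -exp(time_decay), u is the first-token bonus, and k, v have shape (B, T, C)); the function name wkv_reference and these shapes are assumptions for illustration, not the authors' API.

import torch

def wkv_reference(w, u, k, v):
    # Naive O(T^2) reference for the WKV recurrence (assumed RWKV-v4 semantics):
    #   wkv_t = (sum_{i<t} exp((t-1-i)*w + k_i) * v_i + exp(u + k_t) * v_t)
    #           / (sum_{i<t} exp((t-1-i)*w + k_i) + exp(u + k_t))
    # w, u: (C,) with w <= 0; k, v: (B, T, C). No numerical stabilization, so this is
    # only suitable for small toy inputs / unit tests against the CUDA kernel.
    B, T, C = k.shape
    out = torch.empty_like(v)
    for t in range(T):
        steps = torch.arange(t - 1, -1, -1, device=k.device, dtype=k.dtype)  # t-1-i for i = 0..t-1
        past = torch.exp(steps.unsqueeze(-1) * w + k[:, :t])                 # (B, t, C)
        cur = torch.exp(u + k[:, t])                                         # (B, C)
        out[:, t] = ((past * v[:, :t]).sum(dim=1) + cur * v[:, t]) / (past.sum(dim=1) + cur)
    return out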