import time
import math
from functools import partial
from typing import Optional, Callable, Any
import numpy as np
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
# Make sure every module runs on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
class MockSelectiveScanCuda:
    """CPU/GPU fallback that mimics the interface of the selective_scan_cuda extension.

    The forward pass delegates to the naive selective_scan_fn defined below, so the
    model still produces meaningful outputs without the compiled kernel; the backward
    pass returns zero gradients, so this fallback is for inference and shape checks only.
    """

    @staticmethod
    def fwd(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, *args, **kwargs):
        # Align devices defensively, then run the naive reference scan.
        device = u.device
        delta, A, B, C = delta.to(device), A.to(device), B.to(device), C.to(device)
        D = D.to(device) if D is not None else None
        delta_bias = delta_bias.to(device) if delta_bias is not None else None
        out, last_state = selective_scan_fn(
            u, delta, A, B, C, D,
            delta_bias=delta_bias, delta_softplus=delta_softplus, return_last_state=True,
        )
        # The real kernel also returns intermediate states consumed by its backward
        # pass; the final state is returned here only as a placeholder.
        return out, last_state, None, None, None, None, None

    @staticmethod
    def bwd(u, delta, A, B, C, D, *args, **kwargs):
        # Gradients are not implemented for the mock; return zeros of matching shapes.
        return (torch.zeros_like(u), torch.zeros_like(delta), torch.zeros_like(A),
                torch.zeros_like(B), torch.zeros_like(C),
                torch.zeros_like(D) if D is not None else None,
                None, None, None, None, None, None)
# Fallback implementation of selective_scan_fn: a naive reference scan, intended for
# demonstration and CPU testing only; swap in the official CUDA kernel for real use.
def selective_scan_fn(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, return_last_state=False):
    """Naive selective scan following the mamba_ssm calling convention.

    Shapes: u, delta: (B, D, L); A: (D, N); B, C: (B, N, L) or grouped (B, G, N, L);
    D: (D,); delta_bias: (D,).
    """
    batch, dim, seq_len = u.shape
    d_state = A.shape[-1]
    # Discretization step: delta = softplus(delta + bias)
    if delta_bias is not None:
        delta = delta + delta_bias.view(1, -1, 1)
    if delta_softplus:
        delta = F.softplus(delta)
    # Broadcast (grouped) B and C so that every channel has its own (N, L) slice.
    if B.dim() == 3:
        B = B.unsqueeze(1)
    if C.dim() == 3:
        C = C.unsqueeze(1)
    B = B.repeat_interleave(dim // B.shape[1], dim=1)  # (B, D, N, L)
    C = C.repeat_interleave(dim // C.shape[1], dim=1)  # (B, D, N, L)
    # Discretized recurrence: x_t = exp(delta_t * A) * x_{t-1} + delta_t * B_t * u_t
    deltaA = torch.exp(delta.unsqueeze(-1) * A.view(1, dim, 1, d_state))      # (B, D, L, N)
    deltaB_u = delta.unsqueeze(-1) * B.permute(0, 1, 3, 2) * u.unsqueeze(-1)  # (B, D, L, N)
    x = u.new_zeros(batch, dim, d_state)
    ys = []
    for i in range(seq_len):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        ys.append(torch.einsum("bdn,bdn->bd", x, C[:, :, :, i]))
    y = torch.stack(ys, dim=-1)  # (B, D, L)
    if D is not None:
        y = y + u * D.view(1, -1, 1)
    return (y, x) if return_last_state else y
# Prefer the compiled CUDA extension; fall back to the mock when it is unavailable.
try:
import selective_scan_cuda
except ImportError:
selective_scan_cuda = MockSelectiveScanCuda()
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
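# EfficientScan / EfficientMerge implement the strided ("atrous") 2D scan used by the
# SS2D core: the feature map is split into four interleaved sub-grids (even/odd rows x
# even/odd columns, two of them traversed transposed), flattened into 1D sequences for
# the selective scan, and later merged back to the original spatial layout.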
class EfficientMerge(torch.autograd.Function):
@staticmethod
def forward(ctx, ys: torch.Tensor, ori_h: int, ori_w: int, step_size=2):
B, K, C, L = ys.shape
H, W = math.ceil(ori_h / step_size), math.ceil(ori_w / step_size)
ctx.shape = (H, W)
ctx.ori_h = ori_h
ctx.ori_w = ori_w
ctx.step_size = step_size
new_h = H * step_size
new_w = W * step_size
y = ys.new_empty((B, C, new_h, new_w))
y[:, :, ::step_size, ::step_size] = ys[:, 0].reshape(B, C, H, W)
y[:, :, 1::step_size, ::step_size] = ys[:, 1].reshape(B, C, W, H).transpose(dim0=2, dim1=3)
y[:, :, ::step_size, 1::step_size] = ys[:, 2].reshape(B, C, H, W)
y[:, :, 1::step_size, 1::step_size] = ys[:, 3].reshape(B, C, W, H).transpose(dim0=2, dim1=3)
if ori_h != new_h or ori_w != new_w:
y = y[:, :, :ori_h, :ori_w].contiguous()
y = y.view(B, C, -1)
return y
@staticmethod
def backward(ctx, grad_x: torch.Tensor):
H, W = ctx.shape
B, C, L = grad_x.shape
step_size = ctx.step_size
grad_x = grad_x.view(B, C, ctx.ori_h, ctx.ori_w)
if ctx.ori_w % step_size != 0:
pad_w = step_size - ctx.ori_w % step_size
grad_x = F.pad(grad_x, (0, pad_w, 0, 0))
W = grad_x.shape[3]
if ctx.ori_h % step_size != 0:
pad_h = step_size - ctx.ori_h % step_size
grad_x = F.pad(grad_x, (0, 0, 0, pad_h))
H = grad_x.shape[2]
B, C, H, W = grad_x.shape
H = H // step_size
W = W // step_size
grad_xs = grad_x.new_empty((B, 4, C, H * W))
grad_xs[:, 0] = grad_x[:, :, ::step_size, ::step_size].reshape(B, C, -1)
grad_xs[:, 1] = grad_x.transpose(dim0=2, dim1=3)[:, :, ::step_size, 1::step_size].reshape(B, C, -1)
grad_xs[:, 2] = grad_x[:, :, ::step_size, 1::step_size].reshape(B, C, -1)
grad_xs[:, 3] = grad_x.transpose(dim0=2, dim1=3)[:, :, 1::step_size, 1::step_size].reshape(B, C, -1)
return grad_xs, None, None, None
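# Autograd wrapper around the selective-scan kernel (selective_scan_cuda, or the mock
# defined above when the extension is unavailable). The forward pass runs in float32
# under AMP and caches the inputs/states needed by the kernel's backward pass.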
class SelectiveScan(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1):
assert nrows in [1, 2, 3, 4], f"{nrows}"
ctx.delta_softplus = delta_softplus
ctx.nrows = nrows
if u.stride(-1) != 1:
u = u.contiguous()
if delta.stride(-1) != 1:
delta = delta.contiguous()
if D is not None:
D = D.contiguous()
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if B.dim() == 3:
B = B.unsqueeze(dim=1)
ctx.squeeze_B = True
if C.dim() == 3:
C = C.unsqueeze(dim=1)
ctx.squeeze_C = True
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
return out
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, dout, *args):
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
if dout.stride(-1) != 1:
dout = dout.contiguous()
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, False
)
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None)
class EfficientScan(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, step_size=2):
B, C, org_h, org_w = x.shape
ctx.shape = (B, C, org_h, org_w)
ctx.step_size = step_size
if org_w % step_size != 0:
pad_w = step_size - org_w % step_size
x = F.pad(x, (0, pad_w, 0, 0))
W = x.shape[3]
if org_h % step_size != 0:
pad_h = step_size - org_h % step_size
x = F.pad(x, (0, 0, 0, pad_h))
H = x.shape[2]
H = H // step_size
W = W // step_size
xs = x.new_empty((B, 4, C, H * W))
xs[:, 0] = x[:, :, ::step_size, ::step_size].contiguous().view(B, C, -1)
xs[:, 1] = x.transpose(dim0=2, dim1=3)[:, :, ::step_size, 1::step_size].contiguous().view(B, C, -1)
xs[:, 2] = x[:, :, ::step_size, 1::step_size].contiguous().view(B, C, -1)
xs[:, 3] = x.transpose(dim0=2, dim1=3)[:, :, 1::step_size, 1::step_size].contiguous().view(B, C, -1)
xs = xs.view(B, 4, C, -1)
return xs
@staticmethod
def backward(ctx, grad_xs: torch.Tensor):
B, C, org_h, org_w = ctx.shape
step_size = ctx.step_size
newH, newW = math.ceil(org_h / step_size), math.ceil(org_w / step_size)
grad_x = grad_xs.new_empty((B, C, newH * step_size, newW * step_size))
grad_xs = grad_xs.view(B, 4, C, newH, newW)
grad_x[:, :, ::step_size, ::step_size] = grad_xs[:, 0].reshape(B, C, newH, newW)
grad_x[:, :, 1::step_size, ::step_size] = grad_xs[:, 1].reshape(B, C, newW, newH).transpose(dim0=2, dim1=3)
grad_x[:, :, ::step_size, 1::step_size] = grad_xs[:, 2].reshape(B, C, newH, newW)
grad_x[:, :, 1::step_size, 1::step_size] = grad_xs[:, 3].reshape(B, C, newW, newH).transpose(dim0=2, dim1=3)
if org_h != grad_x.shape[-2] or org_w != grad_x.shape[-1]:
grad_x = grad_x[:, :, :org_h, :org_w]
return grad_x, None
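# End-to-end SS2D core: strided 4-way scan (EfficientScan) -> per-direction projections
# of dt/B/C -> selective scan over all four sequences -> merge back to the original
# resolution (EfficientMerge) -> output normalization, returned channels-last.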
def cross_selective_scan(
x: torch.Tensor = None,
x_proj_weight: torch.Tensor = None,
x_proj_bias: torch.Tensor = None,
dt_projs_weight: torch.Tensor = None,
dt_projs_bias: torch.Tensor = None,
A_logs: torch.Tensor = None,
Ds: torch.Tensor = None,
out_norm: torch.nn.Module = None,
nrows=-1,
delta_softplus=True,
to_dtype=True,
step_size=2,
):
B, D, H, W = x.shape
D, N = A_logs.shape
K, D_in, R = dt_projs_weight.shape
L = H * W
if nrows < 1:
if D % 4 == 0:
nrows = 4
elif D % 3 == 0:
nrows = 3
elif D % 2 == 0:
nrows = 2
else:
nrows = 1
ori_h, ori_w = H, W
xs = EfficientScan.apply(x, step_size)
H = math.ceil(H / step_size)
W = math.ceil(W / step_size)
L = H * W
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight)
if x_proj_bias is not None:
x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight)
xs = xs.view(B, -1, L).to(torch.float)
dts = dts.contiguous().view(B, -1, L).to(torch.float)
As = -torch.exp(A_logs.to(torch.float))
Bs = Bs.contiguous().to(torch.float)
Cs = Cs.contiguous().to(torch.float)
Ds = Ds.to(torch.float)
delta_bias = dt_projs_bias.view(-1).to(torch.float)
def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1):
return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows)
ys: torch.Tensor = selective_scan(
xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus, nrows,
).view(B, K, -1, L)
ori_h, ori_w = int(ori_h), int(ori_w)
y = EfficientMerge.apply(ys, ori_h, ori_w, step_size)
H = ori_h
W = ori_w
L = H * W
y = y.transpose(dim0=1, dim1=2).contiguous()
y = out_norm(y).view(B, H, W, -1)
return (y.to(x.dtype) if to_dtype else y)
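# SS2D: 2D selective-scan mixer. Input projection with gating, depthwise convolution,
# the multi-directional selective scan above, then output projection and dropout.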
class SS2D(nn.Module):
def __init__(
self,
d_model=96,
d_state=16,
ssm_ratio=2.0,
ssm_rank_ratio=2.0,
dt_rank="auto",
act_layer=nn.SiLU,
d_conv=3,
conv_bias=True,
dropout=0.0,
bias=False,
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
simple_init=False,
forward_type="v2",
step_size=2,
**kwargs,
):
factory_kwargs = {"device": None, "dtype": None}
super().__init__()
d_expand = int(ssm_ratio * d_model)
d_inner = int(min(ssm_rank_ratio, ssm_ratio) * d_model) if ssm_rank_ratio > 0 else d_expand
self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
self.d_state = math.ceil(d_model / 6) if d_state == "auto" else d_state
self.d_conv = d_conv
        self.d_inner = d_inner  # keep d_inner for the core scan projections
self.step_size = step_size
self.disable_z_act = forward_type[-len("nozact"):] == "nozact"
if self.disable_z_act:
forward_type = forward_type[:-len("nozact")]
if forward_type[-len("softmax"):] == "softmax":
forward_type = forward_type[:-len("softmax")]
self.out_norm = nn.Softmax(dim=1)
elif forward_type[-len("sigmoid"):] == "sigmoid":
forward_type = forward_type[:-len("sigmoid")]
self.out_norm = nn.Sigmoid()
else:
self.out_norm = nn.LayerNorm(d_inner)
self.forward_core = dict(
v0=self.forward_corev0,
v0_seq=self.forward_corev0_seq,
v1=self.forward_corev2,
v2=self.forward_corev2,
share_ssm=self.forward_corev0_share_ssm,
share_a=self.forward_corev0_share_a,
).get(forward_type, self.forward_corev2)
self.K = 4 if forward_type not in ["share_ssm"] else 1
self.K2 = self.K if forward_type not in ["share_a"] else 1
self.in_proj = nn.Linear(d_model, d_expand * 2, bias=bias, **factory_kwargs)
self.act: nn.Module = act_layer()
if self.d_conv > 1:
self.conv2d = nn.Conv2d(
in_channels=d_expand,
out_channels=d_expand,
groups=d_expand,
bias=conv_bias,
kernel_size=d_conv,
padding=(d_conv - 1) // 2,
**factory_kwargs,
)
self.ssm_low_rank = False
if d_inner < d_expand:
self.ssm_low_rank = True
self.in_rank = nn.Conv2d(d_expand, d_inner, kernel_size=1, bias=False, **factory_kwargs)
self.out_rank = nn.Linear(d_inner, d_expand, bias=False, **factory_kwargs)
        # Use d_inner as the input dimension of the per-direction projections so the
        # shapes match the features fed to the selective scan.
self.x_proj = [
nn.Linear(d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs)
for _ in range(self.K)
]
self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))
del self.x_proj
self.dt_projs = [
self.dt_init(self.dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs)
for _ in range(self.K)
]
self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))
self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))
del self.dt_projs
self.A_logs = self.A_log_init(self.d_state, d_inner, copies=self.K2, merge=True)
self.Ds = self.D_init(d_inner, copies=self.K2, merge=True)
self.out_proj = nn.Linear(d_expand, d_model, bias=bias, **factory_kwargs)
self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
if simple_init:
self.Ds = nn.Parameter(torch.ones((self.K2 * d_inner)))
self.A_logs = nn.Parameter(
torch.randn((self.K2 * d_inner, self.d_state)))
self.dt_projs_weight = nn.Parameter(torch.randn((self.K, d_inner, self.dt_rank)))
self.dt_projs_bias = nn.Parameter(torch.randn((self.K, d_inner)))
@staticmethod
def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
**factory_kwargs):
dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
dt_init_std = dt_rank ** -0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
dt = torch.exp(
torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
dt_proj.bias.copy_(inv_dt)
return dt_proj
@staticmethod
def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
A = repeat(
torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=d_inner,
).contiguous()
A_log = torch.log(A)
if copies > 0:
A_log = repeat(A_log, "d n -> r d n", r=copies)
if merge:
A_log = A_log.flatten(0, 1)
A_log = nn.Parameter(A_log)
A_log._no_weight_decay = True
return A_log
@staticmethod
def D_init(d_inner, copies=-1, device=None, merge=True):
D = torch.ones(d_inner, device=device)
if copies > 0:
D = repeat(D, "n1 -> r n1", r=copies)
if merge:
D = D.flatten(0, 1)
D = nn.Parameter(D)
D._no_weight_decay = True
return D
def forward_corev0(self, x: torch.Tensor, to_dtype=False, channel_first=False):
        # Use selective_scan_fn (the fallback above when the CUDA kernel is missing).
selective_scan = selective_scan_fn
if not channel_first:
x = x.permute(0, 3, 1, 2).contiguous()
B, C, H, W = x.shape
L = H * W
K = 4
x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
xs = xs.float().view(B, -1, L)
dts = dts.contiguous().float().view(B, -1, L)
Bs = Bs.float()
Cs = Cs.float()
As = -torch.exp(self.A_logs.float())
Ds = self.Ds.float()
dt_projs_bias = self.dt_projs_bias.float().view(-1)
out_y = selective_scan(
xs, dts, As, Bs, Cs, Ds,
delta_bias=dt_projs_bias,
delta_softplus=True,
).view(B, K, -1, L)
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
y = torch.stack([out_y[:, 0], wh_y, inv_y[:, 0], invwh_y], dim=1).view(B, -1, L)
y = y.transpose(dim0=1, dim1=2).contiguous()
y = self.out_norm(y).view(B, H, W, -1)
return y.to(x.dtype) if to_dtype else y
def forward_corev0_seq(self, x: torch.Tensor, to_dtype=False, channel_first=False):
        # Use selective_scan_fn (the fallback above when the CUDA kernel is missing).
selective_scan = selective_scan_fn
if not channel_first:
x = x.permute(0, 3, 1, 2).contiguous()
B, C, H, W = x.shape
L = H * W
K = 4
x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
xs = xs.float().view(B, -1, L)
dts = dts.contiguous().float().view(B, -1, L)
Bs = Bs.float()
Cs = Cs.float()
As = -torch.exp(self.A_logs.float())
Ds = self.Ds.float()
dt_projs_bias = self.dt_projs_bias.float().view(-1)
out_y = selective_scan(
xs, dts, As, Bs, Cs, Ds,
delta_bias=dt_projs_bias,
delta_softplus=True,
).view(B, K, -1, L)
y = out_y[:, 0].view(B, -1, L)
y = y.transpose(dim0=1, dim1=2).contiguous()
y = self.out_norm(y).view(B, H, W, -1)
return y.to(x.dtype) if to_dtype else y
def forward_corev0_share_ssm(self, x: torch.Tensor, to_dtype=False, channel_first=False):
        # Use selective_scan_fn (the fallback above when the CUDA kernel is missing).
selective_scan = selective_scan_fn
if not channel_first:
x = x.permute(0, 3, 1, 2).contiguous()
B, C, H, W = x.shape
L = H * W
K = 1
x_hwwh = x.view(B, -1, L).unsqueeze(1)
xs = x_hwwh
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
xs = xs.float().view(B, -1, L)
dts = dts.contiguous().float().view(B, -1, L)
Bs = Bs.float()
Cs = Cs.float()
As = -torch.exp(self.A_logs.float())
Ds = self.Ds.float()
dt_projs_bias = self.dt_projs_bias.float().view(-1)
out_y = selective_scan(
xs, dts, As, Bs, Cs, Ds,
delta_bias=dt_projs_bias,
delta_softplus=True,
).view(B, K, -1, L)
y = out_y[:, 0].view(B, -1, L)
y = y.transpose(dim0=1, dim1=2).contiguous()
y = self.out_norm(y).view(B, H, W, -1)
return y.to(x.dtype) if to_dtype else y
def forward_corev0_share_a(self, x: torch.Tensor, to_dtype=False, channel_first=False):
        # Use selective_scan_fn (the fallback above when the CUDA kernel is missing).
selective_scan = selective_scan_fn
if not channel_first:
x = x.permute(0, 3, 1, 2).contiguous()
B, C, H, W = x.shape
L = H * W
K = 4
x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
xs = xs.float().view(B, -1, L)
dts = dts.contiguous().float().view(B, -1, L)
Bs = Bs.float()
Cs = Cs.float()
As = -torch.exp(self.A_logs.float())
Ds = self.Ds.float()
dt_projs_bias = self.dt_projs_bias.float().view(-1)
As = As.repeat(K, 1)
Ds = Ds.repeat(K)
out_y = selective_scan(
xs, dts, As, Bs, Cs, Ds,
delta_bias=dt_projs_bias,
delta_softplus=True,
).view(B, K, -1, L)
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
y = torch.stack([out_y[:, 0], wh_y, inv_y[:, 0], invwh_y], dim=1).view(B, -1, L)
y = y.transpose(dim0=1, dim1=2).contiguous()
y = self.out_norm(y).view(B, H, W, -1)
return y.to(x.dtype) if to_dtype else y
def forward_corev2(self, x: torch.Tensor, to_dtype=False, channel_first=False):
if not channel_first:
x = x.permute(0, 3, 1, 2).contiguous()
B, C, H, W = x.shape
if self.ssm_low_rank:
x = self.in_rank(x)
y = cross_selective_scan(
x,
self.x_proj_weight,
None,
self.dt_projs_weight,
self.dt_projs_bias,
self.A_logs,
self.Ds,
self.out_norm,
to_dtype=to_dtype,
step_size=self.step_size,
)
if self.ssm_low_rank:
y = self.out_rank(y)
        # cross_selective_scan already returns channels-last (B, H, W, C), which is what
        # out_proj in forward() expects, so no permute back is needed here.
        return y
def forward(self, x: torch.Tensor):
B, H, W, C = x.shape
x = self.in_proj(x)
x, gate = x.chunk(2, dim=-1)
x = self.act(x) * gate
x = x.permute(0, 3, 1, 2).contiguous()
if self.d_conv > 1:
x = self.conv2d(x)
x = x.permute(0, 2, 3, 1).contiguous()
x = self.forward_core(x)
x = self.dropout(self.out_proj(x))
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
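# Pre-norm block: LayerNorm -> SS2D branch with residual DropPath, then
# LayerNorm -> MLP branch with residual DropPath.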
class VSSBlock(nn.Module):
def __init__(
self,
hidden_dim,
drop_path=0.0,
norm_layer=nn.LayerNorm,
attn_drop_rate=0.0,
d_state=16,
ssm_ratio=2.0,
ssm_rank_ratio=2.0,
dt_rank="auto",
d_conv=3,
step_size=2,
act_layer=nn.GELU,
):
super().__init__()
self.norm1 = norm_layer(hidden_dim)
self.ssm = SS2D(
d_model=hidden_dim,
d_state=d_state,
ssm_ratio=ssm_ratio,
ssm_rank_ratio=ssm_rank_ratio,
dt_rank=dt_rank,
act_layer=act_layer,
d_conv=d_conv,
dropout=attn_drop_rate,
step_size=step_size,
)
self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
self.norm2 = norm_layer(hidden_dim)
mlp_hidden_dim = int(hidden_dim * 4.0)
self.mlp = Mlp(
in_features=hidden_dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=attn_drop_rate
)
def forward(self, x, H, W):
B, L, C = x.shape
        assert L == H * W, "input feature length must equal H * W"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
x = self.ssm(x)
x = x.reshape(B, L, C)
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class VSSBlock_Cross(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ssm_ratio=2.0,
ssm_rank_ratio=2.0,
dt_rank="auto",
d_state=16,
d_conv=3,
step_size=2,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.ssm = SS2D(
d_model=dim,
d_state=d_state,
ssm_ratio=ssm_ratio,
ssm_rank_ratio=ssm_rank_ratio,
dt_rank=dt_rank,
act_layer=act_layer,
d_conv=d_conv,
dropout=drop,
step_size=step_size,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x, H, W):
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
x = self.ssm(x)
x = x.reshape(B, H * W, C)
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
# Test code
if __name__ == "__main__":
    print("\nTesting VSSBlock...")
hidden_dim = 96
block = VSSBlock(
hidden_dim=hidden_dim,
d_state=16,
drop_path=0.1
).to(device)
B, H, W = 2, 8, 8
x = torch.randn(B, H * W, hidden_dim).to(device)
with torch.no_grad():
output = block(x, H, W)
print(f"输入形状: {x.shape} (设备: {x.device})")
print(f"输出形状: {output.shape} (设备: {output.device})")
print("\n测试VSSBlock_Cross...")
dim = 96
block_cross = VSSBlock_Cross(
dim=dim,
num_heads=8,
mlp_ratio=4.0,
drop=0.0,
drop_path=0.1,
ssm_ratio=2.0,
ssm_rank_ratio=2.0,
d_state=16,
d_conv=3,
step_size=2,
).to(device)
x_cross = torch.randn(B, H * W, dim).to(device)
with torch.no_grad():
output_cross = block_cross(x_cross, H, W)
print(f"输入形状: {x_cross.shape} (设备: {x_cross.device})")
print(f"输出形状: {output_cross.shape} (设备: {output_cross.device})")
print("\n测试SS2D核心功能...")
ssm = SS2D(
d_model=hidden_dim,
d_state=16,
ssm_ratio=2.0,
ssm_rank_ratio=2.0,
dt_rank="auto",
step_size=2,
).to(device)
x_ssm = torch.randn(B, H, W, hidden_dim).to(device)
with torch.no_grad():
output_ssm = ssm(x_ssm)
print(f"输入形状: {x_ssm.shape} (设备: {x_ssm.device})")
print(f"输出形状: {output_ssm.shape} (设备: {output_ssm.device})")
How can I add the CBAM attention mechanism to this code? Please give me the complete code.
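One way to do this is to add a CBAM block and apply it to the SS2D output inside VSSBlock. Below is a minimal sketch, not part of the original model: the module names, the reduction ratio of 16, the 7x7 spatial kernel, and the placement right after self.ssm are all assumptions you may want to adjust.

# --- Hedged sketch: CBAM (channel + spatial attention); assumed add-on, not original code ---
class ChannelAttention(nn.Module):
    """Channel attention: shared MLP over global average- and max-pooled descriptors."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
        )

    def forward(self, x):  # x: (B, C, H, W)
        avg = self.mlp(x.mean(dim=(2, 3)))
        mx = self.mlp(x.amax(dim=(2, 3)))
        return torch.sigmoid(avg + mx).view(x.shape[0], -1, 1, 1)


class SpatialAttention(nn.Module):
    """Spatial attention: conv over the channel-wise mean and max maps."""

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):  # x: (B, C, H, W)
        avg = x.mean(dim=1, keepdim=True)
        mx = x.amax(dim=1, keepdim=True)
        return torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))


class CBAM(nn.Module):
    """Sequential channel-then-spatial attention, applied multiplicatively."""

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channels, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):  # x: (B, C, H, W)
        x = x * self.ca(x)
        return x * self.sa(x)


# Assumed wiring inside VSSBlock (one option among several): create
# `self.cbam = CBAM(hidden_dim)` in __init__, then in forward() apply it to the
# SSM output before the residual addition:
#     x = self.ssm(x)                                           # (B, H, W, C)
#     x = self.cbam(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)  # (B, H, W, C)
#     x = x.reshape(B, L, C)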