I wrote a simple neural network using SS2D and tried to train it, but training fails with the following error:
NotImplementedError: You must implement either the backward or vjp method for your custom autograd.Function to use it with backward mode AD.
Environment:
CUDA11.8
torch=2.0.0
mamba_ssm=2.0.2
causal-conv1d=1.2.1
Ubuntu 22.04, Python 3.11.8
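As far as I understand, PyTorch raises this NotImplementedError whenever autograd reaches a custom torch.autograd.Function for which it cannot find a backward (or vjp) staticmethod on the class. A minimal standalone sketch that reproduces the same message (the MissingBackward class and the tensor shape here are made up purely for illustration):

import torch

class MissingBackward(torch.autograd.Function):
    # only forward is defined; there is no backward/vjp staticmethod
    @staticmethod
    def forward(ctx, x):
        return x * 2

y = MissingBackward.apply(torch.randn(4, requires_grad=True))
y.sum().backward()  # -> NotImplementedError about backward/vjp, same message as above

Every autograd.Function in my SS2D code does define backward, though, so I don't see where the error comes from. The full code is below: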
import torch
import torch.nn as nn
from functools import partial
from typing import Callable
import math
from einops import repeat
from timm.models.layers import DropPath
from mmyolo.registry import MODELS # new
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except Exception as e:
    # print(e, flush=True)
    pass

try:
    "sscore acts the same as mamba_ssm"
    SSMODE = "sscore"
    import selective_scan_cuda_core
except Exception as e:
    print(e, flush=True)
    "you should install mamba_ssm to use this"
    SSMODE = "mamba_ssm"
    import selective_scan_cuda
class SelectiveScanMamba(torch.autograd.Function):
    # comment all checks if inside cross_selective_scan
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1,
                oflex=True):
        ctx.delta_softplus = delta_softplus
        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
            u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
            False
        )
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
class SelectiveScanCore(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1,
                oflex=True):
        ctx.delta_softplus = delta_softplus
        if SSMODE == "mamba_ssm":
            out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
        else:
            out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        # out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
        # ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if SSMODE == "mamba_ssm":
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
                u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
                False  # option to recompute out_z, not used here
            )
        else:
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
                u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
            )
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
class SelectiveScanFake(torch.autograd.Function):
    # comment all checks if inside cross_selective_scan
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1,
                oflex=True):
        ctx.delta_softplus = delta_softplus
        ctx.backnrows = backnrows
        x = delta
        out = u
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # all-zero placeholder gradients, one per differentiable input
        du, ddelta, dA, dB, dC, dD, ddelta_bias = (
            u * 0, delta * 0, A * 0, B * 0, C * 0,
            (D * 0 if D is not None else None),
            (delta_bias * 0 if delta_bias is not None else None),
        )
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
class CrossScan(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        # xs = x.new_empty((B, 4, C, H * W))
        xs = x.new_empty((B, 8, C, H * W))
        xs[:, 0] = x.flatten(2, 3)
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        xs[:, 4] = diagonal_gather(x)
        xs[:, 5] = antidiagonal_gather(x)
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])
        return xs

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb.view(B, -1, H, W)
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L)
        y_da = diagonal_scatter(y_da[:, 0], (B, C, H, W)) + antidiagonal_scatter(y_da[:, 1], (B, C, H, W))
        y_res = y_rb + y_da
        # return y.view(B, -1, H, W)
        return y_res
def antidiagonal_gather(tensor):
    B, C, H, W = tensor.size()
    shift = torch.arange(H, device=tensor.device).unsqueeze(1)
    index = (torch.arange(W, device=tensor.device) - shift) % W
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    return tensor.gather(3, expanded_index).transpose(-1, -2).reshape(B, C, H * W)
def diagonal_gather(tensor):
    B, C, H, W = tensor.size()
    shift = torch.arange(H, device=tensor.device).unsqueeze(1)
    index = (shift + torch.arange(W, device=tensor.device)) % W
    expanded_index = index.unsqueez