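[Note: Mamba is really not meant for Windows. On Linux the environment is easy to set up, errors are rare, and the ones that do come up are mostly minor. Don't waste time forcing it onto Windows; install Ubuntu and you can reproduce things quickly.]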
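For how to set up the mamba_ssm and causal_conv1d modules, see this blog post:
[Mamba environment setup on Windows] Step-by-step guide to configuring the triton, causal conv1d, and mamba_ssm modules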
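After installing mamba_ssm and causal-conv1d successfully as described there, running the code still raised errors from these two modules, for example:
  File "E:\install\Anaconda\envs\SR\Lib\site-packages\mamba_ssm\ops\selective_scan_interface.py", line 10, in <module>
    import causal_conv1d_cuda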
ModuleNotFoundError: No module named 'causal_conv1d_cuda'
I still have not worked out the root cause of these errors, but the workaround below gets past both of them for now.
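Before patching anything, it can help to confirm which of the compiled extensions Python actually fails to locate. The following is a minimal sketch using only the standard library; `find_missing` is a hypothetical helper name, and the two module names are simply the ones from the errors discussed in this post.

```python
import importlib.util

# Names of the compiled CUDA extensions that the errors in this post refer to.
CUDA_EXTENSIONS = ["causal_conv1d_cuda", "selective_scan_cuda"]


def find_missing(module_names):
    """Return the subset of top-level module names that Python cannot locate."""
    missing = []
    for name in module_names:
        # find_spec returns None when no importable top-level module has this name.
        if importlib.util.find_spec(name) is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    for name in find_missing(CUDA_EXTENSIONS):
        print(f"cannot locate: {name}")
```

Run it inside the same conda environment as the failing code. Note that this only checks whether a module can be found; an extension that is found but fails while loading would not show up here.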
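1. Cannot find "causal_conv1d_cuda"
(1) First, activate the conda environment and uninstall the causal-conv1d module. Keep the local source files that were used for the installation; they will be modified and installed again.
(2) In that local folder, find causal_conv1d_interface.py and open it for editing.
(3) Comment out the import of the xx_cuda module.
(4) Replace the causal_conv1d_fn function with the code below (the commented lines are the original source, the uncommented lines are the replacement).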
# def causal_conv1d_fn(x, weight, bias=None, activation=None):
#     """
#     x: (batch, dim, seqlen)
#     weight: (dim, width)
#     bias: (dim,)
#     activation: either None or "silu" or "swish"
#
#     out: (batch, dim, seqlen)
#     """
#     return CausalConv1dFn.apply(x, weight, bias, activation)
def causal_conv1d_fn(x, weight, bias=None, seq_idx=None, initial_states=None,
                     return_final_states=False, final_states_out=None, activation=None):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"
    out: (batch, dim, seqlen)
    """
    # Changed here: call the PyTorch reference implementation instead of the CUDA kernel.
    # Check that this positional order matches causal_conv1d_ref in your installed version.
    return causal_conv1d_ref(x, weight, bias, seq_idx, initial_states,
                             return_final_states, final_states_out, activation)
(5) Reinstall the module the same way it was installed before.
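The effect of steps (3) and (4) can also be had without deleting the original code path, by guarding the import and falling back only when it fails. Below is a rough sketch of that pattern, not the library's own code: `try_import`, `choose_impl`, and the two stub functions are hypothetical names used only so the example runs on its own. In the real file the two candidates would be the CUDA-backed function and causal_conv1d_ref.

```python
import importlib


def try_import(name):
    """Return the imported module, or None if it cannot be imported."""
    try:
        return importlib.import_module(name)
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        return None


def choose_impl(extension, fast_fn, ref_fn):
    """Pick the compiled implementation if the extension loaded, else the reference one."""
    return fast_fn if extension is not None else ref_fn


# Hypothetical stand-ins for the real functions, only so this sketch is runnable.
def fast_stub(x):
    return ("cuda", x)


def ref_stub(x):
    return ("ref", x)


# Attempt the import once at module load time, then bind the implementation to use.
causal_conv1d_cuda = try_import("causal_conv1d_cuda")
conv_fn = choose_impl(causal_conv1d_cuda, fast_stub, ref_stub)

if __name__ == "__main__":
    # Shows which path was selected in the current environment.
    print(conv_fn(1)[0])
```

The advantage of this shape is that the same patched file keeps working if the compiled extension becomes available later, for example on a Linux machine.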
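2. Cannot find "selective_scan_cuda"
(1) First, in the PyTorch environment, uninstall the mamba_ssm module (keep the files from the earlier installation; they will be modified and used to reinstall mamba_ssm).
(2) Following the path where those files are stored, go to mamba_ssm -> ops -> selective_scan_interface.py and open the file.
(3) In the import section of the file, comment out the two lines that import xx_cuda modules.
(4) Modify the selective_scan_fn function as follows (the commented lines are before the change, the uncommented lines are after):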
# def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
#                       return_last_state=False):
#     """if return_last_state is True, returns (out, last_state)
#     last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
#     not considered in the backward pass.
#     """
#     return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
Likewise, modify mamba_inner_fn in the same way:
# def mamba_inner_fn(
#     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
#     out_proj_weight, out_proj_bias,
#     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
#     C_proj_bias=None, delta_softplus=True):
#     return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
#                               out_proj_weight, out_proj_bias,
#                               A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True):
    return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                           out_proj_weight, out_proj_bias,
                           A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
(5) Once the changes are made, reinstall the module following the mamba_ssm installation method from the previous blog post.
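For intuition about what this fallback does: the reference path evaluates the scan as a plain step-by-step recurrence over the sequence instead of a fused CUDA kernel, so it will likely be noticeably slower on long sequences. The pure-Python sketch below shows that recurrence for a single channel. It is a simplified illustration under stated assumptions (one batch element, one channel, real-valued A, no z gating, no delta_bias or softplus, and B and C indexed as [t][n]), not a copy of selective_scan_ref; `selective_scan_1ch` is a hypothetical name.

```python
import math


def selective_scan_1ch(u, delta, A, B, C, D=0.0):
    """Sequential selective scan for one channel.

    u, delta: lists of length L (input and step size per time step)
    A:        list of length N (one decay parameter per state dimension)
    B, C:     L x N nested lists (input and output projections per time step)
    D:        scalar skip-connection weight
    """
    n_state = len(A)
    h = [0.0] * n_state  # hidden state, starts at zero
    ys = []
    for t in range(len(u)):
        for n in range(n_state):
            # Discretized update: h <- exp(delta * A) * h + delta * B * u
            h[n] = math.exp(delta[t] * A[n]) * h[n] + delta[t] * B[t][n] * u[t]
        # Output: y = C . h + D * u
        y = sum(C[t][n] * h[n] for n in range(n_state)) + D * u[t]
        ys.append(y)
    return ys


if __name__ == "__main__":
    # With A = 0 there is no decay, so the state is a running sum of the input.
    print(selective_scan_1ch([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], [0.0],
                             [[1.0]] * 3, [[1.0]] * 3))  # [1.0, 3.0, 6.0]
```

Each time step depends on the previous hidden state, which is why a naive implementation cannot process the sequence positions independently.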
3. Verifying the modules
After the modifications and reinstallation above, running the code produces the following warning:
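This gets past the No module named "causal_conv1d_cuda" and "selective_scan_cuda" errors for now. The root cause is still unclear; if anyone knows, please point it out in the comments!
I will update this post once I understand it. (Likes, favorites, and shares appreciated!)
As for the warning in the figure above, I will cover it in a later update.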