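Note: before starting, make sure the packaging module is installed in your environment:
pip install packaging
If installing this module fails with an error, the cause is a permissions problem from when the environment was created; the fix is described in another post of mine.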
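1. Installing the Triton module
Note: Triton must be installed first, otherwise the later mamba_ssm installation will fail!
(1) Windows builds of Triton are available on Hugging Face (https://hf-mirror.com/madbuda/triton-windows-builds). Choose the file that matches your Python version and download it locally.
(2) Put the downloaded file somewhere convenient. I usually put it in my user directory, which makes installing from Anaconda Prompt easier. Any other location works too, as long as you cd to the directory containing the file before installing.
(3) cd to the directory containing the file and install it with pip install followed by the file name. (Tip: the file name is long, so type the first few characters and press Tab to auto-complete it!)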
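Which file to pick in step (1) depends on the Python version of the environment you will install into. The following minimal sketch prints the CPython tag of the interpreter that runs it. It assumes the downloadable files follow the usual wheel naming convention, where a tag such as cp310 appears in the file name; that is not confirmed here. The helper name python_tag is hypothetical.

```python
import sys

def python_tag(version_info=sys.version_info):
    """Return the CPython tag for an interpreter version, e.g. 'cp310' for 3.10."""
    return f"cp{version_info.major}{version_info.minor}"

if __name__ == "__main__":
    # Run this inside the activated environment you plan to install into.
    print("Python executable:", sys.executable)
    print("Look for a file whose name contains:", python_tag())
```

Run it from the same Anaconda Prompt session you will use for pip install, so that the tag reflects the right interpreter.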
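2. Installing causal-conv1d
(1) Download link: causal-conv1d
Download the archive, extract it, and place it in a local directory.
(2) Make the following changes to the setup.py file in the extracted folder: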
FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
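Change them to the following, so that each flag evaluates to True when the corresponding environment variable is not set: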
FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "FALSE"
SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "FALSE"
(3) cd into the extracted folder and run pip install . to install.
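To see what the edit in step (2) does, here is a minimal sketch that evaluates the original and the edited comparison against the same environment. The helper names flag_original and flag_modified are hypothetical and only mirror the expressions in setup.py. The sketch also shows that, with the unedited file, setting the variable to TRUE gives the same flag value, so exporting the variable before running pip might work as an alternative to editing the file; that route is untested here.

```python
import os

def flag_original(name, env=os.environ):
    # Mirrors the original setup.py line: True only when the variable is "TRUE".
    return env.get(name, "FALSE") == "TRUE"

def flag_modified(name, env=os.environ):
    # Mirrors the edited line: True when the variable is unset or set to "FALSE".
    return env.get(name, "FALSE") == "FALSE"

if __name__ == "__main__":
    name = "CAUSAL_CONV1D_SKIP_CUDA_BUILD"
    unset = {}                 # simulated environment without the variable
    forced = {name: "TRUE"}    # simulated environment with the variable set
    print("original, unset :", flag_original(name, unset))
    print("modified, unset :", flag_modified(name, unset))
    print("original, TRUE  :", flag_original(name, forced))
    print("modified, TRUE  :", flag_modified(name, forced))
```

Note that the edited comparison inverts the meaning of the variable: once setup.py is edited, setting the variable to TRUE turns the flag off.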
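3. Installing mamba_ssm
Make sure Triton has been installed successfully before installing mamba_ssm, otherwise the installation will fail with an error!
(1) Download link: mamba_ssm
As before, download the archive, extract it, and place it in a local directory.
(2) Edit the following lines in setup.py: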
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
Change them to:
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
Next, edit two places in mamba_ssm/ops/selective_scan_interface.py so that they call the reference implementations (selective_scan_ref and mamba_inner_ref) instead of the compiled CUDA kernels:
First place:
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
Change it to:
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
Second place:
def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                              out_proj_weight, out_proj_bias,
                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
Change it to:
def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                           out_proj_weight, out_proj_bias,
                           A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
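For intuition about the kind of computation a reference scan performs step by step, here is a heavily simplified plain-Python sketch of a sequential selective scan for a single channel. It is an illustration only (one channel, no batch dimension, no z gating, no delta_bias or softplus) and is not the code of selective_scan_ref; the function name scan_single_channel is hypothetical.

```python
import math

def scan_single_channel(u, delta, A, B, C, D=0.0):
    """Sequential scan for one channel (illustrative, not the library code).

    u, delta: lists of length L (input and step size per time step)
    A:        list of length N (diagonal state matrix)
    B, C:     lists of length L, each holding an N-element list
    D:        scalar skip-connection weight
    """
    n_state = len(A)
    h = [0.0] * n_state  # hidden state, starts at zero
    y = []
    for t in range(len(u)):
        for n in range(n_state):
            # Discretized update: h = exp(delta * A) * h + delta * B * u
            h[n] = math.exp(delta[t] * A[n]) * h[n] + delta[t] * B[t][n] * u[t]
        # Output: y = C . h + D * u
        y.append(sum(C[t][n] * h[n] for n in range(n_state)) + D * u[t])
    return y

if __name__ == "__main__":
    u = [1.0, 2.0, 3.0]
    ones = [[1.0]] * 3
    print(scan_single_channel(u, [1.0] * 3, [0.0], ones, ones))
```

With A = 0 and delta = B = C = 1 the recurrence reduces to a running sum of the input, which makes a convenient sanity check.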
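(3) With the files ready, cd into the extracted folder and run the following command to install:
pip install .
(4) The installation should now complete successfully.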
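A quick way to confirm the result is to check that each module can be located by the interpreter. The sketch below is a minimal check, assuming the import names are packaging, triton, causal_conv1d, and mamba_ssm; the helper name find_missing is hypothetical. It only verifies that the modules can be found, not that they run correctly.

```python
import importlib.util

def find_missing(modules):
    """Return the names in `modules` that the interpreter cannot locate."""
    return [m for m in modules if importlib.util.find_spec(m) is None]

if __name__ == "__main__":
    # Assumed import names for the modules installed in this post.
    required = ["packaging", "triton", "causal_conv1d", "mamba_ssm"]
    missing = find_missing(required)
    if missing:
        print("Not found:", ", ".join(missing))
    else:
        print("All modules were found.")
```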
Tip: activate the corresponding PyTorch environment before installing any of these modules!
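If you are unsure which environment is active, a minimal sketch like the following prints the interpreter in use. It assumes an Anaconda setup, where activating an environment sets the CONDA_DEFAULT_ENV variable; the helper name describe_environment is hypothetical.

```python
import os
import sys

def describe_environment(env=os.environ):
    """Collect basic facts about the interpreter and the active conda environment."""
    return {
        "executable": sys.executable,               # path of the running Python
        "prefix": sys.prefix,                       # root of the active environment
        "conda_env": env.get("CONDA_DEFAULT_ENV"),  # None if not set
    }

if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

If the printed environment name is not the PyTorch environment you intended, activate the right one before running pip install.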