
——————
The original project was not set up with Windows in mind, so the configuration process needs some adjustments.
——————
0 Prerequisites
- NVIDIA GPU
- PyTorch 1.12+
- CUDA 11.6+
1 Install CUDA and cuDNN
For day-to-day training you can install a newer system-wide CUDA and rely on a lower version inside the conda environment, but to build Mamba the system CUDA version must match the one PyTorch was built against. [Link]
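One way to confirm the match is to compare the CUDA release reported by `nvcc` with the version PyTorch was built against. A minimal sketch; the `cuda_version_from_nvcc` helper is our own illustration, not part of any library:

```python
import re
import subprocess

def cuda_version_from_nvcc(nvcc_output: str) -> str:
    # nvcc prints a line such as "Cuda compilation tools, release 11.8, V11.8.89"
    m = re.search(r"release (\d+\.\d+)", nvcc_output)
    return m.group(1) if m else ""

if __name__ == "__main__":
    try:
        out = subprocess.run(["nvcc", "--version"],
                             capture_output=True, text=True).stdout
        print("system CUDA:", cuda_version_from_nvcc(out))
        # Once PyTorch is installed, compare against it, e.g.:
        # import torch; assert torch.version.cuda == cuda_version_from_nvcc(out)
    except FileNotFoundError:
        print("nvcc not found on PATH")
```

If the two versions disagree, the extension build is likely to fail or produce a module that cannot be imported.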
2 Install VS2019
A Visual Studio version that is too new will cause build errors when compiling causal-conv1d. [Link]
3 Create the conda environment
(1) Create a Python 3.10.13 environment
conda create -n mamba_envs python=3.10.13
(2) Install CUDA Toolkit 11.8, PyTorch 2.1.1, and CUDA NVCC
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
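After these installs, a quick sanity check is worthwhile. A sketch that only reports what it finds in the current environment (the `has_module` helper is our own):

```python
import importlib.util

def has_module(name: str) -> bool:
    # True if the package can be imported in the current environment
    return importlib.util.find_spec(name) is not None

if has_module("torch"):
    import torch
    # Expect 2.1.1, CUDA 11.8, and an available GPU for the Mamba build to work
    print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
else:
    print("torch is not installed in this environment")
```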
(3) Install packaging
conda install packaging
(4) Install setuptools
pip install setuptools==68.2.2
(5) Install Triton
[Link]
Triton is an open-source compiler and deep-learning library for writing efficient custom GPU kernels. Developed by OpenAI, it is mainly used to accelerate matrix multiplication, convolution, and other operators common in deep learning, while letting developers write fast GPU code without extensive CUDA programming experience.
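As a small illustration of what Triton code looks like, here is a sketch following Triton's standard vector-add pattern; it needs a CUDA GPU and the triton package, so both are guarded. The `grid_size` helper is our own:

```python
def grid_size(n: int, block: int) -> int:
    # number of Triton program instances needed to cover n elements
    return (n + block - 1) // block

try:
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n  # guard the tail of the array
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x + y, mask=mask)

    if torch.cuda.is_available():
        n, BLOCK = 1024, 256
        x = torch.rand(n, device="cuda")
        y = torch.rand(n, device="cuda")
        out = torch.empty_like(x)
        add_kernel[(grid_size(n, BLOCK),)](x, y, out, n, BLOCK=BLOCK)
        assert torch.allclose(out, x + y)
except ImportError:
    pass  # triton/torch not installed; the kernel above is illustration only
```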
4 Build causal-conv1d
git clone https://github.com/Dao-AILab/causal-conv1d.git
cd causal-conv1d
git checkout v1.1.1 # current latest version tag
pip install .
5 Build mamba
cd ..
git clone https://github.com/state-spaces/mamba.git
cd mamba
git checkout v1.1.1
———————
Changes to setup.py
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "TRUE") == "TRUE"
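Upstream, this default is "FALSE", and setup.py then tries to fetch a prebuilt (Linux-only) wheel; changing the default to "TRUE" makes pip compile the extension locally instead. A standalone sketch of the changed line's effect:

```python
import os

# Simulate the patched line: with the default set to "TRUE", an unset
# environment variable now forces a local source build.
os.environ.pop("MAMBA_FORCE_BUILD", None)
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "TRUE") == "TRUE"
print(FORCE_BUILD)
```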
Changes to selective_scan_fwd_kernel.cuh
(1) Add at the top of the file
#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634074
#endif
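MSVC only exposes M_LOG2E from the math headers when _USE_MATH_DEFINES is defined, which is why the guarded define is needed on Windows. The constant itself is simply log2(e), as a quick check confirms:

```python
import math

# The literal from the patch matches log2(e) to double precision
M_LOG2E = 1.4426950408889634074
assert abs(M_LOG2E - math.log2(math.e)) < 1e-15
print("ok")
```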
(2) Replace the selective_scan_fwd_launch function
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
    // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
    // processing 1 row.
    static constexpr int kNRows = 1;
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
            BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
                BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
                    using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
                    // constexpr int kSmemSize = Ktraits::kSmemSize;
                    static constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
                    // printf("smem_size = %d\n", kSmemSize);
                    dim3 grid(params.batch, params.dim / kNRows);
                    auto kernel = &selective_scan_fwd_kernel<Ktraits>;
                    if (kSmemSize >= 48 * 1024) {
                        C10_CUDA_CHECK(cudaFuncSetAttribute(
                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
                    }
                    kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                });
            });
        });
    });
}
Changes to selective_scan_bwd_kernel.cuh
(1) Add at the top of the file
#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634074
#endif
Changes to static_switch.h
(1) Replace the BOOL_SWITCH macro
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            static constexpr bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            static constexpr bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()
———————
pip install .
6 Test
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
print('success')