Mamba学习笔记(5)——快速配置(Win版,含Heardware-aware)


👇👇👇👇👇💕💕💕💕💕👇👇👇👇👇
团队公众号 Deep Mapping致力于为大家分享人工智能算法的基本原理与最新进展,探索深度学习与测绘行业的创新结合,展示团队的最新研究成果,并及时报道行业会议动态。每周更新,持续为你带来前沿科技资讯、实用技巧和深入解读。感谢每一位关注我们的读者,期待与大家一起见证更多精彩时刻!
👇👇👇👇👇💕💕💕💕💕👇👇👇👇👇
在这里插入图片描述
——————
原项目中并未考虑在Win下配置,因此需要对配置过程进行调整
——————

0 运行前提

  • NVIDIA GPU
  • PyTorch 1.12+
  • CUDA 11.6+

1 配置cuda和cudnn

日常训练时我们可以安装高版本cuda、然后在conda环境中兼容低版本,但是Mamba配置时需要安装对等的版本 [Link]

2 配置VS2019

VS版本太新看会导致在配置causal-conv1d时报错 [Link]

3 创建conda环境

(1) 创建一个python 3.10.13的环境

conda create -n mamba_envs python=3.10.13

(2) 安装CUDA Toolkit 11.8PyTorch 2.1.1CUDA NVCC

conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc

(3) 安装packaging

conda install packaging

(4) 安装setuptools

pip install setuptools==68.2.2

(5) 安装Triton [Link]

Triton 是一种用于高效编写自定义 GPU 内核的开源编译器和深度学习库。由 OpenAI 开发,主要用于加速在深度学习中常见的矩阵乘法、卷积和其他算子操作,同时使开发者能够编写高效的 GPU 代码而不需要复杂的 CUDA 编程经验。

4 配置causal-conv1d

git clone https://github.com/Dao-AILab/causal-conv1d.git
cd causal-conv1d
git checkout v1.1.1 # current latest version tag
pip install .

5 配置mamba

cd..
git clone https://github.com/state-spaces/mamba.git
cd mamba
git checkout v1.1.1

———————

setup.py修改

FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "TRUE") == "TRUE"

selective_scan_fwd_kernel.cuh修改

(1) 头部加入

#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634074
#endif

(2) 替换void selective_scan_fwd_launch函数

void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
    // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
    // processing 1 row.
    static constexpr int kNRows = 1;
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
            BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
                BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
                    using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
                    // constexpr int kSmemSize = Ktraits::kSmemSize;
                    static constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
                    // printf("smem_size = %d\n", kSmemSize);
                    dim3 grid(params.batch, params.dim / kNRows);
                    auto kernel = &selective_scan_fwd_kernel<Ktraits>;
                    if (kSmemSize >= 48 * 1024) {
                        C10_CUDA_CHECK(cudaFuncSetAttribute(
                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
                    }
                    kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                });
            });
        });
    });
}

selective_scan_bwd_kernel.cuh修改

(1) 头部加入

#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634074
#endif

static_switch.h修改

(1) 替换BOOL_SWITCH函数

#define BOOL_SWITCH(COND, CONST_NAME, ...)                                           \
    [&] {                                                                            \
        if (COND) {                                                                  \
            static constexpr bool CONST_NAME = true;                                        \
            return __VA_ARGS__();                                                    \
        } else {                                                                     \
            static constexpr bool CONST_NAME = false;                                       \
            return __VA_ARGS__();                                                    \
        }                                                                            \
    }()

———————

pip install .

6 测试

import torch
from mamba_ssm import Mamba
 
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
print('success')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

RuiXuan Zhang

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值