我在 pixi 环境（`.pixi/envs/default`，Python 3.10）中安装了 mamba-ssm 2.2.5。系统 NVIDIA 驱动对应的 CUDA 版本为 12.8，环境中已安装 cudatoolkit 12.9（注意：工具包版本比驱动对应的 CUDA 版本更新）。用于检测 Mamba 是否可用的代码如下：
```python
import torch
from mamba_ssm import Mamba
def test_mamba_basic(*, batch_size=1, seq_len=10, d_model=64, device=None):
    """Smoke-test one mamba_ssm ``Mamba`` block with a single forward pass.

    Builds a single ``Mamba`` layer, feeds it a random tensor of shape
    ``(batch_size, seq_len, d_model)`` and checks that the output keeps
    that shape.

    Args:
        batch_size: Number of sequences in the random input batch.
        seq_len: Length of each random input sequence.
        d_model: Model (channel) dimension of the Mamba block.
        device: Device to run on. Defaults to ``"cuda"`` when
            ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.

    Returns:
        True if the forward pass ran and produced the expected shape;
        False otherwise, including when no CUDA device is selected.
    """
    import traceback

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    print(f"使用设备: {device}")
    # Show which CUDA version this PyTorch build targets; handy when the
    # installed toolkit and the driver versions disagree.
    print(f"PyTorch {torch.__version__}, torch 编译所用 CUDA: {torch.version.cuda}")

    if device.type != "cuda":
        # mamba_ssm's selective-scan kernels require an NVIDIA GPU (see the
        # project README), so a CPU run can only fail. Report it up front
        # instead of promising a slow CPU run.
        print("❌ 未选择 CUDA 设备: mamba_ssm 的 Mamba 依赖 CUDA 内核,无法在 CPU 上运行")
        return False

    expected_shape = (batch_size, seq_len, d_model)
    try:
        # Mamba is a single block: it takes d_model plus SSM hyper-parameters
        # and has no n_layers / vocab_size arguments (passing them raises
        # TypeError).
        model = Mamba(d_model=d_model).to(device)
        # Allocate the input directly on the target device.
        x = torch.randn(batch_size, seq_len, d_model, device=device)
        with torch.no_grad():  # forward-only check, no autograd graph needed
            output = model(x)
    except Exception:
        # Smoke-test boundary: keep the full traceback, since CUDA/Triton
        # failures are unreadable from str(e) alone.
        print("❌ 运行失败:")
        traceback.print_exc()
        return False

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if tuple(output.shape) != expected_shape:
        print(f"❌ 输出形状不正确: 期望 {expected_shape}, 实际 {tuple(output.shape)}")
        return False

    print("✅ Mamba前向传播成功!")
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {output.shape}")
    return True


if __name__ == "__main__":
    import sys

    # Non-zero exit status lets shell scripts / CI detect a failed smoke test.
    sys.exit(0 if test_mamba_basic() else 1)
```
运行后，在执行 `from mamba_ssm import Mamba`（脚本第 2 行）时就已报错，测试函数本身并未执行：
```
/tmp/tmp6bkt1lgp/main.c:1:10: fatal error: cuda.h: No such file or directory
1 | #include "cuda.h"
| ^~~~~~~~
compilation terminated.
Traceback (most recent call last):
File "/home/xzj_24/MGCN/test_mamba.py", line 2, in <module>
from mamba_ssm import Mamba
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/mamba_ssm/__init__.py", line 3, in <module>
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 18, in <module>
from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py", line 179, in <module>
def _layer_norm_fwd_1pass_kernel(
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 378, in decorator
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 130, in __init__
self.do_bench = driver.active.get_benchmarker()
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/triton/runtime/driver.py", line 23, in __getattr__
self._initialize_obj()
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/triton/runtime/driver.py", line 20, in _initialize_obj
self._obj = self._init_fn()
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/triton/runtime/driver.py", line 9, in _create_driver
return actives[0]()
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/triton/backends/nvidia/driver.py", line 536, in __init__
self.utils = CudaUtils() # TODO: make static
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/triton/backends/nvidia/driver.py", line 89, in __init__
mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils")
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/triton/backends/nvidia/driver.py", line 66, in compile_module_from_src
so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/triton/runtime/build.py", line 58, in _build
subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
File "/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/subprocess.py", line 369, in check_call
raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmp6bkt1lgp/main.c', '-O3', '-shared', '-fPIC', '-Wno-psabi', '-o', '/tmp/tmp6bkt1lgp/cuda_utils.cpython-310-x86_64-linux-gnu.so', '-lcuda', '-L/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/triton/backends/nvidia/lib', '-L/lib/x86_64-linux-gnu', '-I/home/xzj_24/MGCN/.pixi/envs/default/lib/python3.10/site-packages/triton/backends/nvidia/include', '-I/tmp/tmp6bkt1lgp', '-I/home/xzj_24/MGCN/.pixi/envs/default/include/python3.10', '-I/home/xzj_24/MGCN/.pixi/envs/default/targets/x86_64-linux/include']' returned non-zero exit status 1.
```
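已确认 `.zshrc` 中配置了 CUDA 路径、`CUDA_HOME`、`LD_LIBRARY_PATH` 等环境变量，但仍然报找不到 `cuda.h`。从报错命令可以看到，`/usr/bin/gcc` 只收到了 triton 自带的 `triton/backends/nvidia/include`、Python 头文件目录和 pixi 环境的 `targets/x86_64-linux/include` 这几个 `-I` 目录，其中没有出现 `CUDA_HOME` 下的 include 目录。系统 gcc（`/usr/bin/gcc`）版本为 11。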
请给出排查思路与解决方案。