flash-linear-attention CUDA算子成功实现(但限制极多。。)

博客介绍了在C++/DirectX着色器基础上速成CUDA编程,虽能入门但高级加速不足。给出代码仓库,对比了纯pytorch实现,指出其显存消耗问题。重点介绍5种线性注意力算子写法,包括原始方式、分块方式等,对比了它们的可读性、显存消耗和速度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在 C++/DirectX着色器 的基础上速成CUDA编程,还好思维模式基本通用,就多了线程组排布和共享内存方面的东西,入门还行,高级加速方面就不太行了。

代码仓库:https://github.com/One-sixth/flash-linear-attention-pytorch

虽然相比纯 pytorch 的实现,更大幅度地减少了显存的消耗,但是速度没有办法,缺乏精力和时间去深究 CUDA优化,非常消耗时间。
算了。作为其他类型的 线性注意力算子的起步

5种写法的算子介绍

normal_linear_attention_ops.py
原始方式,显存占用最大,速度最快
可读性:最佳
显存消耗:1X (O^2)
速度:1X

flash_linear_attention_ops.py
原始分块方式
内部使用torch.split,不需要填充到指定长度
可读性:佳
显存消耗:0.7X
速度:0.5X

flash_linear_attention_ops_2.py
基于 flash_linear_attention_ops.py 改为块索引方式,略快一丁点
内部需要填充到指定倍数长度
可读性:佳
显存消耗:0.7X
速度:0.505X

flash_linear_attention_ops_3.py
基于 flash_linear_attention_ops_2.py 加入显式内存复用方式,略快一丁点
即在一开始就分配所有需要的显存,在计算过程中,完全不需要新的显存分配
内部需要填充到指定倍数长度
可读性:中
显存消耗:0.7X
速度:0.51X

flash_linear_attention_ops_4.py
基于 flash_linear_attention_ops_3.py,改为CUDA/C++算子方式
本人的CUDA/C++技术有限,没有精力继续研究了
内部需要填充到指定倍数长度
限制很多,不支持float32以外的数据类型
可以作为其他类型线性注意力的参考实现
算子已经通过 pytorch2.1 + CUDA12.1 环境测试
可读性:较差
显存消耗:0.3X
速度:0.33X

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值