实践:将cuda代码转化成hip代码,并编写为pytorch的extension
将cuda代码转化成hip代码,并编写为pytorch的extension
目的:将fused-attention这份代码,由cuda转化到hip。并且让转化后的hip代码也能编写为pytorch的extension。
难点:fused-attention 使用了wmma,如何让这种调库的代码编写为pytorch的extension。
首先将代码转到hip并能成功执行
- 下载代码到本地
git clone https://github.com/kst179/fused-attention.git
- 使用hipify工具转化代码
/opt/rocm/bin/hipconvertinplace-perl.sh [路径] # 这里路径指向fused attention本地地址
- 修改wmma相关代码
- wmma 替换成rocwmma
- #include <mma.h> 替换成 #include <rocwmma/rocwmma.hpp>
- warp_size 改成64
- 去掉所有的nvcuda
- 删掉多余的配置代码(AMD的共享内存不需要这样设置)
CHECK_CUDA_ERROR(cudaFuncSetAttribute(
attention_kernel<head_dim, chunk_size, n_warps>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size));
- 修改half计算相关代码
- 修改half的定义
using half_t = half;
using half2_t = half2;
- half2不能使用.x和.y获取元素,修改成__low2half()和__high2half()
if (threads_per_row == 1) {
vec[row] = __hmax(threadwise_val.x, threadwise_val.y);
return;
}
...
if (thread_idx < height) {
threadwise_val = *(half2_t*)&aux[thread_idx * ldm_aux];
#pragma unroll
for (uint32_t i = 1; i < threads_per_row; ++i) {
threadwise_val = __hmax2(threadwise_val, *(half2_t*)&aux[thread_idx * ldm_aux + i * elements_per_storage]);
}
vec[thread_idx] = __hmax(threadwise_val

本文详细描述了如何将cuda代码转换为hip,并将其作为pytorch的extension,涉及wmma库的迁移、hipify工具使用、half类型处理以及遇到的问题和解决方案。作者还提供了测试步骤和可能遇到的错误,如outofmemory和精度分析。
最低0.47元/天 解锁文章
3231

被折叠的 条评论
为什么被折叠?



