// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
// // GPU上运行函数kernel对应一个Grid,每个Grid内有多个Block,每个Block由多个Thread组成。
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
// 用于告诉编译器一个指针是唯一的,即没有别的指针指向同一块内存。这样编译器就可以进行更好的优化,例如在循环中使用 restrict 修饰的指针可以避免循环迭代时的指针别名问题,从而提高代码的性能
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y; // 这是处理的第几个seq blockDim.y = seq的总数量
const int partition_idx = blockIdx.z; // 单个头 单个seq 的第几个part
const int max_num_partitions = gridDim.z; // 单个头 单个seq 再分part
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block.
return;
}
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); // 计算seq需要的block数量 blocktable里的blocksize = 16
const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx =
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); // seq 除以 blocksize 向上取整
const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx =
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;
// warp_size = 32 num_threads = 128 这里的threads 我猜应该是一个Block里的thread
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); // = 2
// 确定在一个warp中 线程组的大小, 一个warp32线程 blocksize=16(这个block是blocktable) 也就是两个线程一组 处理一个token
constexpr int NUM_THREAD_GROUPS = // 64 一个Block(核函数的block) 中有多少个线程组, 也就是说一个block处理多少个token
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = // 1 一个thread_group处理几个token
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; // 几个warp
const int thread_idx = threadIdx.x; // 块内线程数
const int warp_idx = thread_idx / WARP_SIZE; // 第几个warp
const int lane = thread_idx % WARP_SIZE; // warp 内线程数
const int head_idx = blockIdx.x; // 第几个头
const int num_heads = gridDim.x; // 一共多少头
const int num_queries_per_kv = num_heads / num_kv_heads; // 一对kv几个q用, GQA 一个组共享一对kv 降低kvcache的大小
const int kv_head_idx = head_idx / num_queries_per_kv; // 第几组kv
const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2.
// 规定一个线程组 每次都要拿到16Byte(字节)的数据来计算
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); // 4 线程组中的每个线程 要处理几个scalar_t类型的数据个数 如果是bf16 16bit 2byte ver_size=4
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; // 模板结构体 Q_vec = scalar_t[vec_size] 比如 bf16[4] 表示一组标量数据的向量类型
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; // 128 head_size代表单个head对一个token的embedding数 也就是字节 这些字节按照ver_size的大小得到一下行
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; // 32 需要32个线程组来处理 一个线程需要处理的vector数量, 上面算出一个vector 处理4个elem
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; // 处于第几个线程组 0-63
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; // q_stride = num_head*head_size,本例子中num_head=16,head_size=256 表示要越过前面所有seq所有head的q
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; // 两个线程 每个线程处理的vector [2][32]
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; // i = idx, i < 32, i += 64
// 处理的问题是 将head_size个即256个BFloat16数据,使用128个线程load进来
// 因为i<32 所以只有0-31 线程组去执行取数据操作 两个warp
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; // 0-63
q_vecs[thread_group_offset][i] =
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); // 将q_ptr + vec_idx * VEC_SIZE指向地址的内容转变为Q_vec类型并赋值给q_vecs
// 一共处理256个数据, q_vecs =2*32=64个vector 一个vector放四个
} // 当前head下的seq的q向量,就被完全load到了共享内存q_vecs中
// 也是向量化的加载和处理,2个warp 32个线程组 64个线程 处理了64个vector 256个scalar_t 减少全局内存的访问次数,提高内存带宽利用率 在很多地方都有这种应用, sgemm ,warpsum
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t);
float qk_max = -FLT_MAX;
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// block table的寻址过程,所有seq的block_table存放在以block_tables为首的连续地址空间中,当前seq的block table的地址需要越过之前所有seq所占用的block条目才能找到才能找到
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
// 当前的循环起始于start_block_idx + warp_idx,也就是0-31个线程并行执行当前循环的时候,block_idx初始值为0,第32-63个线程的block_idx初始值为1
// 剩下的2个线程束中的线程,block_idx初始值分别为2和3
// 所以在执行physical_blcok_number寻找的时候 第i个warp 分别访问的是当前head当前seq的第i个block
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; //
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be
// skipped.
logits[token_idx - start_token_idx] = -FLT_MAX;
}
}
continue;
}
}
const int64_t physical_block_number = // 根据block_table 找到 physical_block_number
static_cast<int64_t>(block_table[block_idx]); // block_idx 存储为int32 但是转为int64 因为int32在与一些较大的值相乘时可能会溢出 比如kv_block_stride
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { // NUM_TOKENS_PER_THREAD_GROUP = 1 本例中1个线程组处理1个token
const int physical_block_offset = // 当前block的第几个token
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; // thread_group_idx 0-63 一组处理一个token
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; // 处理当前seq中token的绝对idx
K_vec k_vecs[NUM_VECS_PER_THREAD]; // NUM_VECS_PER_THREAD 32
// q_vecs = [2][32]
#pragma unroll
// 找到当前token所对应的k向量
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { // j < 32
const cache_t* k_ptr =
// key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x]
// kv_block_stride = key_cache.stride(0) = num_kv_heads * kv_head_stride = num_kv_heads *block_size *head_size = 16*16*256 = 65536
// 每个token16个头,一个block的 tokens即16个token所占bfloat16的总数量
// kv_head_stride = key_cache.stride(1) = (head_size/x) *(block_size) *(x) = 16*256 = 4096
// 本质上是一个block中所有的token 乘以 head_size,只算1个头,所占的bfloat16的数量,
// physical_block_number * kv_block_stride 是前面所有的blcok中的bfloat16的数量,进入了当前block存储的token的kv_cache范围内
// 进入了当前block存储的token的kv_cache范围内; kv_head_idx * kv_head_stride 当前处理的是seq的第kv_head_idx个head,需要越过之前的head;
k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x;
// x = 16 / sizeof(cache_t) 16表示每个线程组希望在每次内存访问k_cache的时候至少加载16个字节 x也就表示一次加载几个 cache_t的数据类型 如果是bf16 x=8
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; // 表示线程在线程组内的相对位置 vec_idx表示当前token的embedding分成了NUM_VECS_PER_THREAD个vec
const int offset1 = (vec_idx * VEC_SIZE) / x; // 0-31
const int offset2 = (vec_idx * VEC_SIZE) % x; // 0-7
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale);
}
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
// 其中q_vecs的维度也是[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD],是根据当前的线程组的线程数量、以及每个线程负责的vec的数量来分组的;
// 当前的k_vecs是存放的就是当前线程的[NUM_VECS_PER_THREAD]个vec。
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
}
}
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
// If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum;
}
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD =
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
v_scale);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// context, we should explicitly zero out the values since they may
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
accs[i] += dot(logits_vec, v_vec);
}
}
}
// Perform reduction within each warp.
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i];
}
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx];
}
}
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]);
}
}
}
}
vLLM - PagedAttention
最新推荐文章于 2025-04-30 19:58:04 发布