From 0cd49949e6c05a0c8f63d050ff96c8f6168cf914 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 20 Jul 2024 17:24:57 -0700 Subject: [PATCH] bugfix: fix sampling API's behavior on cu118 (#386) As observed in #384, we should use different variables for input and output for the `FlagHeads` API in cu118. --- include/flashinfer/sampling.cuh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 2df38d24..1a2aebd1 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -112,18 +112,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb( greater_than_u[j] = inclusive_cdf[j] + aggregate > u; } + bool greater_than_u_diff[VEC_SIZE]; #ifdef FLASHINFER_CUB_SUBTRACTLEFT_DEFINED BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .SubtractLeft(greater_than_u, greater_than_u, BoolDiffOp()); + .SubtractLeft(greater_than_u_diff, greater_than_u, BoolDiffOp()); #else BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .FlagHeads(greater_than_u, greater_than_u, BoolDiffOp()); + .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp()); #endif __syncthreads(); #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - if (greater_than_u[j] && valid[j]) { + if (greater_than_u_diff[j] && valid[j]) { atomicMin(&(temp_storage->data.sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j); } }