Skip to content

Commit

Permalink
bugfix: fix sampling API's behavior on cu118 (#386)
Browse files Browse the repository at this point in the history
As observed in #384, we should use different variables for the input and
output of the `FlagHeads` API in cu118.
  • Loading branch information
yzh119 authored Jul 21, 2024
1 parent b64d5c9 commit 0cd4994
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions include/flashinfer/sampling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
greater_than_u[j] = inclusive_cdf[j] + aggregate > u;
}

bool greater_than_u_diff[VEC_SIZE];
#ifdef FLASHINFER_CUB_SUBTRACTLEFT_DEFINED
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u, BoolDiffOp());
.SubtractLeft<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp());
#else
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.FlagHeads<VEC_SIZE>(greater_than_u, greater_than_u, BoolDiffOp());
.FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp());
#endif
__syncthreads();

#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
if (greater_than_u[j] && valid[j]) {
if (greater_than_u_diff[j] && valid[j]) {
atomicMin(&(temp_storage->data.sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
}
}
Expand Down

0 comments on commit 0cd4994

Please sign in to comment.