bugfix: fix cu118 cub usage (#410)
Related issue: sgl-project/sglang#771

This PR fixes the usage of the `FlagHeads` CUB API in the sampling kernels.
As
[documented](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockDiscontinuity.html),
the default `FlagHeads` overload always flags the first element, which is
incorrect when the first element is not `true`.
> For thread0, item input[0] is always flagged.

This PR passes the `tile_predecessor_item` argument (set to 0), which is
compared against `input[0]`, so the first element is no longer flagged
unconditionally.
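
To illustrate the difference, here is a minimal host-side sketch that emulates the two head-flagging behaviours on a single flat array. It is not the CUB implementation: `flag_heads_default` and `flag_heads_with_predecessor` are hypothetical helper names, and `BoolDiffOp` is assumed to return `lhs != rhs`, in line with how it is used in the diff.

```cpp
#include <array>
#include <cstddef>

// Assumption: mirrors the BoolDiffOp functor in sampling.cuh (flags a change in value).
struct BoolDiffOp {
  bool operator()(bool lhs, bool rhs) const { return lhs != rhs; }
};

// Emulates the default behaviour: the first element is always flagged.
template <std::size_t N>
std::array<bool, N> flag_heads_default(const std::array<bool, N>& input) {
  std::array<bool, N> flags{};
  BoolDiffOp op;
  for (std::size_t i = 0; i < N; ++i) {
    flags[i] = (i == 0) ? true : op(input[i - 1], input[i]);
  }
  return flags;
}

// Emulates the overload taking a predecessor: the first element is
// compared against `tile_predecessor` instead of being flagged blindly.
template <std::size_t N>
std::array<bool, N> flag_heads_with_predecessor(const std::array<bool, N>& input,
                                                bool tile_predecessor) {
  std::array<bool, N> flags{};
  BoolDiffOp op;
  for (std::size_t i = 0; i < N; ++i) {
    const bool prev = (i == 0) ? tile_predecessor : input[i - 1];
    flags[i] = op(prev, input[i]);
  }
  return flags;
}
```

For `greater_than_u = {false, false, true, true}`, the default emulation flags indices 0 and 2, whereas the predecessor variant with 0 (`false`) flags only index 2, the first position where the value becomes `true`.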

CUDA 12+ doesn't have this issue because we use the new
`SubtractLeft` API instead of `FlagHeads`.
yzh119 authored Jul 30, 2024
1 parent aaa929a commit 58d3593
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion include/flashinfer/sampling.cuh
@@ -118,7 +118,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
#else
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
- .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp());
+ .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
#endif
__syncthreads();
