diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 49329fbd..0e0da12b 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -118,7 +118,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb( .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp()); #else BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp()); + .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); #endif __syncthreads();