Skip to content

Commit

Permalink
bugfix: Improve numerical stability of sampling kernels (#429)
Browse files Browse the repository at this point in the history
1. use `sum_of_probs_gt_pivot` rather than `sum_of_probs_leq_pivot`
2. make sure pivot will not decrease during iterations.
  • Loading branch information
yzh119 authored Aug 8, 2024
1 parent ddc1f09 commit 898d8ea
Showing 1 changed file with 37 additions and 51 deletions.
88 changes: 37 additions & 51 deletions include/flashinfer/sampling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -291,13 +291,13 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,

vec_t<DType, VEC_SIZE> probs_vec;
DType aggregate;
DType q = DType(0);
DType q = DType(1);
DType pivot = DType(0);
IdType sampled_id;
for (uint32_t round = 0; round < max_top_k_rounds; ++round) {
temp_storage.data.sampled_id = d - 1;
__syncthreads();
DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q);
DType u = uniform_samples[round * batch_size + bx] * q;
aggregate = DType(0);
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
Expand All @@ -314,42 +314,38 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
}
__syncthreads();
sampled_id = temp_storage.data.sampled_id;
pivot = probs[bx * d + sampled_id];
pivot = max(pivot, probs[bx * d + sampled_id]);

Pair<DType> aggregate_leq_pivot{DType(0), 0};
Pair<DType> aggregate_gt_pivot{DType(0), 0};
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}

Pair<DType> probs_leq_pivot[VEC_SIZE];
Pair<DType> probs_gt_pivot[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_leq_pivot[j] = {
(probs_vec[j] <= pivot) ? probs_vec[j] : DType(0),
(probs_vec[j] <= pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
(probs_vec[j] > pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
}

aggregate_leq_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_pair)
.Sum<VEC_SIZE>(probs_leq_pivot);
aggregate_gt_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_pair)
.Sum<VEC_SIZE>(probs_gt_pivot);
if (tx == 0) {
temp_storage.data.block_aggregate.pair = aggregate_leq_pivot;
temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
}
__syncthreads();
if (temp_storage.data.block_aggregate.pair.count + k > d) {
break;
}
}
q = temp_storage.data.block_aggregate.pair.value;
if (temp_storage.data.block_aggregate.pair.count + k > d) {
if (temp_storage.data.block_aggregate.pair.count < k) {
break;
}
}
__syncthreads();
if (tx == 0) {
if (temp_storage.data.block_aggregate.pair.count + k <= d) {
if (temp_storage.data.block_aggregate.pair.count >= k) {
// failed to sample within MAX_TOP_P_ROUNDS
if (success != nullptr) {
success[bx] = false;
Expand All @@ -363,8 +359,6 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
}
}

constexpr float eps = 1e-5;

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType, typename IdType>
Expand All @@ -387,13 +381,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,

vec_t<DType, VEC_SIZE> probs_vec;
DType aggregate;
DType q = DType(0);
DType q = DType(1);
DType pivot = DType(0);
IdType sampled_id;
for (uint32_t round = 0; round < max_top_p_rounds; ++round) {
temp_storage.data.sampled_id = d - 1;
__syncthreads();
DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q);
DType u = uniform_samples[round * batch_size + bx] * q;
aggregate = DType(0);
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
Expand All @@ -410,39 +404,36 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
}
__syncthreads();
sampled_id = temp_storage.data.sampled_id;
pivot = probs[row_idx * d + sampled_id];
pivot = max(pivot, probs[row_idx * d + sampled_id]);

DType aggregate_leq_pivot = DType(0);
DType aggregate_gt_pivot = DType(0);
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}

DType probs_leq_pivot[VEC_SIZE];
DType probs_gt_pivot[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_leq_pivot[j] = (probs_vec[j] <= pivot) ? probs_vec[j] : DType(0);
probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
}

aggregate_leq_pivot += BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_leq_pivot);
aggregate_gt_pivot += BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot);
if (tx == 0) {
temp_storage.data.block_aggregate.value = aggregate_leq_pivot;
temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
}
__syncthreads();
if (float(temp_storage.data.block_aggregate.value) + top_p > 1 + eps) {
break;
}
}
q = temp_storage.data.block_aggregate.value;
if (float(q) + top_p > 1 + eps) {
if (float(q) < top_p) {
break;
}
}
__syncthreads();
if (tx == 0) {
if (float(q) + top_p <= 1 + eps) {
if (float(q) >= top_p) {
// failed to sample within MAX_TOP_P_ROUNDS
if (success != nullptr) {
success[bx] = false;
Expand Down Expand Up @@ -475,13 +466,13 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp

vec_t<DType, VEC_SIZE> probs_vec;
DType aggregate;
DType q = DType(0);
DType q = DType(1);
DType pivot = DType(0);
IdType sampled_id;
for (uint32_t round = 0; round < max_rounds; ++round) {
temp_storage.data.sampled_id = d - 1;
__syncthreads();
DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q);
DType u = uniform_samples[round * batch_size + bx] * q;
aggregate = DType(0);
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
Expand All @@ -498,43 +489,38 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp
}
__syncthreads();
sampled_id = temp_storage.data.sampled_id;
pivot = probs[bx * d + sampled_id];
pivot = max(pivot, probs[bx * d + sampled_id]);

Pair<DType> aggregate_leq_pivot{DType(0), 0};
Pair<DType> aggregate_gt_pivot{DType(0), 0};
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}

Pair<DType> probs_leq_pivot[VEC_SIZE];
Pair<DType> probs_gt_pivot[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_leq_pivot[j] = {
(probs_vec[j] <= pivot) ? probs_vec[j] : DType(0),
(probs_vec[j] <= pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
(probs_vec[j] > pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
}

aggregate_leq_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_pair)
.Sum<VEC_SIZE>(probs_leq_pivot);
aggregate_gt_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_pair)
.Sum<VEC_SIZE>(probs_gt_pivot);
if (tx == 0) {
temp_storage.data.block_aggregate.pair = aggregate_leq_pivot;
temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
}
__syncthreads();
if (temp_storage.data.block_aggregate.pair.count + k > d &&
float(temp_storage.data.block_aggregate.pair.value) + p > 1 + eps) {
break;
}
}
q = temp_storage.data.block_aggregate.pair.value;
if (temp_storage.data.block_aggregate.pair.count + k > d && float(q) + p > 1 + eps) {
if (temp_storage.data.block_aggregate.pair.count < k && float(q) < p) {
break;
}
}
__syncthreads();
if (tx == 0) {
if (temp_storage.data.block_aggregate.pair.count + k <= d || float(q) + p <= 1 + eps) {
if (temp_storage.data.block_aggregate.pair.count >= k || float(q) >= p) {
// failed to sample within MAX_TOP_P_ROUNDS
if (success != nullptr) {
success[bx] = false;
Expand Down

0 comments on commit 898d8ea

Please sign in to comment.