From fa38b5e34b9591bd5ab07186bea229ea95307755 Mon Sep 17 00:00:00 2001
From: Lily Liu
Date: Fri, 16 Aug 2024 22:25:56 -0700
Subject: [PATCH] feat: add accept num, emit num metric for ChainSpeculativeSampling (#450)

---
 include/flashinfer/sampling.cuh | 43 +++++++++++++++++++++++++--------
 python/csrc/flashinfer_ops.h    |  7 +++---
 python/csrc/sampling.cu         | 27 ++++++++++++++++-----
 python/flashinfer/sampling.py   | 19 ++++++++++++++-
 python/tests/test_sampling.py   | 36 +++++++++++++++++++++++----
 5 files changed, 107 insertions(+), 25 deletions(-)

diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh
index 4f4272f6..f5b35efa 100644
--- a/include/flashinfer/sampling.cuh
+++ b/include/flashinfer/sampling.cuh
@@ -1154,8 +1154,10 @@ template
 __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
                                          DType* uniform_samples, DType* target_probs,
-                                         IdType* output_token_ids, uint32_t num_speculative_tokens,
-                                         uint32_t d) {
+                                         IdType* output_token_ids,
+                                         IdType* output_accepted_token_num,
+                                         IdType* output_emitted_token_num,
+                                         uint32_t num_speculative_tokens, uint32_t d) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   const uint32_t row_idx = bx;
 
@@ -1165,20 +1167,38 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
   auto& temp_storage = reinterpret_cast<
       SamplingTempStorage&>(smem_sampling);
 
-  uint32_t pos = 0;
-  for (pos = 0; pos < num_speculative_tokens; ++pos) {
-    IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + pos];
-    float q = target_probs[(row_idx * (num_speculative_tokens + 1) + pos) * d + draft_id],
-          p = draft_probs[(row_idx * num_speculative_tokens + pos) * d + draft_id];
-    DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + pos];
+  uint32_t pos = num_speculative_tokens;
+  for (uint32_t i = 0; i < num_speculative_tokens; ++i) {
+    IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + i];
+    float q = target_probs[(row_idx * (num_speculative_tokens + 1) + i) * d + draft_id],
+          p = draft_probs[(row_idx * num_speculative_tokens + i) * d + draft_id];
+    DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + i];
     if (u * p < q) {
       // accept the draft models output
-      output_token_ids[row_idx * (num_speculative_tokens + 1) + pos] = draft_id;
+      output_token_ids[row_idx * (num_speculative_tokens + 1) + i] = draft_id;
     } else {
+      pos = i;
       break;
     }
   }
 
+  uint32_t emitted_token_num = pos;
+  uint32_t accepted_token_num = pos;
+  for (uint32_t i = pos; i < num_speculative_tokens; ++i) {
+    IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + i];
+    float q = target_probs[(row_idx * (num_speculative_tokens + 1) + i) * d + draft_id],
+          p = draft_probs[(row_idx * num_speculative_tokens + i) * d + draft_id];
+    DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + i];
+    if (u * p < q) {
+      ++accepted_token_num;
+    }
+  }
+
+  if (tx == 0) {
+    output_accepted_token_num[row_idx] += accepted_token_num;
+    output_emitted_token_num[row_idx] += emitted_token_num;
+  }
+
   // sample from relu(target_probs - draft_probs)
   DType sum_relu_q_minus_p(0);
   vec_t q_vec, p_vec;
@@ -1284,7 +1304,8 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o
 template <typename DType, typename IdType>
 cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
                                      DType* uniform_samples, DType* target_probs,
-                                     IdType* output_token_ids, uint32_t batch_size,
+                                     IdType* output_token_ids, IdType* output_accepted_token_num,
+                                     IdType* output_emitted_token_num, uint32_t batch_size,
                                      uint32_t num_speculative_tokens, uint32_t d,
                                      bool deterministic, cudaStream_t stream = 0) {
   constexpr uint32_t BLOCK_THREADS = 1024;
@@ -1299,6 +1320,8 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids
                           &uniform_samples,
                           &target_probs,
                           &output_token_ids,
+                          &output_accepted_token_num,
+                          &output_emitted_token_num,
                           &num_speculative_tokens,
                           &d};
   DISPATCH_ALIGNED_VEC_SIZE(
diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h
index 073b7219..3d0b678d 100644
--- a/python/csrc/flashinfer_ops.h
+++ b/python/csrc/flashinfer_ops.h
@@ -67,9 +67,10 @@ torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional<torch::Tensor> maybe_top_k_arr, unsigned int top_k_val, double eps);
 
-torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
-                                         torch::Tensor uniform_samples, torch::Tensor target_probs,
-                                         bool deterministic);
+std::vector<torch::Tensor> chain_speculative_sampling(
+    torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
+    torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
+    std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic);
 
 torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);
diff --git a/python/csrc/sampling.cu b/python/csrc/sampling.cu
index 623a0c45..dea531e0 100644
--- a/python/csrc/sampling.cu
+++ b/python/csrc/sampling.cu
@@ -315,9 +315,10 @@ torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor>
 
-torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
-                                         torch::Tensor uniform_samples, torch::Tensor target_probs,
-                                         bool deterministic) {
+std::vector<torch::Tensor> chain_speculative_sampling(
+    torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
+    torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
+    std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic) {
   CHECK_INPUT(draft_probs);
   CHECK_INPUT(draft_token_ids);
   CHECK_INPUT(uniform_samples);
@@ -349,14 +350,28 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
   auto output_token_ids = torch::empty({batch_size, num_speculate_tokens + 1},
                                        torch::dtype(torch::kInt32).device(device));
 
+  bool has_output_accepted_token_num = maybe_output_accepted_token_num.has_value();
+  bool has_output_emitted_token_num = maybe_output_emitted_token_num.has_value();
+  auto output_accepted_token_num = maybe_output_accepted_token_num.value_or(
+      torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device)));
+  auto output_emitted_token_num = maybe_output_emitted_token_num.value_or(
+      torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device)));
+  if (has_output_accepted_token_num) {
+    CHECK_EQ(has_output_emitted_token_num, true);
+    CHECK_EQ(batch_size, output_accepted_token_num.size(0));
+    CHECK_EQ(batch_size, output_emitted_token_num.size(0));
+  }
+
   cudaError_t status = sampling::ChainSpeculativeSampling(
       static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
       static_cast<float*>(uniform_samples.data_ptr()), static_cast<float*>(target_probs.data_ptr()),
-      static_cast<int*>(output_token_ids.data_ptr()), batch_size, num_speculate_tokens, vocab_size,
-      deterministic, torch_current_stream);
+      static_cast<int*>(output_token_ids.data_ptr()),
+      static_cast<int*>(output_accepted_token_num.data_ptr()),
+      static_cast<int*>(output_emitted_token_num.data_ptr()), batch_size, num_speculate_tokens,
+      vocab_size, deterministic, torch_current_stream);
   TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
                                          std::string(cudaGetErrorString(status)));
 
-  return output_token_ids;
+  return {output_token_ids, output_accepted_token_num, output_emitted_token_num};
 }
diff --git a/python/flashinfer/sampling.py b/python/flashinfer/sampling.py
index b6c23d73..290d4461 100644
--- a/python/flashinfer/sampling.py
+++ b/python/flashinfer/sampling.py
@@ -592,6 +592,8 @@ def chain_speculative_sampling(
     draft_token_ids,
     uniform_samples,
     target_probs,
+    maybe_output_accepted_token_num: torch.Tensor = None,
+    maybe_output_emitted_token_num: torch.Tensor = None,
     deterministic: bool = True,
 ) -> torch.Tensor:
     r"""Fused-GPU kernel for speculative sampling for sequence generation (proposed in
@@ -614,6 +616,15 @@ def chain_speculative_sampling(
         Compared to input :attr:`draft_probs`, the target model's probability has an additional
         slot at the end because the target model will generate one more token than the draft model.
         Shape: ``(batch_size, num_speculate_tokens + 1, vocab_size)``
+    maybe_output_accepted_token_num: torch.Tensor
+        The number of tokens that can be accepted if each token is considered independently for each request.
+        This metric does not consider the fact that rejection sampling will stop at the first token that does not
+        satisfy the probability requirement r < p/q.
+        It only evaluates the alignment of the draft model and the target model.
+        Shape: ``(batch_size)``
+    maybe_output_emitted_token_num: torch.Tensor
+        The number of tokens that are finally emitted/generated for each request.
+        Shape: ``(batch_size)``
     deterministic: bool
         Whether to use deterministic kernel implementation, default is ``True``.
 
@@ -628,5 +639,11 @@ def chain_speculative_sampling(
         Shape: (batch_size, num_specutate_tokens + 1)
     """
     return _kernels.chain_speculative_sampling(
-        draft_probs, draft_token_ids, uniform_samples, target_probs, deterministic
+        draft_probs,
+        draft_token_ids,
+        uniform_samples,
+        target_probs,
+        maybe_output_accepted_token_num,
+        maybe_output_emitted_token_num,
+        deterministic,
     )
diff --git a/python/tests/test_sampling.py b/python/tests/test_sampling.py
index 43a49bfe..64344029 100644
--- a/python/tests/test_sampling.py
+++ b/python/tests/test_sampling.py
@@ -339,11 +339,17 @@ def test_chain_speculative_sampling(
     # NOTE(Zihao): this is a very simple test that only checks whether output is valid or not.
for trials in range(10): uniform_samples.uniform_() - output_token_ids = flashinfer.sampling.chain_speculative_sampling( - normalized_draft_prob, - draft_token_ids, - uniform_samples, - target_onehot_prob, + accepted_num = torch.zeros(batch_size, dtype=torch.int32).to(0) + emitted_num = torch.zeros(batch_size, dtype=torch.int32).to(0) + output_token_ids, accepted_num, emitted_num = ( + flashinfer.sampling.chain_speculative_sampling( + normalized_draft_prob, + draft_token_ids, + uniform_samples, + target_onehot_prob, + accepted_num, + emitted_num, + ) ) if onehot_target: assert torch.all(output_token_ids == target_token_ids) @@ -359,6 +365,26 @@ def test_chain_speculative_sampling( # from the second mismatched token on, the output tokens should be -1 assert torch.all(output_token_ids[row, mismatch_idx[0] + 1 :] == -1) + assert torch.all(emitted_num + 1 == (output_token_ids != -1).sum(dim=1)) + batch_indices = torch.arange(batch_size, device=normalized_draft_prob.device)[ + :, None + ] + probs_indicies = torch.arange( + num_speculate_tokens, device=normalized_draft_prob.device + ) + selected_draft_probs = normalized_draft_prob[ + batch_indices, probs_indicies, draft_token_ids + ] + selected_target_probs = target_onehot_prob[ + batch_indices, probs_indicies, draft_token_ids + ] + capped_ratio = torch.minimum( + selected_target_probs / selected_draft_probs, + torch.full((1,), 1, device=normalized_draft_prob.device), + ) + ref_accepted = (uniform_samples[:, :-1] < capped_ratio).sum(dim=1) + assert torch.all(accepted_num == ref_accepted) + if __name__ == "__main__": test_sampling(1, 111)
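
A minimal usage sketch of the Python API as extended by this patch, assuming flashinfer is built from this change and a CUDA device is available. The batch size, token counts, random inputs, and the acceptance-rate arithmetic at the end are illustrative assumptions, not part of the patch or its tests.

    import torch
    import flashinfer

    # Hypothetical sizes for illustration.
    batch_size, num_speculate_tokens, vocab_size = 4, 3, 128
    device = "cuda"

    # Draft-model distributions over the speculated positions, and the tokens it proposed.
    draft_probs = torch.softmax(
        torch.randn(batch_size, num_speculate_tokens, vocab_size, device=device), dim=-1
    )
    draft_token_ids = (
        torch.multinomial(draft_probs.view(-1, vocab_size), num_samples=1)
        .view(batch_size, num_speculate_tokens)
        .int()
    )
    # Target-model distributions; one extra slot for the bonus token.
    target_probs = torch.softmax(
        torch.randn(batch_size, num_speculate_tokens + 1, vocab_size, device=device), dim=-1
    )
    uniform_samples = torch.rand(batch_size, num_speculate_tokens + 1, device=device)

    # Pre-allocated counters. The kernel accumulates into them (+=), so zero them
    # before the first call, or carry them across calls to aggregate statistics.
    accepted_num = torch.zeros(batch_size, dtype=torch.int32, device=device)
    emitted_num = torch.zeros(batch_size, dtype=torch.int32, device=device)

    output_token_ids, accepted_num, emitted_num = (
        flashinfer.sampling.chain_speculative_sampling(
            draft_probs,
            draft_token_ids,
            uniform_samples,
            target_probs,
            accepted_num,
            emitted_num,
        )
    )

    # emitted_num counts draft tokens kept before the first rejection;
    # accepted_num counts every draft position with u * p_draft < p_target,
    # independent of where the chain stopped.
    acceptance_rate = accepted_num.float() / num_speculate_tokens  # per request, single call

Passing the two counter tensors is optional: when they are omitted, the binding allocates zero-initialized int32 tensors of shape (batch_size,) and returns them alongside output_token_ids.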