feat: deterministic sampling (#417)
Our previous sampling kernels rely on CUB's BlockScan, which is not
deterministic, as reported in NVIDIA/cub#454.

This PR implements a deterministic BlockScan using the Blelloch scan
algorithm, which is slower than CUB's but guarantees determinism.
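
For reference, below is a minimal sketch of a block-level Blelloch (work-efficient) exclusive scan. It is illustrative only, not the code in `include/flashinfer/sampling.cuh`; the kernel name, `BLOCK_SIZE`, and the one-element-per-thread layout are assumptions. The pairing of operands is fixed by the tree structure rather than by thread scheduling, so the floating-point additions happen in the same order on every run, which is what makes the result deterministic.

```cuda
#include <cuda_runtime.h>

constexpr int BLOCK_SIZE = 1024;  // elements handled by one thread block

__global__ void blelloch_exclusive_scan(const float* in, float* out) {
  __shared__ float temp[BLOCK_SIZE];
  const int tid = threadIdx.x;  // one element per thread in this sketch
  temp[tid] = in[tid];
  __syncthreads();

  // Up-sweep (reduce) phase: build a sum tree in place. The operand
  // pairing is fixed by the tree, not by warp scheduling, so float
  // additions occur in the same order on every run.
  for (int stride = 1; stride < BLOCK_SIZE; stride <<= 1) {
    const int idx = (tid + 1) * stride * 2 - 1;
    if (idx < BLOCK_SIZE) temp[idx] += temp[idx - stride];
    __syncthreads();
  }

  // Clear the root (identity of addition) before the down-sweep.
  if (tid == 0) temp[BLOCK_SIZE - 1] = 0.0f;
  __syncthreads();

  // Down-sweep phase: push prefix sums back down the tree.
  for (int stride = BLOCK_SIZE / 2; stride > 0; stride >>= 1) {
    const int idx = (tid + 1) * stride * 2 - 1;
    if (idx < BLOCK_SIZE) {
      const float left = temp[idx - stride];
      temp[idx - stride] = temp[idx];
      temp[idx] += left;
    }
    __syncthreads();
  }

  out[tid] = temp[tid];  // exclusive prefix sum of `in`
}

// Usage in this sketch: scan one block of BLOCK_SIZE floats.
//   blelloch_exclusive_scan<<<1, BLOCK_SIZE>>>(d_in, d_out);
```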
yzh119 authored Aug 2, 2024
1 parent 146c31e commit 0dd801d
Showing 7 changed files with 283 additions and 115 deletions.
274 changes: 203 additions & 71 deletions include/flashinfer/sampling.cuh

Large diffs are not rendered by default.
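
Since the large diff is not rendered, here is a rough, hypothetical sketch (not the actual contents of `sampling.cuh`) of how the new `deterministic` flag can be lowered to a compile-time parameter that picks between CUB's `BlockScan` and a fixed-order scan; `DeterministicInclusiveSum` below is an assumed placeholder, not FlashInfer's API:

```cuda
#include <cub/block/block_scan.cuh>

// Assumed stand-in for a fixed-order (Blelloch-style) scan; declaration only.
template <int BLOCK_THREADS>
__device__ float DeterministicInclusiveSum(float val);

template <int BLOCK_THREADS, bool DETERMINISTIC>
__device__ __forceinline__ float BlockInclusiveSum(float val) {
  if constexpr (DETERMINISTIC) {
    // Fixed reduction order: identical results on every run.
    return DeterministicInclusiveSum<BLOCK_THREADS>(val);
  } else {
    // CUB's BlockScan is faster, but its float addition order can vary
    // between runs (the behavior reported in NVIDIA/cub#454).
    using BlockScan = cub::BlockScan<float, BLOCK_THREADS>;
    __shared__ typename BlockScan::TempStorage temp_storage;
    float out;
    BlockScan(temp_storage).InclusiveSum(val, out);
    return out;
  }
}
```

At the launch site, the runtime `bool deterministic` argument threaded through this PR would then select the `<true>` or `<false>` instantiation, typically via an if/else or a dispatch macro.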

15 changes: 9 additions & 6 deletions python/csrc/flashinfer_ops.h
@@ -54,26 +54,29 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe

std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s);

torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples);
torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples,
bool deterministic);

std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples, double top_p);
torch::Tensor uniform_samples, double top_p,
bool deterministic);

std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
unsigned int top_k);
unsigned int top_k, bool deterministic);

std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
torch::Tensor top_k,
torch::Tensor top_p);
torch::Tensor top_k, torch::Tensor top_p,
bool deterministic);

torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps);

torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps);

torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
torch::Tensor uniform_samples, torch::Tensor target_probs);
torch::Tensor uniform_samples, torch::Tensor target_probs,
bool deterministic);

torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);

31 changes: 17 additions & 14 deletions python/csrc/sampling.cu
@@ -20,7 +20,8 @@

using namespace flashinfer;

torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples) {
torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples,
bool deterministic) {
CHECK_INPUT(probs);
CHECK_INPUT(uniform_samples);
auto device = probs.device();
@@ -36,16 +37,18 @@ torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_sam
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));

cudaError_t status = sampling::SamplingFromProb(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
static_cast<int*>(samples.data_ptr()), batch_size, vocab_size, torch_current_stream);
cudaError_t status = sampling::SamplingFromProb(static_cast<float*>(probs.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()),
static_cast<int*>(samples.data_ptr()), batch_size,
vocab_size, deterministic, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));
return samples;
}

std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples, double top_p) {
torch::Tensor uniform_samples, double top_p,
bool deterministic) {
CHECK_INPUT(probs);
CHECK_INPUT(uniform_samples);
auto device = probs.device();
@@ -66,7 +69,7 @@ std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
cudaError_t status = sampling::TopPSamplingFromProb<float, int>(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), top_p,
batch_size, vocab_size, max_top_p_rounds, torch_current_stream);
batch_size, vocab_size, max_top_p_rounds, deterministic, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "TopPSamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));

@@ -75,7 +78,7 @@ std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,

std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
unsigned int top_k) {
unsigned int top_k, bool deterministic) {
CHECK_INPUT(probs);
CHECK_INPUT(uniform_samples);
auto device = probs.device();
@@ -96,7 +99,7 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
cudaError_t status = sampling::TopKSamplingFromProb<float, int>(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), top_k,
batch_size, vocab_size, max_top_k_rounds, torch_current_stream);
batch_size, vocab_size, max_top_k_rounds, deterministic, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "TopKSamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));

@@ -105,8 +108,8 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,

std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
torch::Tensor top_k,
torch::Tensor top_p) {
torch::Tensor top_k, torch::Tensor top_p,
bool deterministic) {
CHECK_INPUT(probs);
CHECK_INPUT(uniform_samples);
CHECK_INPUT(top_k);
@@ -138,7 +141,7 @@ std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
static_cast<int*>(top_k.data_ptr()), static_cast<float*>(top_p.data_ptr()),
static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), batch_size,
vocab_size, max_rounds, torch_current_stream);
vocab_size, max_rounds, deterministic, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "TopKTopPSamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));

@@ -187,8 +190,8 @@ torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double
}

torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
torch::Tensor uniform_samples,
torch::Tensor target_probs) {
torch::Tensor uniform_samples, torch::Tensor target_probs,
bool deterministic) {
CHECK_INPUT(draft_probs);
CHECK_INPUT(draft_token_ids);
CHECK_INPUT(uniform_samples);
@@ -224,7 +227,7 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()), static_cast<float*>(target_probs.data_ptr()),
static_cast<int*>(output_token_ids.data_ptr()), batch_size, num_speculate_tokens, vocab_size,
torch_current_stream);
deterministic, torch_current_stream);

TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
std::string(cudaGetErrorString(status)));
38 changes: 30 additions & 8 deletions python/flashinfer/sampling.py
@@ -32,7 +32,7 @@


def sampling_from_probs(
probs: torch.Tensor, uniform_samples: torch.Tensor
probs: torch.Tensor, uniform_samples: torch.Tensor, deterministic: bool = True
) -> torch.Tensor:
r"""Fused GPU kernel for category sampling from probabilities.
@@ -43,6 +43,8 @@ def sampling_from_probs(
uniform_samples: torch.Tensor
The uniform samples used as needles for sampling, shape ``(batch_size,)``.
Expected to be uniformly distributed in ``[0, 1)``.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
Returns
-------
@@ -73,11 +75,14 @@
-----
This function expects float32 inputs, and the output is int32.
"""
return _kernels.sampling_from_probs(probs, uniform_samples)
return _kernels.sampling_from_probs(probs, uniform_samples, deterministic)


def top_p_sampling_from_probs(
probs: torch.Tensor, uniform_samples: torch.Tensor, top_p: float
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_p: float,
deterministic: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
@@ -95,6 +100,8 @@ def top_p_sampling_from_probs(
Expected to be uniformly distributed in ``[0, 1)``.
top_p: float
The threshold for top-p sampling.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
Returns
-------
@@ -134,11 +141,16 @@ def top_p_sampling_from_probs(
We encourage users to set ``max_top_p_rounds`` to a reasonable value, e.g., 32. The actual
implementation usually uses far fewer rounds for rejection sampling because of early stopping.
"""
return _kernels.top_p_sampling_from_probs(probs, uniform_samples, top_p)
return _kernels.top_p_sampling_from_probs(
probs, uniform_samples, top_p, deterministic
)


def top_k_sampling_from_probs(
probs: torch.Tensor, uniform_samples: torch.Tensor, top_k: int
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_k: int,
deterministic: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Fused GPU kernel for top-k sampling from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
@@ -156,6 +168,8 @@ def top_k_sampling_from_probs(
Expected to be uniformly distributed in ``[0, 1)``.
top_k: int
The k in "top-k".
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
Returns
-------
@@ -195,14 +209,17 @@ def top_k_sampling_from_probs(
We encourage users to set ``max_top_k_rounds`` to a reasonable value, e.g., 32. The actual
implementation usually uses far fewer rounds for rejection sampling because of early stopping.
"""
return _kernels.top_k_sampling_from_probs(probs, uniform_samples, top_k)
return _kernels.top_k_sampling_from_probs(
probs, uniform_samples, top_k, deterministic
)


def top_k_top_p_sampling_from_probs(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_k: torch.Tensor,
top_p: torch.Tensor,
deterministic: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Fused GPU kernel for joint top-k and top-p sampling from probabilities,
@@ -223,6 +240,8 @@ def top_k_top_p_sampling_from_probs(
The k in "top-k" for each request, shape ``(batch_size,)``.
top_p: torch.Tensor
The threshold for top-p sampling for each request, shape ``(batch_size,)``.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
Returns
-------
@@ -264,7 +283,7 @@ def top_k_top_p_sampling_from_probs(
implementation usually uses far fewer rounds for rejection sampling because of early stopping.
"""
return _kernels.top_k_top_p_sampling_from_probs(
probs, uniform_samples, top_k, top_p
probs, uniform_samples, top_k, top_p, deterministic
)


@@ -328,6 +347,7 @@ def chain_speculative_sampling(
draft_token_ids,
uniform_samples,
target_probs,
deterministic: bool = True,
) -> torch.Tensor:
r"""Fused-GPU kernel for speculative sampling for sequence generation (proposed in
paper `Accelerating Large Language Model Decoding with Speculative Sampling <https://arxiv.org/pdf/2302.01318>`_),
@@ -349,6 +369,8 @@
Compared to input :attr:`draft_probs`, the target model's probability has an additional
slot at the end because the target model will generate one more token than the draft model.
Shape: ``(batch_size, num_speculate_tokens + 1, vocab_size)``
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
Returns
-------
@@ -361,5 +383,5 @@
Shape: ``(batch_size, num_speculate_tokens + 1)``
"""
return _kernels.chain_speculative_sampling(
draft_probs, draft_token_ids, uniform_samples, target_probs
draft_probs, draft_token_ids, uniform_samples, target_probs, deterministic
)
18 changes: 12 additions & 6 deletions src/bench_sampling.cu
@@ -26,6 +26,7 @@ template <typename T>
void bench_sampling_with_probability(nvbench::state& state) {
size_t batch_size = state.get_int64("batch_size");
size_t vocab_size = state.get_int64("vocab_size");
bool deterministic = state.get_int64("deterministic");

std::vector<T> probs_h(batch_size * vocab_size);
std::vector<T> uniform_samples_h(batch_size);
@@ -55,7 +56,7 @@ void bench_sampling_with_probability(nvbench::state& state) {
cudaError_t status = sampling::SamplingFromProb<T>(
thrust::raw_pointer_cast(probs_d.data()),
thrust::raw_pointer_cast(uniform_samples_d.data()),
thrust::raw_pointer_cast(output_d.data()), batch_size, vocab_size);
thrust::raw_pointer_cast(output_d.data()), batch_size, vocab_size, deterministic);
timer.stop();
if (status != cudaSuccess) {
state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
@@ -67,6 +68,7 @@ template <typename T>
void bench_top_p_sampling_with_probability(nvbench::state& state) {
size_t batch_size = state.get_int64("batch_size");
size_t vocab_size = state.get_int64("vocab_size");
bool deterministic = state.get_int64("deterministic");
double p = state.get_float64("p");
constexpr uint32_t max_top_p_rounds = 32;

@@ -100,7 +102,7 @@ void bench_top_p_sampling_with_probability(nvbench::state& state) {
thrust::raw_pointer_cast(probs_d.data()),
thrust::raw_pointer_cast(uniform_samples_d.data()),
thrust::raw_pointer_cast(output_d.data()), thrust::raw_pointer_cast(success_d.data()), p,
batch_size, vocab_size, max_top_p_rounds);
batch_size, vocab_size, max_top_p_rounds, deterministic);
timer.stop();
if (status != cudaSuccess) {
state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
@@ -113,6 +115,7 @@ void bench_top_k_sampling_with_probability(nvbench::state& state) {
size_t batch_size = state.get_int64("batch_size");
size_t vocab_size = state.get_int64("vocab_size");
size_t k = state.get_int64("k");
bool deterministic = state.get_int64("deterministic");
constexpr uint32_t max_top_k_rounds = 32;

std::vector<T> probs_h(batch_size * vocab_size);
Expand Down Expand Up @@ -145,7 +148,7 @@ void bench_top_k_sampling_with_probability(nvbench::state& state) {
thrust::raw_pointer_cast(probs_d.data()),
thrust::raw_pointer_cast(uniform_samples_d.data()),
thrust::raw_pointer_cast(output_d.data()), thrust::raw_pointer_cast(success_d.data()), k,
batch_size, vocab_size, max_top_k_rounds);
batch_size, vocab_size, max_top_k_rounds, deterministic);
timer.stop();
if (status != cudaSuccess) {
state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
@@ -157,18 +160,21 @@ auto bench_sampling_with_probability_f32 = bench_sampling_with_probability<float
NVBENCH_BENCH(bench_sampling_with_probability_f32)
.set_name("bench_sampling_with_probability_f32")
.add_int64_axis("batch_size", {16, 32, 128, 512, 2048})
.add_int64_axis("vocab_size", {32000, 32001, 32002, 128000, 256000});
.add_int64_axis("vocab_size", {32000, 32001, 32002, 128000, 256000})
.add_int64_axis("determinisic", {0, 1});

auto bench_top_p_sampling_with_probability_f32 = bench_top_p_sampling_with_probability<float>;
NVBENCH_BENCH(bench_top_p_sampling_with_probability_f32)
.set_name("bench_top_p_sampling_with_probability_f32")
.add_int64_axis("batch_size", {16, 32, 128, 512, 2048})
.add_int64_axis("vocab_size", {32000, 32001, 32002, 128000, 256000})
.add_float64_axis("p", {0.1, 0.5, 0.9, 1.0});
.add_float64_axis("p", {0.1, 0.5, 0.9, 1.0})
.add_int64_axis("determinisic", {0, 1});

auto bench_top_k_sampling_with_probability_f32 = bench_top_k_sampling_with_probability<float>;
NVBENCH_BENCH(bench_top_k_sampling_with_probability_f32)
.set_name("bench_top_k_sampling_with_probability_f32")
.add_int64_axis("batch_size", {16, 32, 128, 512, 2048})
.add_int64_axis("vocab_size", {32000, 32001, 32002, 128000, 256000})
.add_int64_axis("k", {16, 32, 128, 1024});
.add_int64_axis("k", {16, 32, 128, 1024})
.add_int64_axis("determinisic", {0, 1});