feat: Fused GPU sampling kernel for joint top-k & top-p sampling (#374)
Currently our sampling kernels support either top-k or top-p sampling, but not
both at once. Since the two algorithms are commonly combined in practice, this
PR implements a fused kernel that performs top-k and top-p sampling jointly.
A reference sketch of the joint filtering semantics follows the changed-file
summary below.
yzh119 authored Jul 13, 2024
1 parent e14fa81 commit 6e028eb
Showing 8 changed files with 325 additions and 16 deletions.
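
Before the per-file diffs, here is a minimal PyTorch sketch of the joint filtering semantics the fused kernel targets — the slow mask-and-renormalize formulation, mirroring the mask construction used in the new test below. The helper name top_k_top_p_filter is illustrative only, not part of the library API:

import torch

def top_k_top_p_filter(probs: torch.Tensor, k: int, p: float) -> torch.Tensor:
    # top-k mask: keep entries no smaller than the k-th largest probability
    sorted_desc, _ = torch.sort(probs, descending=True, dim=-1)
    top_k_mask = probs >= sorted_desc[..., k - 1 : k]
    # top-p mask: keep the high-probability entries whose ascending CDF exceeds 1 - p
    sorted_asc, idx = torch.sort(probs, dim=-1)
    cdf = torch.cumsum(sorted_asc, dim=-1)
    top_p_mask = torch.zeros_like(probs, dtype=torch.bool).scatter(-1, idx, cdf > 1 - p)
    # joint filter, then renormalize the surviving mass
    filtered = torch.where(top_k_mask & top_p_mask, probs, torch.zeros_like(probs))
    return filtered / filtered.sum(dim=-1, keepdim=True)

The fused kernel samples from this filtered distribution directly, without ever materializing the sorts or the masks.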
10 changes: 4 additions & 6 deletions include/flashinfer/attention/cascade.cuh
@@ -89,12 +89,11 @@ __global__ void MergeStateKernel(DTypeIn* __restrict__ v_a, float* __restrict__
 template <uint32_t vec_size, typename DType>
 __global__ void MergeStateInPlaceKernel(DType* __restrict__ v, float* __restrict__ s,
                                         DType* __restrict__ v_other, float* __restrict__ s_other,
-                                        uint8_t* __restrict__ mask,
-                                        uint32_t num_heads, uint32_t head_dim) {
+                                        uint8_t* __restrict__ mask, uint32_t num_heads,
+                                        uint32_t head_dim) {
   uint32_t pos = blockIdx.x;
 
-  if (mask != nullptr && mask[pos] == 0)
-    return;
+  if (mask != nullptr && mask[pos] == 0) return;
 
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t head_idx = ty;
@@ -396,8 +395,7 @@ cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DType
  */
 template <typename DType>
 cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other, uint32_t seq_len,
-                              uint32_t num_heads, uint32_t head_dim,
-                              uint8_t* mask = nullptr,
+                              uint32_t num_heads, uint32_t head_dim, uint8_t* mask = nullptr,
                               cudaStream_t stream = nullptr) {
   DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
     constexpr uint32_t vec_size = std::max(16U / sizeof(DType), HEAD_DIM / 32U);
114 changes: 114 additions & 0 deletions include/flashinfer/sampling.cuh
@@ -16,6 +16,8 @@
#ifndef FLASHINFER_SAMPLING_CUH_
#define FLASHINFER_SAMPLING_CUH_

#include <driver_types.h>

#include <cub/block/block_adjacent_difference.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_scan.cuh>
@@ -342,6 +344,96 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
}
}

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, typename DType, typename IdType>
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* top_k,
                                               DType* top_p, IdType* output, bool* success,
                                               uint32_t d, uint32_t max_rounds) {
  const uint32_t batch_size = gridDim.x;
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  IdType k = top_k[bx];
  DType p = top_p[bx];

  extern __shared__ __align__(
      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
      uint8_t smem_sampling[];
  auto& temp_storage = reinterpret_cast<
      SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

  vec_t<DType, VEC_SIZE> probs_vec;
  DType aggregate;
  DType q = DType(0);
  DType pivot = DType(0);
  IdType sampled_id;
  for (uint32_t round = 0; round < max_rounds; ++round) {
    temp_storage.data.sampled_id = d - 1;
    __syncthreads();
    DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q);
    aggregate = DType(0);
    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
      probs_vec.fill(DType(0));
      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
      }

      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DType>(
          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
      if (aggregate > u) {
        break;
      }
    }
    __syncthreads();
    sampled_id = temp_storage.data.sampled_id;
    pivot = probs[bx * d + sampled_id];

    Pair<DType> aggregate_leq_pivot{DType(0), 0};
    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
      probs_vec.fill(DType(0));
      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
      }

      Pair<DType> probs_leq_pivot[VEC_SIZE];
#pragma unroll
      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
        probs_leq_pivot[j] = {
            (probs_vec[j] <= pivot) ? probs_vec[j] : DType(0),
            (probs_vec[j] <= pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
      }

      aggregate_leq_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                 temp_storage.block_prim.reduce_pair)
                                 .Sum<VEC_SIZE>(probs_leq_pivot);
      if (tx == 0) {
        temp_storage.data.block_aggregate.pair = aggregate_leq_pivot;
      }
      __syncthreads();
      if (temp_storage.data.block_aggregate.pair.count + k > d &&
          float(temp_storage.data.block_aggregate.pair.value) + p > 1 + eps) {
        break;
      }
    }
    q = temp_storage.data.block_aggregate.pair.value;
    if (temp_storage.data.block_aggregate.pair.count + k > d && float(q) + p > 1 + eps) {
      break;
    }
  }
  __syncthreads();
  if (tx == 0) {
    if (temp_storage.data.block_aggregate.pair.count + k <= d || float(q) + p <= 1 + eps) {
      // failed to sample within MAX_TOP_P_ROUNDS
      if (success != nullptr) {
        success[bx] = false;
      }
    } else {
      output[bx] = sampled_id;
      if (success != nullptr) {
        success[bx] = true;
      }
    }
  }
}

template <typename T, typename IdType>
cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint32_t batch_size,
                             uint32_t d, cudaStream_t stream = 0) {

@@ -434,6 +526,28 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
  return cudaSuccess;
}

template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k, T* top_p,
                                     IdType* output, bool* success, uint32_t batch_size, uint32_t d,
                                     uint32_t max_rounds, cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t smem_size = sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
  void* args[] = {&probs, &uniform_samples, &top_k, &top_p, &output, &success, &d, &max_rounds};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel =
        TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE, T, IdType>;
    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });
  return cudaSuccess;
}

template <typename T, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
struct RenormTempStorage {
  union {
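To make the control flow of TopKTopPSamplingFromProbKernel above easier to follow, here is a single-sequence Python simulation of the same rejection loop — a sketch for exposition only, with a made-up helper name. Each round draws a needle u over the probability mass still above the current pivot, promotes the sampled probability to the new pivot, and accepts only when fewer than k entries lie strictly above the pivot (the top-k condition) and the mass strictly above the pivot is below p (the top-p condition):

import torch

def simulate_top_k_top_p_rejection(probs, uniforms, k, p, eps=1e-5):
    # probs: (d,) normalized probabilities; uniforms: (max_rounds,) in [0, 1)
    d = probs.numel()
    pivot, q = 0.0, 0.0  # q = probability mass at or below the current pivot
    sampled_id, success = d - 1, False
    for r in range(uniforms.numel()):
        u = uniforms[r].item() * (1.0 - q)  # needle over the surviving mass
        masked = torch.where(probs > pivot, probs, torch.zeros_like(probs))
        cdf = torch.cumsum(masked, dim=0)
        idx = torch.searchsorted(cdf, torch.tensor([u], dtype=cdf.dtype), right=True)
        sampled_id = min(int(idx.item()), d - 1)
        pivot = float(probs[sampled_id])
        leq = probs <= pivot
        count, total = int(leq.sum()), float(probs[leq].sum())
        # accept iff d - count < k entries exceed the pivot (top-k), and
        # 1 - total, the mass strictly above the pivot, is below p (top-p)
        if count + k > d and total + p > 1 + eps:
            success = True
            break
        q = total  # reject: shrink the candidate set and redraw next round
    return sampled_id, success

The CUDA kernel evaluates the same accept test with block-wide reductions over (sum, count) pairs of the probabilities ≤ pivot, and early-exits the scan as soon as the test passes.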
2 changes: 2 additions & 0 deletions python/csrc/flashinfer_ops.cu
@@ -35,6 +35,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "Top-k sampling from probabilities");
  m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs,
        "Top-p sampling from probabilities");
  m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs,
        "Top-k and top-p sampling from probabilities");
  m.def("top_k_renorm_prob", &top_k_renorm_prob, "Renormalize probabilities by top-k mask");
  m.def("top_p_renorm_prob", &top_p_renorm_prob, "Renormalize probabilities by top-p mask");
  m.def("chain_speculative_sampling", &chain_speculative_sampling,
5 changes: 5 additions & 0 deletions python/csrc/flashinfer_ops.h
@@ -59,6 +59,11 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
                                                     torch::Tensor uniform_samples,
                                                     unsigned int top_k);

std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
                                                           torch::Tensor uniform_samples,
                                                           torch::Tensor top_k,
                                                           torch::Tensor top_p);

torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps);

torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps);
42 changes: 42 additions & 0 deletions python/csrc/sampling.cu
@@ -103,6 +103,48 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
  return {samples, success};
}

std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
                                                           torch::Tensor uniform_samples,
                                                           torch::Tensor top_k,
                                                           torch::Tensor top_p) {
  CHECK_INPUT(probs);
  CHECK_INPUT(uniform_samples);
  CHECK_INPUT(top_k);
  CHECK_INPUT(top_p);
  auto device = probs.device();
  CHECK_EQ(uniform_samples.device(), device);
  CHECK_EQ(top_k.device(), device);
  CHECK_EQ(top_p.device(), device);
  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
  CHECK_DIM(2, uniform_samples);  // uniform_samples: (max_rounds, batch_size)
  CHECK_DIM(1, top_k);            // top_k: (batch_size,)
  CHECK_DIM(1, top_p);            // top_p: (batch_size,)
  unsigned int batch_size = probs.size(0);
  unsigned int vocab_size = probs.size(1);
  unsigned int max_rounds = uniform_samples.size(0);
  CHECK_EQ(uniform_samples.size(1), batch_size);
  CHECK_EQ(top_k.size(0), batch_size);
  CHECK_EQ(top_p.size(0), batch_size);
  probs = probs.to(torch::kFloat32);
  uniform_samples = uniform_samples.to(torch::kFloat32);
  top_k = top_k.to(torch::kInt32);
  top_p = top_p.to(torch::kFloat32);

  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
  auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
  auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));

  cudaError_t status = sampling::TopKTopPSamplingFromProb<float, int>(
      static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
      static_cast<int*>(top_k.data_ptr()), static_cast<float*>(top_p.data_ptr()),
      static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), batch_size,
      vocab_size, max_rounds, torch_current_stream);
  TORCH_CHECK(status == cudaSuccess, "TopKTopPSamplingFromProbs failed with error code " +
                                         std::string(cudaGetErrorString(status)));

  return {samples, success};
}

torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps) {
  CHECK_INPUT(probs);
  auto device = probs.device();
71 changes: 71 additions & 0 deletions python/flashinfer/sampling.py
@@ -196,6 +196,77 @@ def top_k_sampling_from_probs(
    return _kernels.top_k_sampling_from_probs(probs, uniform_samples, top_k)


def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: torch.Tensor,
    top_p: torch.Tensor,
):
    r"""Fused GPU kernel for joint top-k and top-p sampling from probabilities.

    This operator implements GPU-based rejection sampling without explicit sorting:
    the multiple rounds of rejection sampling run inside a single CUDA kernel, which
    is more efficient than a naive implementation that launches a series of kernels.

    Parameters
    ----------
    probs: torch.Tensor
        Probabilities, shape ``(batch_size, num_classes)``.
    uniform_samples: torch.Tensor
        The uniform samples used as needles for sampling, shape
        ``(max_top_k_rounds, batch_size)``, where the first dimension is the maximum
        number of rounds of rejection sampling. Expected to be uniformly distributed
        in ``[0, 1)``.
    top_k: torch.Tensor
        The k in "top-k" for each request, shape ``(batch_size,)``.
    top_p: torch.Tensor
        The threshold for top-p sampling for each request, shape ``(batch_size,)``.

    Returns
    -------
    (samples, success): Tuple[torch.Tensor, torch.Tensor]
        samples: torch.Tensor
            Sampled categories, shape ``(batch_size,)``.
        success: torch.Tensor
            Whether the sampling succeeded within ``max_top_k_rounds`` rounds,
            shape ``(batch_size,)``.

    Examples
    --------
    >>> import torch
    >>> import flashinfer
    >>> torch.manual_seed(42)
    >>> batch_size = 4
    >>> vocab_size = 5
    >>> max_rounds = 3
    >>> top_p = torch.full((batch_size,), 0.2).to(0)
    >>> top_k = torch.full((batch_size,), 2).to(0)
    >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    >>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    >>> norm_prob
    tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
            [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
            [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
            [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
    >>> uniform_samples = torch.rand(max_rounds, batch_size).to(0)
    >>> samples, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(norm_prob, uniform_samples, top_k, top_p)
    >>> samples
    tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32)
    >>> success
    tensor([True, True, True, True], device='cuda:0')

    Notes
    -----
    This function expects float32 inputs, and the output is int32.
    We encourage users to set ``max_rounds`` to a reasonable value, e.g., 32; the
    actual implementation usually uses far fewer rounds thanks to early stopping.
    """
    return _kernels.top_k_top_p_sampling_from_probs(
        probs, uniform_samples, top_k, top_p
    )


def top_p_renorm_prob(probs: torch.Tensor, top_p: float, eps: float = 1e-5):
    r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding.
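For completeness, an end-to-end smoke test of the new Python entry point using the larger round budget recommended in the Notes — a sketch with arbitrary shapes and thresholds, assuming a CUDA device and an installed flashinfer build:

import torch
import flashinfer

batch_size, vocab_size, max_rounds = 2, 32000, 32
probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)
uniform_samples = torch.rand(max_rounds, batch_size, device="cuda")
top_k = torch.full((batch_size,), 40, dtype=torch.int32, device="cuda")
top_p = torch.full((batch_size,), 0.9, device="cuda")
samples, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(
    probs, uniform_samples, top_k, top_p
)
print(samples, success)  # success is False for any row that exhausted all 32 rounds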
44 changes: 44 additions & 0 deletions python/tests/test_sampling.py
@@ -95,6 +95,50 @@ def test_top_k_sampling(batch_size, vocab_size, k):
    ]


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling(batch_size, vocab_size, p):
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")
    max_top_k_trails = 32
    eps = 1e-4
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    # top-p mask
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
    # top-k mask
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
    # overall mask
    mask = torch.minimum(mask_top_p, mask_top_k)
    uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(
        0
    )
    top_p_tensor = torch.full((batch_size,), p).to(0)
    top_k_tensor = torch.full((batch_size,), k).to(0)

    num_trails = 1000
    for _ in range(num_trails):
        uniform_samples.uniform_()
        samples, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(
            normalized_prob, uniform_samples, top_k_tensor, top_p_tensor
        )
        assert torch.all(success)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
            torch.arange(batch_size), samples
        ]


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])