[LLM INFER] top_p_sampling_reject support top_p=0 and custom seed #9202

Merged: 7 commits, Oct 11, 2024
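This PR adds an integer seed attribute to the top_p_sampling_reject custom op and makes top_p = 0 fall back to greedy (top-1) selection, so the kernel no longer needs the success / row_indices plumbing. A minimal usage sketch of the new Python-facing call (a sketch only, assuming paddlenlp_ops is built with this patch; shapes and values are illustrative):

import paddle
from paddlenlp_ops import top_p_sampling_reject

# probs: [batch, vocab] float32 probabilities, already normalized per row
probs = paddle.nn.functional.softmax(paddle.randn([4, 32000]), axis=-1)
top_p = paddle.full([4], 0.8, dtype="float32")

next_tokens = top_p_sampling_reject(probs, top_p, 42)                      # third argument is the new seed
greedy_tokens = top_p_sampling_reject(probs, paddle.zeros_like(top_p), 0)  # top_p = 0 selects the argmax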
67 changes: 30 additions & 37 deletions csrc/gpu/sample_kernels/sampling.cuh
@@ -33,6 +33,15 @@ namespace sampling {

using namespace cub;

#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) \
if (compute_capacity.first >= 8) { \
constexpr uint32_t BLOCK_THREADS = 1024; \
__VA_ARGS__ \
} else { \
constexpr uint32_t BLOCK_THREADS = 512; \
__VA_ARGS__ \
}

constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS;
constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;

@@ -277,17 +286,12 @@ template <uint32_t BLOCK_THREADS,
__global__ void TopPSamplingFromProbKernel(DType* probs,
DType* uniform_samples,
IdType* output,
bool* success,
IdType* row_indices,
float* top_p_arr,
float* top_p_val,
uint32_t d,
uint32_t max_top_p_rounds) {
const uint32_t batch_size = gridDim.x;
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
float top_p = (top_p_arr == nullptr) ? top_p_val[bx] : top_p_arr[bx];

const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
float top_p = top_p_val[bx];

extern __shared__ __align__(alignof(SamplingTempStorage<DType,
BLOCK_THREADS,
@@ -313,7 +317,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs,
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + row_idx * d +
probs_vec.load(probs + bx * d +
(i * BLOCK_THREADS + tx) * VEC_SIZE);
}

@@ -330,58 +334,51 @@ __global__ void TopPSamplingFromProbKernel(DType* probs,
}
__syncthreads();
sampled_id = temp_storage.data.sampled_id;
pivot = max(pivot, probs[row_idx * d + sampled_id]);
pivot = max(pivot, probs[bx * d + sampled_id]);

DType aggregate_gt_pivot = DType(0);
Pair<DType> aggregate_gt_pivot{DType(0), 0};
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + row_idx * d +
(i * BLOCK_THREADS + tx) * VEC_SIZE);
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}

DType probs_gt_pivot[VEC_SIZE];
Pair<DType> probs_gt_pivot[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
(probs_vec[j] > pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
}

aggregate_gt_pivot +=
BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot);
aggregate_gt_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_pair)
.Sum<VEC_SIZE>(probs_gt_pivot);
if (tx == 0) {
temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
}
__syncthreads();
}
q = temp_storage.data.block_aggregate.value;
if (float(q) < top_p) {
q = temp_storage.data.block_aggregate.pair.value;
if (float(q) > 0 && float(q) < top_p) {
// top_p is not 0
break;
} else {
// top_p is 0, use top_k, k=1
if (temp_storage.data.block_aggregate.pair.count < 1) {
break;
}
}
}
__syncthreads();
if (tx == 0) {
output[bx] = sampled_id;
if (float(q) >= top_p) {
// failed to sample within MAX_TOP_P_ROUNDS
if (success != nullptr) {
success[bx] = false;
}
} else {
if (success != nullptr) {
success[bx] = true;
}
}
}
}
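With this change the kernel drops the success and row_indices parameters and reduces a (value, count) pair of the probability mass above the pivot; when top_p is 0 the rejection loop degenerates to top-k with k = 1, i.e. the most probable token is returned. A rough NumPy sketch of the intended per-row semantics (reference behaviour only, not the CUDA implementation; names are illustrative):

import numpy as np

def top_p_sample_reference(probs_row, top_p, rng):
    # top_p == 0 behaves like top-k with k = 1: pick the most probable token
    if top_p == 0:
        return int(np.argmax(probs_row))
    # otherwise keep the smallest prefix of tokens (sorted by probability)
    # whose cumulative mass reaches top_p, renormalize, and sample from it
    order = np.argsort(probs_row)[::-1]
    cdf = np.cumsum(probs_row[order])
    cutoff = int(np.searchsorted(cdf, top_p)) + 1
    kept = order[:cutoff]
    renormalized = probs_row[kept] / probs_row[kept].sum()
    return int(rng.choice(kept, p=renormalized))

# rng = np.random.default_rng(seed) mirrors the op's new seed argument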


template <typename T, typename IdType>
cudaError_t TopPSamplingFromProb(T* probs,
T* uniform_samples,
IdType* output,
bool* success,
T* top_p_arr,
uint32_t batch_size,
const T* top_p_val,
uint32_t d,
@@ -395,13 +392,9 @@ cudaError_t TopPSamplingFromProb(T* probs,
sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
IdType* row_indices_placeholder = nullptr;
void* args[] = {&probs,
&uniform_samples,
&output,
&success,
&row_indices_placeholder,
&top_p_arr,
&top_p_val,
&d,
&max_top_p_rounds};
@@ -425,4 +418,4 @@ cudaError_t TopPSamplingFromProb(T* probs,
return cudaSuccess;
}

} // namespace sampling
} // namespace sampling
55 changes: 27 additions & 28 deletions csrc/gpu/sample_kernels/top_p_sampling_reject.cu
@@ -16,48 +16,46 @@
#include "sample_kernels/sampling.cuh"

std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor& probs,
const paddle::Tensor& top_p) {
const paddle::Tensor& top_p,
int seed) {
std::vector<int64_t> probs_shape = probs.shape();
unsigned int batch_size = probs_shape[0];
unsigned int vocab_size = probs_shape[1];

// default is 32
unsigned int max_top_p_rounds = 32;
std::vector<int64_t> uniform_samples_shape = {batch_size, max_top_p_rounds};
paddle::Tensor uniform_samples = paddle::experimental::uniform(
uniform_samples_shape, paddle::DataType::FLOAT32, 0, 1, 0, probs.place());
paddle::Tensor uniform_samples =
paddle::experimental::uniform(uniform_samples_shape,
paddle::DataType::FLOAT32,
0,
1,
seed,
probs.place());

// todo: add parameter for deterministic, now default is true
bool deterministic = true;
paddle::Tensor probs_input;

probs_input = paddle::experimental::cast(probs, paddle::DataType::FLOAT32);
auto cu_stream = probs.stream();

auto samples =
paddle::full({batch_size}, 0, paddle::DataType::INT32, probs.place());
auto success =
paddle::full({batch_size}, 0, paddle::DataType::BOOL, probs.place());
paddle::empty({batch_size, 1}, paddle::DataType::INT64, probs.place());

cudaError_t status;

cudaError_t status =
sampling::TopPSamplingFromProb<float, int>(probs_input.data<float>(),
uniform_samples.data<float>(),
samples.data<int>(),
success.data<bool>(),
nullptr,
batch_size,
top_p.data<float>(),
vocab_size,
max_top_p_rounds,
deterministic,
cu_stream);
status = sampling::TopPSamplingFromProb<float, int64_t>(
const_cast<float*>(probs.data<float>()),
uniform_samples.data<float>(),
samples.data<int64_t>(),
batch_size,
top_p.data<float>(),
vocab_size,
max_top_p_rounds,
true,
cu_stream);

PD_CHECK(status == cudaSuccess,
"SamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));

paddle::Tensor samples_output;
samples_output = paddle::experimental::cast(samples, paddle::DataType::INT64);
return {samples_output};
return {samples};
}

std::vector<std::vector<int64_t>> TopPSamplingRejectInferShape(
@@ -69,12 +67,13 @@ std::vector<std::vector<int64_t>> TopPSamplingRejectInferShape(

std::vector<paddle::DataType> TopPSamplingRejectInferDtype(
const paddle::DataType& probs_dtype, const paddle::DataType& top_p_shape) {
return {probs_dtype};
return {paddle::DataType::INT64};
}

PD_BUILD_OP(top_p_sampling_reject)
.Inputs({"probs", "top_p"})
.Outputs({"samples"})
.Attrs({"seed: int"})
.SetKernelFn(PD_KERNEL(TopPSamplingReject))
.SetInferShapeFn(PD_INFER_SHAPE(TopPSamplingRejectInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingRejectInferDtype));
.SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingRejectInferDtype));
63 changes: 63 additions & 0 deletions csrc/gpu/test/python/test_top_p_sampling_reject.py
@@ -0,0 +1,63 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle
from paddlenlp_ops import top_p_sampling_reject

paddle.seed(2023)

batch_size = 3
vocab_size = 40080
max_rounds = 32

class SetPreidsTokenPenaltyMultiScores(unittest.TestCase):
def test_top_p_sampling_reject_case1(self):
# top_p is 1, different seeds
pre_norm_prob_np = np.random.rand(batch_size, vocab_size).astype(np.float32)

paddle_pre_norm_prob = paddle.to_tensor(pre_norm_prob_np)
paddle_norm_prob = paddle_pre_norm_prob / paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
top_p_paddle = paddle.full((batch_size,), 1)
samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 0)
print(samples)
samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 1024)
print(samples)
samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 2033)
print(samples)

def test_top_p_sampling_reject_case2(self):
# top_p is 0
pre_norm_prob_np = np.random.rand(batch_size, vocab_size).astype(np.float32)

paddle_pre_norm_prob = paddle.to_tensor(pre_norm_prob_np)
paddle_norm_prob = paddle_pre_norm_prob / paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
top_p_paddle = paddle.full((batch_size,), 0)
samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 0)
print(samples)

def test_top_p_sampling_reject_case3(self):
# each batch row uses a different top_p value
pre_norm_prob_np = np.random.rand(batch_size, vocab_size).astype(np.float32)

paddle_pre_norm_prob = paddle.to_tensor(pre_norm_prob_np)
paddle_norm_prob = paddle_pre_norm_prob / paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
top_p_paddle = paddle.uniform(shape=[batch_size,1], min=0, max=1)
samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 0)
print(samples)

if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -333,7 +333,7 @@
try:
from paddlenlp_ops import top_p_sampling_reject

next_tokens = top_p_sampling_reject(probs, top_p)
next_tokens = top_p_sampling_reject(probs, top_p, 0)

Codecov / codecov/patch: added line paddlenlp/experimental/transformers/generation_utils.py#L336 was not covered by tests.
except:
_, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)

@@ -677,7 +677,7 @@
try:
from paddlenlp_ops import top_p_sampling_reject

next_tokens = top_p_sampling_reject(probs, top_p)
next_tokens = top_p_sampling_reject(probs, top_p, 0)

Codecov / codecov/patch: added line paddlenlp/experimental/transformers/generation_utils.py#L680 was not covered by tests.
except:
_, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)
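Both call sites in generation_utils.py now pass an explicit seed of 0 to the custom op and keep the existing fallback to paddle.tensor.top_p_sampling when paddlenlp_ops is not available. A self-contained sketch of that fallback pattern (a sketch only; the helper name is illustrative, and ImportError is used here in place of the bare except in the diff):

import paddle

def sample_next_tokens(probs, top_p, seed=0):
    # prefer the fused rejection-sampling custom op when paddlenlp_ops is built
    try:
        from paddlenlp_ops import top_p_sampling_reject
        return top_p_sampling_reject(probs, top_p, seed)
    except ImportError:
        # fall back to Paddle's built-in top-p sampling (no custom seed argument)
        _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)
        return next_tokens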
