refactor: Break up _kernels into multiple modules (#428)
Breaks up the `_kernels` module into multiple modules to avoid issues
caused by the file growing too large.
Yard1 authored Aug 8, 2024
1 parent 898d8ea commit 8e482d9
Showing 13 changed files with 311 additions and 202 deletions.
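For orientation, here is a minimal sketch of how the split binding files could be compiled and loaded as separate PyTorch extensions. This is illustrative only: it uses torch.utils.cpp_extension.load rather than the project's actual build, the additional kernel sources and compile flags FlashInfer needs are omitted, and the extension names are assumptions.

# Illustrative sketch only -- not FlashInfer's real build. Each binding file
# from this commit becomes its own extension module instead of one monolithic
# _kernels module. Extra .cu sources and flags required by the real build are
# omitted here.
from torch.utils.cpp_extension import load

decode_ops = load(
    name="flashinfer_ops_decode",  # assumed module name
    sources=[
        "python/csrc/flashinfer_ops_decode.cu",
        "python/csrc/batch_decode.cu",
        # ... remaining decode kernel sources omitted
    ],
)

prefill_ops = load(
    name="flashinfer_ops_prefill",  # assumed module name
    sources=[
        "python/csrc/flashinfer_ops_prefill.cu",
        "python/csrc/batch_prefill.cu",
        # ... remaining prefill kernel sources omitted
    ],
)

# Each compiled module then exposes only its own bindings, e.g.
# decode_ops.single_decode_with_kv_cache(...) and
# prefill_ops.single_prefill_with_kv_cache(...).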
2 changes: 1 addition & 1 deletion python/csrc/batch_decode.cu
@@ -15,7 +15,7 @@
*/
#include <flashinfer/decode_attention_decl.cuh>

#include "flashinfer_ops.h"
#include "flashinfer_ops_decode.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;
2 changes: 1 addition & 1 deletion python/csrc/batch_prefill.cu
@@ -15,7 +15,7 @@
*/
#include <flashinfer/prefill_attention_decl.cuh>

#include "flashinfer_ops.h"
#include "flashinfer_ops_prefill.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;
37 changes: 0 additions & 37 deletions python/csrc/flashinfer_ops.cu
@@ -18,13 +18,6 @@
#include "flashinfer_ops.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,
"Single-request decode with KV-Cache operator");
m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,
"Single-request prefill with KV-Cache operator, return logsumexp");
m.def(
"single_prefill_with_kv_cache_custom_mask", &single_prefill_with_kv_cache_custom_mask,
"Single-request prefill with KV-Cache operator, user defined custom mask, return logsumexp");
m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator");
m.def("merge_state", &merge_state, "Merge two self-attention states");
m.def("merge_state_in_place", &merge_state_in_place,
@@ -50,36 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
m.def("packbits", &packbits, "GPU packbits operator");
m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
"BatchDecodeWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool, unsigned int>())
.def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool>())
.def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward)
.def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask);
py::class_<BatchPrefillWithRaggedKVCachePyTorchWrapper>(
m, "BatchPrefillWithRaggedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool>())
.def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled",
&BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward)
.def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask);
py::class_<CutlassSegmentGEMMPyTorchWrapper>(m, "CutlassSegmentGEMMPyTorchWrapper")
.def(py::init<torch::Tensor>())
.def("register_workspace", &CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer)
113 changes: 0 additions & 113 deletions python/csrc/flashinfer_ops.h
@@ -16,29 +16,10 @@
#pragma once
#include <torch/extension.h>

#include <flashinfer/attention/handler.cuh>
#include <flashinfer/group_gemm/handler.cuh>
#include <flashinfer/layout.cuh>
#include <memory>

torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v,
torch::Tensor tmp, unsigned int pos_encoding_mode,
unsigned int layout, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta);

std::vector<torch::Tensor> single_prefill_with_kv_cache(
torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal,
unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
int window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
bool return_lse);

std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(
torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor packed_custom_mask,
torch::Tensor tmp, unsigned int layout, unsigned int pos_encoding_mode,
bool allow_fp16_qk_reduction, int window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);

void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
torch::Tensor append_indptr, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache,
@@ -106,100 +87,6 @@ torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);
torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
torch::Tensor output_indptr, const std::string& bitorder);

class BatchDecodeWithPagedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr,
torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
unsigned int pos_encoding_mode, float logits_soft_cap,
torch::Tensor empty_q_data, torch::Tensor empty_kv_data);
void EndForward();
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
std::vector<torch::Tensor> Forward(torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len,
unsigned int pos_encoding_mode, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta, bool return_lse);
BatchDecodeWithPagedKVCachePyTorchWrapper(
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_ptr, flashinfer::QKVLayout kv_layout)
: handler_(handler_ptr), kv_layout_(kv_layout) {}
BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph,
unsigned int fixed_batch_size)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchDecodeHandler>(enable_cuda_graph,
fixed_batch_size)) {}

protected:
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_;
flashinfer::QKVLayout kv_layout_;
};

class BatchPrefillWithPagedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
torch::Tensor page_kv_indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
unsigned page_size, torch::Tensor empty_q_data);
void EndForward();
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, bool causal,
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
int window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
std::vector<torch::Tensor> ForwardCustomMask(
torch::Tensor q, torch::Tensor qo_indptr, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, torch::Tensor packed_custom_mask,
torch::Tensor qk_indptr, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
int window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
bool return_lse);
BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchPrefillHandler>(enable_cuda_graph)) {}

private:
std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
flashinfer::QKVLayout kv_layout_;
};

class BatchPrefillWithRaggedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int head_dim, torch::Tensor empty_q_data);
void EndForward();
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
torch::Tensor v, torch::Tensor kv_indptr, bool causal,
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
int window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
std::vector<torch::Tensor> ForwardCustomMask(
torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v,
torch::Tensor kv_indptr, torch::Tensor packed_custom_mask, torch::Tensor qk_indptr,
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse);
BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchPrefillHandler>(enable_cuda_graph)) {}

private:
std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
flashinfer::QKVLayout kv_layout_;
};

class CutlassSegmentGEMMPyTorchWrapper {
public:
void RegisterWorkspaceBuffer(torch::Tensor workspace_buffer);
32 changes: 32 additions & 0 deletions python/csrc/flashinfer_ops_decode.cu
@@ -0,0 +1,32 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>

#include "flashinfer_ops_decode.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,
"Single-request decode with KV-Cache operator");
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
"BatchDecodeWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool, unsigned int>())
.def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
}
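As a usage illustration, a call into the split decode module might look like the sketch below. The argument order follows the single_decode_with_kv_cache declaration added in flashinfer_ops_decode.h (shown next); the tensor shapes, scratch-buffer size, and parameter values are assumptions chosen for illustration, not values taken from this commit.

import torch

def run_single_decode(decode_ops):
    # decode_ops is the compiled decode extension module (see the loading
    # sketch above). All sizes and parameter values below are assumed.
    num_heads, head_dim, kv_len = 32, 128, 1024
    q = torch.randn(num_heads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    v = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    # Scratch buffer; dtype and size are placeholders, not FlashInfer's requirement.
    tmp = torch.empty(8 * 1024 * 1024, dtype=torch.float32, device="cuda")

    out = decode_ops.single_decode_with_kv_cache(
        q, k, v, tmp,
        0,                 # pos_encoding_mode (assumed: no positional encoding)
        0,                 # layout (assumed: NHD)
        -1,                # window_left: -1 for no sliding window (assumed)
        0.0,               # logits_soft_cap
        head_dim ** -0.5,  # sm_scale
        1.0,               # rope_scale
        1e4,               # rope_theta
    )
    return out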
59 changes: 59 additions & 0 deletions python/csrc/flashinfer_ops_decode.h
@@ -0,0 +1,59 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>

#include <flashinfer/attention/handler.cuh>
#include <flashinfer/layout.cuh>
#include <memory>

torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v,
torch::Tensor tmp, unsigned int pos_encoding_mode,
unsigned int layout, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta);

class BatchDecodeWithPagedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr,
torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
unsigned int pos_encoding_mode, float logits_soft_cap,
torch::Tensor empty_q_data, torch::Tensor empty_kv_data);
void EndForward();
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
std::vector<torch::Tensor> Forward(torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len,
unsigned int pos_encoding_mode, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta, bool return_lse);
BatchDecodeWithPagedKVCachePyTorchWrapper(
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_ptr, flashinfer::QKVLayout kv_layout)
: handler_(handler_ptr), kv_layout_(kv_layout) {}
BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph,
unsigned int fixed_batch_size)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchDecodeHandler>(enable_cuda_graph,
fixed_batch_size)) {}

protected:
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_;
flashinfer::QKVLayout kv_layout_;
};
47 changes: 47 additions & 0 deletions python/csrc/flashinfer_ops_prefill.cu
@@ -0,0 +1,47 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>

#include "flashinfer_ops_prefill.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,
"Single-request prefill with KV-Cache operator, return logsumexp");
m.def(
"single_prefill_with_kv_cache_custom_mask", &single_prefill_with_kv_cache_custom_mask,
"Single-request prefill with KV-Cache operator, user defined custom mask, return logsumexp");
py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool>())
.def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward)
.def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask);
py::class_<BatchPrefillWithRaggedKVCachePyTorchWrapper>(
m, "BatchPrefillWithRaggedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool>())
.def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled",
&BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward)
.def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask);
}