From 1116237ac1e5690cf404841327b58b1d268d9951 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 10 Jul 2024 16:10:02 -0700 Subject: [PATCH] perf: Optimize tensor conversions in C++ code to avoid unnecessary copies (#366) Small tweak to avoid unnecessary copying by combining `to` calls. Discovered during profiling. --- python/csrc/batch_decode.cu | 2 +- python/csrc/batch_prefill.cu | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 110e918d..549a906a 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -154,7 +154,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor o = torch::empty_like(q); torch::Tensor lse; if (return_lse) { - lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32); + lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32))); } TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 85b92e3d..03ab80be 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -31,8 +31,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward( CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); CHECK_DIM(1, qo_indptr); CHECK_DIM(1, workspace_buffer); - qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32); - paged_kv_indptr = paged_kv_indptr.to(torch::kCPU).to(torch::kInt32); + qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); + paged_kv_indptr = paged_kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); auto device = workspace_buffer.device(); size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); @@ -111,7 +111,7 @@ std::vector<torch::Tensor>
BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor o = torch::empty_like(q, q.options()); torch::Tensor lse = torch::empty({0}); if (return_lse) { - lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); + lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32)); } MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); @@ -226,7 +226,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu torch::Tensor o = torch::empty_like(q, q.options()); torch::Tensor lse = torch::empty({0}); if (return_lse) { - lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); + lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32)); } constexpr MaskMode MASK_MODE = MaskMode::kCustom; TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); @@ -288,8 +288,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward( CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); CHECK_DIM(1, qo_indptr); CHECK_DIM(1, workspace_buffer); - qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32); - kv_indptr = kv_indptr.to(torch::kCPU).to(torch::kInt32); + qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); + kv_indptr = kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); auto device = workspace_buffer.device(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); @@ -354,7 +354,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( torch::Tensor o = torch::empty_like(q, q.options()); torch::Tensor lse = torch::empty({0}); if (return_lse) { - lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); + lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32)); 
} MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; @@ -452,7 +452,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC torch::Tensor o = torch::empty_like(q, q.options()); torch::Tensor lse = torch::empty({0}); if (return_lse) { - lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); + lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype((torch::kFloat32))); } constexpr MaskMode MASK_MODE = MaskMode::kCustom;