From 8d55dda9f81d78549ea99895b755ed2d87c44c03 Mon Sep 17 00:00:00 2001 From: Romaric Jodin Date: Thu, 11 Jan 2024 15:32:15 +0100 Subject: [PATCH] fix local storage if there are no counters --- source/opt/vksp_passes.cpp | 52 ++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/source/opt/vksp_passes.cpp b/source/opt/vksp_passes.cpp index baf01992c1..8233860411 100644 --- a/source/opt/vksp_passes.cpp +++ b/source/opt/vksp_passes.cpp @@ -22,6 +22,8 @@ #include "spirv/unified1/spirv.hpp11" #include "vulkan/vulkan.h" +#define UNDEFINED_ID (UINT32_MAX) + namespace spvtools { namespace opt { @@ -408,12 +410,16 @@ void ExtractVkspReflectInfoPass::CreateVariables( {SPV_OPERAND_TYPE_DECORATION, {(uint32_t)spv::Decoration::Block}}}); module->AddAnnotationInst(std::unique_ptr(decorate_arr_st_inst)); - local_counters_id = context()->TakeNextId(); - auto local_counters_inst = new Instruction( - context(), spv::Op::OpVariable, local_counters_ty_id, local_counters_id, - {{SPV_OPERAND_TYPE_LITERAL_INTEGER, - {(uint32_t)spv::StorageClass::Private}}}); - module->AddGlobalValue(std::unique_ptr(local_counters_inst)); + if (local_counters_ty_id != UNDEFINED_ID) { + local_counters_id = context()->TakeNextId(); + auto local_counters_inst = new Instruction( + context(), spv::Op::OpVariable, local_counters_ty_id, local_counters_id, + {{SPV_OPERAND_TYPE_LITERAL_INTEGER, + {(uint32_t)spv::StorageClass::Private}}}); + module->AddGlobalValue(std::unique_ptr(local_counters_inst)); + } else { + local_counters_id = UNDEFINED_ID; + } global_counters_id = context()->TakeNextId(); auto global_counters_inst = new Instruction( @@ -447,7 +453,9 @@ void ExtractVkspReflectInfoPass::CreatePrologue( Function*& function, uint32_t& read_clock_id) { auto* cst_mgr = context()->get_constant_mgr(); entry_point_inst->AddOperand({SPV_OPERAND_TYPE_ID, {global_counters_id}}); - entry_point_inst->AddOperand({SPV_OPERAND_TYPE_ID, {local_counters_id}}); + if (local_counters_id != UNDEFINED_ID) { + entry_point_inst->AddOperand({SPV_OPERAND_TYPE_ID, {local_counters_id}}); + } auto function_id = entry_point_inst->GetOperand(1).AsId(); function = context()->GetFunction(function_id); @@ -696,23 +704,29 @@ Pass::Status ExtractVkspReflectInfoPass::Process() { analysis::Pointer u64_ty_ptr(u64_ty, spv::StorageClass::StorageBuffer); auto u64_ptr_ty = type_mgr->GetRegisteredType(&u64_ty_ptr); - analysis::Array arr( - u64_ty, analysis::Array::LengthInfo{ - cst_mgr->GetUIntConstId((uint32_t)start_counters.size()), - {0, (uint32_t)start_counters.size()}}); - auto u64_arr_ty = type_mgr->GetRegisteredType(&arr); - analysis::Pointer u64_arr_ty_ptr(u64_arr_ty, spv::StorageClass::Private); - auto u64_arr_ptr_ty = type_mgr->GetRegisteredType(&u64_arr_ty_ptr); - analysis::Pointer u64_ty_ptr_private(u64_ty, spv::StorageClass::Private); - auto u64_private_ptr_ty = type_mgr->GetRegisteredType(&u64_ty_ptr_private); - - auto local_counters_ty_id = type_mgr->GetId(u64_arr_ptr_ty); auto counters_ty_id = type_mgr->GetId(u64_run_arr_st_ptr_ty); auto u64_ty_id = type_mgr->GetId(u64_ty); auto u64_ptr_ty_id = type_mgr->GetId(u64_ptr_ty); auto u64_arr_ty_id = type_mgr->GetId(u64_run_arr_ty); auto u64_arr_st_ty_id = type_mgr->GetId(u64_run_arr_st_ty); - auto u64_private_ptr_ty_id = type_mgr->GetId(u64_private_ptr_ty); + + uint32_t local_counters_ty_id = UNDEFINED_ID; + uint32_t u64_private_ptr_ty_id = UNDEFINED_ID; + + if (start_counters.size() > 0) { + analysis::Array arr( + u64_ty, analysis::Array::LengthInfo{ + cst_mgr->GetUIntConstId((uint32_t)start_counters.size()), + {0, (uint32_t)start_counters.size()}}); + auto u64_arr_ty = type_mgr->GetRegisteredType(&arr); + analysis::Pointer u64_arr_ty_ptr(u64_arr_ty, spv::StorageClass::Private); + auto u64_arr_ptr_ty = type_mgr->GetRegisteredType(&u64_arr_ty_ptr); + analysis::Pointer u64_ty_ptr_private(u64_ty, spv::StorageClass::Private); + auto u64_private_ptr_ty = type_mgr->GetRegisteredType(&u64_ty_ptr_private); + + local_counters_ty_id = type_mgr->GetId(u64_arr_ptr_ty); + u64_private_ptr_ty_id = type_mgr->GetId(u64_private_ptr_ty); + } auto subgroup_scope_id = cst_mgr->GetUIntConstId((uint32_t)spv::Scope::Subgroup);