Skip to content

Commit

Permalink
fix local storage if there are no counters
Browse files Browse the repository at this point in the history
  • Loading branch information
rjodinchr committed Jan 13, 2024
1 parent 3edc6ee commit 8d55dda
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions source/opt/vksp_passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "spirv/unified1/spirv.hpp11"
#include "vulkan/vulkan.h"

#define UNDEFINED_ID (UINT32_MAX)

namespace spvtools {
namespace opt {

Expand Down Expand Up @@ -408,12 +410,16 @@ void ExtractVkspReflectInfoPass::CreateVariables(
{SPV_OPERAND_TYPE_DECORATION, {(uint32_t)spv::Decoration::Block}}});
module->AddAnnotationInst(std::unique_ptr<Instruction>(decorate_arr_st_inst));

local_counters_id = context()->TakeNextId();
auto local_counters_inst = new Instruction(
context(), spv::Op::OpVariable, local_counters_ty_id, local_counters_id,
{{SPV_OPERAND_TYPE_LITERAL_INTEGER,
{(uint32_t)spv::StorageClass::Private}}});
module->AddGlobalValue(std::unique_ptr<Instruction>(local_counters_inst));
if (local_counters_ty_id != UNDEFINED_ID) {
local_counters_id = context()->TakeNextId();
auto local_counters_inst = new Instruction(
context(), spv::Op::OpVariable, local_counters_ty_id, local_counters_id,
{{SPV_OPERAND_TYPE_LITERAL_INTEGER,
{(uint32_t)spv::StorageClass::Private}}});
module->AddGlobalValue(std::unique_ptr<Instruction>(local_counters_inst));
} else {
local_counters_id = UNDEFINED_ID;
}

global_counters_id = context()->TakeNextId();
auto global_counters_inst = new Instruction(
Expand Down Expand Up @@ -447,7 +453,9 @@ void ExtractVkspReflectInfoPass::CreatePrologue(
Function*& function, uint32_t& read_clock_id) {
auto* cst_mgr = context()->get_constant_mgr();
entry_point_inst->AddOperand({SPV_OPERAND_TYPE_ID, {global_counters_id}});
entry_point_inst->AddOperand({SPV_OPERAND_TYPE_ID, {local_counters_id}});
if (local_counters_id != UNDEFINED_ID) {
entry_point_inst->AddOperand({SPV_OPERAND_TYPE_ID, {local_counters_id}});
}

auto function_id = entry_point_inst->GetOperand(1).AsId();
function = context()->GetFunction(function_id);
Expand Down Expand Up @@ -696,23 +704,29 @@ Pass::Status ExtractVkspReflectInfoPass::Process() {
analysis::Pointer u64_ty_ptr(u64_ty, spv::StorageClass::StorageBuffer);
auto u64_ptr_ty = type_mgr->GetRegisteredType(&u64_ty_ptr);

analysis::Array arr(
u64_ty, analysis::Array::LengthInfo{
cst_mgr->GetUIntConstId((uint32_t)start_counters.size()),
{0, (uint32_t)start_counters.size()}});
auto u64_arr_ty = type_mgr->GetRegisteredType(&arr);
analysis::Pointer u64_arr_ty_ptr(u64_arr_ty, spv::StorageClass::Private);
auto u64_arr_ptr_ty = type_mgr->GetRegisteredType(&u64_arr_ty_ptr);
analysis::Pointer u64_ty_ptr_private(u64_ty, spv::StorageClass::Private);
auto u64_private_ptr_ty = type_mgr->GetRegisteredType(&u64_ty_ptr_private);

auto local_counters_ty_id = type_mgr->GetId(u64_arr_ptr_ty);
auto counters_ty_id = type_mgr->GetId(u64_run_arr_st_ptr_ty);
auto u64_ty_id = type_mgr->GetId(u64_ty);
auto u64_ptr_ty_id = type_mgr->GetId(u64_ptr_ty);
auto u64_arr_ty_id = type_mgr->GetId(u64_run_arr_ty);
auto u64_arr_st_ty_id = type_mgr->GetId(u64_run_arr_st_ty);
auto u64_private_ptr_ty_id = type_mgr->GetId(u64_private_ptr_ty);

uint32_t local_counters_ty_id = UNDEFINED_ID;
uint32_t u64_private_ptr_ty_id = UNDEFINED_ID;

if (start_counters.size() > 0) {
analysis::Array arr(
u64_ty, analysis::Array::LengthInfo{
cst_mgr->GetUIntConstId((uint32_t)start_counters.size()),
{0, (uint32_t)start_counters.size()}});
auto u64_arr_ty = type_mgr->GetRegisteredType(&arr);
analysis::Pointer u64_arr_ty_ptr(u64_arr_ty, spv::StorageClass::Private);
auto u64_arr_ptr_ty = type_mgr->GetRegisteredType(&u64_arr_ty_ptr);
analysis::Pointer u64_ty_ptr_private(u64_ty, spv::StorageClass::Private);
auto u64_private_ptr_ty = type_mgr->GetRegisteredType(&u64_ty_ptr_private);

local_counters_ty_id = type_mgr->GetId(u64_arr_ptr_ty);
u64_private_ptr_ty_id = type_mgr->GetId(u64_private_ptr_ty);
}

auto subgroup_scope_id =
cst_mgr->GetUIntConstId((uint32_t)spv::Scope::Subgroup);
Expand Down

0 comments on commit 8d55dda

Please # to comment.