From c8b71c669a93685ac435fed7c7ed81c0a89a2c00 Mon Sep 17 00:00:00 2001 From: Diego Novillo Date: Mon, 8 Apr 2024 16:24:26 -0400 Subject: [PATCH] Fix rebuilding types with circular references (#5623). This fixes the problem reported in #5623 using the observation that if we are re-building a type that already exists in the type pool, we should just return that type. This makes type rebuilding more efficient, and it also prevents the type builder from getting itself into infinite recursion (as reported in this issue). In fixing this, I found a couple of other bugs in the type builder: - When rebuilding an Array type, we were not re-building the element type. This caused stale type references in the rebuilt type. - This bug had not been caught by the test, because the test itself had a bug in it: the test was rebuilding types on top of the same ID (the ID counter was never incremented). Initially, the bug in the test caused a failure with the new logic in the builder because we now return types from the pool directly, which causes a failure when two incompatible types are registered under the same ID. Fixing that issue in the test exposed another bug in the rebuilder: we were not re-building the element type for Array types. This was causing a stale type reference inside Array types which was later caught by the type removal logic in the test. --- source/opt/type_manager.cpp | 64 +++++++++++++++++++++------------- source/opt/type_manager.h | 4 ++- test/opt/type_manager_test.cpp | 38 ++++++++++++++++++-- 3 files changed, 79 insertions(+), 27 deletions(-) diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp index ae320772df..7b609bc776 100644 --- a/source/opt/type_manager.cpp +++ b/source/opt/type_manager.cpp @@ -517,13 +517,24 @@ void TypeManager::CreateDecoration(uint32_t target, context()->get_def_use_mgr()->AnalyzeInstUse(inst); } -Type* TypeManager::RebuildType(const Type& type) { +Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) { + assert(type_id != 0); + // The comparison and hash on the type pool will avoid inserting the rebuilt // type if an equivalent type already exists. The rebuilt type will be deleted // when it goes out of scope at the end of the function in that case. Repeated // insertions of the same Type will, at most, keep one corresponding object in // the type pool. std::unique_ptr rebuilt_ty; + + // If |type_id| is already present in the type pool, return the existing type. + // This saves extra work in the type builder and prevents running into + // circular issues (https://github.com/KhronosGroup/SPIRV-Tools/issues/5623). + Type* pool_ty = GetType(type_id); + if (pool_ty != nullptr) { + return pool_ty; + } + switch (type.kind()) { #define DefineNoSubtypeCase(kind) \ case Type::k##kind: \ @@ -550,43 +561,46 @@ Type* TypeManager::RebuildType(const Type& type) { case Type::kVector: { const Vector* vec_ty = type.AsVector(); const Type* ele_ty = vec_ty->element_type(); - rebuilt_ty = - MakeUnique(RebuildType(*ele_ty), vec_ty->element_count()); + rebuilt_ty = MakeUnique(RebuildType(GetId(ele_ty), *ele_ty), + vec_ty->element_count()); break; } case Type::kMatrix: { const Matrix* mat_ty = type.AsMatrix(); const Type* ele_ty = mat_ty->element_type(); - rebuilt_ty = - MakeUnique(RebuildType(*ele_ty), mat_ty->element_count()); + rebuilt_ty = MakeUnique(RebuildType(GetId(ele_ty), *ele_ty), + mat_ty->element_count()); break; } case Type::kImage: { const Image* image_ty = type.AsImage(); const Type* ele_ty = image_ty->sampled_type(); - rebuilt_ty = - MakeUnique(RebuildType(*ele_ty), image_ty->dim(), - image_ty->depth(), image_ty->is_arrayed(), - image_ty->is_multisampled(), image_ty->sampled(), - image_ty->format(), image_ty->access_qualifier()); + rebuilt_ty = MakeUnique( + RebuildType(GetId(ele_ty), *ele_ty), image_ty->dim(), + image_ty->depth(), image_ty->is_arrayed(), + image_ty->is_multisampled(), image_ty->sampled(), image_ty->format(), + image_ty->access_qualifier()); break; } case Type::kSampledImage: { const SampledImage* image_ty = type.AsSampledImage(); const Type* ele_ty = image_ty->image_type(); - rebuilt_ty = MakeUnique(RebuildType(*ele_ty)); + rebuilt_ty = + MakeUnique(RebuildType(GetId(ele_ty), *ele_ty)); break; } case Type::kArray: { const Array* array_ty = type.AsArray(); - rebuilt_ty = - MakeUnique(array_ty->element_type(), array_ty->length_info()); + const Type* ele_ty = array_ty->element_type(); + rebuilt_ty = MakeUnique(RebuildType(GetId(ele_ty), *ele_ty), + array_ty->length_info()); break; } case Type::kRuntimeArray: { const RuntimeArray* array_ty = type.AsRuntimeArray(); const Type* ele_ty = array_ty->element_type(); - rebuilt_ty = MakeUnique(RebuildType(*ele_ty)); + rebuilt_ty = + MakeUnique(RebuildType(GetId(ele_ty), *ele_ty)); break; } case Type::kStruct: { @@ -594,7 +608,7 @@ Type* TypeManager::RebuildType(const Type& type) { std::vector subtypes; subtypes.reserve(struct_ty->element_types().size()); for (const auto* ele_ty : struct_ty->element_types()) { - subtypes.push_back(RebuildType(*ele_ty)); + subtypes.push_back(RebuildType(GetId(ele_ty), *ele_ty)); } rebuilt_ty = MakeUnique(subtypes); Struct* rebuilt_struct = rebuilt_ty->AsStruct(); @@ -611,7 +625,7 @@ Type* TypeManager::RebuildType(const Type& type) { case Type::kPointer: { const Pointer* pointer_ty = type.AsPointer(); const Type* ele_ty = pointer_ty->pointee_type(); - rebuilt_ty = MakeUnique(RebuildType(*ele_ty), + rebuilt_ty = MakeUnique(RebuildType(GetId(ele_ty), *ele_ty), pointer_ty->storage_class()); break; } @@ -621,9 +635,10 @@ Type* TypeManager::RebuildType(const Type& type) { std::vector param_types; param_types.reserve(function_ty->param_types().size()); for (const auto* param_ty : function_ty->param_types()) { - param_types.push_back(RebuildType(*param_ty)); + param_types.push_back(RebuildType(GetId(param_ty), *param_ty)); } - rebuilt_ty = MakeUnique(RebuildType(*ret_ty), param_types); + rebuilt_ty = MakeUnique(RebuildType(GetId(ret_ty), *ret_ty), + param_types); break; } case Type::kForwardPointer: { @@ -633,7 +648,7 @@ Type* TypeManager::RebuildType(const Type& type) { const Pointer* target_ptr = forward_ptr_ty->target_pointer(); if (target_ptr) { rebuilt_ty->AsForwardPointer()->SetTargetPointer( - RebuildType(*target_ptr)->AsPointer()); + RebuildType(GetId(target_ptr), *target_ptr)->AsPointer()); } break; } @@ -641,16 +656,17 @@ Type* TypeManager::RebuildType(const Type& type) { const CooperativeMatrixNV* cm_type = type.AsCooperativeMatrixNV(); const Type* component_type = cm_type->component_type(); rebuilt_ty = MakeUnique( - RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(), - cm_type->columns_id()); + RebuildType(GetId(component_type), *component_type), + cm_type->scope_id(), cm_type->rows_id(), cm_type->columns_id()); break; } case Type::kCooperativeMatrixKHR: { const CooperativeMatrixKHR* cm_type = type.AsCooperativeMatrixKHR(); const Type* component_type = cm_type->component_type(); rebuilt_ty = MakeUnique( - RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(), - cm_type->columns_id(), cm_type->use_id()); + RebuildType(GetId(component_type), *component_type), + cm_type->scope_id(), cm_type->rows_id(), cm_type->columns_id(), + cm_type->use_id()); break; } default: @@ -669,7 +685,7 @@ Type* TypeManager::RebuildType(const Type& type) { void TypeManager::RegisterType(uint32_t id, const Type& type) { // Rebuild |type| so it and all its constituent types are owned by the type // pool. - Type* rebuilt = RebuildType(type); + Type* rebuilt = RebuildType(id, type); assert(rebuilt->IsSame(&type)); id_to_type_[id] = rebuilt; if (GetId(rebuilt) == 0) { diff --git a/source/opt/type_manager.h b/source/opt/type_manager.h index a70c371db0..948b691bac 100644 --- a/source/opt/type_manager.h +++ b/source/opt/type_manager.h @@ -260,7 +260,9 @@ class TypeManager { // Returns an equivalent pointer to |type| built in terms of pointers owned by // |type_pool_|. For example, if |type| is a vec3 of bool, it will be rebuilt // replacing the bool subtype with one owned by |type_pool_|. - Type* RebuildType(const Type& type); + // + // The re-built type will have ID |type_id|. + Type* RebuildType(uint32_t type_id, const Type& type); // Completes the incomplete type |type|, by replaces all references to // ForwardPointer by the defining Pointer. diff --git a/test/opt/type_manager_test.cpp b/test/opt/type_manager_test.cpp index 946f06cc08..d4d0fef524 100644 --- a/test/opt/type_manager_test.cpp +++ b/test/opt/type_manager_test.cpp @@ -942,10 +942,11 @@ OpMemoryModel Logical GLSL450 EXPECT_NE(context, nullptr); std::vector> types = GenerateAllTypes(); - uint32_t id = 1u; + uint32_t id = 0u; for (auto& t : types) { - context->get_type_mgr()->RegisterType(id, *t); + context->get_type_mgr()->RegisterType(++id, *t); EXPECT_EQ(*t, *context->get_type_mgr()->GetType(id)); + EXPECT_EQ(id, context->get_type_mgr()->GetId(t.get())); } types.clear(); @@ -1199,6 +1200,39 @@ OpMemoryModel Logical GLSL450 Match(text, context.get()); } +// Structures containing circular type references +// (from https://github.com/KhronosGroup/SPIRV-Tools/issues/5623). +TEST(TypeManager, CircularPointerToStruct) { + const std::string text = R"( + OpCapability VariablePointers + OpCapability PhysicalStorageBufferAddresses + OpCapability Int64 + OpCapability Shader + OpExtension "SPV_KHR_variable_pointers" + OpExtension "SPV_KHR_physical_storage_buffer" + OpMemoryModel PhysicalStorageBuffer64 GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + OpExecutionMode %1 DepthReplacing + OpDecorate %1200 ArrayStride 24 + OpMemberDecorate %600 0 Offset 0 + OpMemberDecorate %800 0 Offset 0 + OpMemberDecorate %120 0 Offset 16 + OpTypeForwardPointer %1200 PhysicalStorageBuffer + %600 = OpTypeStruct %1200 + %800 = OpTypeStruct %1200 + %120 = OpTypeStruct %800 + %1200 = OpTypePointer PhysicalStorageBuffer %120 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + TypeManager manager(nullptr, context.get()); + uint32_t id = manager.FindPointerToType(600, spv::StorageClass::Function); + EXPECT_EQ(id, 1201); +} + } // namespace } // namespace analysis } // namespace opt