Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Fix rebuilding types with circular references (#5623). #5637

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 40 additions & 24 deletions source/opt/type_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,13 +517,24 @@ void TypeManager::CreateDecoration(uint32_t target,
context()->get_def_use_mgr()->AnalyzeInstUse(inst);
}

Type* TypeManager::RebuildType(const Type& type) {
Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
assert(type_id != 0);

// The comparison and hash on the type pool will avoid inserting the rebuilt
// type if an equivalent type already exists. The rebuilt type will be deleted
// when it goes out of scope at the end of the function in that case. Repeated
// insertions of the same Type will, at most, keep one corresponding object in
// the type pool.
std::unique_ptr<Type> rebuilt_ty;

// If |type_id| is already present in the type pool, return the existing type.
// This saves extra work in the type builder and prevents running into
// circular issues (https://github.com/KhronosGroup/SPIRV-Tools/issues/5623).
Type* pool_ty = GetType(type_id);
if (pool_ty != nullptr) {
return pool_ty;
}

switch (type.kind()) {
#define DefineNoSubtypeCase(kind) \
case Type::k##kind: \
Expand All @@ -550,51 +561,54 @@ Type* TypeManager::RebuildType(const Type& type) {
case Type::kVector: {
const Vector* vec_ty = type.AsVector();
const Type* ele_ty = vec_ty->element_type();
rebuilt_ty =
MakeUnique<Vector>(RebuildType(*ele_ty), vec_ty->element_count());
rebuilt_ty = MakeUnique<Vector>(RebuildType(GetId(ele_ty), *ele_ty),
vec_ty->element_count());
break;
}
case Type::kMatrix: {
const Matrix* mat_ty = type.AsMatrix();
const Type* ele_ty = mat_ty->element_type();
rebuilt_ty =
MakeUnique<Matrix>(RebuildType(*ele_ty), mat_ty->element_count());
rebuilt_ty = MakeUnique<Matrix>(RebuildType(GetId(ele_ty), *ele_ty),
mat_ty->element_count());
break;
}
case Type::kImage: {
const Image* image_ty = type.AsImage();
const Type* ele_ty = image_ty->sampled_type();
rebuilt_ty =
MakeUnique<Image>(RebuildType(*ele_ty), image_ty->dim(),
image_ty->depth(), image_ty->is_arrayed(),
image_ty->is_multisampled(), image_ty->sampled(),
image_ty->format(), image_ty->access_qualifier());
rebuilt_ty = MakeUnique<Image>(
RebuildType(GetId(ele_ty), *ele_ty), image_ty->dim(),
image_ty->depth(), image_ty->is_arrayed(),
image_ty->is_multisampled(), image_ty->sampled(), image_ty->format(),
image_ty->access_qualifier());
break;
}
case Type::kSampledImage: {
const SampledImage* image_ty = type.AsSampledImage();
const Type* ele_ty = image_ty->image_type();
rebuilt_ty = MakeUnique<SampledImage>(RebuildType(*ele_ty));
rebuilt_ty =
MakeUnique<SampledImage>(RebuildType(GetId(ele_ty), *ele_ty));
break;
}
case Type::kArray: {
const Array* array_ty = type.AsArray();
rebuilt_ty =
MakeUnique<Array>(array_ty->element_type(), array_ty->length_info());
const Type* ele_ty = array_ty->element_type();
rebuilt_ty = MakeUnique<Array>(RebuildType(GetId(ele_ty), *ele_ty),
array_ty->length_info());
break;
}
case Type::kRuntimeArray: {
const RuntimeArray* array_ty = type.AsRuntimeArray();
const Type* ele_ty = array_ty->element_type();
rebuilt_ty = MakeUnique<RuntimeArray>(RebuildType(*ele_ty));
rebuilt_ty =
MakeUnique<RuntimeArray>(RebuildType(GetId(ele_ty), *ele_ty));
break;
}
case Type::kStruct: {
const Struct* struct_ty = type.AsStruct();
std::vector<const Type*> subtypes;
subtypes.reserve(struct_ty->element_types().size());
for (const auto* ele_ty : struct_ty->element_types()) {
subtypes.push_back(RebuildType(*ele_ty));
subtypes.push_back(RebuildType(GetId(ele_ty), *ele_ty));
}
rebuilt_ty = MakeUnique<Struct>(subtypes);
Struct* rebuilt_struct = rebuilt_ty->AsStruct();
Expand All @@ -611,7 +625,7 @@ Type* TypeManager::RebuildType(const Type& type) {
case Type::kPointer: {
const Pointer* pointer_ty = type.AsPointer();
const Type* ele_ty = pointer_ty->pointee_type();
rebuilt_ty = MakeUnique<Pointer>(RebuildType(*ele_ty),
rebuilt_ty = MakeUnique<Pointer>(RebuildType(GetId(ele_ty), *ele_ty),
pointer_ty->storage_class());
break;
}
Expand All @@ -621,9 +635,10 @@ Type* TypeManager::RebuildType(const Type& type) {
std::vector<const Type*> param_types;
param_types.reserve(function_ty->param_types().size());
for (const auto* param_ty : function_ty->param_types()) {
param_types.push_back(RebuildType(*param_ty));
param_types.push_back(RebuildType(GetId(param_ty), *param_ty));
}
rebuilt_ty = MakeUnique<Function>(RebuildType(*ret_ty), param_types);
rebuilt_ty = MakeUnique<Function>(RebuildType(GetId(ret_ty), *ret_ty),
param_types);
break;
}
case Type::kForwardPointer: {
Expand All @@ -633,24 +648,25 @@ Type* TypeManager::RebuildType(const Type& type) {
const Pointer* target_ptr = forward_ptr_ty->target_pointer();
if (target_ptr) {
rebuilt_ty->AsForwardPointer()->SetTargetPointer(
RebuildType(*target_ptr)->AsPointer());
RebuildType(GetId(target_ptr), *target_ptr)->AsPointer());
}
break;
}
case Type::kCooperativeMatrixNV: {
const CooperativeMatrixNV* cm_type = type.AsCooperativeMatrixNV();
const Type* component_type = cm_type->component_type();
rebuilt_ty = MakeUnique<CooperativeMatrixNV>(
RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(),
cm_type->columns_id());
RebuildType(GetId(component_type), *component_type),
cm_type->scope_id(), cm_type->rows_id(), cm_type->columns_id());
break;
}
case Type::kCooperativeMatrixKHR: {
const CooperativeMatrixKHR* cm_type = type.AsCooperativeMatrixKHR();
const Type* component_type = cm_type->component_type();
rebuilt_ty = MakeUnique<CooperativeMatrixKHR>(
RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(),
cm_type->columns_id(), cm_type->use_id());
RebuildType(GetId(component_type), *component_type),
cm_type->scope_id(), cm_type->rows_id(), cm_type->columns_id(),
cm_type->use_id());
break;
}
default:
Expand All @@ -669,7 +685,7 @@ Type* TypeManager::RebuildType(const Type& type) {
void TypeManager::RegisterType(uint32_t id, const Type& type) {
// Rebuild |type| so it and all its constituent types are owned by the type
// pool.
Type* rebuilt = RebuildType(type);
Type* rebuilt = RebuildType(id, type);
assert(rebuilt->IsSame(&type));
id_to_type_[id] = rebuilt;
if (GetId(rebuilt) == 0) {
Expand Down
4 changes: 3 additions & 1 deletion source/opt/type_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,9 @@ class TypeManager {
// Returns an equivalent pointer to |type| built in terms of pointers owned by
// |type_pool_|. For example, if |type| is a vec3 of bool, it will be rebuilt
// replacing the bool subtype with one owned by |type_pool_|.
Type* RebuildType(const Type& type);
//
// The re-built type will have ID |type_id|.
Type* RebuildType(uint32_t type_id, const Type& type);

// Completes the incomplete type |type|, by replaces all references to
// ForwardPointer by the defining Pointer.
Expand Down
38 changes: 36 additions & 2 deletions test/opt/type_manager_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,10 +942,11 @@ OpMemoryModel Logical GLSL450
EXPECT_NE(context, nullptr);

std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
uint32_t id = 1u;
uint32_t id = 0u;
for (auto& t : types) {
context->get_type_mgr()->RegisterType(id, *t);
context->get_type_mgr()->RegisterType(++id, *t);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow. Nice catch!

EXPECT_EQ(*t, *context->get_type_mgr()->GetType(id));
EXPECT_EQ(id, context->get_type_mgr()->GetId(t.get()));
}
types.clear();

Expand Down Expand Up @@ -1199,6 +1200,39 @@ OpMemoryModel Logical GLSL450
Match(text, context.get());
}

// Structures containing circular type references
// (from https://github.com/KhronosGroup/SPIRV-Tools/issues/5623).
TEST(TypeManager, CircularPointerToStruct) {
const std::string text = R"(
OpCapability VariablePointers
OpCapability PhysicalStorageBufferAddresses
OpCapability Int64
OpCapability Shader
OpExtension "SPV_KHR_variable_pointers"
OpExtension "SPV_KHR_physical_storage_buffer"
OpMemoryModel PhysicalStorageBuffer64 GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpExecutionMode %1 DepthReplacing
OpDecorate %1200 ArrayStride 24
OpMemberDecorate %600 0 Offset 0
OpMemberDecorate %800 0 Offset 0
OpMemberDecorate %120 0 Offset 16
OpTypeForwardPointer %1200 PhysicalStorageBuffer
%600 = OpTypeStruct %1200
%800 = OpTypeStruct %1200
%120 = OpTypeStruct %800
%1200 = OpTypePointer PhysicalStorageBuffer %120
)";

std::unique_ptr<IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
TypeManager manager(nullptr, context.get());
uint32_t id = manager.FindPointerToType(600, spv::StorageClass::Function);
EXPECT_EQ(id, 1201);
}

} // namespace
} // namespace analysis
} // namespace opt
Expand Down