Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Root/Push constants #22

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/FlyCube/BindingSetLayout/DXBindingSetLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,17 @@ DXBindingSetLayout::DXBindingSetLayout(DXDevice& device, const std::vector<BindK
return root_param_index;
};

auto add_root_constant = [&](uint32_t shader_register, uint32_t register_space, uint32_t num_constants,
ShaderType shader_type) {
D3D12_ROOT_PARAMETER root_constant_param = {};
root_constant_param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS;
root_constant_param.Constants.Num32BitValues = num_constants;
root_constant_param.Constants.ShaderRegister = shader_register;
root_constant_param.Constants.RegisterSpace = register_space;
root_constant_param.ShaderVisibility = GetVisibility(shader_type);
root_parameters.push_back(root_constant_param);
};

auto add_bindless_range = [&](ShaderType shader_type, ViewType view_type, uint32_t base_slot, uint32_t space) {
auto& descriptor_table_range = bindless_ranges.emplace_back();
descriptor_table_range.RangeType = GetRangeType(view_type);
Expand All @@ -118,6 +129,11 @@ DXBindingSetLayout::DXBindingSetLayout(DXDevice& device, const std::vector<BindK
continue;
}

if (bind_key.is_root_constant) {
add_root_constant(bind_key.slot, bind_key.space, 1, bind_key.shader_type);
continue;
}

D3D12_DESCRIPTOR_HEAP_TYPE heap_type = GetHeapType(bind_key.view_type);
decltype(auto) layout = m_layout[bind_key];
layout.heap_type = heap_type;
Expand Down
2 changes: 2 additions & 0 deletions src/FlyCube/CommandList/CommandList.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,7 @@ class CommandList : public QueryInterface {
uint32_t query_count,
const std::shared_ptr<Resource>& dst_buffer,
uint64_t dst_offset) = 0;
virtual void SetGraphicsConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset) = 0;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better to use PushConstants, to match it with [[vk::push_constant]] in hlsl.
void PushConstants(ShaderType shader_type, uint32_t dst_offset, const void* data, uint32_t size)

In Vulkan we can have at most 1 push_constant block in each shader, I suggest to add the same limitation in FlyCube. In this case ShaderType + active Pipeline will probably be enough to find root_parameter_index in DirectX12 implementation.

In my opinion PushConstants is enough to for all cases, but I don't mind having PushConstant as well.
void PushConstant(ShaderType shader_type, uint32_t dst_offset, uint32_t value)

virtual void SetComputeConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset) = 0;
virtual void SetName(const std::string& name) = 0;
};
10 changes: 10 additions & 0 deletions src/FlyCube/CommandList/DXCommandList.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,16 @@ void DXCommandList::ResolveQueryData(const std::shared_ptr<QueryHeap>& query_hea
D3D12_RESOURCE_STATE_UNORDERED_ACCESS, 0));
}

void DXCommandList::SetGraphicsConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the current implementation, how do you know which root_parameter_index value you need to pass?

Copy link
Contributor Author

@alelievr alelievr Mar 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The root parameter index is the index of the binding set/layout set in the list given to CreateBindingSetLayout or CreateBindingSet.

{
m_command_list->SetGraphicsRoot32BitConstant(root_parameter_index, value, byte_offset);
}

void DXCommandList::SetComputeConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset)
{
m_command_list->SetComputeRoot32BitConstant(root_parameter_index, value, byte_offset);
}

ComPtr<ID3D12GraphicsCommandList> DXCommandList::GetCommandList()
{
return m_command_list;
Expand Down
2 changes: 2 additions & 0 deletions src/FlyCube/CommandList/DXCommandList.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class DXCommandList : public CommandList {
const std::shared_ptr<Resource>& dst_buffer,
uint64_t dst_offset) override;

void SetGraphicsConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset) override;
void SetComputeConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset) override;
void SetName(const std::string& name) override;

ComPtr<ID3D12GraphicsCommandList> GetCommandList();
Expand Down
2 changes: 2 additions & 0 deletions src/FlyCube/CommandList/MTCommandList.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class MTCommandList : public CommandList {
uint32_t query_count,
const std::shared_ptr<Resource>& dst_buffer,
uint64_t dst_offset) override;
void SetGraphicsConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset) override;
void SetComputeConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset) override;
void SetName(const std::string& name) override;

id<MTLCommandBuffer> GetCommandBuffer();
Expand Down
35 changes: 35 additions & 0 deletions src/FlyCube/CommandList/MTCommandList.mm
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,41 @@ MTLCullMode ConvertCullMode(CullMode cull_mode)
assert(false);
}

void MTCommandList::SetGraphicsConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset)
{
ApplyAndRecord([&command_buffer = m_command_buffer, value, byte_offset] {
// Use MTLRenderCommandEncoder for graphics (vertex/fragment) constants
id<MTLRenderCommandEncoder> render_encoder = [command_buffer renderCommandEncoder];

// Set the constant value for vertex and fragment shaders (example with vertex shader)
[render_encoder setVertexBytes:&value
length:sizeof(value)
atIndex:root_parameter_index]; // root_parameter_index used as the buffer index

// Optionally set for fragment shader if required
[render_encoder setFragmentBytes:&value
length:sizeof(value)
atIndex:root_parameter_index]; // root_parameter_index for fragment

[render_encoder endEncoding];
});
}

void MTCommandList::SetComputeConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset)
{
ApplyAndRecord([&command_buffer = m_command_buffer, value, byte_offset] {
// Use MTLComputeCommandEncoder for compute shader constants
id<MTLComputeCommandEncoder> compute_encoder = [command_buffer computeCommandEncoder];

// Set the constant value for compute shaders
[compute_encoder setBytes:&value
length:sizeof(value)
atIndex:root_parameter_index]; // root_parameter_index used as buffer index

[compute_encoder endEncoding];
});
}

void MTCommandList::SetName(const std::string& name)
{
ApplyAndRecord([&command_buffer = m_command_buffer, name] {
Expand Down
16 changes: 16 additions & 0 deletions src/FlyCube/CommandList/VKCommandList.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,22 @@ void VKCommandList::ResolveQueryData(const std::shared_ptr<QueryHeap>& query_hea
vk::QueryResultFlagBits::eWait);
}

void VKCommandList::SetGraphicsConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset)
{
decltype(auto) pipeline_layout = m_state->GetPipelineLayout();
m_command_list->pushConstants(pipeline_layout,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You also need to specify pipeline_layout_info.pPushConstantRanges in VKBindingSetLayout. But it probably won't work without information about push constant size.

vk::ShaderStageFlagBits::eVertex | vk::ShaderStageFlagBits::eFragment,
byte_offset,
sizeof(value),
&value);
}

void VKCommandList::SetComputeConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset)
{
decltype(auto) pipeline_layout = m_state->GetPipelineLayout();
m_command_list->pushConstants(pipeline_layout, vk::ShaderStageFlagBits::eCompute, byte_offset, sizeof(value), &value);
}

void VKCommandList::SetName(const std::string& name)
{
vk::DebugUtilsObjectNameInfoEXT info = {};
Expand Down
2 changes: 2 additions & 0 deletions src/FlyCube/CommandList/VKCommandList.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class VKCommandList : public CommandList {
uint32_t query_count,
const std::shared_ptr<Resource>& dst_buffer,
uint64_t dst_offset) override;
void SetGraphicsConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset) override;
void SetComputeConstant(uint32_t root_parameter_index, uint32_t value, uint32_t byte_offset) override;
void SetName(const std::string& name) override;

vk::CommandBuffer GetCommandList();
Expand Down
3 changes: 2 additions & 1 deletion src/FlyCube/Instance/BaseTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ struct BindKey {
uint32_t space = 0;
uint32_t count = 1;
uint32_t remapped_slot = ~0;
bool is_root_constant = false;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest to use ViewType::kPushConstant instead.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually no, it's better to use ViewType::kConstantBuffer. Because ShaderReflection must be able to distinguish ViewType::kPushConstant from ViewType::kConstantBuffer, but there is no way to do so in dxil, since dxc ignore [[vk::push_constant]] when compile to dxil.

But since we still need to pass the size for push_constant, it is better to pass it separately from the BindKey.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I use ViewType::kConstantBuffer instead of this boolean, how can I know if the root BindKey should go to a root constant?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, we can separate them.

struct PushConstant {
    ShaderType shader_type = ShaderType::kUnknown;
    uint32_t size = 0;
};

std::shared_ptr<BindingSetLayout> CreateBindingSetLayout(const std::vector<BindKey>& descs, const std::vector<PushConstant>& push_constants = {});

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what is best to use, just ShaderType or BindKey.

  1. ShaderType
    At most 1 push constant per shader stage to match Vulkan limitation.
    Slot and space are 0 because spirv do not save bindings when [[vk::push_constant]] is used.
  2. BindKey
    Allow multiple push constant in DirectX12 and Metal, but Vulkan still supports at most 1 push constant per shader stage at slot/space 0 with using [[vk::push_constant]].
    More flexible, but harder to use in Vulkan.


uint32_t GetRemappedSlot() const
{
Expand All @@ -419,7 +420,7 @@ struct BindKey {

auto MakeTie() const
{
return std::tie(shader_type, view_type, slot, space, count);
return std::tie(shader_type, view_type, slot, space, count, is_root_constant);
}
};

Expand Down
Loading