Skip to content

Commit 736f4af

Browse files
committed
[ET-VK] Moving device capabilities check to DispatchNode and PrepackNode ctor.
The changes in this diff move the device capabilities check from the encode method to the constructor of DispatchNode and PrepackNode classes. Differential Revision: [D74481839](https://our.internmc.facebook.com/intern/diff/D74481839/) ghstack-source-id: 283053996 Pull Request resolved: #10785
1 parent 7221f42 commit 736f4af

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ DispatchNode::DispatchNode(
3333
spec_vars_(spec_vars),
3434
push_constants_(push_constants) {
3535
graph.update_descriptor_counts(shader, /*execute = */ true);
36+
graph.context()->check_device_capabilities(shader_);
3637
}
3738

3839
void DispatchNode::encode(ComputeGraph* graph) {
@@ -42,8 +43,6 @@ void DispatchNode::encode(ComputeGraph* graph) {
4243
api::Context* const context = graph->context();
4344
vkapi::PipelineBarrier pipeline_barrier{};
4445

45-
context->check_device_capabilities(shader_);
46-
4746
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
4847

4948
std::array<uint8_t, kMaxPushConstantSize> push_constants_data;

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ PrepackNode::PrepackNode(
4545
push_constants_(push_constants) {
4646
graph.update_descriptor_counts(shader, /*execute = */ false);
4747
graph.update_descriptor_counts(noop_shader_, /*execute = */ false);
48+
graph.context()->check_device_capabilities(shader_);
4849
}
4950

5051
api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
@@ -70,8 +71,6 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
7071
void PrepackNode::encode(ComputeGraph* graph) {
7172
api::Context* const context = graph->context();
7273

73-
context->check_device_capabilities(shader_);
74-
7574
vTensorPtr packed = graph->get_tensor(packed_);
7675
api::StagingBuffer staging = create_staging_buffer(graph);
7776

0 commit comments

Comments
 (0)