Skip to content

Commit 630f9c4

Browse files
committed
fix: support dict type for input in shape analysis
Signed-off-by: Bo Wang <bowa@nvidia.com>
1 parent 1c294fa commit 630f9c4

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

core/partitioning/shape_analysis.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ void getSegmentsOutputByRunning(
8484
jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple());
8585
} else if (input->type()->kind() == torch::jit::TypeKind::NumberType) {
8686
jit_inputs_ivalues.push_back(ivalues_maps[input].toScalar());
87+
} else if (input->type()->kind() == torch::jit::TypeKind::DictType) {
88+
jit_inputs_ivalues.push_back(ivalues_maps[input].toGenericDict());
8789
} else {
8890
TORCHTRT_THROW_ERROR(
8991
"Expected to find type " << input->type()->str() << " for value " << input->debugName()

0 commit comments

Comments
 (0)