diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp index e773b8f321..a899b3313e 100644 --- a/core/partitioning/shape_analysis.cpp +++ b/core/partitioning/shape_analysis.cpp @@ -84,6 +84,8 @@ void getSegmentsOutputByRunning( jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple()); } else if (input->type()->kind() == torch::jit::TypeKind::NumberType) { jit_inputs_ivalues.push_back(ivalues_maps[input].toScalar()); + } else if (input->type()->kind() == torch::jit::TypeKind::DictType) { + jit_inputs_ivalues.push_back(ivalues_maps[input].toGenericDict()); } else { TORCHTRT_THROW_ERROR( "Expected to find type " << input->type()->str() << " for value " << input->debugName()