diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp index e2e2bfe150..aa6c0055de 100644 --- a/core/partitioning/shape_analysis.cpp +++ b/core/partitioning/shape_analysis.cpp @@ -81,8 +81,13 @@ void getSegmentsOutputByRunning( jit_inputs_ivalues.push_back(ivalues_maps[input].toList()); } else if (input->type()->kind() == torch::jit::TypeKind::TupleType) { jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple()); + } else if (input->type()->kind() == torch::jit::TypeKind::NumberType) { + jit_inputs_ivalues.push_back(ivalues_maps[input].toScalar()); } else { - TORCHTRT_THROW_ERROR("Unable to find type for value: " << input->debugName() << " to get the ivalues.\n"); + TORCHTRT_THROW_ERROR( + "Unable to find type for value: " << input->debugName() + << " to get the ivalues. The type for this value should be " + << input->type()->str() << " \n"); } } @@ -110,28 +115,31 @@ void getSegmentsOutputByRunning( for (auto& i : seg_block.raw_inputs()) { if (ivalues_maps[i].isTensor()) { // set the input_shape and data_type - at::ScalarType t = ivalues_maps[i].toTensor().scalar_type(); + // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for + // shape inference + auto cur_ivalue = ivalues_maps[i]; + at::ScalarType t = cur_ivalue.toTensor().scalar_type(); if (!partition_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) { TORCHTRT_THROW_ERROR( "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled"); } else if (partition_info.truncate_long_and_double && t == at::kLong) { - ivalues_maps[i] = ivalues_maps[i].toTensor().to(at::kInt); + cur_ivalue = cur_ivalue.toTensor().to(at::kInt); LOG_WARNING("Truncating graph input type from at::kLong to at::kInt"); } else if (partition_info.truncate_long_and_double && t == at::kDouble) { - ivalues_maps[i] = ivalues_maps[i].toTensor().to(at::kFloat); + cur_ivalue = cur_ivalue.toTensor().to(at::kFloat); LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat"); } - c10::optional dtype = util::optTypeMetaToTRTDataType(ivalues_maps[i].toTensor().dtype()); + c10::optional dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype()); if (dtype == c10::nullopt) { - TORCHTRT_THROW_ERROR("Unsupported input data type " << ivalues_maps[i].toTensor().dtype()); + TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype()); } - if (ivalues_maps[i].toTensor().sizes().size() == 0) { + if (cur_ivalue.toTensor().sizes().size() == 0) { // handle Scalar types, which has sizes of [] input_shapes.push_back(util::toVec(util::toDims(c10::List({1})))); } else { - input_shapes.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes()))); + input_shapes.push_back(util::toVec(util::toDims(cur_ivalue.toTensor().sizes()))); } - input_types.push_back(ivalues_maps[i].toTensor().scalar_type()); + input_types.push_back(cur_ivalue.toTensor().scalar_type()); } }