From 12c9f940a01d8f7f72a582b401f80d28b9ec27c9 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Tue, 28 Mar 2023 11:01:19 -0700 Subject: [PATCH] fix undefined attr issue --- core/conversion/evaluators/eval_util.cpp | 51 ++++++++++++------------ core/ir/ir.cpp | 3 +- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp index c14f9a6714..88d06a9952 100644 --- a/core/conversion/evaluators/eval_util.cpp +++ b/core/conversion/evaluators/eval_util.cpp @@ -27,72 +27,73 @@ c10::optional toIValue(const torch::jit::Value* v) { } const torch::jit::Node* node = v->node(); const c10::TypePtr& type = v->type(); + + c10::Symbol attr_value = c10::Symbol::fromDomainAndUnqualString(c10::attr::value.domainString(), "value"); + if (type->isSubtypeOf(c10::TensorType::get())) { - return node->t(c10::attr::value); + return node->t(attr_value); } else if (type->isSubtypeOf(c10::BoolType::get())) { - return (bool)node->i(c10::attr::value); - } else if ( - type->isSubtypeOf(c10::NumberType::get()) && node->kindOf(c10::attr::value) == torch::jit::AttributeKind::i) { - return node->i(c10::attr::value); - } else if ( - type->isSubtypeOf(c10::NumberType::get()) && node->kindOf(c10::attr::value) == torch::jit::AttributeKind::f) { - return node->f(c10::attr::value); + return (bool)node->i(attr_value); + } else if (type->isSubtypeOf(c10::NumberType::get()) && node->kindOf(attr_value) == torch::jit::AttributeKind::i) { + return node->i(attr_value); + } else if (type->isSubtypeOf(c10::NumberType::get()) && node->kindOf(attr_value) == torch::jit::AttributeKind::f) { + return node->f(attr_value); } else if (type->isSubtypeOf(c10::ListType::ofInts())) { try { - const auto& is = node->is(c10::attr::value); + const auto& is = node->is(attr_value); return is; } catch (const std::exception& ex) { - const auto& ival = node->ival(c10::attr::value); + const auto& ival = node->ival(attr_value); return ival; } } else if (type->isSubtypeOf(c10::ListType::ofFloats())) { try { - const auto& fs = node->fs(c10::attr::value); + const auto& fs = node->fs(attr_value); return fs; } catch (const std::exception& ex) { - const auto& ival = node->ival(c10::attr::value); + const auto& ival = node->ival(attr_value); return ival; } } else if (type->isSubtypeOf(c10::ListType::ofBools())) { - const auto bs = c10::fmap(node->is(c10::attr::value)); + const auto bs = c10::fmap(node->is(attr_value)); return bs; } else if (type->isSubtypeOf(c10::ListType::ofTensors())) { try { - const auto& ts = node->ts(c10::attr::value); + const auto& ts = node->ts(attr_value); return ts; } catch (const std::exception& ex) { - const auto& ival = node->ival(c10::attr::value); + const auto& ival = node->ival(attr_value); return ival; } } else if (type->isSubtypeOf(c10::ListType::ofStrings())) { try { - const auto& ss = node->ss(c10::attr::value); + const auto& ss = node->ss(attr_value); auto vals = c10::impl::GenericList(c10::StringType::get()); for (const auto& str : ss) { vals.push_back(str); } return vals; } catch (const std::exception& ex) { - const auto& ival = node->ival(c10::attr::value); + const auto& ival = node->ival(attr_value); return ival; } - } else if (type->cast() && node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) { - const auto& list = node->ival(c10::attr::value); + } else if (type->cast() && node->kindOf(attr_value) == torch::jit::AttributeKind::ival) { + const auto& list = node->ival(attr_value); TORCHTRT_ASSERT(list.isList(), "Is not a list"); return list; - } else if (type->cast() && node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) { - const auto& dict = node->ival(c10::attr::value); + } else if (type->cast() && node->kindOf(attr_value) == torch::jit::AttributeKind::ival) { + const auto& dict = node->ival(attr_value); TORCHTRT_ASSERT(dict.isGenericDict(), "Is not a dict"); return dict; - } else if (type->cast() && node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) { - const auto& tup = node->ival(c10::attr::value); + } else if (type->cast() && node->kindOf(attr_value) == torch::jit::AttributeKind::ival) { + const auto& tup = node->ival(attr_value); TORCHTRT_ASSERT(tup.isTuple(), "Is not a tuple"); return tup; } else if (type == c10::StringType::get()) { - const auto& s = node->s(c10::attr::value); + const auto& s = node->s(attr_value); return s; } else if (type == c10::DeviceObjType::get()) { - auto d = c10::Device(node->s(c10::attr::value)); + auto d = c10::Device(node->s(attr_value)); return d; } else if (node->mustBeNone()) { return torch::jit::IValue(); diff --git a/core/ir/ir.cpp b/core/ir/ir.cpp index 99bf4f97b1..c98d17c5ef 100644 --- a/core/ir/ir.cpp +++ b/core/ir/ir.cpp @@ -160,7 +160,8 @@ c10::optional get_value_first_calc_dtype_opt(torch::jit::Block* LOG_GRAPH("Input outputs a Tensor"); if (in->node()->kind() == torch::jit::prim::Constant) { LOG_GRAPH("Input is a constant"); - auto const_val = in->node()->t(c10::attr::value); + auto const_val = + in->node()->t(c10::Symbol::fromDomainAndUnqualString(c10::attr::value.domainString(), "value")); LOG_GRAPH("Found that constant tensor has type: " << const_val.scalar_type()); dtype = {const_val.scalar_type()}; goto exit_first_calc_dtype;