@@ -36,6 +36,9 @@ class NodeEvaluatorRegistry {
36
36
" Attempting to override already registered evaluator " << node_kind.toQualString ()
37
37
<< " , merge implementations instead" );
38
38
}
39
+ for (auto const & e : eval_reg.options .supported_variants ) {
40
+ registered_evaluator_schemas_.insert (e);
41
+ }
39
42
evaluator_lut_[node_kind] = std::move (eval_reg);
40
43
}
41
44
@@ -76,6 +79,12 @@ class NodeEvaluatorRegistry {
76
79
return evaluator;
77
80
}
78
81
82
+ std::vector<std::string> GetRegisteredEvaluatorList () {
83
+ std::vector<std::string> evaluator_list;
84
+ std::copy (registered_evaluator_schemas_.begin (), registered_evaluator_schemas_.end (), std::back_inserter (evaluator_list));
85
+ return evaluator_list;
86
+ }
87
+
79
88
bool EvalAtConversionTime (const torch::jit::Node* n) {
80
89
auto evaluator = FindEvaluator (n);
81
90
if (evaluator == nullptr ) {
@@ -87,6 +96,7 @@ class NodeEvaluatorRegistry {
87
96
88
97
private:
89
98
EvaluatorLUT evaluator_lut_;
99
+ std::set<std::string> registered_evaluator_schemas_;
90
100
};
91
101
92
102
NodeEvaluatorRegistry& get_evaluator_registry () {
@@ -99,6 +109,10 @@ bool shouldEvalAtConversionTime(const torch::jit::Node* n) {
99
109
return get_evaluator_registry ().EvalAtConversionTime (n);
100
110
}
101
111
112
+ std::vector<std::string> getEvaluatorList () {
113
+ return get_evaluator_registry ().GetRegisteredEvaluatorList ();
114
+ }
115
+
102
116
c10::optional<torch::jit::IValue> EvalNode (const torch::jit::Node* n, kwargs& args) {
103
117
auto evaluator = get_evaluator_registry ().GetEvaluator (n);
104
118
return evaluator (n, args);
0 commit comments