Skip to content

Commit 872d9a3

Browse files
committed
feat(supportedops): Application to dump a list of supported operators
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 1c9dfe2 commit 872d9a3

File tree

6 files changed

+90
-0
lines changed

6 files changed

+90
-0
lines changed

core/conversion/converters/NodeConverterRegistry.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class NodeConverterRegistry {
4747
public:
4848
bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) {
4949
LOG_DEBUG("Registering converter for " << canonical_schema_string(*signature));
50+
registered_converter_schemas_.insert(c10::toString(*signature));
5051
auto name = signature->operator_name();
5152
auto iter = converter_lut_.find(name);
5253
if (iter != converter_lut_.end()) {
@@ -83,8 +84,15 @@ class NodeConverterRegistry {
8384
}
8485
}
8586

87+
std::vector<std::string> GetRegisteredConverterList() {
88+
std::vector<std::string> converter_list;
89+
std::copy(registered_converter_schemas_.begin(), registered_converter_schemas_.end(), std::back_inserter(converter_list));
90+
return converter_list;
91+
}
92+
8693
private:
8794
ConverterLUT converter_lut_;
95+
std::set<std::string> registered_converter_schemas_;
8896
};
8997

9098
NodeConverterRegistry& get_converter_registry() {
@@ -115,6 +123,10 @@ bool node_is_convertable(const torch::jit::Node* n) {
115123
return get_converter_registry().Convertable(n);
116124
}
117125

126+
std::vector<std::string> get_converter_list() {
127+
return get_converter_registry().GetRegisteredConverterList();
128+
}
129+
118130
RegisterNodeConversionPatterns&& RegisterNodeConversionPatterns::pattern(ConversionPattern p) && {
119131
register_node_converter(std::move(p));
120132
return std::move(*this);

core/conversion/converters/converters.h

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class RegisterNodeConversionPatterns {
3939

4040
bool node_is_convertable(const torch::jit::Node* n);
4141
OpConverter get_node_converter_for(const torch::jit::FunctionSchema* signature);
42+
std::vector<std::string> get_converter_list();
4243

4344
} // namespace converters
4445
} // namespace conversion

core/conversion/evaluators/NodeEvaluatorRegistry.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ class NodeEvaluatorRegistry {
3636
"Attempting to override already registered evaluator " << node_kind.toQualString()
3737
<< ", merge implementations instead");
3838
}
39+
for (auto const& e : eval_reg.options.supported_variants) {
40+
registered_evaluator_schemas_.insert(e);
41+
}
3942
evaluator_lut_[node_kind] = std::move(eval_reg);
4043
}
4144

@@ -76,6 +79,12 @@ class NodeEvaluatorRegistry {
7679
return evaluator;
7780
}
7881

82+
std::vector<std::string> GetRegisteredEvaluatorList() {
83+
std::vector<std::string> evaluator_list;
84+
std::copy(registered_evaluator_schemas_.begin(), registered_evaluator_schemas_.end(), std::back_inserter(evaluator_list));
85+
return evaluator_list;
86+
}
87+
7988
bool EvalAtConversionTime(const torch::jit::Node* n) {
8089
auto evaluator = FindEvaluator(n);
8190
if (evaluator == nullptr) {
@@ -87,6 +96,7 @@ class NodeEvaluatorRegistry {
8796

8897
private:
8998
EvaluatorLUT evaluator_lut_;
99+
std::set<std::string> registered_evaluator_schemas_;
90100
};
91101

92102
NodeEvaluatorRegistry& get_evaluator_registry() {
@@ -99,6 +109,10 @@ bool shouldEvalAtConversionTime(const torch::jit::Node* n) {
99109
return get_evaluator_registry().EvalAtConversionTime(n);
100110
}
101111

112+
std::vector<std::string> getEvaluatorList() {
113+
return get_evaluator_registry().GetRegisteredEvaluatorList();
114+
}
115+
102116
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
103117
auto evaluator = get_evaluator_registry().GetEvaluator(n);
104118
return evaluator(n, args);

core/conversion/evaluators/evaluators.h

+3
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*,
3838
struct EvalOptions {
3939
std::set<c10::TypePtr> blacklisted_output_types;
4040
std::vector<c10::OperatorName> valid_schemas;
41+
std::vector<std::string> supported_variants;
4142
EvalOptions() = default;
4243
EvalOptions& blacklistOutputTypes(std::set<c10::TypePtr> types) {
4344
use_options = true;
4445
blacklisted_output_types = types;
4546
return *this;
4647
}
4748
EvalOptions& validSchemas(std::set<std::string> schemas) {
49+
std::copy(schemas.begin(), schemas.end(), std::back_inserter(supported_variants));
4850
use_options = true;
4951
for (auto s : schemas) {
5052
valid_schemas.push_back(torch::jit::parseSchema(s).operator_name());
@@ -72,6 +74,7 @@ struct EvalRegistration {
7274

7375
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args);
7476
bool shouldEvalAtConversionTime(const torch::jit::Node* n);
77+
std::vector<std::string> getEvaluatorList();
7578
void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator);
7679
void register_node_evaluator(EvalRegistration r);
7780

cpp/supportedops/BUILD

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
cc_binary(
4+
name = "supportedops",
5+
srcs = [
6+
"main.cpp"
7+
],
8+
deps = [
9+
"//cpp/api:trtorch",
10+
"//core/conversion/converters"
11+
],
12+
)

cpp/supportedops/main.cpp

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#include "core/conversion/converters/converters.h"
2+
#include "core/conversion/evaluators/evaluators.h"
3+
4+
#include <string>
5+
#include <sstream>
6+
#include <vector>
7+
#include <iostream>
8+
9+
int main(int argc, const char* argv[]) {
10+
std::vector<std::string> converters = trtorch::core::conversion::converters::get_converter_list();
11+
std::vector<std::string> evaluators = trtorch::core::conversion::evaluators::getEvaluatorList();
12+
13+
std::stringstream ss;
14+
15+
ss << R"TITLE(
16+
=================================
17+
Operators Supported
18+
=================================
19+
20+
)TITLE";
21+
22+
ss << R"SEC(
23+
Operators Currently Supported Through Converters
24+
-------------------------------------------------
25+
26+
)SEC";
27+
28+
for (auto c : converters) {
29+
ss << "- " << c << std::endl;
30+
}
31+
32+
ss << R"SEC(
33+
Operators Currently Supported Through Evaluators
34+
-------------------------------------------------
35+
36+
)SEC";
37+
38+
for (auto e : evaluators) {
39+
ss << "- " << e << std::endl;
40+
}
41+
42+
std::ofstream ofs;
43+
ofs.open(argv[1]);
44+
45+
ofs << ss.rdbuf();
46+
47+
return 0;
48+
}

0 commit comments

Comments
 (0)