@@ -22,7 +22,7 @@ TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
22
22
torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
23
23
auto sg = std::make_shared<torch::jit::Graph>();
24
24
torch::jit::parseIR (source_graph, sg.get ());
25
- torch_tensorrt::core::lowering::passes::RemoveContiguous (sg);
25
+ torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts (sg);
26
26
27
27
auto tg = std::make_shared<torch::jit::Graph>();
28
28
torch::jit::parseIR (target_graph, tg.get ());
@@ -46,7 +46,7 @@ TEST(LoweringPasses, RemoveUnnecessaryCastFloatCorrectly) {
46
46
torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
47
47
auto sg = std::make_shared<torch::jit::Graph>();
48
48
torch::jit::parseIR (source_graph, sg.get ());
49
- torch_tensorrt::core::lowering::passes::RemoveContiguous (sg);
49
+ torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts (sg);
50
50
51
51
auto tg = std::make_shared<torch::jit::Graph>();
52
52
torch::jit::parseIR (target_graph, tg.get ());
@@ -70,7 +70,85 @@ TEST(LoweringPasses, RemoveUnnecessaryCastBoolCorrectly) {
70
70
torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
71
71
auto sg = std::make_shared<torch::jit::Graph>();
72
72
torch::jit::parseIR (source_graph, sg.get ());
73
- torch_tensorrt::core::lowering::passes::RemoveContiguous (sg);
73
+ torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts (sg);
74
+
75
+ auto tg = std::make_shared<torch::jit::Graph>();
76
+ torch::jit::parseIR (target_graph, tg.get ());
77
+
78
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
79
+ }
80
+
81
+ TEST (LoweringPasses, RemoveSingleUse0DTensorsIntCorrectly) {
82
+ std::string source_graph = R"IR(
83
+ graph(%0: int):
84
+ %1: Tensor = prim::Constant[value=[8]]()
85
+ %2: int = prim::Constant[value=1]()
86
+ %3: Tensor = prim::NumToTensor(%0)
87
+ %4: Tensor = aten::add(%1, %3, %2)
88
+ %5: int = aten::Int(%4)
89
+ %6: int = aten::add(%5, %5)
90
+ return (%6))IR" ;
91
+ std::string target_graph = R"IR(
92
+ graph(%0: int):
93
+ %1: int = prim::Constant[value=8]()
94
+ %4: int = aten::add(%1, %0)
95
+ %6: int = aten::add(%4, %4)
96
+ return (%6))IR" ;
97
+
98
+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
99
+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
100
+ auto sg = std::make_shared<torch::jit::Graph>();
101
+ torch::jit::parseIR (source_graph, sg.get ());
102
+
103
+ auto first_op = *(sg->block ()->nodes ().begin ());
104
+ torch::jit::WithInsertPoint guard (first_op);
105
+ torch::jit::Value* r = sg->insertConstant (
106
+ c10::scalar_to_tensor (8 ), c10::nullopt, first_op->scope ());
107
+ r->copyMetadata (first_op->output ());
108
+ r->setType (c10::TensorType::get ());
109
+ first_op->output ()->replaceAllUsesWith (r);
110
+ first_op->destroy ();
111
+
112
+ torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors (sg);
113
+
114
+ auto tg = std::make_shared<torch::jit::Graph>();
115
+ torch::jit::parseIR (target_graph, tg.get ());
116
+
117
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
118
+ }
119
+
120
+ TEST (LoweringPasses, RemoveSingleUse0DTensorsFloatCorrectly) {
121
+ std::string source_graph = R"IR(
122
+ graph(%0: float):
123
+ %1: Tensor = prim::Constant[value=[8.]]()
124
+ %2: float = prim::Constant[value=1.]()
125
+ %3: Tensor = prim::NumToTensor(%0)
126
+ %4: Tensor = aten::add(%1, %3, %2)
127
+ %5: float = aten::Float(%4)
128
+ %6: float = aten::add(%5, %5)
129
+ return (%6))IR" ;
130
+ std::string target_graph = R"IR(
131
+ graph(%0: float):
132
+ %1: float = prim::Constant[value=8.]()
133
+ %4: float = aten::add(%1, %0)
134
+ %6: float = aten::add(%4, %4)
135
+ return (%6))IR" ;
136
+
137
+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
138
+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
139
+ auto sg = std::make_shared<torch::jit::Graph>();
140
+ torch::jit::parseIR (source_graph, sg.get ());
141
+
142
+ auto first_op = *(sg->block ()->nodes ().begin ());
143
+ torch::jit::WithInsertPoint guard (first_op);
144
+ torch::jit::Value* r = sg->insertConstant (
145
+ c10::scalar_to_tensor (8.0 ), c10::nullopt, first_op->scope ());
146
+ r->copyMetadata (first_op->output ());
147
+ r->setType (c10::TensorType::get ());
148
+ first_op->output ()->replaceAllUsesWith (r);
149
+ first_op->destroy ();
150
+
151
+ torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors (sg);
74
152
75
153
auto tg = std::make_shared<torch::jit::Graph>();
76
154
torch::jit::parseIR (target_graph, tg.get ());
0 commit comments