Commit 94f7b10

Make quantize_pt2 return an ExportedProgram instead of a GraphModule
Differential Revision: D73722640
Pull Request resolved: #10644
1 parent f203c94
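
In caller-facing terms, quantize_pt2 previously returned the fused torch.fx.GraphModule, and callers had to thread the example inputs into a separate export step before lowering; it now calls torch.export.export internally and returns an ExportedProgram. A minimal migration sketch using lower_ep_to_cadence (added in this commit); the Linear module and shapes below are illustrative, not from this commit:

    import torch

    from executorch.backends.cadence.aot.compiler import (
        lower_ep_to_cadence,
        quantize_pt2,
    )

    model = torch.nn.Linear(4, 4)  # hypothetical toy model
    inputs = (torch.randn(1, 4),)

    # After this commit quantize_pt2 returns an ExportedProgram (it runs
    # torch.export.export(fused_gm, inputs, strict=True) internally), so the
    # example inputs no longer need to be passed to the lowering step.
    program = quantize_pt2(model, inputs)
    edge_prog_manager = lower_ep_to_cadence(program, opt_level=1)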

File tree: 2 files changed (+34 −9 lines)


backends/cadence/aot/compiler.py (+33 −6)
@@ -151,7 +151,7 @@ def quantize_pt2(
     quantizer: Optional[CadenceQuantizer] = None,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
     Trace, prepare, convert and fuse the model using the given quantizer.
     If calibration data is provided, it will be used to calibrate the model. If
@@ -178,7 +178,9 @@ def quantize_pt2(
         logging.info("Graph after quantization and fusion:")
         logging.info(fused_gm.graph.print_tabular())

-    return fused_gm
+    program = torch.export.export(fused_gm, inputs, strict=True)
+
+    return program


 # Export the model and lower it to an ExportedProgram (in aten IR)
@@ -260,21 +262,43 @@ def quantize_and_export_to_edge(
     dump_graphs: bool = False,
     constant_methods: Optional[dict[str, object]] = None,
 ) -> EdgeProgramManager:
+    """
+    Trace, quantize and lower a model/inputs pair to edge IR.
+    """
     quantized_model = quantize_pt2(
         model,
         inputs,
         quantizer=quantizer,
         dump_graphs=dump_graphs,
     )

-    return export_to_edge(
+    return lower_ep_to_edge(
         quantized_model,
-        inputs,
         dump_graphs=dump_graphs,
         constant_methods=constant_methods,
     )


+def lower_ep_to_cadence(
+    program: ExportedProgram,
+    dump_graphs: bool = False,
+    opt_level: int = 1,
+) -> EdgeProgramManager:
+    """
+    Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
+    """
+    edge_prog_manager = lower_ep_to_edge(program, dump_graphs=dump_graphs)
+    cadence_passes = get_cadence_passes(opt_level)
+
+    # Run a couple required passes for quant/dequant ops
+    cadence_prog_manager = edge_prog_manager.transform(
+        cast(
+            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
+        )
+    )
+    return cadence_prog_manager
+
+
 def export_to_cadence(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
@@ -299,11 +323,14 @@ def quantize_and_export_to_cadence(
     dump_graphs: bool = False,
     opt_level: int = 1,
 ) -> EdgeProgramManager:
+    """
+    Trace, quantize, lower a model/inputs pair to edge IR and apply frontend
+    optimization passes.
+    """
     quantized_model = quantize_pt2(model, inputs)

-    return export_to_cadence(
+    return lower_ep_to_cadence(
         quantized_model,
-        inputs,
         opt_level=opt_level,
         dump_graphs=dump_graphs,
     )
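
Per the diff above, quantize_and_export_to_cadence is now quantize_pt2 followed by the new lower_ep_to_cadence, so the one-shot wrapper and the explicit two-step flow should be interchangeable. A sketch of both paths (the model and shapes are hypothetical):

    import torch

    from executorch.backends.cadence.aot.compiler import (
        lower_ep_to_cadence,
        quantize_and_export_to_cadence,
        quantize_pt2,
    )

    model = torch.nn.Linear(8, 8)  # hypothetical toy model
    inputs = (torch.randn(2, 8),)

    # One-shot convenience wrapper.
    edge_a = quantize_and_export_to_cadence(model, inputs, opt_level=1)

    # Equivalent two-step flow: quantize to an ExportedProgram, then lower it.
    # This is useful when the ExportedProgram needs to be inspected or
    # serialized between the two stages.
    program = quantize_pt2(model, inputs)
    edge_b = lower_ep_to_cadence(program, opt_level=1)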

backends/cadence/aot/tests/test_replace_ops_passes.py (+1 −3)
@@ -16,7 +16,6 @@
 from executorch.backends.cadence.aot.compiler import (
     export_to_edge,
     quantize_and_export_to_edge,
-    quantize_pt2,
 )
 from executorch.backends.cadence.aot.graph_builder import (
     GraphBuilder,
@@ -113,9 +112,8 @@ def forward(self, x, y):
         Y = torch.randn(y_shape)
         p = ReplaceMatmulWithTransposedMatmulPass()
         inputs = (X, Y)
-        quantized_model = quantize_pt2(model, inputs)
         graph_module = (
-            export_to_edge(quantized_model, inputs).exported_program().graph_module
+            quantize_and_export_to_edge(model, inputs).exported_program().graph_module
         )
         # pyre-fixme[16]: Optional type has no attribute `graph_module`
         graph_after_passes = p(graph_module).graph_module
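
The updated test collapses the old quantize_pt2 + export_to_edge pair into a single quantize_and_export_to_edge call. A self-contained sketch of that migrated pattern; the matmul module, shapes, and the pass's import path are assumptions for illustration, while the call pattern mirrors the test above:

    import torch

    from executorch.backends.cadence.aot.compiler import quantize_and_export_to_edge
    # Import path for the pass is assumed; only its usage comes from the diff.
    from executorch.backends.cadence.aot.replace_ops import (
        ReplaceMatmulWithTransposedMatmulPass,
    )


    class MatmulModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.matmul(x, y)


    inputs = (torch.randn(4, 8), torch.randn(8, 16))

    # One call now quantizes, exports, and lowers to edge IR.
    graph_module = (
        quantize_and_export_to_edge(MatmulModel(), inputs)
        .exported_program()
        .graph_module
    )
    # Apply the replacement pass, as in the test above (the pass result may be
    # Optional, per the test's pyre-fixme note).
    graph_after_passes = ReplaceMatmulWithTransposedMatmulPass()(graph_module).graph_module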
