Commit d26ff69

mcremon-meta authored and facebook-github-bot committed
Make quantize_pt2 return an ExportedProgram instead of a GraphModule
Summary:
This will help differentiate the fp32 models from the quantized models, and prevent people from using the wrong APIs. For fp32 cases, we have a `torch.nn.Module`, which we trace and then lower. For quantized cases, we trace, quantize, and lower.

After this diff, `export_to_<edge, executorch>` will ONLY handle non-quantized cases, and importantly, the sequence of `quantize_pt2` followed by `export_to_<edge, executorch>` will no longer work. Those cases should use the (existing) `lower_ep_to_<edge, executorch>` instead.

Differential Revision: D73722640
1 parent 8ffdea1 commit d26ff69
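
As a quick illustration of the split described above, here is a minimal sketch of the two call paths after this change. The toy model and the import path are assumptions for illustration; only the function names and signatures come from the diff below.

    import torch

    # Assumed import path, inferred from the file location in this commit.
    from executorch.backends.cadence.aot.compiler import (
        export_to_edge,
        lower_ep_to_edge,
        quantize_pt2,
    )

    # Hypothetical toy model/inputs pair, for illustration only.
    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    model = Toy()
    inputs = (torch.randn(4, 8),)

    # fp32 path: trace and lower a plain torch.nn.Module.
    fp32_edge = export_to_edge(model, inputs)

    # Quantized path: quantize_pt2 now returns an ExportedProgram, so it must be
    # lowered with lower_ep_to_edge; chaining it into export_to_edge no longer works.
    quantized_ep = quantize_pt2(model, inputs)
    quantized_edge = lower_ep_to_edge(quantized_ep)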

File tree

2 files changed (+34 -10)

2 files changed

+34
-10
lines changed

backends/cadence/aot/compiler.py (+33 -8)
@@ -151,16 +151,14 @@ def quantize_pt2(
     quantizer: Optional[CadenceQuantizer] = None,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
     Trace, prepare, convert and fuse the model using the given quantizer.
     If calibration data is provided, it will be used to calibrate the model. If
     not, the inputs will be used for calibration instead, which is useful for
     unit tests but should not be used for end-to-end use cases.
     Returns a GraphModule with the quantized model.
     """
-    # Make the model inference mode by calling model.eval()
-    model.eval()
 
     # Instantiate the quantizer to CadenceQuantizer if not supplied
     if not quantizer:
@@ -178,7 +176,9 @@ def quantize_pt2(
     logging.info("Graph after quantization and fusion:")
     logging.info(fused_gm.graph.print_tabular())
 
-    return fused_gm
+    program = torch.export.export(fused_gm, inputs, strict=True)
+
+    return program
 
 
 # Export the model and lower it to an ExportedProgram (in aten IR)
@@ -260,21 +260,43 @@ def quantize_and_export_to_edge(
     dump_graphs: bool = False,
     constant_methods: Optional[dict[str, object]] = None,
 ) -> EdgeProgramManager:
+    """
+    Trace, quantize and lower a model/inputs pair to edge IR.
+    """
     quantized_model = quantize_pt2(
         model,
         inputs,
         quantizer=quantizer,
         dump_graphs=dump_graphs,
     )
 
-    return export_to_edge(
+    return lower_ep_to_edge(
         quantized_model,
-        inputs,
         dump_graphs=dump_graphs,
         constant_methods=constant_methods,
     )
 
 
+def lower_ep_to_cadence(
+    program: ExportedProgram,
+    dump_graphs: bool = False,
+    opt_level: int = 1,
+) -> EdgeProgramManager:
+    """
+    Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
+    """
+    edge_prog_manager = lower_ep_to_edge(program, dump_graphs=dump_graphs)
+    cadence_passes = get_cadence_passes(opt_level)
+
+    # Run a couple required passes for quant/dequant ops
+    cadence_prog_manager = edge_prog_manager.transform(
+        cast(
+            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
+        )
+    )
+    return cadence_prog_manager
+
+
 def export_to_cadence(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
@@ -299,11 +321,14 @@ def quantize_and_export_to_cadence(
     dump_graphs: bool = False,
     opt_level: int = 1,
 ) -> EdgeProgramManager:
+    """
+    Trace, quantize, lower a model/inputs pair to edge IR and apply frontend
+    optimization passes.
+    """
     quantized_model = quantize_pt2(model, inputs)
 
-    return export_to_cadence(
+    return lower_ep_to_cadence(
         quantized_model,
-        inputs,
         opt_level=opt_level,
         dump_graphs=dump_graphs,
     )
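
The new lower_ep_to_cadence mirrors lower_ep_to_edge but additionally applies the Cadence pass pipeline selected by opt_level. A usage sketch, reusing the hypothetical model/inputs pair and the assumed import path from above:

    from executorch.backends.cadence.aot.compiler import (
        lower_ep_to_cadence,
        quantize_and_export_to_cadence,
        quantize_pt2,
    )

    # Two-step: quantize to an ExportedProgram, then lower with Cadence passes.
    ep = quantize_pt2(model, inputs)
    cadence_prog = lower_ep_to_cadence(ep, opt_level=1)

    # One-step convenience wrapper, equivalent to the two calls above.
    cadence_prog = quantize_and_export_to_cadence(model, inputs, opt_level=1)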

backends/cadence/aot/tests/test_replace_ops_passes.py (+1 -2)
@@ -113,9 +113,8 @@ def forward(self, x, y):
     Y = torch.randn(y_shape)
     p = ReplaceMatmulWithTransposedMatmulPass()
     inputs = (X, Y)
-    quantized_model = quantize_pt2(model, inputs)
     graph_module = (
-        export_to_edge(quantized_model, inputs).exported_program().graph_module
+        quantize_and_export_to_edge(model, inputs).exported_program().graph_module
     )
     # pyre-fixme[16]: Optional type has no attribute `graph_module`
     graph_after_passes = p(graph_module).graph_module
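
The test update above is the general migration pattern for any caller that previously chained quantize_pt2 with export_to_edge; a sketch of hypothetical caller code:

    from executorch.backends.cadence.aot.compiler import quantize_and_export_to_edge

    # Before this change (no longer works):
    #   gm = quantize_pt2(model, inputs)      # returned a GraphModule
    #   edge = export_to_edge(gm, inputs)
    #
    # After this change:
    edge = quantize_and_export_to_edge(model, inputs)
    graph_module = edge.exported_program().graph_module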
