@@ -151,7 +151,7 @@ def quantize_pt2(
151
151
quantizer : Optional [CadenceQuantizer ] = None ,
152
152
calibration_data : Optional [list [tuple [object , ...]]] = None ,
153
153
dump_graphs : bool = False ,
154
- ) -> torch . fx . GraphModule :
154
+ ) -> ExportedProgram :
155
155
"""
156
156
Trace, prepare, convert and fuse the model using the given quantizer.
157
157
If calibration data is provided, it will be used to calibrate the model. If
@@ -178,7 +178,9 @@ def quantize_pt2(
178
178
logging .info ("Graph after quantization and fusion:" )
179
179
logging .info (fused_gm .graph .print_tabular ())
180
180
181
- return fused_gm
181
+ program = torch .export .export (fused_gm , inputs , strict = True )
182
+
183
+ return program
182
184
183
185
184
186
# Export the model and lower it to an ExportedProgram (in aten IR)
@@ -260,21 +262,43 @@ def quantize_and_export_to_edge(
260
262
dump_graphs : bool = False ,
261
263
constant_methods : Optional [dict [str , object ]] = None ,
262
264
) -> EdgeProgramManager :
265
+ """
266
+ Trace, quantize and lower a model/inputs pair to edge IR.
267
+ """
263
268
quantized_model = quantize_pt2 (
264
269
model ,
265
270
inputs ,
266
271
quantizer = quantizer ,
267
272
dump_graphs = dump_graphs ,
268
273
)
269
274
270
- return export_to_edge (
275
+ return lower_ep_to_edge (
271
276
quantized_model ,
272
- inputs ,
273
277
dump_graphs = dump_graphs ,
274
278
constant_methods = constant_methods ,
275
279
)
276
280
277
281
282
+ def lower_ep_to_cadence (
283
+ program : ExportedProgram ,
284
+ dump_graphs : bool = False ,
285
+ opt_level : int = 1 ,
286
+ ) -> EdgeProgramManager :
287
+ """
288
+ Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
289
+ """
290
+ edge_prog_manager = lower_ep_to_edge (program , dump_graphs = dump_graphs )
291
+ cadence_passes = get_cadence_passes (opt_level )
292
+
293
+ # Run a couple required passes for quant/dequant ops
294
+ cadence_prog_manager = edge_prog_manager .transform (
295
+ cast (
296
+ list [Callable [[torch .fx .GraphModule ], Optional [PassResult ]]], cadence_passes
297
+ )
298
+ )
299
+ return cadence_prog_manager
300
+
301
+
278
302
def export_to_cadence (
279
303
model : torch .nn .Module ,
280
304
inputs : tuple [object , ...],
@@ -299,11 +323,14 @@ def quantize_and_export_to_cadence(
299
323
dump_graphs : bool = False ,
300
324
opt_level : int = 1 ,
301
325
) -> EdgeProgramManager :
326
+ """
327
+ Trace, quantize, lower a model/inputs pair to edge IR and apply frontend
328
+ optimization passes.
329
+ """
302
330
quantized_model = quantize_pt2 (model , inputs )
303
331
304
- return export_to_cadence (
332
+ return lower_ep_to_cadence (
305
333
quantized_model ,
306
- inputs ,
307
334
opt_level = opt_level ,
308
335
dump_graphs = dump_graphs ,
309
336
)
0 commit comments