Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Support exporting to onnx for torch>=1.10 #1231

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import os
import copy
from collections import defaultdict
import logging
import torch
import torch.nn as nn
import torch.onnx.symbolic_caffe2
Expand Down Expand Up @@ -104,6 +105,24 @@
}


def export_to_onnx(*args, **kwargs):
"""
A wrapper function to export torch module to onnx

`enable_checker` is ignored for pytorch >= 1.10
"""
enable_checker = kwargs.get('enable_onnx_checker', None)
if version.parse(torch.__version__) >= version.parse("1.10") and not enable_checker:
logging.warning('Export torch module to onnx with `enable_onnx_checker` deprecated')
kwargs.pop('enable_onnx_checker')
try:
torch.onnx.export(*args, **kwargs)
except torch.onnx.utils.ONNXCheckerError as e:
logging.error('Error when exporting to onnx: {}, could be ignored'.format(e))
else:
torch.onnx.export(*args, **kwargs)


if version.parse(torch.__version__) >= version.parse("1.9"):
onnx_subgraph_op_to_pytorch_module_param_name = {
torch.nn.GroupNorm:
Expand Down Expand Up @@ -656,10 +675,18 @@ def _create_onnx_model_with_markers(cls, dummy_input, pt_model, working_dir, onn
if is_conditional:
dummy_output = model(*dummy_input)
scripted_model = torch.jit.script(model)
torch.onnx.export(scripted_model, dummy_input, temp_file, example_outputs=dummy_output,
enable_onnx_checker=False, **onnx_export_args.kwargs)
export_to_onnx(scripted_model,
dummy_input,
temp_file,
example_outputs=dummy_output,
enable_onnx_checker=False,
**onnx_export_args.kwargs)
else:
torch.onnx.export(model, dummy_input, temp_file, enable_onnx_checker=False, **onnx_export_args.kwargs)
export_to_onnx(model,
dummy_input,
temp_file,
enable_onnx_checker=False,
**onnx_export_args.kwargs)
onnx_model = onnx.load(temp_file)
return onnx_model

Expand Down