From c54441a1d662669b0076878eae4b5a74adbac5be Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 20 Aug 2024 16:29:02 -0700
Subject: [PATCH 1/6] Create minifier

---
 src/torch_onnx/_verification.py | 225 +++++++++++++++++++++++++++++++-
 1 file changed, 223 insertions(+), 2 deletions(-)

diff --git a/src/torch_onnx/_verification.py b/src/torch_onnx/_verification.py
index c7e55af..3fdf8ff 100644
--- a/src/torch_onnx/_verification.py
+++ b/src/torch_onnx/_verification.py
@@ -1,13 +1,19 @@
 # mypy: allow-untyped-defs
 from __future__ import annotations
 
+import copy
 import dataclasses
-from typing import Any
+import operator
+from typing import TYPE_CHECKING, Any, Sequence
 
 import torch
+from torch._functorch import fx_minifier
 from torch.utils import _pytree
 
-from torch_onnx import _onnx_program
+from torch_onnx import _core, _onnx_program, _testing
+
+if TYPE_CHECKING:
+    import torch.fx
 
 
 @dataclasses.dataclass
@@ -21,6 +27,20 @@ class VerificationInfo:
     # and checked by the runtime
 
 
+@dataclasses.dataclass
+class SearchResult:
+    graph_module: torch.fx.GraphModule
+    inputs: Sequence[Any]
+
+    @property
+    def graph(self) -> torch.fx.Graph:
+        return self.graph_module.graph
+
+    @graph.setter
+    def graph(self, fx_g: torch.fx.Graph):
+        self.graph_module.graph = fx_g
+
+
 def _compare_tensors(
     expected: torch.Tensor,
     actual: torch.Tensor,
@@ -75,3 +95,204 @@ def verify_onnx_program(
             )
         )
     return results
+
+
+def _exported_program_to_fx_graph_module_and_inputs(
+    exported_program: torch.export.ExportedProgram,
+):
+    fx_gm = exported_program.graph_module
+    fx_inputs = _pytree.tree_map(
+        torch.tensor,
+        exported_program._graph_module_flat_inputs(*exported_program.example_inputs),
+    )
+    return fx_gm, fx_inputs
+
+
+def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
+    """Turn all operator getitem nodes in the ExportedProgram FX graph into
+
+    new nodes composed of "computation + getitem". The normalization duplicates
+    some computations in the graph but makes the graph friendlier for
+    partitioning in the FX minifier.
+    """
+    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py#L191
+    # Original code Copyright 2024 The AI Edge Torch Authors.
+    # Apache License, Version 2.0
+
+    fx_gm = copy.deepcopy(fx_gm)
+    graph = fx_gm.graph
+    for n in graph.nodes:
+        if n.target != operator.getitem:
+            continue
+
+        src_n, key = n.args
+        assert n.op == "call_function"
+        with graph.inserting_after(n):
+            new_n = graph.call_function(
+                lambda src_target, key, args, kwargs: operator.getitem(
+                    src_target(*args, **kwargs), key
+                ),
+                (src_n.target, key, src_n.args, src_n.kwargs),
+            )
+            n.replace_all_uses_with(new_n)
+
+    graph.eliminate_dead_code()
+    fx_gm.graph = graph
+    return fx_gm
+
+
+def _erase_unused_inputs(fx_gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]):
+    fx_gm = copy.deepcopy(fx_gm)
+    inputs = tuple(inputs)
+    args = fx_gm.graph.process_inputs(*inputs)
+    args_iter = iter(args)
+
+    graph = fx_gm.graph
+    new_inputs = []
+    for n in graph.nodes:
+        if n.op == "placeholder":
+            if n.target.startswith("*"):
+                new_inputs += list(args_iter)
+            elif len(n.users) > 0:
+                new_inputs.append(next(args_iter))
+            else:
+                graph.erase_node(n)
+                next(args_iter)
+    new_inputs = tuple(new_inputs)
+    fx_gm.graph = graph
+    return fx_gm, new_inputs
+
+
+def _lift_dead_ops_to_outputs(fx_gm: torch.fx.GraphModule):
+    fx_gm = copy.deepcopy(fx_gm)
+
+    new_outputs = []
+    graph = fx_gm.graph
+    nodes = list(graph.nodes)
+    assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
+    for node in nodes:
+        if node.op not in ("placeholder", "output") and len(node.users) == 0:
+            new_outputs.append(node)
+
+    output_node = nodes[-1]
+    # FX output node returns the first arg as is.
+    # ref: https://github.com/pytorch/pytorch/blob/1a578df57cc0f417f671634e564c62ef5d9a97e2/torch/fx/interpreter.py#L337
+    new_outputs, _ = _pytree.tree_flatten([new_outputs, output_node.args[0]])
+    output_node.update_arg(0, tuple(new_outputs))
+
+    fx_gm.graph = graph
+    return fx_gm
+
+
+def _normalize_minified_fx_gm(
+    fx_gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]
+):
+    fx_gm, inputs = _erase_unused_inputs(fx_gm, inputs)
+    fx_gm = _lift_dead_ops_to_outputs(fx_gm)
+    return fx_gm, inputs
+
+
+def _erase_trivial_outputs(fx_gm: torch.fx.GraphModule):
+    """Remove output nodes directly connected to an input node."""
+    fx_gm = copy.deepcopy(fx_gm)
+
+    graph = fx_gm.graph
+    nodes = list(graph.nodes)
+    assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
+    output_node = nodes[-1]
+
+    outputs, _ = _pytree.tree_flatten(output_node.args[0])
+    new_outputs = [output for output in outputs if output.op != "placeholder"]
+    output_node.update_arg(0, tuple(new_outputs))
+
+    fx_gm.recompile()
+    return fx_gm
+
+
+def _erase_sub_gm_from_gm(
+    fx_gm: torch.fx.GraphModule,
+    inputs: Sequence[torch.Tensor],
+    sub_gm: torch.fx.GraphModule,
+    sub_inputs: Sequence[torch.Tensor],
+):
+    fx_gm = copy.deepcopy(fx_gm)
+    fx_inputs = list(inputs)
+
+    class EraseNodeInterpreter(torch.fx.Interpreter):
+        def run_node(self, node):
+            nonlocal fx_gm, fx_inputs
+            res = super().run_node(node)
+            if node.op not in ("placeholder", "output"):
+                to_erase = next(m for m in fx_gm.graph.nodes if m.name == node.name)
+                # Raise the output (tensor) of the erased node to be an input of
+                # the new model graph. Some raised inputs may become unused later,
+                # when all of their users are within the erased subgraph; those
+                # inputs will be removed by the subsequent `_erase_unused_inputs` pass.
+                with fx_gm.graph.inserting_before(to_erase):
+                    new_input = fx_gm.graph.placeholder(node.name + "__value")
+                    to_erase.replace_all_uses_with(new_input)
+
+                fx_gm.graph.erase_node(to_erase)
+                fx_inputs.append(res)
+            return res
+
+    interpreter = EraseNodeInterpreter(sub_gm)
+    interpreter.run(*sub_inputs)
+
+    fx_gm.graph.lint()
+    fx_gm.recompile()
+
+    # Ops prior to the erased subgraph may be dangling. Lift them as outputs.
+    fx_gm = _lift_dead_ops_to_outputs(fx_gm)
+    fx_gm = _erase_trivial_outputs(fx_gm)
+    fx_gm, fx_inputs = _erase_unused_inputs(fx_gm, fx_inputs)
+
+    fx_gm.graph.lint()
+    fx_gm.recompile()
+    return fx_gm, fx_inputs
+
+
+def minimize_inaccurate_subgraph(
+    exported_program: torch.export.ExportedProgram,
+    rtol: float | None = None,
+    atol: float | None = None,
+):
+    """Find the subgraph with error and minimize it."""
+
+    def _export_and_verify(
+        torch_module: torch.fx.GraphModule,
+        inputs: Any,
+    ) -> bool:
+        try:
+            onnx_program = _core.export(torch_module, args=inputs)
+            _testing.assert_onnx_program(onnx_program, rtol=rtol, atol=atol)
+        except Exception:
+            return True
+        return False
+
+    # Get the subgraph with error
+    fx_gm, fx_inputs = _exported_program_to_fx_graph_module_and_inputs(exported_program)
+    found_culprits_num = 0
+    while True:
+        try:
+            graph_module = _normalize_getitem_nodes(fx_gm)
+            raw_min_fx_gm, raw_min_inputs = fx_minifier.minifier(
+                graph_module,
+                fx_inputs,
+                _export_and_verify,
+            )
+            min_fx_gm, min_inputs = _normalize_minified_fx_gm(
+                raw_min_fx_gm, raw_min_inputs
+            )
+            found_culprits_num += 1
+            yield SearchResult(min_fx_gm, min_inputs)
+            fx_gm, fx_inputs = _erase_sub_gm_from_gm(
+                fx_gm, fx_inputs, raw_min_fx_gm, raw_min_inputs
+            )
+        except RuntimeError as e:  # noqa: PERF203
+            if (
+                str(e) == "Input graph did not fail the tester"
+                and found_culprits_num > 0
+            ):
+                break
+            raise

From ccf7ad01925bc148e21bb99ce1ef74da29647ff6 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 20 Aug 2024 16:35:50 -0700
Subject: [PATCH 2/6] try this

---
 src/torch_onnx/_verification.py       |  4 +++-
 tests/models/longformer_export_acc.py | 20 ++++++++++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)
 create mode 100644 tests/models/longformer_export_acc.py

diff --git a/src/torch_onnx/_verification.py b/src/torch_onnx/_verification.py
index 3fdf8ff..0a4c121 100644
--- a/src/torch_onnx/_verification.py
+++ b/src/torch_onnx/_verification.py
@@ -264,7 +264,9 @@ def _export_and_verify(
         inputs: Any,
     ) -> bool:
         try:
-            onnx_program = _core.export(torch_module, args=inputs)
+            exported_program = torch.export.export(torch_module, inputs)
+            onnx_model = _core.exported_program_to_ir(exported_program)
+            onnx_program = _onnx_program.ONNXProgram(onnx_model, exported_program)
             _testing.assert_onnx_program(onnx_program, rtol=rtol, atol=atol)
         except Exception:
             return True
diff --git a/tests/models/longformer_export_acc.py b/tests/models/longformer_export_acc.py
new file mode 100644
index 0000000..870f564
--- /dev/null
+++ b/tests/models/longformer_export_acc.py
@@ -0,0 +1,20 @@
+import torch
+import torch_onnx
+from torch_onnx import _verification
+from transformers import LongformerModel, LongformerTokenizer
+
+torch_onnx.patch_torch()
+
+tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
+model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
+text = "Replace me by any text you'd like."
+encoded_input = tokenizer(text, return_tensors="pt")
+print("Exporting model...")
+
+ep = torch.export.export(
+    model,
+    (encoded_input["input_ids"], encoded_input["attention_mask"]),
+)
+
+for result in _verification.minimize_inaccurate_subgraph(ep, rtol=10.0):
+    print(result)

From 5e010f610a57172f7c58b0842751ffa4c70e2e0d Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 20 Aug 2024 16:41:36 -0700
Subject: [PATCH 3/6] fix

---
 src/torch_onnx/_core.py         | 6 +++---
 src/torch_onnx/_verification.py | 9 ++++++++-
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/torch_onnx/_core.py b/src/torch_onnx/_core.py
index 9626817..043eade 100644
--- a/src/torch_onnx/_core.py
+++ b/src/torch_onnx/_core.py
@@ -903,6 +903,9 @@ def _exported_program_to_onnx_program(
 
     # TODO: Decide if we should keep mutated buffers as inputs/outputs
 
+    # TODO(justinchuby): Remove the hack
+    _ir_passes.add_torchlib_common_imports(model)
+
     return _onnx_program.ONNXProgram(model, exported_program)
 
 
@@ -1129,9 +1132,6 @@ def export(
         if output_names:
             _ir_passes.rename_outputs(onnx_program.model, output_names)
 
-        # TODO(justinchuby): Remove the hack
-        _ir_passes.add_torchlib_common_imports(onnx_program.model)
-
         export_status.onnx_translation = True
         verbose_print("Translate the graph into ONNX... ✅")
     except Exception as e:
diff --git a/src/torch_onnx/_verification.py b/src/torch_onnx/_verification.py
index 0a4c121..a43eca7 100644
--- a/src/torch_onnx/_verification.py
+++ b/src/torch_onnx/_verification.py
@@ -3,6 +3,7 @@
 
 import copy
 import dataclasses
+import logging
 import operator
 from typing import TYPE_CHECKING, Any, Sequence
 
@@ -16,6 +17,9 @@
     import torch.fx
 
 
+logger = logging.getLogger(__name__)
+
+
 @dataclasses.dataclass
 class VerificationInfo:
     name: str
@@ -268,8 +272,11 @@ def _export_and_verify(
             onnx_model = _core.exported_program_to_ir(exported_program)
             onnx_program = _onnx_program.ONNXProgram(onnx_model, exported_program)
             _testing.assert_onnx_program(onnx_program, rtol=rtol, atol=atol)
-        except Exception:
+        except AssertionError:
             return True
+        except Exception:
+            logger.exception("Error during verification")
+            return False
         return False
 
     # Get the subgraph with error

From c565b2fb878953aab0e47849c584dd3d64ba48b5 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 20 Aug 2024 17:13:05 -0700
Subject: [PATCH 4/6] Create the search

---
 src/torch_onnx/_verification.py       | 72 ++++++++++++++++++++-------
 tests/models/longformer_export_acc.py |  2 +-
 2 files changed, 54 insertions(+), 20 deletions(-)

diff --git a/src/torch_onnx/_verification.py b/src/torch_onnx/_verification.py
index a43eca7..42ddcab 100644
--- a/src/torch_onnx/_verification.py
+++ b/src/torch_onnx/_verification.py
@@ -1,17 +1,24 @@
 # mypy: allow-untyped-defs
 from __future__ import annotations
 
+__all__ = [
+    "VerificationInfo",
+    "SearchResult",
+    "verify_onnx_program",
+    "minimize_inaccurate_subgraph",
+]
+
 import copy
 import dataclasses
 import logging
 import operator
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any, Iterator, Sequence
 
 import torch
 from torch._functorch import fx_minifier
 from torch.utils import _pytree
 
-from torch_onnx import _core, _onnx_program, _testing
+from torch_onnx import _core, _onnx_program
 
 if TYPE_CHECKING:
     import torch.fx
@@ -52,6 +59,9 @@ def _compare_tensors(
     # Move tensors to the same device
     expected = expected.detach().cpu()
     actual = actual.detach().cpu()
+    if expected.dtype == torch.bool:
+        expected = expected.to(torch.int)
+        actual = actual.to(torch.int)
     absolute_difference = torch.abs(expected - actual).max().item()
     eps = 1e-7
     relative_difference = (
@@ -66,6 +76,11 @@ def verify_onnx_program(
     kwargs: dict[str, Any] | None = None,
 ) -> list[VerificationInfo]:
     exported_program = onnx_program.exported_program
+    if exported_program is None:
+        raise ValueError(
+            "The ONNX program does not contain an exported_program. "
+            "Please provide an exported_program to verify the ONNX program."
+        )
     if args is None and kwargs is None:
         # User did not provide example inputs, use the default example inputs
         if exported_program.example_inputs is None:
@@ -103,7 +118,10 @@ def verify_onnx_program(
 
 def _exported_program_to_fx_graph_module_and_inputs(
     exported_program: torch.export.ExportedProgram,
-):
+) -> tuple[torch.fx.GraphModule, Sequence[torch.Tensor]]:
+    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
+    # Original code Copyright 2024 The AI Edge Torch Authors.
+    # Apache License, Version 2.0
     fx_gm = exported_program.graph_module
     fx_inputs = _pytree.tree_map(
         torch.tensor,
         exported_program._graph_module_flat_inputs(*exported_program.example_inputs),
     )
     return fx_gm, fx_inputs
@@ -119,7 +137,7 @@ def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
     some computations in the graph but makes the graph friendlier for
     partitioning in the FX minifier.
     """
-    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py#L191
+    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
     # Original code Copyright 2024 The AI Edge Torch Authors.
     # Apache License, Version 2.0
 
@@ -146,8 +164,10 @@ def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
 
 
 def _erase_unused_inputs(fx_gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]):
+    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
+    # Original code Copyright 2024 The AI Edge Torch Authors.
+    # Apache License, Version 2.0
     fx_gm = copy.deepcopy(fx_gm)
-    inputs = tuple(inputs)
     args = fx_gm.graph.process_inputs(*inputs)
     args_iter = iter(args)
 
@@ -168,6 +188,9 @@ def _erase_unused_inputs(fx_gm: torch.fx.GraphModule, inputs: Sequence[torch.Ten
 
 
 def _lift_dead_ops_to_outputs(fx_gm: torch.fx.GraphModule):
+    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
+    # Original code Copyright 2024 The AI Edge Torch Authors.
+    # Apache License, Version 2.0
     fx_gm = copy.deepcopy(fx_gm)
 
     new_outputs = []
@@ -176,7 +199,7 @@ def _lift_dead_ops_to_outputs(fx_gm: torch.fx.GraphModule):
     for node in nodes:
         if node.op not in ("placeholder", "output") and len(node.users) == 0:
-            new_outputs.append(node)
+            new_outputs.append(node)  # noqa: PERF401
 
     output_node = nodes[-1]
     # FX output node returns the first arg as is.
@@ -191,6 +214,9 @@ def _lift_dead_ops_to_outputs(fx_gm: torch.fx.GraphModule):
 def _normalize_minified_fx_gm(
     fx_gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]
 ):
+    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
+    # Original code Copyright 2024 The AI Edge Torch Authors.
+    # Apache License, Version 2.0
     fx_gm, inputs = _erase_unused_inputs(fx_gm, inputs)
     fx_gm = _lift_dead_ops_to_outputs(fx_gm)
     return fx_gm, inputs
@@ -198,6 +224,9 @@ def _normalize_minified_fx_gm(
 
 def _erase_trivial_outputs(fx_gm: torch.fx.GraphModule):
     """Remove output nodes directly connected to an input node."""
+    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
+    # Original code Copyright 2024 The AI Edge Torch Authors.
+    # Apache License, Version 2.0
     fx_gm = copy.deepcopy(fx_gm)
 
     graph = fx_gm.graph
@@ -219,6 +248,9 @@ def _erase_trivial_outputs(fx_gm: torch.fx.GraphModule):
     sub_gm: torch.fx.GraphModule,
     sub_inputs: Sequence[torch.Tensor],
 ):
+    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
+    # Original code Copyright 2024 The AI Edge Torch Authors.
+    # Apache License, Version 2.0
     fx_gm = copy.deepcopy(fx_gm)
     fx_inputs = list(inputs)
 
@@ -258,25 +290,26 @@ def run_node(self, node):
 
 def minimize_inaccurate_subgraph(
     exported_program: torch.export.ExportedProgram,
-    rtol: float | None = None,
-    atol: float | None = None,
-):
+    rtol: float = 1e-4,
+    atol: float = 1e-5,
+) -> Iterator[SearchResult]:
     """Find the subgraph with error and minimize it."""
+    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
+    # Original code Copyright 2024 The AI Edge Torch Authors.
+    # Apache License, Version 2.0
 
     def _export_and_verify(
         torch_module: torch.fx.GraphModule,
         inputs: Any,
     ) -> bool:
-        try:
-            exported_program = torch.export.export(torch_module, inputs)
-            onnx_model = _core.exported_program_to_ir(exported_program)
-            onnx_program = _onnx_program.ONNXProgram(onnx_model, exported_program)
-            _testing.assert_onnx_program(onnx_program, rtol=rtol, atol=atol)
-        except AssertionError:
-            return True
-        except Exception:
-            logger.exception("Error during verification")
-            return False
+        exported_program = torch.export.export(torch_module, tuple(inputs))
+        onnx_model = _core.exported_program_to_ir(exported_program)
+        onnx_program = _onnx_program.ONNXProgram(onnx_model, exported_program)
+        verification_info = verify_onnx_program(onnx_program)
+        for info in verification_info:
+            if info.absolute_difference > atol or info.relative_difference > rtol:
+                print(f"Found culprit: {info}")
+                return True
         return False
 
     # Get the subgraph with error
@@ -290,6 +323,7 @@ def _export_and_verify(
             fx_inputs,
             _export_and_verify,
         )
+        raw_min_inputs = tuple(raw_min_inputs)
         min_fx_gm, min_inputs = _normalize_minified_fx_gm(
             raw_min_fx_gm, raw_min_inputs
         )
diff --git a/tests/models/longformer_export_acc.py b/tests/models/longformer_export_acc.py
index 870f564..930138e 100644
--- a/tests/models/longformer_export_acc.py
+++ b/tests/models/longformer_export_acc.py
@@ -16,5 +16,5 @@
     (encoded_input["input_ids"], encoded_input["attention_mask"]),
 )
 
-for result in _verification.minimize_inaccurate_subgraph(ep, rtol=10.0):
+for result in _verification.minimize_inaccurate_subgraph(ep, atol=1e-2, rtol=10.0):
     print(result)

From a74a62f48d25a5058d8b40f004da64fb58a17fce Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 20 Aug 2024 17:47:10 -0700
Subject: [PATCH 5/6] Import

---
 src/torch_onnx/__init__.py      | 5 +++--
 src/torch_onnx/_verification.py | 8 ++++----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/torch_onnx/__init__.py b/src/torch_onnx/__init__.py
index 8a3cf64..cc9c1b7 100644
--- a/src/torch_onnx/__init__.py
+++ b/src/torch_onnx/__init__.py
@@ -7,15 +7,16 @@
     "exported_program_to_ir",
     "patch_torch",
     "unpatch_torch",
-    "verify_onnx_program",
+    # Modules
     "testing",
+    "verification",
 ]
 
 from . import _testing as testing
+from . import _verification as verification
 from ._analysis import analyze
 from ._core import export, exported_program_to_ir
 from ._onnx_program import ONNXProgram
 from ._patch import _torch_onnx_export as export_compat
 from ._patch import patch_torch, unpatch_torch
 from ._registration import ONNXRegistry
-from ._verification import verify_onnx_program
diff --git a/src/torch_onnx/_verification.py b/src/torch_onnx/_verification.py
index 42ddcab..fa18221 100644
--- a/src/torch_onnx/_verification.py
+++ b/src/torch_onnx/_verification.py
@@ -308,13 +308,13 @@ def _export_and_verify(
         verification_info = verify_onnx_program(onnx_program)
         for info in verification_info:
             if info.absolute_difference > atol or info.relative_difference > rtol:
-                print(f"Found culprit: {info}")
+                logger.warning("Found inaccuracy: %s", info)
                 return True
         return False
 
     # Get the subgraph with error
     fx_gm, fx_inputs = _exported_program_to_fx_graph_module_and_inputs(exported_program)
-    found_culprits_num = 0
+    found_inaccuracies_num = 0
     while True:
         try:
             graph_module = _normalize_getitem_nodes(fx_gm)
@@ -327,7 +327,7 @@ def _export_and_verify(
             min_fx_gm, min_inputs = _normalize_minified_fx_gm(
                 raw_min_fx_gm, raw_min_inputs
             )
-            found_culprits_num += 1
+            found_inaccuracies_num += 1
             yield SearchResult(min_fx_gm, min_inputs)
             fx_gm, fx_inputs = _erase_sub_gm_from_gm(
                 fx_gm, fx_inputs, raw_min_fx_gm, raw_min_inputs
@@ -335,7 +335,7 @@ def _export_and_verify(
         except RuntimeError as e:  # noqa: PERF203
             if (
                 str(e) == "Input graph did not fail the tester"
-                and found_culprits_num > 0
+                and found_inaccuracies_num > 0
             ):
                 break
             raise

From 9c5e3fe37eef98176a37eca7d6b615c700e7b144 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 20 Aug 2024 17:56:13 -0700
Subject: [PATCH 6/6] import

---
 tests/models/longformer_export_acc.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/models/longformer_export_acc.py b/tests/models/longformer_export_acc.py
index 930138e..b3d84ca 100644
--- a/tests/models/longformer_export_acc.py
+++ b/tests/models/longformer_export_acc.py
@@ -1,6 +1,5 @@
 import torch
 import torch_onnx
-from torch_onnx import _verification
 from transformers import LongformerModel, LongformerTokenizer
 
 torch_onnx.patch_torch()
@@ -16,5 +15,7 @@
     (encoded_input["input_ids"], encoded_input["attention_mask"]),
 )
 
-for result in _verification.minimize_inaccurate_subgraph(ep, atol=1e-2, rtol=10.0):
+for result in torch_onnx.verification.minimize_inaccurate_subgraph(
+    ep, atol=1e-2, rtol=10.0
+):
     print(result)
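
For reference, a minimal sketch of driving the API added by this series end to end, without the transformers dependency. The toy module and the explicit tolerances are illustrative assumptions; only patch_torch, verification.minimize_inaccurate_subgraph, and the SearchResult fields come from the patches above:

import torch
import torch_onnx

torch_onnx.patch_torch()


class ToyModel(torch.nn.Module):  # hypothetical stand-in for a real model
    def forward(self, x):
        return torch.sigmoid(x @ x) + x


ep = torch.export.export(ToyModel(), (torch.randn(4, 4),))

# Each yielded SearchResult carries the minified torch.fx.GraphModule and the
# inputs that reproduce the PyTorch-vs-ONNX accuracy gap, so it can serve as a
# self-contained repro case.
for result in torch_onnx.verification.minimize_inaccurate_subgraph(
    ep, atol=1e-5, rtol=1e-4
):
    print(result.graph_module.code)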