Create minifier to identify errors #163

Merged · 6 commits · Aug 21, 2024
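This PR adds a minifier for accuracy debugging: torch_onnx.verification.minimize_inaccurate_subgraph exports FX subgraphs to ONNX, verifies them against PyTorch outputs, and yields minimized subgraphs whose results diverge beyond the given tolerances. A minimal usage sketch (the model and tolerances are placeholders; see the Longformer test script at the end of this diff for a concrete run):

import torch
import torch_onnx

model = MyModel()                     # placeholder torch.nn.Module
example_args = (torch.randn(2, 8),)   # placeholder example inputs

ep = torch.export.export(model, example_args)
for result in torch_onnx.verification.minimize_inaccurate_subgraph(
    ep, atol=1e-5, rtol=1e-4
):
    # Each SearchResult carries a minimized torch.fx.GraphModule and its inputs.
    print(result.graph)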
5 changes: 3 additions & 2 deletions src/torch_onnx/__init__.py
@@ -7,15 +7,16 @@
"exported_program_to_ir",
"patch_torch",
"unpatch_torch",
"verify_onnx_program",
# Modules
"testing",
"verification",
]

from . import _testing as testing
from . import _verification as verification
from ._analysis import analyze
from ._core import export, exported_program_to_ir
from ._onnx_program import ONNXProgram
from ._patch import _torch_onnx_export as export_compat
from ._patch import patch_torch, unpatch_torch
from ._registration import ONNXRegistry
from ._verification import verify_onnx_program
6 changes: 3 additions & 3 deletions src/torch_onnx/_core.py
@@ -903,6 +903,9 @@ def _exported_program_to_onnx_program(

    # TODO: Decide if we should keep mutated buffers as inputs/outputs

    # TODO(justinchuby): Remove the hack
    _ir_passes.add_torchlib_common_imports(model)

    return _onnx_program.ONNXProgram(model, exported_program)


@@ -1129,9 +1132,6 @@ def export(
        if output_names:
            _ir_passes.rename_outputs(onnx_program.model, output_names)

        # TODO(justinchuby): Remove the hack
        _ir_passes.add_torchlib_common_imports(onnx_program.model)

        export_status.onnx_translation = True
        verbose_print("Translate the graph into ONNX... ✅")
    except Exception as e:
268 changes: 266 additions & 2 deletions src/torch_onnx/_verification.py
@@ -1,13 +1,30 @@
# mypy: allow-untyped-defs
from __future__ import annotations

__all__ = [
"VerificationInfo",
"SearchResult",
"verify_onnx_program",
"minimize_inaccurate_subgraph",
]

import copy
import dataclasses
from typing import Any
import logging
import operator
from typing import TYPE_CHECKING, Any, Iterator, Sequence

import torch
from torch._functorch import fx_minifier
from torch.utils import _pytree

from torch_onnx import _onnx_program
from torch_onnx import _core, _onnx_program

if TYPE_CHECKING:
import torch.fx


logger = logging.getLogger(__name__)


@dataclasses.dataclass
@@ -21,13 +38,30 @@ class VerificationInfo:
    # and checked by the runtime


@dataclasses.dataclass
class SearchResult:
    graph_module: torch.fx.GraphModule
    inputs: Sequence[Any]

    @property
    def graph(self) -> torch.fx.Graph:
        return self.graph_module.graph

    @graph.setter
    def graph(self, fx_g: torch.fx.Graph):
        self.graph_module.graph = fx_g


def _compare_tensors(
    expected: torch.Tensor,
    actual: torch.Tensor,
) -> tuple[float, float]:
    # Move tensors to the same device
    expected = expected.detach().cpu()
    actual = actual.detach().cpu()
    if expected.dtype == torch.bool:

Review comment: Do we need to compare the dtype to make sure they are equal before we compare their values?

Reply (Owner, Author): In theory no, because ORT would have complained already. But we could check.
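
A minimal sketch of such a check, as a hypothetical follow-up (not part of this diff), assuming we want to fail fast rather than rely on ORT's own dtype validation:

def _assert_dtypes_match(expected: torch.Tensor, actual: torch.Tensor) -> None:
    # Hypothetical guard: ORT typically rejects dtype mismatches on its own,
    # so this only surfaces the failure explicitly at comparison time.
    if expected.dtype != actual.dtype:
        raise AssertionError(
            f"dtype mismatch: expected {expected.dtype}, actual {actual.dtype}"
        )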

        expected = expected.to(torch.int)
        actual = actual.to(torch.int)
    absolute_difference = torch.abs(expected - actual).max().item()
    eps = 1e-7
    relative_difference = (
@@ -42,6 +76,11 @@ def verify_onnx_program(
    kwargs: dict[str, Any] | None = None,
) -> list[VerificationInfo]:
    exported_program = onnx_program.exported_program
    if exported_program is None:
        raise ValueError(
            "The ONNX program does not contain an exported_program. "
            "Please provide an exported_program to verify the ONNX program."
        )
    if args is None and kwargs is None:
        # User did not provide example inputs, use the default example inputs
        if exported_program.example_inputs is None:
@@ -75,3 +114,228 @@
            )
        )
    return results


def _exported_program_to_fx_graph_module_and_inputs(
    exported_program: torch.export.ExportedProgram,
) -> tuple[torch.fx.GraphModule, Sequence[torch.Tensor]]:
    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
    # Original code Copyright 2024 The AI Edge Torch Authors.
    # Apache License, Version 2.0
    fx_gm = exported_program.graph_module
    fx_inputs = _pytree.tree_map(
        torch.tensor,
        exported_program._graph_module_flat_inputs(*exported_program.example_inputs),
    )
    return fx_gm, fx_inputs


def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
"""This function turns all operator getitem nodes in ExportedProgram FX graph to

new nodes composed of "computation + getitem". The normalization duplicates
some computations in the graph but would make the graph more friendly for
partitioning in FX minifier.
"""
# Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
# Original code Copyright 2024 The AI Edge Torch Authors.
# Apache License, Version 2.0

fx_gm = copy.deepcopy(fx_gm)
graph = fx_gm.graph
for n in graph.nodes:
if n.target != operator.getitem:
continue

src_n, key = n.args
assert n.op == "call_function"
with graph.inserting_after(n):
new_n = graph.call_function(
lambda src_target, key, args, kwargs: operator.getitem(
src_target(*args, **kwargs), key
),
(src_n.target, key, src_n.args, src_n.kwargs),
)
n.replace_all_uses_with(new_n)

graph.eliminate_dead_code()
fx_gm.graph = graph
return fx_gm
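
For intuition, a minimal sketch (illustrative, not part of this diff) of the kind of graph this pass rewrites; torch.topk stands in for any multi-output op:

import torch
import torch.fx

def f(x):
    # Indexing a multi-output op produces an operator.getitem node in FX.
    return torch.topk(x, 2)[0] + 1

gm = torch.fx.symbolic_trace(f)
print(gm.graph)
# Traced: topk -> getitem(topk, 0) -> add.
# After _normalize_getitem_nodes, the getitem is replaced by a single node
# that re-invokes topk and indexes its result, so minifier partitions never
# separate an op from the getitem that selects its output.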


def _erase_unused_inputs(fx_gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]):
    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
    # Original code Copyright 2024 The AI Edge Torch Authors.
    # Apache License, Version 2.0
    fx_gm = copy.deepcopy(fx_gm)
    args = fx_gm.graph.process_inputs(*inputs)
    args_iter = iter(args)

    graph = fx_gm.graph
    new_inputs = []
    for n in graph.nodes:
        if n.op == "placeholder":
            if n.target.startswith("*"):
                new_inputs += list(args_iter)
            elif len(n.users) > 0:
                new_inputs.append(next(args_iter))
            else:
                graph.erase_node(n)
                next(args_iter)
    new_inputs = tuple(new_inputs)
    fx_gm.graph = graph
    return fx_gm, new_inputs


def _lift_dead_ops_to_outputs(fx_gm: torch.fx.GraphModule):
    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
    # Original code Copyright 2024 The AI Edge Torch Authors.
    # Apache License, Version 2.0
    fx_gm = copy.deepcopy(fx_gm)

    new_outputs = []
    graph = fx_gm.graph
    nodes = list(graph.nodes)
    assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
    for node in nodes:
        if node.op not in ("placeholder", "output") and len(node.users) == 0:
            new_outputs.append(node)  # noqa: PERF401

    output_node = nodes[-1]
    # FX output node returns the first arg as is.
    # ref: https://github.com/pytorch/pytorch/blob/1a578df57cc0f417f671634e564c62ef5d9a97e2/torch/fx/interpreter.py#L337
    new_outputs, _ = _pytree.tree_flatten([new_outputs, output_node.args[0]])
    output_node.update_arg(0, tuple(new_outputs))

    fx_gm.graph = graph
    return fx_gm


def _normalize_minified_fx_gm(
    fx_gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]
):
    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
    # Original code Copyright 2024 The AI Edge Torch Authors.
    # Apache License, Version 2.0
    fx_gm, inputs = _erase_unused_inputs(fx_gm, inputs)
    fx_gm = _lift_dead_ops_to_outputs(fx_gm)
    return fx_gm, inputs


def _erase_trivial_outputs(fx_gm: torch.fx.GraphModule):
"""Remove output nodes directly connected to an input node."""
# Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
# Original code Copyright 2024 The AI Edge Torch Authors.
# Apache License, Version 2.0
fx_gm = copy.deepcopy(fx_gm)

graph = fx_gm.graph
nodes = list(graph.nodes)
assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
output_node = nodes[-1]

outputs, _ = _pytree.tree_flatten(output_node.args[0])
new_outputs = [output for output in outputs if output.op != "placeholder"]
output_node.update_arg(0, tuple(new_outputs))

fx_gm.recompile()
return fx_gm


def _erase_sub_gm_from_gm(
    fx_gm: torch.fx.GraphModule,
    inputs: Sequence[torch.Tensor],
    sub_gm: torch.fx.GraphModule,
    sub_inputs: Sequence[torch.Tensor],
):
    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
    # Original code Copyright 2024 The AI Edge Torch Authors.
    # Apache License, Version 2.0
    fx_gm = copy.deepcopy(fx_gm)
    fx_inputs = list(inputs)

    class EraseNodeInterpreter(torch.fx.Interpreter):
        def run_node(self, node):
            nonlocal fx_gm, fx_inputs
            res = super().run_node(node)
            if node.op not in ("placeholder", "output"):
                to_erase = next(m for m in fx_gm.graph.nodes if m.name == node.name)
                # Raise the output (tensor) of the erased node to be an input of
                # the new model graph. Some raised inputs may become unused later,
                # when all of their users are within the erased subgraph; those
                # inputs are removed by the `_erase_unused_inputs` pass below.
                with fx_gm.graph.inserting_before(to_erase):
                    new_input = fx_gm.graph.placeholder(node.name + "__value")
                    to_erase.replace_all_uses_with(new_input)

                fx_gm.graph.erase_node(to_erase)
                fx_inputs.append(res)
            return res

    interpreter = EraseNodeInterpreter(sub_gm)
    interpreter.run(*sub_inputs)

    fx_gm.graph.lint()
    fx_gm.recompile()

    # Ops prior to the erased subgraph may be dangling. Lift them as outputs.
    fx_gm = _lift_dead_ops_to_outputs(fx_gm)
    fx_gm = _erase_trivial_outputs(fx_gm)
    fx_gm, fx_inputs = _erase_unused_inputs(fx_gm, fx_inputs)

    fx_gm.graph.lint()
    fx_gm.recompile()
    return fx_gm, fx_inputs


def minimize_inaccurate_subgraph(
    exported_program: torch.export.ExportedProgram,
    rtol: float = 1e-4,
    atol: float = 1e-5,
) -> Iterator[SearchResult]:
    """Find the subgraph with error and minimize it."""
    # Adapted from https://github.com/google-ai-edge/ai-edge-torch/blob/a54d10d4fcf53339d32b00dda71918e810064e22/ai_edge_torch/debug/culprit.py
    # Original code Copyright 2024 The AI Edge Torch Authors.
    # Apache License, Version 2.0

    def _export_and_verify(
        torch_module: torch.fx.GraphModule,
        inputs: Any,
    ) -> bool:
        exported_program = torch.export.export(torch_module, tuple(inputs))
        onnx_model = _core.exported_program_to_ir(exported_program)
        onnx_program = _onnx_program.ONNXProgram(onnx_model, exported_program)
        verification_info = verify_onnx_program(onnx_program)
        for info in verification_info:
            if info.absolute_difference > atol or info.relative_difference > rtol:
                logger.warning("Found inaccuracy: %s", info)
                return True
        return False

    # Get the subgraph with error
    fx_gm, fx_inputs = _exported_program_to_fx_graph_module_and_inputs(exported_program)
    found_inaccuracies_num = 0
    while True:
        try:
            graph_module = _normalize_getitem_nodes(fx_gm)
            raw_min_fx_gm, raw_min_inputs = fx_minifier.minifier(
                graph_module,
                fx_inputs,
                _export_and_verify,
            )
            raw_min_inputs = tuple(raw_min_inputs)
            min_fx_gm, min_inputs = _normalize_minified_fx_gm(
                raw_min_fx_gm, raw_min_inputs
            )
            found_inaccuracies_num += 1
            yield SearchResult(min_fx_gm, min_inputs)
            fx_gm, fx_inputs = _erase_sub_gm_from_gm(
                fx_gm, fx_inputs, raw_min_fx_gm, raw_min_inputs
            )
        except RuntimeError as e:  # noqa: PERF203
            if (
                str(e) == "Input graph did not fail the tester"
                and found_inaccuracies_num > 0
            ):
                break
            raise
21 changes: 21 additions & 0 deletions tests/models/longformer_export_acc.py
@@ -0,0 +1,21 @@
import torch
import torch_onnx
from transformers import LongformerModel, LongformerTokenizer

torch_onnx.patch_torch()

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors="pt")
print("Exporting model...")

ep = torch.export.export(
    model,
    (encoded_input["input_ids"], encoded_input["attention_mask"]),
)

for result in torch_onnx.verification.minimize_inaccurate_subgraph(
    ep, atol=1e-2, rtol=10.0
):
    print(result)