
Optimum profiling #70

Open

justinchuby opened this issue Jun 25, 2024 · 13 comments

Comments

justinchuby (Owner) commented Jun 25, 2024

optimum-cli export onnx --model openai/whisper-large-v3 whisper/

mprof run optimum-cli export onnx --model openai/whisper-large-v3 whisper/
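
The call trees below were captured with pyinstrument (v4.6.2 in the banners). A minimal sketch of capturing an equivalent profile programmatically, assuming the export is wrapped in a hypothetical `run_export()` helper:

```python
from pyinstrument import Profiler

# Profile the export end to end and print the same kind of call tree that
# appears in the comments below.
profiler = Profiler()
profiler.start()
run_export()  # hypothetical wrapper around the optimum-cli export invocation
profiler.stop()
print(profiler.output_text(unicode=True))
```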

justinchuby (Owner, Author) commented

PyTorch ONNX Conversion Error Report

✅ Obtain model graph with `torch.export.export`
✅ Translate the graph into ONNX
❌ Run `onnx.checker` on the ONNX model
⚪ Execute the model with ONNX Runtime
⚪ Validate model output accuracy

Error message:

Traceback (most recent call last):

  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/_core.py", line 933, in export
    onnx.checker.check_model(onnx_program.model_proto, full_check=True)

  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/onnx/checker.py", line 176, in check_model
    raise ValueError(

ValueError: This protobuf of onnx model is too large (>2GB). Call check_model with model path instead.
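
The failure is only the 2GB protobuf limit in the checker; as the message says, the check can be run against the saved model file instead. A minimal sketch, assuming the program has been saved to `whisper/model.onnx` (hypothetical path) with external data:

```python
import onnx

# For models whose serialized proto exceeds 2GB, pass a file path so the
# checker loads the model (and its external data) from disk.
onnx.checker.check_model("whisper/model.onnx", full_check=True)
```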

Analysis

PyTorch ONNX Conversion Analysis

Model Information

The model has 636968960 parameters and 0 buffers (non-trainable parameters).
Number of parameters per dtype:

defaultdict(<class 'int'>, {torch.float32: 636968960})

Number of buffers per dtype:

defaultdict(<class 'int'>, {})

Inputs:

  • arg487_1: TensorMetadata(shape=torch.Size([2, 128, 3000]), dtype=torch.float32, requires_grad=False, stride=(384000, 3000, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})

Outputs:

  • getitem_192: TensorMetadata(shape=torch.Size([2, 1500, 1280]), dtype=torch.float32, requires_grad=False, stride=(1920000, 1280, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})

The FX graph has 2290 nodes in total. Number of FX nodes per op:

  • placeholder: 488
  • call_function: 1801
  • output: 1

Of the call_function nodes, the counts of operators used are:

  • aten.view.default: 608
  • aten.clone.default: 257
  • aten.t.default: 192
  • aten.addmm.default: 160
  • aten.transpose.int: 160
  • aten.add.Tensor: 65
  • aten.native_layer_norm.default: 65
  • <built-in function getitem>: 65
  • aten.bmm.default: 64
  • aten.gelu.default: 34
  • aten.mul.Tensor: 32
  • aten.mm.default: 32
  • aten._softmax.default: 32
  • aten._unsafe_view.default: 32
  • aten.convolution.default: 2
  • aten.permute.default: 1

ONNX Conversion Information

All operators in the model have registered ONNX decompositions.
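
For reference, a minimal sketch (not the actual report generator) of how the per-dtype parameter counts and the per-operator FX node counts above can be reproduced from the model and its exported program:

```python
import collections

import torch

def param_count_per_dtype(model: torch.nn.Module) -> dict[torch.dtype, int]:
    # Mirrors the "Number of parameters per dtype" defaultdict above.
    counts: dict[torch.dtype, int] = collections.defaultdict(int)
    for param in model.parameters():
        counts[param.dtype] += param.numel()
    return counts

def fx_op_counts(exported_program: torch.export.ExportedProgram) -> collections.Counter:
    # Mirrors the per-operator counts of call_function nodes above.
    return collections.Counter(
        str(node.target)
        for node in exported_program.graph.nodes
        if node.op == "call_function"
    )
```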

Profiling result


  _     ._   __/__   _ _  _  _ _/_   Recorded: 21:11:01  Samples:  11267
 /_//_/// /_\ / //_// / //_'/ //     Duration: 12.009    CPU time: 13.367
/   _/                      v4.6.2

Program: /Users/justinc/Documents/GitHub/torch-onnx/venv/bin/optimum-cli export onnx --model openai/whisper-large-v3 whisper/

12.008 export  torch_onnx/_core.py:793
├─ 7.311 export  torch/export/__init__.py:73
│     [234 frames hidden]  torch, contextlib, dis, importlib, co...
└─ 4.695 exported_program_to_ir  torch_onnx/_core.py:618
   ├─ 3.392 wrapper  torch/export/exported_program.py:80
   │     [80 frames hidden]  torch, <string>
   ├─ 0.640 _add_nodes  torch_onnx/_core.py:486
   │  └─ 0.631 _handle_call_function_node_with_lowering  torch_onnx/_core.py:356
   │     └─ 0.424 TracedOnnxFunction.__call__  ../../onnxscript/onnxscript/values.py:581
   │        └─ 0.269 SymbolicTensor.aten_view  ../../onnxscript/onnxscript/function_libs/torch_lib/ops/core.py:8740
   │           └─ 0.153 Opset18.Cast  ../../onnxscript/onnxscript/onnx_opset/_impl/opset13.py:241
   │              └─ 0.145 Op.__call__  ../../onnxscript/onnxscript/values.py:291
   │                 └─ 0.144 OpRecorder.eval  torch_onnx/_building.py:390
   ├─ 0.342 OnnxRegistry.from_torchlib  torch_onnx/_registration.py:114
   │  └─ 0.182 _get_overload  torch_onnx/_registration.py:57
   │     └─ 0.176 <module>  torchvision/__init__.py:1
   └─ 0.301 insert_type_promotion_nodes  torch_onnx/_fx_passes.py:13
      └─ 0.280 wrapper  torch/onnx/_internal/diagnostics/infra/decorator.py:71
            [13 frames hidden]  torch

justinchuby (Owner, Author) commented

  _     ._   __/__   _ _  _  _ _/_   Recorded: 21:21:09  Samples:  3045
 /_//_/// /_\ / //_// / //_'/ //     Duration: 20.408    CPU time: 78.383
/   _/                      v4.6.2

Program: /Users/justinc/Documents/GitHub/torch-onnx/venv/bin/optimum-cli export onnx --model openai/whisper-large-v3 whisper2/

20.407 export_pytorch  optimum/exporters/onnx/convert.py:484
└─ 20.407 export  torch/onnx/utils.py:189
      [52 frames hidden]  torch, optimum, transformers, <built-in>
         4.360 PyCapsule._jit_pass_onnx_graph_shape_type_inference  <built-in>

justinchuby (Owner, Author) commented Jun 25, 2024

torch.onnx

  _     ._   __/__   _ _  _  _ _/_   Recorded: 21:21:35  Samples:  8542
 /_//_/// /_\ / //_// / //_'/ //     Duration: 39.021    CPU time: 46.264
/   _/                      v4.6.2

Program: /Users/justinc/Documents/GitHub/torch-onnx/venv/bin/optimum-cli export onnx --model openai/whisper-large-v3 whisper2/

39.020 export_pytorch  optimum/exporters/onnx/convert.py:484
└─ 39.020 export  torch/onnx/utils.py:189
      [56 frames hidden]  torch, <built-in>, optimum, transformers
         25.273 _optimize_graph  torch/onnx/utils.py:574
         ├─ 13.000 PyCapsule._jit_pass_onnx_graph_shape_type_inference  <built-in>
         ├─ 8.258 [self]  torch/onnx/utils.py

justinchuby (Owner, Author) commented Jun 25, 2024

torch.onnx

  _     ._   __/__   _ _  _  _ _/_   Recorded: 21:22:28  Samples:  6560
 /_//_/// /_\ / //_// / //_'/ //     Duration: 29.185    CPU time: 42.493
/   _/                      v4.6.2

Program: /Users/justinc/Documents/GitHub/torch-onnx/venv/bin/optimum-cli export onnx --model openai/whisper-large-v3 whisper2/

29.184 export_pytorch  optimum/exporters/onnx/convert.py:484
└─ 29.184 export  torch/onnx/utils.py:189
      [59 frames hidden]  torch, <built-in>, optimum, transformers
         19.346 _optimize_graph  torch/onnx/utils.py:574
         ├─ 10.390 PyCapsule._jit_pass_onnx_graph_shape_type_inference  <built-in>
         ├─ 6.166 [self]  torch/onnx/utils.py

justinchuby (Owner, Author) commented Jun 25, 2024

Profiling result dynamo


  _     ._   __/__   _ _  _  _ _/_   Recorded: 21:11:34  Samples:  18373
 /_//_/// /_\ / //_// / //_'/ //     Duration: 19.273    CPU time: 21.785
/   _/                      v4.6.2

Program: /Users/justinc/Documents/GitHub/torch-onnx/venv/bin/optimum-cli export onnx --model openai/whisper-large-v3 whisper/

19.273 export  torch_onnx/_core.py:793
├─ 12.199 export  torch/export/__init__.py:73
│     [269 frames hidden]  torch, contextlib, copy, dis, optimum...
└─ 7.071 exported_program_to_ir  torch_onnx/_core.py:618
   ├─ 5.120 wrapper  torch/export/exported_program.py:80
   │     [77 frames hidden]  torch, <string>
   ├─ 1.281 _add_nodes  torch_onnx/_core.py:486
   │  └─ 1.265 _handle_call_function_node_with_lowering  torch_onnx/_core.py:356
   │     └─ 0.903 TracedOnnxFunction.__call__  ../../onnxscript/onnxscript/values.py:581
   │        └─ 0.620 SymbolicTensor.aten_view  ../../onnxscript/onnxscript/function_libs/torch_lib/ops/core.py:8740
   │           └─ 0.435 Opset18.Cast  ../../onnxscript/onnxscript/onnx_opset/_impl/opset13.py:241
   │              └─ 0.430 Op.__call__  ../../onnxscript/onnxscript/values.py:291
   │                 └─ 0.423 OpRecorder.eval  torch_onnx/_building.py:390
   │                    └─ 0.274 OpRecorder._call_op  torch_onnx/_building.py:352
   │                       └─ 0.250 _process_python_constants_and_sequences  torch_onnx/_building.py:185
   │                          └─ 0.214 TensorProtocol.__instancecheck__  typing.py:1990
   │                                [2 frames hidden]  typing
   └─ 0.491 insert_type_promotion_nodes  torch_onnx/_fx_passes.py:13
      └─ 0.453 wrapper  torch/onnx/_internal/diagnostics/infra/decorator.py:71
            [13 frames hidden]  torch

justinchuby (Owner, Author) commented Jun 25, 2024

Profiling result dynamo


  _     ._   __/__   _ _  _  _ _/_   Recorded: 21:12:24  Samples:  17774
 /_//_/// /_\ / //_// / //_'/ //     Duration: 18.581    CPU time: 21.069
/   _/                      v4.6.2

Program: /Users/justinc/Documents/GitHub/torch-onnx/venv/bin/optimum-cli export onnx --model openai/whisper-large-v3 whisper/

18.580 export  torch_onnx/_core.py:793
├─ 12.479 export  torch/export/__init__.py:73
│     [302 frames hidden]  torch, contextlib, copy, dis, ast, op...
└─ 6.098 exported_program_to_ir  torch_onnx/_core.py:618
   ├─ 4.429 wrapper  torch/export/exported_program.py:80
   │     [73 frames hidden]  torch, <string>
   ├─ 0.877 _add_nodes  torch_onnx/_core.py:486
   │  └─ 0.860 _handle_call_function_node_with_lowering  torch_onnx/_core.py:356
   │     └─ 0.560 TracedOnnxFunction.__call__  ../../onnxscript/onnxscript/values.py:581
   │        └─ 0.385 SymbolicTensor.aten_view  ../../onnxscript/onnxscript/function_libs/torch_lib/ops/core.py:8740
   │           └─ 0.220 Opset18.Cast  ../../onnxscript/onnxscript/onnx_opset/_impl/opset13.py:241
   │              └─ 0.212 Op.__call__  ../../onnxscript/onnxscript/values.py:291
   │                 └─ 0.211 OpRecorder.eval  torch_onnx/_building.py:390
   └─ 0.604 insert_type_promotion_nodes  torch_onnx/_fx_passes.py:13
      └─ 0.574 wrapper  torch/onnx/_internal/diagnostics/infra/decorator.py:71
            [17 frames hidden]  torch

justinchuby (Owner, Author) commented

torch.onnx memory

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   582   6297.8 MiB   6297.8 MiB           1               @memory_profiler.profile
   583                                                     def export():
   584   7491.5 MiB   1193.7 MiB           2                   onnx_export(
   585   6297.8 MiB      0.0 MiB           1                       model,
   586   6297.8 MiB      0.0 MiB           1                       (dummy_inputs,),
   587   6297.8 MiB      0.0 MiB           1                       f=output.as_posix(),
   588   6297.8 MiB      0.0 MiB           1                       input_names=input_names,
   589   6297.8 MiB      0.0 MiB           1                       output_names=output_names,
   590                                                             # dynamic_axes=dynamix_axes,
   591   6297.8 MiB      0.0 MiB           1                       do_constant_folding=do_constant_folding,
   592   6297.8 MiB      0.0 MiB           1                       opset_version=opset,
   593                                                         )


Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   582   8819.4 MiB   8819.4 MiB           1               @memory_profiler.profile
   583                                                     def export():
   584   9594.7 MiB    775.3 MiB           2                   onnx_export(
   585   8819.4 MiB      0.0 MiB           1                       model,
   586   8819.4 MiB      0.0 MiB           1                       (dummy_inputs,),
   587   8819.4 MiB      0.0 MiB           1                       f=output.as_posix(),
   588   8819.4 MiB      0.0 MiB           1                       input_names=input_names,
   589   8819.4 MiB      0.0 MiB           1                       output_names=output_names,
   590                                                             # dynamic_axes=dynamix_axes,
   591   8819.4 MiB      0.0 MiB           1                       do_constant_folding=do_constant_folding,
   592   8819.4 MiB      0.0 MiB           1                       opset_version=opset,
   593                                                         )


Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   582   5770.4 MiB   5770.4 MiB           1               @memory_profiler.profile
   583                                                     def export():
   584   8630.9 MiB   2860.5 MiB           2                   onnx_export(
   585   5770.4 MiB      0.0 MiB           1                       model,
   586   5770.4 MiB      0.0 MiB           1                       (dummy_inputs,),
   587   5770.4 MiB      0.0 MiB           1                       f=output.as_posix(),
   588   5770.4 MiB      0.0 MiB           1                       input_names=input_names,
   589   5770.4 MiB      0.0 MiB           1                       output_names=output_names,
   590                                                             # dynamic_axes=dynamix_axes,
   591   5770.4 MiB      0.0 MiB           1                       do_constant_folding=do_constant_folding,
   592   5770.4 MiB      0.0 MiB           1                       opset_version=opset,
   593                                                         )

justinchuby (Owner, Author) commented

dynamo memory usage

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   582   6295.5 MiB   6295.5 MiB           1               @memory_profiler.profile
   583                                                     def export():
   584   6295.5 MiB  -1803.0 MiB           2                   onnx_export(
   585   6295.5 MiB      0.0 MiB           1                       model,
   586   6295.5 MiB      0.0 MiB           1                       (dummy_inputs,),
   587   6295.5 MiB      0.0 MiB           1                       f=output.as_posix(),
   588   6295.5 MiB      0.0 MiB           1                       input_names=input_names,
   589   6295.5 MiB      0.0 MiB           1                       output_names=output_names,
   590                                                             # dynamic_axes=dynamix_axes,
   591   6295.5 MiB      0.0 MiB           1                       do_constant_folding=do_constant_folding,
   592   6295.5 MiB      0.0 MiB           1                       opset_version=opset,
   593                                                         )


Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   582   4711.7 MiB   4711.7 MiB           1               @memory_profiler.profile
   583                                                     def export():
   584   4820.3 MiB    108.6 MiB           2                   onnx_export(
   585   4711.7 MiB      0.0 MiB           1                       model,
   586   4711.7 MiB      0.0 MiB           1                       (dummy_inputs,),
   587   4711.7 MiB      0.0 MiB           1                       f=output.as_posix(),
   588   4711.7 MiB      0.0 MiB           1                       input_names=input_names,
   589   4711.7 MiB      0.0 MiB           1                       output_names=output_names,
   590                                                             # dynamic_axes=dynamix_axes,
   591   4711.7 MiB      0.0 MiB           1                       do_constant_folding=do_constant_folding,
   592   4711.7 MiB      0.0 MiB           1                       opset_version=opset,
   593                                                         )


Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   582   5004.2 MiB   5004.2 MiB           1               @memory_profiler.profile
   583                                                     def export():
   584   5373.4 MiB    369.2 MiB           2                   onnx_export(
   585   5004.2 MiB      0.0 MiB           1                       model,
   586   5004.2 MiB      0.0 MiB           1                       (dummy_inputs,),
   587   5004.2 MiB      0.0 MiB           1                       f=output.as_posix(),
   588   5004.2 MiB      0.0 MiB           1                       input_names=input_names,
   589   5004.2 MiB      0.0 MiB           1                       output_names=output_names,
   590                                                             # dynamic_axes=dynamix_axes,
   591   5004.2 MiB      0.0 MiB           1                       do_constant_folding=do_constant_folding,
   592   5004.2 MiB      0.0 MiB           1                       opset_version=opset,
   593                                                         )

justinchuby (Owner, Author) commented Jun 25, 2024

new dynamo w/o external tensor IR pass

[image]

justinchuby (Owner, Author) commented Jun 25, 2024

torch.onnx

[image]

  • The increase may be a memory leak.

justinchuby (Owner, Author) commented Jun 25, 2024

There may be shared tensors, as the safetensors checkpoint is 3GB but the external data is 6, 11, and 16GB?
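
A quick way to test that hypothesis (a sketch, not the exporter's own logic): group parameters by the storage they alias. Tied weights are stored once in the safetensors checkpoint but could end up serialized as separate initializers in the external data.

```python
import collections

import torch

def find_shared_parameters(model: torch.nn.Module) -> dict[int, list[str]]:
    # Group parameter names by the data pointer of their underlying storage;
    # any group with more than one name is a set of tied/shared weights.
    groups: dict[int, list[str]] = collections.defaultdict(list)
    for name, param in model.named_parameters():
        groups[param.untyped_storage().data_ptr()].append(name)
    return {ptr: names for ptr, names in groups.items() if len(names) > 1}
```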

justinchuby (Owner, Author) commented Jun 25, 2024

new dynamo with IR pass that avoids creating tensor protos first (theoretical)

[image]

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   583   6315.5 MiB   6315.5 MiB           1               @memory_profiler.profile
   584                                                     def export():
   585   6439.4 MiB    123.9 MiB           2                   onnx_export(
   586   6315.5 MiB      0.0 MiB           1                       model,
   587   6315.5 MiB      0.0 MiB           1                       (dummy_inputs,),
   588   6315.5 MiB      0.0 MiB           1                       f=output.as_posix(),
   589   6315.5 MiB      0.0 MiB           1                       input_names=input_names,
   590   6315.5 MiB      0.0 MiB           1                       output_names=output_names,
   591                                                             # dynamic_axes=dynamix_axes,
   592   6315.5 MiB      0.0 MiB           1                       do_constant_folding=do_constant_folding,
   593   6315.5 MiB      0.0 MiB           1                       opset_version=opset,
   594   6315.5 MiB      0.0 MiB           1                       export_params=False, # MARK
   595                                                         )

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   583   6444.5 MiB   6444.5 MiB           1               @memory_profiler.profile
   584                                                     def export():
   585   6541.0 MiB     96.5 MiB           2                   onnx_export(
   586   6444.5 MiB      0.0 MiB           1                       model,
   587   6444.5 MiB      0.0 MiB           1                       (dummy_inputs,),
   588   6444.5 MiB      0.0 MiB           1                       f=output.as_posix(),
   589   6444.5 MiB      0.0 MiB           1                       input_names=input_names,
   590   6444.5 MiB      0.0 MiB           1                       output_names=output_names,
   591                                                             # dynamic_axes=dynamix_axes,
   592   6444.5 MiB      0.0 MiB           1                       do_constant_folding=do_constant_folding,
   593   6444.5 MiB      0.0 MiB           1                       opset_version=opset,
   594   6444.5 MiB      0.0 MiB           1                       export_params=False, # MARK
   595                                                         )

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   583   6543.2 MiB   6543.2 MiB           1               @memory_profiler.profile
   584                                                     def export():
   585   6588.6 MiB     45.4 MiB           2                   onnx_export(
   586   6543.2 MiB      0.0 MiB           1                       model,
   587   6543.2 MiB      0.0 MiB           1                       (dummy_inputs,),
   588   6543.2 MiB      0.0 MiB           1                       f=output.as_posix(),
   589   6543.2 MiB      0.0 MiB           1                       input_names=input_names,
   590   6543.2 MiB      0.0 MiB           1                       output_names=output_names,
   591                                                             # dynamic_axes=dynamix_axes,
   592   6543.2 MiB      0.0 MiB           1                       do_constant_folding=do_constant_folding,
   593   6543.2 MiB      0.0 MiB           1                       opset_version=opset,
   594   6543.2 MiB      0.0 MiB           1                       export_params=False, # MARK
   595                                                         )

justinchuby (Owner, Author) commented

current dynamo

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   583   6304.6 MiB   6304.6 MiB           1               @memory_profiler.profile
   584                                                     def export():
   585   6304.6 MiB      0.0 MiB           1                   export_options = torch.onnx.ExportOptions(dynamic_shapes=False)
   586   8923.8 MiB   2619.2 MiB           4                   onnx_program = torch.onnx.dynamo_export(
   587   6304.6 MiB      0.0 MiB           1                       model,
   588   6304.6 MiB      0.0 MiB           1                       export_options = export_options,
   589   6304.6 MiB      0.0 MiB           1                       **dummy_inputs,
   590                                                         )

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   583   1634.0 MiB   1634.0 MiB           1               @memory_profiler.profile
   584                                                     def export():
   585   1634.0 MiB      0.0 MiB           1                   export_options = torch.onnx.ExportOptions(dynamic_shapes=False)
   586   8891.5 MiB   7257.5 MiB           4                   onnx_program = torch.onnx.dynamo_export(
   587   1634.0 MiB      0.0 MiB           1                       model,
   588   1634.0 MiB      0.0 MiB           1                       export_options = export_options,
   589   1634.0 MiB      0.0 MiB           1                       **dummy_inputs,
   590                                                         )

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   583   3830.5 MiB   3830.5 MiB           1               @memory_profiler.profile
   584                                                     def export():
   585   3830.8 MiB      0.2 MiB           1                   export_options = torch.onnx.ExportOptions(dynamic_shapes=False)
   586   7247.3 MiB   3416.6 MiB           4                   onnx_program = torch.onnx.dynamo_export(
   587   3830.8 MiB      0.0 MiB           1                       model,
   588   3830.8 MiB      0.0 MiB           1                       export_options = export_options,
   589   3830.8 MiB      0.0 MiB           1                       **dummy_inputs,
   590                                                         )

[image]
