Skip to content

fix/feat: Add lowering pass to resolve most aten::Int.Tensor uses #1937

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 2 commits into from
May 30, 2023

Conversation

gs-olive
Copy link
Collaborator

Description

  • Adds improved support for full-conversion for a variety of models
  • Implement lowering pass which detects canonical aten::Int.Tensor cases and recursively replaces input Value pointers until all 0D tensors have been resolved to their scalar components
  • Lowering pass is specialized to replacing strictly integer-typed Value pointers and can only trace through aten::mul and aten::floor_divide operators, which are two of the most common cases of use
  • Lowering pass traverses the graph until one of three base cases are encountered (or an invalid Value type is detected). These cases are prim::NumToTensor, prim::Constant (0D tensor), or simple integers. It then replaces the child nodes with the integer equivalents of the produced Tensors
  • Added extensive testing of new capabilities for accuracy, robustness, and functionality

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes #1880
Fixes #1836
Fixes #513

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • [ x ] My code follows the style guidelines of this project (You can use the linters)
  • [ x ] I have performed a self-review of my own code
  • [ x ] I have commented my code, particularly in hard-to-understand areas and hacks
  • [ x ] I have made corresponding changes to the documentation
  • [ x ] I have added tests to verify my fix or my feature
  • [ x ] New and existing unit tests pass locally with my changes
  • [ x ] I have added the relevant labels to my PR in so that relevant reviewers are notified

@gs-olive gs-olive requested a review from narendasan May 19, 2023 21:50
@gs-olive gs-olive self-assigned this May 19, 2023
@github-actions github-actions bot added component: core Issues re: The core compiler component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests labels May 19, 2023
@gs-olive gs-olive changed the title fix/feat: Add lowering pass to resolve aten::Int.Tensor fix/feat: Add lowering pass to resolve most aten::Int.Tensor invocations May 19, 2023
@gs-olive gs-olive changed the title fix/feat: Add lowering pass to resolve most aten::Int.Tensor invocations fix/feat: Add lowering pass to resolve most aten::Int.Tensor uses May 19, 2023
@gs-olive gs-olive force-pushed the replace_aten_int_schema branch from 8d43fb0 to 0ca6a52 Compare May 19, 2023 21:53
- Implement lowering pass which detects canonical `aten::Int.Tensor`
cases and recursively replaces input Value pointers until all 0D tensors
have been resolved to their scalar components
- Lowering pass is specialized to replacing strictly integer-typed Value pointers
and can only trace through aten::mul and aten::floor_divide operators,
which are two of the most common cases of use
- Lowering pass traverses the graph until one of three base cases are
encountered (or an invalid Value type is detected). These cases are
`prim::NumToTensor`, `prim::Constant` (0D tensor), or simple integers.
It then replaces the child nodes with the integer equivalents of the
produced Tensors
- Added extensive testing of new capabilities for accuracy, robustness,
and functionality
@gs-olive gs-olive force-pushed the replace_aten_int_schema branch from 0ca6a52 to a86ac93 Compare May 19, 2023 21:55
@gs-olive gs-olive requested a review from narendasan May 22, 2023 17:20
torch::jit::aten::floor_divide,
};

c10::optional<torch::jit::Value*> Validate0DTensor(torch::jit::Value* value) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored to use c10::optional wrapper instead of nullptr + replaced pointer checks with .has_value()

- Edit in favor of `c10::optional` type usage
@gs-olive gs-olive force-pushed the replace_aten_int_schema branch from d3c0c7a to 15d9fcd Compare May 23, 2023 04:54
Copy link
Collaborator

@narendasan narendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@gs-olive gs-olive merged commit 3fc3c6d into pytorch:main May 30, 2023
@gs-olive gs-olive deleted the replace_aten_int_schema branch May 30, 2023 21:54
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
cla signed component: core Issues re: The core compiler component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests
Projects
None yet
3 participants