
feat: Add support for device compilation setting #2190


Merged

Conversation

gs-olive
Collaborator

@gs-olive gs-olive commented Aug 10, 2023

Description

  • Enables support for device compilation settings in Dynamo paths
  • Designed so that extensions for multiple GPUs can be made in a straightforward fashion
    • The default device type/format can be customized
  • Defaults to the currently active device, with context-aware switching, enabling the following workflow for torch.compile:
    import torch
    import torch_tensorrt as torchtrt
    import torchvision.models as models

    with torch.cuda.device(1):
        model = models.resnet18(pretrained=True).eval().cuda()
        input = torch.randn((1, 3, 224, 224)).cuda()

        compile_spec = {
            "inputs": [
                torchtrt.Input(
                    input.shape, dtype=torch.float, format=torch.contiguous_format
                )
            ],
            "enabled_precisions": {torch.float},
            "debug": True,
            "ir": "torch_compile",
        }

        trt_mod = torchtrt.compile(model, **compile_spec)

Fixes #2172

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • [x] My code follows the style guidelines of this project (you can use the linters)
  • [x] I have performed a self-review of my own code
  • [x] I have commented my code, particularly in hard-to-understand areas and hacks
  • [x] I have made corresponding changes to the documentation
  • [x] I have added tests to verify my fix or my feature
  • [x] New and existing unit tests pass locally with my changes
  • [x] I have added the relevant labels to my PR so that the relevant reviewers are notified

@gs-olive gs-olive requested review from narendasan and peri044 August 10, 2023 22:48
@gs-olive gs-olive self-assigned this Aug 10, 2023
@github-actions github-actions bot added component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Aug 10, 2023
@github-actions github-actions bot requested a review from apbose August 10, 2023 22:48
@gs-olive gs-olive force-pushed the device_setting_checking_standardization branch from fa81096 to 9e1e58e Compare August 10, 2023 23:58
@github-actions github-actions bot added the component: tests Issues re: Tests label Aug 10, 2023
@gs-olive gs-olive force-pushed the device_setting_checking_standardization branch from 9e1e58e to 8e32c07 Compare August 11, 2023 16:58
@gs-olive gs-olive force-pushed the device_setting_checking_standardization branch from 8e32c07 to 776839a Compare August 15, 2023 22:12
@gs-olive gs-olive force-pushed the device_setting_checking_standardization branch from 776839a to 89764ce Compare August 22, 2023 22:36
- Add updated Device utilities and automatic context-aware device
detection for torch compile
- Add testing for new utilities
@gs-olive gs-olive force-pushed the device_setting_checking_standardization branch from 89764ce to ba18185 Compare August 23, 2023 20:46
@@ -54,3 +56,4 @@ class CompilationSettings:
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
+    device: Device = field(default_factory=default_device)
Collaborator Author

Device is now populated with a default factory once the CompilationSettings object gets instantiated.
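The `default_factory` approach ensures each `CompilationSettings` instance gets its own `Device` built at instantiation time, rather than a single default evaluated once at class-definition time. A minimal sketch of the pattern, using hypothetical stand-in classes rather than the real torch_tensorrt types:

```python
from dataclasses import dataclass, field


# Hypothetical stand-in for torch_tensorrt's Device class.
@dataclass
class Device:
    gpu_id: int = 0


def default_device() -> Device:
    # In the real code, this would query the active CUDA context;
    # here we simply return device 0.
    return Device(gpu_id=0)


@dataclass
class CompilationSettings:
    truncate_long_and_double: bool = False
    # default_factory is called per-instance, so the default Device is
    # resolved when the settings object is created, not at import time.
    device: Device = field(default_factory=default_device)


a, b = CompilationSettings(), CompilationSettings()
assert a.device is not b.device  # each instance owns a distinct Device
```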

@@ -38,7 +43,7 @@ def compile(
     gm: Any,
     inputs: Any,
     *,
-    device: Device = Device._current_device(),
+    device: Optional[Union[Device, torch.device, str]] = DEVICE,
Collaborator Author

Device can now be None, which will take the Torch default device.
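Accepting `Optional[Union[Device, torch.device, str]]` implies a normalization step that maps any of the accepted forms onto one canonical `Device`. A minimal sketch under assumptions, with a hypothetical `to_device` helper and a stand-in `Device` class in place of the real torch_tensorrt implementation:

```python
from typing import Optional, Union


# Hypothetical stand-in for torch_tensorrt's Device class.
class Device:
    def __init__(self, gpu_id: int = 0):
        self.gpu_id = gpu_id


def to_device(device: Optional[Union[Device, str]]) -> Device:
    """Normalize the user-facing device argument to a Device object."""
    if device is None:
        # Fall back to the default device (the real code would query
        # torch.cuda.current_device() here).
        return Device(gpu_id=0)
    if isinstance(device, Device):
        return device
    # Parse strings such as "cuda:1"; a bare "cuda" means device 0.
    _, _, index = device.partition(":")
    return Device(gpu_id=int(index) if index else 0)


assert to_device(None).gpu_id == 0
assert to_device("cuda:1").gpu_id == 1
```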

Comment on lines +153 to +154
elif device is None:
return Device(gpu_id=torch.cuda.current_device())
Collaborator Author

If the device is None, we use the default Torch context device.

Collaborator

@narendasan narendasan left a comment

LGTM

Collaborator

@peri044 peri044 left a comment

LGTM

@gs-olive gs-olive merged commit 64ce49b into pytorch:main Aug 25, 2023
@gs-olive gs-olive deleted the device_setting_checking_standardization branch August 25, 2023 02:44
Successfully merging this pull request may close these issues.

✨[Feature] Device Auto-Detection in Torch Compile Path