
[torch.compile] Request for shape ranges in torch.compile workflow #115137

Closed
@peri044

Description


🚀 The feature, motivation and pitch

Torch-TensorRT has two workflows for optimizing PyTorch models with TensorRT: one based on torch.export and one based on torch.compile.

  • The torch.export.export() flow accepts dynamic shapes with min and max ranges specified for any input dimension. This range information is accessible for intermediate nodes in the graph via node.shape_env.var_to_range. TensorRT uses it when building engines with dynamic shapes.

  • torch.compile provides torch._dynamo.mark_dynamic(tensor, dim), which is great. Could this be extended to accept ranges (or something similar to the torch.export.Dim(name, min, max) API)?

Ultimately, the ask is for the graph that torch.compile passes to its backend to carry the same range information that torch.export.export() provides.
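To illustrate why a backend wants these ranges, here is a hypothetical sketch in plain Python (no Torch-TensorRT calls) of how per-dimension (min, max) ranges could be turned into the min/opt/max shapes that a TensorRT optimization profile needs. The helper name to_shape_profile and the choice of the example size as the "opt" shape are assumptions made for illustration only:

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical input format: {dim_index: (min, max)} for dynamic dimensions,
# in the spirit of torch.export.Dim(name, min, max). Static dims are absent.
DimRanges = Dict[int, Tuple[int, int]]


def to_shape_profile(
    example_shape: Sequence[int], ranges: DimRanges
) -> Tuple[List[int], List[int], List[int]]:
    """Return (min_shape, opt_shape, max_shape) for one input.

    Static dimensions keep their example size in all three shapes.
    For dynamic dimensions, opt is the example size clamped into [min, max].
    """
    min_shape: List[int] = []
    opt_shape: List[int] = []
    max_shape: List[int] = []
    for dim, size in enumerate(example_shape):
        if dim in ranges:
            lo, hi = ranges[dim]
            if lo > hi:
                raise ValueError(f"dim {dim}: min {lo} > max {hi}")
            min_shape.append(lo)
            opt_shape.append(min(max(size, lo), hi))  # clamp into [lo, hi]
            max_shape.append(hi)
        else:
            # Without range information, a backend can only treat it as static.
            min_shape.append(size)
            opt_shape.append(size)
            max_shape.append(size)
    return min_shape, opt_shape, max_shape
```

For example, to_shape_profile((4, 8), {0: (2, 32)}) returns ([2, 8], [4, 8], [32, 8]): dimension 0 spans 2 to 32 while dimension 1 stays fixed at 8. Without ranges in the torch.compile graph, the backend has no principled values for the min and max entries of dynamic dimensions.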

Please let me know if there is currently a way to do this, or if you have any questions.

Alternatives

No response

Additional context

No response

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Metadata


Assignees

No one assigned

    Labels

    module: dynamic shapes, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
