
Refactor onnx-ir to remove static shape inference and rely on rank inference only #2478

Open
antimora opened this issue Nov 11, 2024 · 6 comments
@antimora (Collaborator) commented Nov 11, 2024

The goal of this refactor is to remove all static shape inference within the onnx-ir module and focus solely on rank inference. This shift aims to:

  1. Simplify the Shape Inference Process: remove the static shape inference logic and rely on rank-based inference to determine tensor ranks across operations. This approach supports the necessary runtime shapes without involving complex static shape calculations.

  2. Align with ONNX and Runtime Shapes: onnx-ir distinguishes between onnx-dynamic-shapes, runtime-shapes, and static-shapes, which often causes confusion. By removing static shape inference, we emphasize runtime-shapes, which are central to burn-import, thereby reducing the cognitive load on developers.

  3. Enable Cleaner and More Focused Code: the current codebase intertwines shape and rank inference, leading to redundant boilerplate and added complexity, particularly in ONNX operations. With this refactor, we aim to streamline code maintenance and make the inference logic easier for developers to follow and extend.

Rationale

Static shape inference at build time should be handled separately, reducing the dependency on static shape information. This proposal emphasizes a cleaner, rank-inference-only approach that aligns with ONNX’s dynamic capabilities and reduces the overhead for burn-import contributors.

Action Items

  • Remove Static Shape Inference Code: eliminate the existing shape inference logic from onnx-ir.
  • Refactor for Rank-Only Inference: adjust the onnx-ir and tensor APIs to perform rank inference exclusively (see the sketch after this list).
  • Testing: ensure the updated onnx-ir handles a variety of ONNX models with consistent rank inference across all operations.
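
To make the target behavior concrete, here is a minimal sketch of rank-only inference for two common cases; the function names are illustrative, not existing onnx-ir API:

```rust
/// Rank-only inference for a broadcasted elementwise op (e.g. Add, Mul):
/// under ONNX multidirectional broadcasting, the output rank is the
/// maximum of the input ranks; no dimension sizes are required.
fn infer_elementwise_rank(lhs_rank: usize, rhs_rank: usize) -> usize {
    lhs_rank.max(rhs_rank)
}

/// Rank-only inference for Unsqueeze: each entry in `axes` inserts one
/// new dimension, so the output rank is the input rank plus `axes.len()`.
fn infer_unsqueeze_rank(input_rank: usize, axes: &[i64]) -> usize {
    input_rank + axes.len()
}
```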

Benefits

This refactor will unburden developers from managing complex static shapes, streamline the review process, and make Burn more adaptable for dynamic ONNX models.

antimora added the onnx label Nov 11, 2024
@antimora (Collaborator, Author) commented:

CCing: @laggui, @skewballfox, tiruka, @hexd0t, @nathanielsimard

@skewballfox (Contributor) commented Nov 11, 2024

While this is a bit short of runtime support, I've been thinking about a way to support compile-time dynamic shapes with onnx-ir. It essentially boils down to:

  • Change the dim type to an enum whose value is either a usize or some kind of identifier (I guess a string), as sketched below.
  • Add a method to graph_builder for passing in a HashMap that provides the initial size of any named dynamic dimensions for the inputs.
  • Each time the size of a shape needs to be inferred (i.e., the enum is the named variant), compute its size from the information we have so far.

This, again, would fall a bit short of runtime support, but it allows more flexibility for build-time configuration, and I think we'd need the same information if using burn-import for building the models at runtime.
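
A minimal sketch of what that representation could look like, assuming illustrative names (`Dim`, `resolve`) rather than existing onnx-ir API:

```rust
use std::collections::HashMap;

/// A dimension is either a concrete size or a named symbolic dimension
/// whose size is supplied later (e.g. through graph_builder).
#[derive(Debug, Clone, PartialEq)]
pub enum Dim {
    Static(usize),
    Named(String),
}

impl Dim {
    /// Resolve this dimension against user-provided bindings for named
    /// dynamic dimensions; returns None when the name is unbound.
    pub fn resolve(&self, bindings: &HashMap<String, usize>) -> Option<usize> {
        match self {
            Dim::Static(size) => Some(*size),
            Dim::Named(name) => bindings.get(name).copied(),
        }
    }
}
```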

EDIT: from what I remember, the ONNX parsing in tract involved some sort of Solver; I assume this is what they meant.

EDIT2: from #2304

> I recommend we do not track shapes statically primarily because generated ONNX Models should be able to accept different size inputs at runtime

@antimora I'm making a few assumptions here and feel free to correct me if any of them are off base. The way I see this being used (at runtime) is the user loads an ONNX file at the beginning of a program or runtime initialization, the graph is built, and then the sizes don't change. The inputs are coming through a loop where they are exactly the same size every time.

If that doesn't necessarily hold and the inputs can change, then we should probably set the sizes (or capacity) to the expected max input size, partly so that if broadcasting is used, the tensors won't have to be moved around in memory.

In either case, for the intermediate representation produced by onnx-ir, it makes sense to keep the shape data available.

> After reading the comments, I think I am also in favor of tracking shape information. But this means that we need to make sure all ops currently track shapes properly, including already implemented ones (which is a bit more restrictive and will require more work).

@laggui I could open a PR for this. It might have to wait until the weekend, but I know this part of the code well enough that it wouldn't take me too long. I could do this separately from the stuff about inferred shapes above to make it easier to review.

@antimora
Copy link
Collaborator Author

@skewballfox

I should have made the requirements clearer. Currently there are two:

  1. Burn's Shape Handling: For Burn, we aim to avoid static shapes entirely (Burn's own graph will track shapes), which encourages more flexible, creative approaches when passing shape information is necessary. This can include creating or fusing into nodes like NodeLike to facilitate operations without relying on static shape data.

  2. ONNX-IR Shape Inference: With ONNX-IR, dynamic shape inference can be implemented where shapes (including dynamic dimensions represented by identifiers, e.g. -1) are inferred. This allows rank and dimension inference from input and output node information, enhancing flexibility without static dependencies.

I would like to take up the first one behind a feature flag in the onnx-ir crate (see the sketch below). I want to disable shape information for burn-import because it is impossible to guarantee that a runtime shape will match the static shape information generated during the build.
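
A minimal sketch of that gating, assuming a hypothetical `static-shapes` Cargo feature; the function and the stand-in graph type are illustrative, not existing onnx-ir code:

```rust
/// Stand-in for onnx-ir's parsed graph type (illustrative only).
pub struct OnnxGraph;

/// Static shape propagation is compiled in only when the (hypothetical)
/// `static-shapes` feature is enabled; burn-import would build onnx-ir
/// without it and see rank information only.
#[cfg(feature = "static-shapes")]
pub fn propagate_static_shapes(_graph: &mut OnnxGraph) {
    // walk the graph and fill in shapes from the ONNX file's tensor info
}

/// Without the feature, the same call site compiles to a no-op.
#[cfg(not(feature = "static-shapes"))]
pub fn propagate_static_shapes(_graph: &mut OnnxGraph) {}
```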

antimora self-assigned this Nov 14, 2024
antimora changed the title from "Refactor burn-onnxir to Enhance Dynamic Shape Support and Shape Inference" to "Refactor onnx-ir to remove static shape inference to rely on rank inference only" Nov 14, 2024
antimora changed the title from "Refactor onnx-ir to remove static shape inference to rely on rank inference only" to "Refactor onnx-ir to remove static shape inference and rely on rank inference only" Nov 14, 2024
@skewballfox (Contributor) commented:

For 1, could you point me to some relevant examples of what the syntax might look like? Or just code out some examples here? I'm mainly trying to get a better idea of what changes would be required for node shape handling.

> This allows rank and dimension inference from input and output node information, enhancing flexibility without static dependencies.

Even if the shapes can be changed at runtime, it still seems reasonable to want to initialize the tensors with some expected capacity.

@antimora (Collaborator, Author) commented:

> For 1, could you point me to some relevant examples of what the syntax might look like? Or just code out some examples here? I'm mainly trying to get a better idea of what changes would be required for node shape handling.
>
> > This allows rank and dimension inference from input and output node information, enhancing flexibility without static dependencies.
>
> Even if the shapes can be changed at runtime, it still seems reasonable to want to initialize the tensors with some expected capacity.

I am currently refactoring onnx-ir and burn-import.

The main changes: I removed shape from TensorType, renamed dim to rank, and renamed Tensor to TensorData, which still contains shape because it's part of the input data.

```rust
#[derive(Debug, Clone, Default)]
pub struct TensorType {
    /// The element type of the tensor.
    pub elem_type: ElementType,

    /// The rank of the tensor.
    pub rank: Rank,
}

#[derive(Debug, Clone, Default)]
pub struct TensorData {
    /// The element type of the tensor.
    pub elem_type: ElementType,

    /// The rank of the tensor.
    pub rank: Rank,

    /// The data of the tensor.
    pub data: Option<Data>,

    /// The shape of the tensor.
    pub shape: Option<Shape>,
}

/// The type of an attribute.
#[derive(Debug, Clone)]
pub enum AttributeValue {
    Float32(f32),
    Float32s(Vec<f32>),
    Int64(i64),
    Int64s(Vec<i64>),
    String(String),
    Strings(Vec<String>),
    Tensor(TensorData),
    Tensors(Vec<TensorData>),
}

/// A node input or output.
#[derive(Debug, Clone)]
pub struct Argument {
    /// The name of the node input.
    pub name: String,

    /// The type of the argument.
    pub ty: ArgType,

    /// The data of the argument.
    pub value: Option<TensorData>,

    /// True if the argument is passed to the node, false otherwise. We use it mainly for informational purposes.
    /// The argument should contain a value if `passed` is false.
    pub passed: bool,
}
```
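
For illustration, a tensor type under the refactored API carries only the element type and rank (assuming `Rank` is a plain integer alias); any shape known from the file lives on `TensorData` instead:

```rust
// Hypothetical usage of the refactored types shown above: no static
// shape is tracked on the type, only element type and rank.
let ty = TensorType {
    elem_type: ElementType::Float32,
    rank: 2,
};
```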

The yellow bar indicates a change.

[Three screenshots of the diff omitted.]

@antimora (Collaborator, Author) commented:

> Even if the shapes can be changed at runtime, it still seems reasonable to want to initialize the tensors with some expected capacity.

This feature can be implemented if a use case warrants it. Shape inference likely won't be needed in most cases, since other libraries using onnx-ir will track shapes at runtime themselves.

Currently, onnx-ir tracks shapes statically, mainly from tensor information contained in ONNX files. I didn't want downstream developers to use that data on the assumption that it is accurate.
