Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Feat/Split ONNX Import #2568

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft

Conversation

agelas
Copy link
Contributor

@agelas agelas commented Nov 29, 2024

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#2440

Changes

Adds split to the list of supported ops that can be imported via ONNX.

Testing

tbd

@agelas
Copy link
Contributor Author

agelas commented Dec 14, 2024

@antimora For nodes that produce multiple outputs, is there a recommended pattern for how to assign and return these outputs in the forward method? I tried to do the reverse of concat, but still having a bit of trouble.

Also, in the generated ONNX graph, all three outputs of the graph are named "split1_out1", but the outputs of the node are named uniquely. How do we ensure that each output of a multi-output node is named uniquely (assuming they should be)? I keep getting this error when I try generating the IR:

thread 'main' panicked at crates/burn-import/src/burn/graph.rs:566:40:
Output type not found for split1_out1

I have a feeling that the non-unique names might be tripping this up, because the overall structure of the outputs is consistent with other IRs I've generated.

I also copied some of the generated graph to show you what I mean.

ParsedOnnxGraph(
    // omitting the constant node for brevity
            Node {
                // omitting a bunch of stuff here too
                outputs: [
                    Argument {
                        name: "split1_out1",
                        ty: Tensor(
                            TensorType {
                                elem_type: Int64,
                                dim: 2,
                                shape: None,
                            },
                        ),
                        value: None,
                        passed: false,
                    },
                    Argument {
                        name: "split1_out2",
                        ty: Tensor(
                            TensorType {
                                elem_type: Int64,
                                dim: 2,
                                shape: None,
                            },
                        ),
                        value: None,
                        passed: false,
                    },
                    Argument {
                        name: "split1_out3",
                        ty: Tensor(
                            TensorType {
                                elem_type: Int64,
                                dim: 2,
                                shape: None,
                            },
                        ),
                        value: None,
                        passed: false,
                    },
                ],
                attrs: {
                    "axis": Int64(
                        0,
                    ),
                },
            },
        ],
        inputs: [
           // ignoring this
        ],
        outputs: [
            Argument {
                name: "split1_out1", <--- this name is the same for everything in outputs[]
                ty: Tensor(
                    TensorType {
                        elem_type: Int64,
                        dim: 2,
                        shape: None,
                    },
                ),
                value: None,
                passed: false,
            },
            Argument {
                name: "split1_out1",
                ty: Tensor(
                    TensorType {
                        elem_type: Int64,
                        dim: 2,
                        shape: None,
                    },
                ),
                value: None,
                passed: false,
            },
            Argument {
                name: "split1_out1",
                ty: Tensor(
                    TensorType {
                        elem_type: Int64,
                        dim: 2,
                        shape: None,
                    },
                ),
                value: None,
                passed: false,
            },
        ],
    },
)

# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant