ONNX to Mabor: Development Guide
This guide offers in-depth design insights and step-by-step procedures for developers working on the ONNX to Mabor conversion tool. This tool imports ONNX models into Mabor, a deep learning framework written in Rust. It converts ONNX model graphs to Rust source code and model weights to Mabor state files.
For an introduction to ONNX import in Mabor, see the ONNX import section of the Mabor book.
Design Overview
Design Goals
- Perform best-effort conversion of ONNX models to Rust source code via Mabor APIs.
- Convert ONNX model weights to Mabor state files.
- Support ONNX models generated by PyTorch (ONNX Opset 16).
- Produce easy-to-understand and modifiable models.
- Ensure the generated models are trainable using Mabor APIs.
Design Decisions
- Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process.
- Ensure consistent operator behavior across different opset versions.
- Exclude any ONNX/Protobuf-specific logic from the Mabor graph.
The conversion process involves three main stages:
- Convert ONNX model to Intermediate Representation (IR).
- Translate IR to a Mabor graph.
- Generate Rust source code from the Mabor graph.
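The three stages can be pictured as three functions chained together. The sketch below is a toy model, not mabor-import's actual code: IrNode, GraphNode, onnx_to_ir, ir_to_graph, and graph_to_rust are hypothetical, the "ONNX model" is just a list of (op_type, input, output) triples, and the emitted Rust strings are illustrative rather than confirmed Mabor API calls.

```rust
/// Stage 1 output: a toy IR node (hypothetical; a real IR node also carries
/// attributes, argument types, and constant values).
#[derive(Debug, Clone, PartialEq)]
struct IrNode {
    op_type: String,
    input: String,
    output: String,
}

/// Stage 2 output: a toy graph node with no ONNX-specific details left.
#[derive(Debug, PartialEq)]
enum GraphNode {
    Relu { input: String, output: String },
    Neg { input: String, output: String },
}

/// Stage 1: turn raw (op_type, input, output) triples into IR nodes.
fn onnx_to_ir(raw: &[(&str, &str, &str)]) -> Vec<IrNode> {
    raw.iter()
        .map(|&(op, input, output)| IrNode {
            op_type: op.to_string(),
            input: input.to_string(),
            output: output.to_string(),
        })
        .collect()
}

/// Stage 2: translate IR nodes into graph nodes, rejecting unknown operators.
fn ir_to_graph(nodes: Vec<IrNode>) -> Result<Vec<GraphNode>, String> {
    nodes
        .into_iter()
        .map(|n| match n.op_type.as_str() {
            "Relu" => Ok(GraphNode::Relu { input: n.input, output: n.output }),
            "Neg" => Ok(GraphNode::Neg { input: n.input, output: n.output }),
            other => Err(format!("unsupported operator: {other}")),
        })
        .collect()
}

/// Stage 3: emit one line of (illustrative) Rust source per graph node.
fn graph_to_rust(graph: &[GraphNode]) -> String {
    graph
        .iter()
        .map(|node| match node {
            GraphNode::Relu { input, output } => format!("let {output} = relu({input});\n"),
            GraphNode::Neg { input, output } => format!("let {output} = {input}.neg();\n"),
        })
        .collect()
}
```

Because stage 2 already strips away everything ONNX-specific, stage 3 is a pure translation over graph nodes, which mirrors the design decision to keep ONNX/Protobuf logic out of the Mabor graph.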
Adding New Operators
To extend mabor-import with support for new ONNX operators, follow these steps:
- Create PyTorch Script: Place a PyTorch script using the new operator under crates/mabor-import/onnx-tests/tests/<op>/<op>.py. Make sure to print both the input and output tensors for end-to-end testing.
- Generate ONNX Model: Run the PyTorch script to produce an ONNX model.
- Visualize ONNX Model: Use Netron to verify that the ONNX model contains the expected operators.
- Generate IR and Mabor Graph: Navigate to crates/mabor-import/ and run:
  cargo r -- ./onnx-tests/tests/<op>/<op>.onnx ./out
- Implement Missing Operators: If you encounter an error stating that an operator is unsupported, implement it. The ./out/my-model.graph.txt file should provide relevant information.
- Inspect Generated Files: my-model.graph.txt contains IR details, my-model.rs holds the Mabor model in Rust code, and my-model.json includes the model data.
- Add End-to-End Test: Include the test in crates/mabor-import/onnx-tests/tests/test_onnx.rs. Further details can be found in the onnx-tests README.
Implementing a New Operator
To extend the capabilities of the Mabor library by supporting new operations imported from ONNX
graphs, developers must go through a few systematic steps. Here, we detail the process, using the
implementation of the Squeeze
operation to illustrate points as needed. All file/directory paths
are relative to the root of the mabor repository.
Step 1: Visibility
To make a new operation accessible, update two key modules:
- In crates/onnx-ir/src/node/mod.rs, add your new operation module to make it visible within the IR.
- In crates/mabor-import/src/mabor/node/mod.rs, make the corresponding node type visible within mabor-import.
Step 2: Node Implementation
Within onnx-ir
The onnx-ir crate handles the Intermediate Representation (IR) of ONNX models. For each operation:
- Add the operation to the NodeType enum in crates/onnx-ir/src/ir.rs.
- Create a new module file in crates/onnx-ir/src/node/<operation_name>.rs. This file should include:
  - A <operation_name>_config function to extract operation parameters
  - A <operation_name>_update_output function for rank inference
- If the operation might work with constants, add it to the list of node types checked for constants in crates/onnx-ir/src/from_onnx.rs.
For example, the squeeze operation is defined in crates/onnx-ir/src/node/squeeze.rs and contains:
- A squeeze_config function that extracts the axes from the node attributes
- A squeeze_update_output function that computes the output rank by subtracting the number of squeezed axes from the input rank
Within mabor-import
- Create a new file named <operation_name>.rs in the crates/mabor-import/src/mabor/node/ directory. This file defines the structure and functionality of your new operation. By convention, the information needed to carry out an operation is encapsulated in a struct named <Operation>Node. For the Squeeze operation, we defined a struct called SqueezeNode that holds the input tensor, the output tensor, and the axes for the operation. If you are implementing a unary or binary operation, see the note below.
- The core of integrating a new operation is implementing the NodeCodegen trait for your node. This trait defines how the node generates code during the graph compilation process. The implementation must provide methods to define the input and output types, to generate the forward pass code, and to wrap the node into the more general Node structure. Specifically:
  - output_types and input_types return the tensor (or element) types for the outputs and inputs of the node, respectively.
  - forward generates the Rust code that performs the operation during the execution phase. The quote! macro is used to generate this code; make sure the generated code is syntactically correct Rust and uses valid Mabor APIs.
  - into_node wraps the specific node in the general Node type, so that it can be included in the broader Mabor graph structure.
- This file is also where you put test_codegen_nodes(), to make sure that the generated code works within the Mabor library.
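To make the shape of NodeCodegen concrete, here is a deliberately simplified, std-only sketch. NodeCodegenSketch, TensorType, and SqueezeNode below are hypothetical stand-ins: the real trait produces token streams with quote! and also provides into_node, and the squeeze::<D>(...) call in the emitted string is illustrative, not a confirmed Mabor tensor method.

```rust
/// Simplified stand-in for a tensor type: a variable name and a rank.
#[derive(Debug, Clone, PartialEq)]
struct TensorType {
    name: String,
    rank: usize,
}

/// Hypothetical, simplified take on NodeCodegen: the real trait produces
/// token streams via quote!, whereas this sketch returns plain strings.
trait NodeCodegenSketch {
    fn input_types(&self) -> Vec<TensorType>;
    fn output_types(&self) -> Vec<TensorType>;
    fn forward(&self) -> String;
}

/// Toy squeeze node holding the input, output, and axes.
struct SqueezeNode {
    input: TensorType,
    output: TensorType,
    axes: Vec<i64>,
}

impl NodeCodegenSketch for SqueezeNode {
    fn input_types(&self) -> Vec<TensorType> {
        vec![self.input.clone()]
    }

    fn output_types(&self) -> Vec<TensorType> {
        vec![self.output.clone()]
    }

    fn forward(&self) -> String {
        // The output rank becomes a const generic argument in the emitted code.
        format!(
            "let {} = {}.squeeze::<{}>({:?});",
            self.output.name, self.input.name, self.output.rank, self.axes
        )
    }
}
```

The key design point is that forward only needs the names and ranks stored on the node; all shape reasoning has already happened during rank inference in onnx-ir.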
For unary and binary operations: Most of the NodeCodegen implementation lives in binary.rs and unary.rs, so each new operation only has to define a method that applies the function to the input token stream(s).
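The sketch below illustrates that split between shared and per-operation code. UnaryNodeSketch and its neg/exp constructors are hypothetical; the real unary.rs works with token streams rather than strings, and the emitted .neg() and .exp() calls are illustrative.

```rust
/// Hypothetical shared unary node: everything common lives here, and each
/// operation only supplies a function from the input expression to the
/// output expression.
struct UnaryNodeSketch {
    input: String,
    output: String,
    function: fn(&str) -> String,
}

impl UnaryNodeSketch {
    /// Shared forward codegen used by every unary operation.
    fn forward(&self) -> String {
        format!("let {} = {};", self.output, (self.function)(&self.input))
    }

    /// Adding a new unary op means adding one small constructor like this.
    fn neg(input: &str, output: &str) -> Self {
        Self {
            input: input.to_string(),
            output: output.to_string(),
            function: |x: &str| format!("{x}.neg()"),
        }
    }

    fn exp(input: &str, output: &str) -> Self {
        Self {
            input: input.to_string(),
            output: output.to_string(),
            function: |x: &str| format!("{x}.exp()"),
        }
    }
}
```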
Step 3: Registering New Operations
- In crates/mabor-import/src/onnx/to_mabor.rs, add the operation to the match statement in the into_mabor() method:
```rust
impl ParsedOnnxGraph {
    pub fn into_mabor<PS: PrecisionSettings + 'static>(self) -> MaborGraph<PS> {
        // ...
        for node in self.0.nodes {
            match node.node_type {
                // ...
                NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
                // Add your new operation here
            }
        }
        // ... (the populated graph is returned here)
    }
}
```
- Create a conversion function that creates an instance of your Mabor node:
```rust
fn squeeze_conversion(node: Node) -> SqueezeNode {
    // Convert the IR input/output arguments into mabor-import tensor types
    let input = TensorType::from(node.inputs.first().unwrap());
    let output = TensorType::from(node.outputs.first().unwrap());
    // Reuse the onnx-ir config function to extract the axes
    let axes = squeeze_config(&node);

    SqueezeNode::new(input, output, axes)
}
```
This function extracts the necessary information from the ONNX node and passes it to your node's constructor.
Step 4: Create a Config Function
In crates/onnx-ir/src/node/<operation_name>.rs, create a config function that extracts operation-specific parameters from the ONNX node:
```rust
pub fn squeeze_config(curr: &Node) -> Vec<i64> {
    // Look up the "axes" attribute; default to an empty list if it is absent
    let axes = curr
        .attrs
        .iter()
        .filter_map(|(key, value)| {
            if key == "axes" {
                Some(value.clone().into_i64s())
            } else {
                None
            }
        })
        .next()
        .unwrap_or_else(Vec::new);

    // Validate that the first input is a tensor
    match curr.inputs.first().unwrap().clone().ty {
        ArgType::Tensor(tensor) => tensor,
        _ => panic!("Only tensor input is valid"),
    };

    axes
}
```
This config function is responsible for parsing the ONNX node attributes and extracting operation-specific parameters. In this case, it extracts the "axes" attribute from the squeeze operation.
Step 5: Rank Inference
In crates/onnx-ir/src/node/<operation_name>.rs, implement a rank inference function that updates the output rank based on the operation:
```rust
pub fn squeeze_update_output(node: &mut Node) {
    // Extract axes information
    let axes = /* ... */;
    let input_rank = /* ... */;

    // Each squeezed axis removes one dimension
    let output_rank = input_rank - axes.len();

    // Update output rank
    node.outputs[0].ty = ArgType::Tensor(TensorType {
        elem_type: node.inputs[0].ty.elem_type().clone(),
        rank: output_rank,
        static_shape: None,
    });
}
```
Then register this function in crates/onnx-ir/src/rank_inference.rs by adding it to the match statement:
```rust
pub fn rank_inference(node: &mut Node) {
    match node.node_type {
        // ...
        NodeType::Squeeze => squeeze_update_output(node),
        // Add your new operation here
    }
}
```
The rank_inference.rs file is responsible for determining the output tensor rank for each node in the graph.
If the rank remains unchanged, you can use helper functions like same_as_input() or same_as_input_broadcast() instead of writing a custom update function.
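The following self-contained toy shows the same rank arithmetic and dispatch end to end. NodeSketch, ArgTypeSketch, and NodeTypeSketch are hypothetical simplifications; the axes are assumed to be known already (for example, from the config function), and same_as_input here only mimics the idea of the real helper, not its signature.

```rust
/// Hypothetical, simplified argument type: element type plus rank.
#[derive(Debug, Clone, PartialEq)]
struct ArgTypeSketch {
    elem_type: &'static str,
    rank: usize,
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum NodeTypeSketch {
    Squeeze,
    Relu,
}

/// Hypothetical node: the axes are assumed to have been extracted already.
#[derive(Debug)]
struct NodeSketch {
    node_type: NodeTypeSketch,
    axes: Vec<i64>,
    inputs: Vec<ArgTypeSketch>,
    outputs: Vec<ArgTypeSketch>,
}

/// Squeeze: every squeezed axis removes exactly one dimension.
fn squeeze_update_output(node: &mut NodeSketch) {
    let input = node.inputs[0].clone();
    assert!(
        node.axes.len() <= input.rank,
        "cannot squeeze more axes than the input rank"
    );
    node.outputs[0] = ArgTypeSketch {
        elem_type: input.elem_type,
        rank: input.rank - node.axes.len(),
    };
}

/// Rank-preserving ops: copy the input type, in the spirit of same_as_input().
fn same_as_input(node: &mut NodeSketch) {
    node.outputs[0] = node.inputs[0].clone();
}

/// Dispatch on the node type, mirroring the match in rank_inference.rs.
fn rank_inference(node: &mut NodeSketch) {
    match node.node_type {
        NodeTypeSketch::Squeeze => squeeze_update_output(node),
        NodeTypeSketch::Relu => same_as_input(node),
    }
}
```

For example, squeezing one axis from a rank-3 f32 input yields a rank-2 f32 output, while a Relu node simply inherits its input type.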
Step 6: Integrate into the Graph Building Process
When a new node type is introduced, it must be added to the Node<PS: PrecisionSettings> enum in crates/mabor-import/src/mabor/node/base.rs and to the match_all! macro in the same file.
The Node enum abstracts over the different types of operations (nodes) within a network graph. Each variant of the enum corresponds to a specific type of operation and encapsulates the operation-specific data structure (like SqueezeNode) defined in Step 2.
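The relationship between the enum and the macro can be sketched with toy types. Codegen, Node, SqueezeNode, ReluNode, and this match_all! are hypothetical simplifications (the real enum is generic over PrecisionSettings and has many more variants). The point it illustrates: the macro pastes the same function into every arm, so each new variant must be added to both the enum and the macro, otherwise the macro's match is non-exhaustive and the crate does not compile.

```rust
/// Hypothetical codegen trait shared by all node structs.
trait Codegen {
    fn forward(&self) -> String;
}

struct SqueezeNode {
    axes: Vec<i64>,
}

struct ReluNode;

impl Codegen for SqueezeNode {
    fn forward(&self) -> String {
        format!("squeeze {:?}", self.axes)
    }
}

impl Codegen for ReluNode {
    fn forward(&self) -> String {
        "relu".to_string()
    }
}

/// One variant per operation, each wrapping its operation-specific struct.
enum Node {
    Squeeze(SqueezeNode),
    Relu(ReluNode),
}

/// Apply the same function to whichever concrete node the enum holds.
/// Because `$func` is pasted into every arm, each arm type-checks it
/// against a different concrete node type.
macro_rules! match_all {
    ($node_enum:expr, $func:expr) => {
        match $node_enum {
            Node::Squeeze(node) => $func(node),
            Node::Relu(node) => $func(node),
        }
    };
}

impl Node {
    fn forward(&self) -> String {
        match_all!(self, Codegen::forward)
    }
}

impl SqueezeNode {
    /// Counterpart of into_node: wrap the specific node in the general enum.
    fn into_node(self) -> Node {
        Node::Squeeze(self)
    }
}
```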
Step 7: Add Newly Supported Op!
As a reward, add a checkmark for your operator in crates/mabor-import/SUPPORTED-ONNX-OPS.md!
Lifting Constant Nodes
If your operation takes inputs from constant nodes (such as weights in Conv1d, shape tensors in Reshape, etc.), you need to add your operation's NodeType to the LIFT_CONSTANTS_FOR_NODE_TYPES array in crates/onnx-ir/src/from_onnx.rs. Because this is a fixed-size array, also increment its declared length when you add an entry.
```rust
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 16] = [
    NodeType::BatchNormalization,
    // other operations...
    NodeType::Squeeze,
    NodeType::Unsqueeze,
    // Add your operation here if it needs constants to be processed
    // (and increment the array length above accordingly)
];
```
"Lifting" constants means converting Constant nodes into direct input values. This is similar to how ONNX initializers work. For example, instead of having a separate Constant node providing weights to a Convolution operation, the weights are directly embedded as values in the Convolution node's inputs.
This transformation makes it easier to:
- Access the constant values during node configuration
- Process operations like Conv1d that expect weights as direct inputs
- Handle shape-defining inputs needed for operations like Reshape
Without this, operations that need to extract configuration from constant inputs (such as shapes, weights, or other parameters) would not work correctly because they wouldn't have direct access to those constant values.
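The idea of lifting can be sketched as a single pass over a toy graph. ToyNode, Input, LIFT_FOR, and lift_constants are hypothetical and much simpler than the real logic in crates/onnx-ir/src/from_onnx.rs; the sketch only replaces references to Constant outputs with embedded values, and only for node types on an allow-list.

```rust
use std::collections::HashMap;

/// A node input: either a reference to another node's output or an embedded value.
#[derive(Debug, Clone, PartialEq)]
enum Input {
    Ref(String),
    Value(Vec<i64>),
}

/// Hypothetical, minimal graph node.
#[derive(Debug, Clone, PartialEq)]
struct ToyNode {
    op_type: &'static str,
    output: String,
    inputs: Vec<Input>,
    /// Only set for Constant nodes.
    value: Option<Vec<i64>>,
}

/// Hypothetical allow-list playing the role of LIFT_CONSTANTS_FOR_NODE_TYPES.
const LIFT_FOR: [&str; 2] = ["Reshape", "Squeeze"];

/// Replace references to Constant outputs with embedded values, but only for
/// node types on the allow-list.
fn lift_constants(nodes: &mut [ToyNode]) {
    // Map each Constant node's output name to its value.
    let constants: HashMap<String, Vec<i64>> = nodes
        .iter()
        .filter(|n| n.op_type == "Constant")
        .filter_map(|n| n.value.clone().map(|v| (n.output.clone(), v)))
        .collect();

    for node in nodes.iter_mut().filter(|n| LIFT_FOR.contains(&n.op_type)) {
        for input in node.inputs.iter_mut() {
            let lifted = match input {
                Input::Ref(name) => constants.get(name.as_str()).cloned(),
                Input::Value(_) => None,
            };
            if let Some(value) = lifted {
                *input = Input::Value(value);
            }
        }
    }
}
```

After the pass, a Reshape node's config function can read its target shape directly from its inputs, while node types not on the allow-list keep plain references.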
Testing
When implementing a new operator, there are several levels of testing to consider:
Unit Testing
- Node Configuration: Write unit tests for the <operation_name>_config function in crates/onnx-ir/src/node/<operation_name>.rs to verify that it correctly extracts parameters from ONNX nodes.
- Rank Inference: Test the <operation_name>_update_output function to ensure it correctly computes output ranks.
- Code Generation: Test the Node implementation in mabor-import to verify that it generates correct Rust code.
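A config unit test mostly checks that attributes are read, defaulted, and rejected correctly. The sketch below uses a hypothetical AttrValue enum and a HashMap of attributes instead of onnx-ir's real Node type, so it shows the shape of such tests rather than code that can be pasted into crates/onnx-ir.

```rust
use std::collections::HashMap;

/// Hypothetical attribute value; onnx-ir has its own, richer attribute type.
#[derive(Debug, Clone, PartialEq)]
enum AttrValue {
    Int64s(Vec<i64>),
    Float32(f32),
}

/// Hypothetical, simplified squeeze config: read "axes", defaulting to empty.
fn squeeze_config(attrs: &HashMap<String, AttrValue>) -> Vec<i64> {
    match attrs.get("axes") {
        Some(AttrValue::Int64s(axes)) => axes.clone(),
        Some(other) => panic!("expected a list of ints for axes, got {other:?}"),
        None => Vec::new(),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_axes_attribute() {
        let attrs = HashMap::from([("axes".to_string(), AttrValue::Int64s(vec![0, 2]))]);
        assert_eq!(squeeze_config(&attrs), vec![0, 2]);
    }

    #[test]
    fn missing_axes_defaults_to_empty() {
        assert!(squeeze_config(&HashMap::new()).is_empty());
    }

    #[test]
    #[should_panic(expected = "expected a list of ints")]
    fn rejects_wrong_attribute_type() {
        let attrs = HashMap::from([("axes".to_string(), AttrValue::Float32(1.0))]);
        squeeze_config(&attrs);
    }
}
```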
Integration Testing
- Create small ONNX models that use your operator and test the end-to-end conversion process
- Ensure the generated Rust code compiles and produces the expected outputs
- Add these tests to crates/mabor-import/onnx-tests/tests/test_onnx.rs
End-to-End Testing
- Test with realistic ONNX models that use your operator in conjunction with others
- Verify that inputs and outputs match between the original ONNX model and the converted Mabor model
- Include models that test edge cases (e.g., different input shapes, parameter combinations)
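When comparing outputs, exact float equality is usually too strict, because the converted model may evaluate operations in a different order than PyTorch. A tolerance-based check such as the hypothetical check_all_close below (not an existing mabor-import helper) is one way to compare the converted model's outputs against the reference values printed by the PyTorch script, element by element, with NaN matching only NaN.

```rust
/// Compare actual vs. expected values using |a - e| <= atol + rtol * |e|.
/// Returns a description of the first mismatch, if any.
fn check_all_close(actual: &[f32], expected: &[f32], atol: f32, rtol: f32) -> Result<(), String> {
    if actual.len() != expected.len() {
        return Err(format!(
            "length mismatch: {} vs {}",
            actual.len(),
            expected.len()
        ));
    }
    for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
        let tol = atol + rtol * e.abs();
        // A NaN difference never exceeds tol, so NaN is checked separately:
        // it only matches when both sides are NaN.
        if (a - e).abs() > tol || a.is_nan() != e.is_nan() {
            return Err(format!("mismatch at index {i}: got {a}, expected {e}"));
        }
    }
    Ok(())
}
```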
Testing both the rank inference and node configuration is particularly important as these components directly affect the correctness of the conversion process. Incorrect rank inference can lead to mismatched tensor shapes, while incorrect configuration can cause runtime errors or incorrect results.