ONNX to Mabor: Development Guide

This guide offers in-depth design insights and step-by-step procedures for developers working on the ONNX to Mabor conversion tool. The tool imports ONNX models into Mabor, a deep learning framework written in Rust, by converting the ONNX model to Rust source code and its weights to Mabor state files.

For an introduction to ONNX import in Mabor, see the ONNX import section of the Mabor book.

Design Overview

Design Goals

  • Perform best-effort conversion of ONNX models to Rust source code via Mabor APIs.
  • Convert ONNX model weights to Mabor state files.
  • Support ONNX models generated by PyTorch (ONNX Opset 16).
  • Produce easy-to-understand and modifiable models.
  • Ensure the generated models are trainable using Mabor APIs.

Design Decisions

  • Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process.
  • Ensure consistent operator behavior across different opset versions.
  • Exclude any ONNX/Protobuf-specific logic from the Mabor graph.

The conversion process involves three main stages:

  1. Convert ONNX model to Intermediate Representation (IR).
  2. Translate IR to a Mabor graph.
  3. Generate Rust source code from the Mabor graph.
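To make the three stages concrete, here is a minimal sketch that models them as plain functions. All types (OnnxNode, NodeType, IrNode, GraphNode) and the emitted method names are hypothetical simplifications, not the real onnx-ir or mabor-import types, and code is generated as a String rather than a token stream.

```rust
/// Hypothetical raw ONNX node: operator name plus input/output names.
struct OnnxNode {
    op_type: String,
    inputs: Vec<String>,
    outputs: Vec<String>,
}

/// Hypothetical IR operator kinds.
#[derive(Debug, Clone, Copy, PartialEq)]
enum NodeType {
    Relu,
    Squeeze,
}

/// Stage 1 output: an IR node with a typed operator kind.
struct IrNode {
    node_type: NodeType,
    input: String,
    output: String,
}

/// Stage 2 output: a graph node with no ONNX-specific details left.
enum GraphNode {
    Relu { input: String, output: String },
    Squeeze { input: String, output: String },
}

/// Stage 1: ONNX -> IR; unknown operators become errors.
fn to_ir(node: &OnnxNode) -> Result<IrNode, String> {
    let node_type = match node.op_type.as_str() {
        "Relu" => NodeType::Relu,
        "Squeeze" => NodeType::Squeeze,
        other => return Err(format!("unsupported operator: {other}")),
    };
    Ok(IrNode {
        node_type,
        input: node.inputs[0].clone(),
        output: node.outputs[0].clone(),
    })
}

/// Stage 2: IR -> graph node.
fn to_graph(node: IrNode) -> GraphNode {
    let IrNode { node_type, input, output } = node;
    match node_type {
        NodeType::Relu => GraphNode::Relu { input, output },
        NodeType::Squeeze => GraphNode::Squeeze { input, output },
    }
}

/// Stage 3: graph node -> Rust source (method names are placeholders).
fn codegen(node: &GraphNode) -> String {
    match node {
        GraphNode::Relu { input, output } => format!("let {output} = relu({input});"),
        GraphNode::Squeeze { input, output } => format!("let {output} = squeeze({input});"),
    }
}
```

Only to_ir sees ONNX operator names; the later stages work purely on IR and graph types, mirroring the decision to keep ONNX/Protobuf logic out of the Mabor graph.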

Adding New Operators

To extend mabor-import with support for new ONNX operators, follow these steps:

  1. Create PyTorch Script: Place a PyTorch script using the new operator under crates/mabor-import/onnx-tests/tests/<op>/<op>.py. Make sure to print both input and output tensors for end-to-end testing.

  2. Generate ONNX Model: Run the PyTorch script to produce an ONNX model.

  3. Visualize ONNX Model: Use Netron to verify the ONNX model contains the expected operators.

  4. Generate IR and Mabor Graph: Navigate to crates/mabor-import/ and run:

    cargo r -- ./onnx-tests/tests/<op>/<op>.onnx ./out
    
  5. Implement Missing Operators: If you encounter an error stating that an operator is unsupported, implement it. The ./out/my-model.graph.txt should provide relevant information.

  6. Inspect Generated Files: The my-model.graph.txt contains IR details, my-model.rs holds the Mabor model in Rust code, and my-model.json includes the model data.

  7. Add End-to-End Test: Include the test in crates/mabor-import/onnx-tests/tests/test_onnx.rs. Further details can be found in the onnx-tests README.

Implementing a New Operator

To extend the capabilities of the Mabor library by supporting new operations imported from ONNX graphs, developers must go through a few systematic steps. Here, we detail the process, using the implementation of the Squeeze operation to illustrate points as needed. All file/directory paths are relative to the root of the mabor repository.

Step 1: Visibility

To make a new operation accessible, there are two key modules to update:

  1. In crates/onnx-ir/src/node/mod.rs, add your new operation module to make it visible within the IR
  2. In crates/mabor-import/src/mabor/node/mod.rs, make the corresponding node type visible within mabor-import

Step 2: Node Implementation

Within onnx-ir

The onnx-ir crate handles the Intermediate Representation (IR) of ONNX models. For each operation:

  1. Add the operation to the NodeType enum in crates/onnx-ir/src/ir.rs.

  2. Create a new module file in crates/onnx-ir/src/node/<operation_name>.rs. This file should include:

    • A <operation_name>_config function to extract operation parameters
    • A <operation_name>_update_output function for rank inference (computing the output rank from the inputs)
  3. If the operation might work with constants, add it to the list of node types checked for constants in crates/onnx-ir/src/from_onnx.rs.

For example, the squeeze operation is defined in crates/onnx-ir/src/node/squeeze.rs and contains:

  • A squeeze_config function that extracts axes from node attributes
  • A squeeze_update_output function that sets the output rank to the input rank minus the number of squeezed axes
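The rank arithmetic behind squeeze_update_output can be sketched on its own. The helper squeeze_output_rank below is hypothetical. It follows the ONNX Squeeze rules that axes lie in [-r, r-1] (negative values count from the end) and that omitting axes squeezes every size-1 dimension; that last case cannot be resolved from the rank alone, so this sketch rejects it.

```rust
/// Hypothetical helper: output rank of ONNX Squeeze given the input rank and `axes`.
fn squeeze_output_rank(input_rank: usize, axes: &[i64]) -> Result<usize, String> {
    // ONNX: omitting `axes` removes every size-1 dimension, which needs the
    // static shape, not just the rank, so this sketch rejects it.
    if axes.is_empty() {
        return Err("empty axes: output rank depends on the input shape".to_string());
    }
    let r = input_rank as i64;
    let mut normalized_axes: Vec<i64> = Vec::with_capacity(axes.len());
    for &axis in axes {
        // Negative axes count from the end, e.g. -1 means r - 1.
        let normalized = if axis < 0 { axis + r } else { axis };
        if !(0..r).contains(&normalized) {
            return Err(format!("axis {axis} is out of range for rank {input_rank}"));
        }
        if normalized_axes.contains(&normalized) {
            return Err(format!("axis {axis} is repeated"));
        }
        normalized_axes.push(normalized);
    }
    // Each squeezed axis removes exactly one dimension.
    Ok(input_rank - normalized_axes.len())
}
```

For example, a rank-4 input squeezed on axes [1, -1] yields rank 2, since -1 normalizes to 3.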

Within mabor-import

  1. Create a new file named <operation_name>.rs in the crates/mabor-import/src/mabor/node/ directory. This file defines the structure and functionality of your new operation. By convention, the information needed to carry out an operation is encapsulated in a struct named <Operation>Node. For the Squeeze operation, we defined a struct called SqueezeNode that holds the input tensor, output tensor, and axes for the operation. If you are implementing a unary or binary operation, see the note below.

  2. The core of integrating a new operation involves implementing the NodeCodegen trait for your node. This trait defines how the node generates code during the graph compilation process. The implementation must provide methods to define input and output types, to generate the forward pass code, and to encapsulate the node into the more general Node structure. Specifically:

    • output_types and input_types return the tensor (or element) types for the output and inputs of the node, respectively.
    • forward generates the Rust code that performs the operation during the execution phase. The quote! macro is used to generate this code; make sure the emitted tokens form syntactically valid Rust that uses the Mabor API.
    • into_node wraps the specific node in a general Node type, facilitating its inclusion in the broader Mabor graph structure.
  3. This file is also where you put the test_codegen_nodes() test, which verifies the code generated for your node against the Mabor library.
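The following sketch shows the shape of a NodeCodegen implementation using a hypothetical, simplified trait: types are plain strings, forward returns a String instead of a quote!-generated token stream, and Node is a one-variant stand-in for the real enum. The squeeze_dims method in the emitted code is a placeholder, not a confirmed Mabor API.

```rust
/// Hypothetical, simplified stand-in for the real NodeCodegen trait.
trait NodeCodegen {
    fn input_types(&self) -> Vec<String>;
    fn output_types(&self) -> Vec<String>;
    fn forward(&self) -> String;
    fn into_node(self) -> Node;
}

/// Stand-in for the general Node enum that wraps every concrete node.
#[derive(Debug, PartialEq)]
enum Node {
    Squeeze(SqueezeNode),
}

#[derive(Debug, PartialEq)]
struct SqueezeNode {
    input: String,
    output: String,
    input_rank: usize,
    axes: Vec<i64>,
}

impl NodeCodegen for SqueezeNode {
    fn input_types(&self) -> Vec<String> {
        vec![format!("Tensor<B, {}>", self.input_rank)]
    }

    fn output_types(&self) -> Vec<String> {
        // Squeezing removes one dimension per axis.
        vec![format!("Tensor<B, {}>", self.input_rank - self.axes.len())]
    }

    fn forward(&self) -> String {
        // The method name in the emitted string is a placeholder.
        let axes: Vec<String> = self.axes.iter().map(|a| a.to_string()).collect();
        format!(
            "let {} = {}.squeeze_dims::<{}>(&[{}]);",
            self.output,
            self.input,
            self.input_rank - self.axes.len(),
            axes.join(", ")
        )
    }

    fn into_node(self) -> Node {
        Node::Squeeze(self)
    }
}
```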

For unary and binary operations: most of the NodeCodegen implementation already lives in binary.rs and unary.rs, so each new operation only needs to define a method that applies the function to the input token stream(s).
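A minimal sketch of why unary operations need so little code: the shared node owns the codegen boilerplate, and each operator contributes only a function from the input expression to the result expression. UnaryNode, relu, and neg here are hypothetical, strings stand in for token streams, and the emitted method names are placeholders.

```rust
/// Hypothetical shared unary node: codegen boilerplate lives here once.
struct UnaryNode {
    input: String,
    output: String,
    // Maps the input expression to the expression computing the result.
    function: fn(&str) -> String,
}

impl UnaryNode {
    fn forward(&self) -> String {
        format!("let {} = {};", self.output, (self.function)(&self.input))
    }
}

// Per-operator constructors: only the expression differs.
fn relu(input: &str, output: &str) -> UnaryNode {
    UnaryNode { input: input.into(), output: output.into(), function: |x| format!("relu({x})") }
}

fn neg(input: &str, output: &str) -> UnaryNode {
    UnaryNode { input: input.into(), output: output.into(), function: |x| format!("{x}.neg()") }
}
```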

Step 3: Registering New Operations

  1. In crates/mabor-import/src/onnx/to_mabor.rs, add the operation to the match statement in the into_mabor() method:
impl ParsedOnnxGraph {
    pub fn into_mabor<PS: PrecisionSettings + 'static>(self) -> MaborGraph<PS> {
        // ... (create `graph`)
        for node in self.0.nodes {
            match node.node_type {
                // ...
                NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
                // Add your new operation here
            }
        }
        // ...
        graph
    }
}
  2. Create a conversion function that creates an instance of your Mabor node:
fn squeeze_conversion(node: Node) -> SqueezeNode {
    // Map the IR input/output arguments to Mabor tensor types.
    let input = TensorType::from(node.inputs.first().unwrap());
    let output = TensorType::from(node.outputs.first().unwrap());
    // Reuse the onnx-ir config function to extract the axes.
    let axes = squeeze_config(&node);

    SqueezeNode::new(input, output, axes)
}

This function extracts the necessary information from the ONNX node and passes it to your node's constructor.
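The registration and conversion steps can be mimicked with stand-in types to show the data flow from an IR node to a typed graph node. NodeType, IrNode, SqueezeNode, GraphNode, squeeze_conversion, and into_graph below are hypothetical simplifications of the real types.

```rust
/// Hypothetical operator kinds.
#[derive(Debug, Clone, Copy)]
enum NodeType {
    Squeeze,
    Relu,
}

/// Hypothetical IR node carrying everything the conversions need.
struct IrNode {
    node_type: NodeType,
    input: String,
    output: String,
    axes: Vec<i64>,
}

#[derive(Debug, PartialEq)]
struct SqueezeNode {
    input: String,
    output: String,
    axes: Vec<i64>,
}

#[derive(Debug, PartialEq)]
enum GraphNode {
    Squeeze(SqueezeNode),
    Relu { input: String, output: String },
}

/// Per-operator conversion: pull out what the typed node needs.
fn squeeze_conversion(node: IrNode) -> SqueezeNode {
    SqueezeNode { input: node.input, output: node.output, axes: node.axes }
}

/// Registration loop: one match arm per supported NodeType.
fn into_graph(nodes: Vec<IrNode>) -> Vec<GraphNode> {
    let mut graph = Vec::new();
    for node in nodes {
        let converted = match node.node_type {
            NodeType::Squeeze => GraphNode::Squeeze(squeeze_conversion(node)),
            NodeType::Relu => GraphNode::Relu { input: node.input, output: node.output },
        };
        graph.push(converted);
    }
    graph
}
```

In this sketch the match has no wildcard arm, so adding a NodeType variant makes the compiler flag into_graph until a conversion is registered; whether the real into_mabor uses a fallback arm is not covered here.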

Step 4: Create a Config Function

In crates/onnx-ir/src/node/<operation_name>.rs, create a config function that extracts operation-specific parameters from the ONNX node:

pub fn squeeze_config(curr: &Node) -> Vec<i64> {
    // Read the optional "axes" attribute; fall back to an empty list if it is absent.
    let axes = curr
        .attrs
        .iter()
        .find_map(|(key, value)| {
            if key == "axes" {
                Some(value.clone().into_i64s())
            } else {
                None
            }
        })
        .unwrap_or_default();

    // Validate that the first input is a tensor.
    if !matches!(curr.inputs.first().unwrap().ty, ArgType::Tensor(_)) {
        panic!("Only tensor input is valid");
    }

    axes
}

This config function is responsible for parsing the ONNX node attributes and extracting operation-specific parameters. In this case, it extracts the "axes" attribute from the squeeze operation.
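One reason config functions need care: ONNX moved Squeeze's axes from an attribute (opsets before 13) to an optional second input (opset 13 and later). A config function that handles both keeps behavior consistent across opset versions. The Node type and squeeze_axes function below are hypothetical stand-ins; reading the input form only works once that constant has been lifted (see Lifting Constant Nodes).

```rust
use std::collections::HashMap;

/// Hypothetical simplified IR node: attributes plus known constant input values.
struct Node {
    attrs: HashMap<String, Vec<i64>>,
    // Constant value of each input, if known (None for runtime tensors).
    input_values: Vec<Option<Vec<i64>>>,
}

/// Read Squeeze axes from the `axes` attribute (opset < 13) or,
/// failing that, from the optional second input (opset >= 13).
fn squeeze_axes(node: &Node) -> Option<Vec<i64>> {
    if let Some(axes) = node.attrs.get("axes") {
        return Some(axes.clone());
    }
    node.input_values.get(1).cloned().flatten()
}
```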

Step 5: Rank Inference

In crates/onnx-ir/src/node/<operation_name>.rs, implement a rank inference function that updates the output rank based on the operation:

pub fn squeeze_update_output(node: &mut Node) {
    // Extract the axes with the config function from Step 4.
    let axes = squeeze_config(node);

    // Read the rank of the (tensor) input.
    let input_rank = match &node.inputs[0].ty {
        ArgType::Tensor(tensor) => tensor.rank,
        _ => panic!("Only tensor input is valid"),
    };
    let output_rank = input_rank - axes.len();

    // Update output rank
    node.outputs[0].ty = ArgType::Tensor(TensorType {
        elem_type: node.inputs[0].ty.elem_type().clone(),
        rank: output_rank,
        static_shape: None,
    });
}

Then register this function in crates/onnx-ir/src/rank_inference.rs by adding it to the match statement:

pub fn rank_inference(node: &mut Node) {
    match node.node_type {
        // ...
        NodeType::Squeeze => squeeze_update_output(node),
        // Add your new operation here
    }
}

The rank_inference.rs file is responsible for determining the output tensor rank for each node in the graph.

If the rank remains unchanged, you can use helper functions like same_as_input() or same_as_input_broadcast() instead of writing a custom update function.
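To illustrate what helpers like same_as_input() and same_as_input_broadcast() do, here is a sketch on simplified stand-in types. The bodies and signatures are assumptions for illustration (the real helpers may, for example, treat scalar inputs differently); the broadcast variant assumes the output rank is the largest input rank, as in NumPy-style broadcasting.

```rust
/// Hypothetical simplified tensor type: element type name and rank only.
#[derive(Debug, Clone, PartialEq)]
struct TensorType {
    elem_type: String,
    rank: usize,
}

struct Node {
    inputs: Vec<TensorType>,
    outputs: Vec<TensorType>,
}

/// Output has exactly the type and rank of the first input (e.g. Relu, Neg).
fn same_as_input(node: &mut Node) {
    node.outputs[0] = node.inputs[0].clone();
}

/// Output rank is the largest input rank, as with broadcasting binary ops (e.g. Add).
fn same_as_input_broadcast(node: &mut Node) {
    let rank = node.inputs.iter().map(|t| t.rank).max().unwrap_or(0);
    node.outputs[0] = TensorType { elem_type: node.inputs[0].elem_type.clone(), rank };
}
```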

Step 6: Integrate into the Graph Building Process

When a new node type is introduced, it must be added to the Node<PS: PrecisionSettings> enum in crates/mabor-import/src/mabor/node/base.rs and the match_all! macro in the same file.

The Node enum abstracts over different types of operations (nodes) within a network graph. Each variant of the enum corresponds to a specific type of operation and encapsulates the operation-specific data structure (like SqueezeNode) defined in Step 2.
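A sketch of the pattern behind the Node enum and a match_all!-style macro: the macro expands one expression across every variant, so shared methods on Node are written once. The node structs and macro syntax here are hypothetical; the real match_all! in base.rs may take its arguments differently.

```rust
// Hypothetical stand-ins for two concrete node types.
struct SqueezeNode {
    output: String,
}
struct ReluNode {
    output: String,
}

impl SqueezeNode {
    fn output_name(&self) -> &str {
        &self.output
    }
}
impl ReluNode {
    fn output_name(&self) -> &str {
        &self.output
    }
}

/// One variant per operation, each wrapping its operation-specific struct.
enum Node {
    Squeeze(SqueezeNode),
    Relu(ReluNode),
}

/// Expand `$expr` once per variant, binding the inner node to `$node`.
/// Adding a variant to `Node` means adding one arm here.
macro_rules! match_all {
    ($self:expr, $node:ident => $expr:expr) => {
        match $self {
            Node::Squeeze($node) => $expr,
            Node::Relu($node) => $expr,
        }
    };
}

impl Node {
    fn output_name(&self) -> &str {
        match_all!(self, node => node.output_name())
    }
}
```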

Step 7: Add Newly Supported Op!

As a reward, add an extra check to crates/mabor-import/SUPPORTED-ONNX-OPS.md!

Lifting Constant Nodes

If your operation takes inputs from constant nodes (such as weights in Conv1d, shape tensors in Reshape, etc.), you need to add your operation's NodeType to the LIFT_CONSTANTS_FOR_NODE_TYPES array in crates/onnx-ir/src/from_onnx.rs.

const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 16] = [
    NodeType::BatchNormalization,
    // other operations...
    NodeType::Squeeze,
    NodeType::Unsqueeze,
    // Add your operation here if it needs constants to be processed,
    // and increase the array length (16 above) to match the new number of entries.
];

"Lifting" constants means converting Constant nodes into direct input values. This is similar to how ONNX initializers work. For example, instead of having a separate Constant node providing weights to a Convolution operation, the weights are directly embedded as values in the Convolution node's inputs.

This transformation makes it easier to:

  1. Access the constant values during node configuration
  2. Process operations like Conv1d that expect weights as direct inputs
  3. Handle shape-defining inputs needed for operations like Reshape

Without this, operations that need to extract configuration from constant inputs (such as shapes, weights, or other parameters) would not work correctly because they wouldn't have direct access to those constant values.
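The lifting transformation can be sketched on simplified stand-in types. Everything here (NodeType, Input, Node, LIFT_FOR, lift_constants) is hypothetical; in particular, dropping Constant nodes that no longer feed any runtime input is a choice made for this sketch, not necessarily what from_onnx.rs does.

```rust
use std::collections::HashMap;

/// Hypothetical, simplified node kinds.
#[derive(Debug, Clone, Copy, PartialEq)]
enum NodeType {
    Constant,
    Reshape,
}

/// A node input: the producing tensor's name and, once lifted, its constant value.
struct Input {
    name: String,
    value: Option<Vec<i64>>,
}

struct Node {
    node_type: NodeType,
    inputs: Vec<Input>,
    outputs: Vec<String>,
    // Payload of a Constant node; None for other node types.
    constant: Option<Vec<i64>>,
}

/// Stand-in for LIFT_CONSTANTS_FOR_NODE_TYPES.
const LIFT_FOR: [NodeType; 1] = [NodeType::Reshape];

fn lift_constants(nodes: Vec<Node>) -> Vec<Node> {
    // Output name -> value for every Constant node.
    let constants: HashMap<String, Vec<i64>> = nodes
        .iter()
        .filter(|n| n.node_type == NodeType::Constant)
        .filter_map(|n| Some((n.outputs[0].clone(), n.constant.clone()?)))
        .collect();

    // Copy constant values into the inputs of the listed node types only.
    let mut lifted: Vec<Node> = nodes
        .into_iter()
        .map(|mut node| {
            if LIFT_FOR.contains(&node.node_type) {
                for input in &mut node.inputs {
                    if let Some(value) = constants.get(&input.name) {
                        input.value = Some(value.clone());
                    }
                }
            }
            node
        })
        .collect();

    // Drop Constant nodes whose output no input still reads at runtime.
    let still_read: Vec<String> = lifted
        .iter()
        .flat_map(|n| n.inputs.iter())
        .filter(|i| i.value.is_none())
        .map(|i| i.name.clone())
        .collect();
    lifted.retain(|n| n.node_type != NodeType::Constant || still_read.contains(&n.outputs[0]));
    lifted
}
```

After lifting, the Reshape node can read its target shape directly from its own inputs during configuration.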

Testing

When implementing a new operator, there are several levels of testing to consider:

Unit Testing

  • Node Configuration: Write unit tests for the <operation_name>_config function in crates/onnx-ir/src/node/<operation_name>.rs to verify that it correctly extracts parameters from ONNX nodes.

  • Rank Inference: Test the <operation_name>_update_output function to ensure it correctly computes output ranks.

  • Code Generation: Test the Node implementation in mabor-import to verify that it generates correct Rust code.
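A config unit test might look like the sketch below. The Node type, the squeeze_config stand-in, and the node_with_axes helper are hypothetical simplifications; real tests would build onnx-ir nodes instead.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for an IR node's attributes.
struct Node {
    attrs: HashMap<String, Vec<i64>>,
}

/// Config under test: read `axes`, defaulting to an empty list.
fn squeeze_config(node: &Node) -> Vec<i64> {
    node.attrs.get("axes").cloned().unwrap_or_default()
}

/// Test helper: build a node with the given `axes` attribute (or none).
fn node_with_axes(axes: Option<Vec<i64>>) -> Node {
    let mut attrs = HashMap::new();
    if let Some(axes) = axes {
        attrs.insert("axes".to_string(), axes);
    }
    Node { attrs }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_axes_attribute() {
        assert_eq!(squeeze_config(&node_with_axes(Some(vec![0, -1]))), vec![0, -1]);
    }

    #[test]
    fn missing_axes_defaults_to_empty() {
        assert!(squeeze_config(&node_with_axes(None)).is_empty());
    }
}
```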

Integration Testing

  • Create small ONNX models that use your operator and test the end-to-end conversion process
  • Ensure the generated Rust code compiles and produces the expected outputs
  • Add these tests to crates/mabor-import/onnx-tests/tests/test_onnx.rs

End-to-End Testing

  • Test with realistic ONNX models that use your operator in conjunction with others
  • Verify that inputs and outputs match between the original ONNX model and the converted Mabor model
  • Include models that test edge cases (e.g., different input shapes, parameter combinations)
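When comparing outputs of the converted model against the values printed by the PyTorch script, exact equality is usually too strict for floating point. Below is a minimal sketch of a hypothetical comparison helper, first_mismatch, using the same tolerance rule as NumPy's allclose (|actual - expected| <= atol + rtol * |expected|).

```rust
/// Compare two flattened tensors element-wise with absolute and relative tolerance.
/// Returns the index of the first mismatch, None if all elements match,
/// or an error if the lengths differ.
fn first_mismatch(actual: &[f32], expected: &[f32], atol: f32, rtol: f32) -> Result<Option<usize>, String> {
    if actual.len() != expected.len() {
        return Err(format!("length mismatch: {} vs {}", actual.len(), expected.len()));
    }
    // Written as !(diff <= tol) so that NaN values count as mismatches.
    Ok(actual
        .iter()
        .zip(expected)
        .position(|(a, e)| !((a - e).abs() <= atol + rtol * e.abs())))
}
```

Suitable atol and rtol values depend on the operator and precision, so the values in any given test are a judgment call.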

Testing both the rank inference and node configuration is particularly important as these components directly affect the correctness of the conversion process. Incorrect rank inference can lead to mismatched tensor shapes, while incorrect configuration can cause runtime errors or incorrect results.

Resources

  1. PyTorch to ONNX
  2. ONNX to PyTorch
  3. ONNX Introduction
  4. ONNX Operators
  5. ONNX Protos
  6. ONNX Optimizer
  7. Netron