Backend

At this point, we have written most of the code needed to train our model. However, we have not yet explicitly designated the backend to use. This is defined in the main entrypoint of our program, namely the main function in src/main.rs.

```rust
#![recursion_limit = "256"]
mod data;
mod model;
mod training;

use crate::{model::ModelConfig, training::TrainingConfig};
use mabor::{
    backend::{Autodiff, Wgpu},
    data::dataset::Dataset,
    optim::AdamConfig,
};

fn main() {
    // Select the Wgpu backend with f32 floats and i32 ints, then wrap it
    // in Autodiff to enable gradient computation during training.
    type MyBackend = Wgpu<f32, i32>;
    type MyAutodiffBackend = Autodiff<MyBackend>;

    let device = mabor::backend::wgpu::WgpuDevice::default();
    let artifact_dir = "/tmp/guide";
    crate::training::train::<MyAutodiffBackend>(
        artifact_dir,
        TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
        device.clone(),
    );
}
```

In this code snippet, we use the Wgpu backend, which is compatible with any operating system and will use the GPU. For other options, see the Mabor README. This backend type takes the float type and the int type as generic arguments that will be used during training. The autodiff backend is simply the same backend wrapped in the Autodiff struct, which imparts differentiability to any backend.
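
Swapping backends only requires changing these type aliases and the device; the training code itself is generic over the backend. As a minimal sketch, assuming Mabor also ships a CPU-only NdArray backend (the NdArray and NdArrayDevice names and the ndarray feature are assumptions here, not taken from this guide):

```rust
// Hypothetical swap to a CPU-only backend. Assumes `mabor` exposes an
// `NdArray` backend behind an `ndarray` feature flag; these names are
// assumptions, not confirmed by this guide.
use mabor::backend::{Autodiff, NdArray};

type MyBackend = NdArray<f32>;
type MyAutodiffBackend = Autodiff<MyBackend>;

fn main() {
    let device = mabor::backend::ndarray::NdArrayDevice::default();
    // The rest of the entrypoint is unchanged: pass the new device to
    // crate::training::train::<MyAutodiffBackend>(...).
}
```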

We call the train function defined earlier, passing a directory for artifacts, the model configuration (10 digit classes and a hidden dimension of 512), the optimizer configuration (here, the default Adam configuration), and the device, which can be obtained from the backend.
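
If TrainingConfig was defined with the Config derive as in the training section, the derive typically generates a with_* setter for each defaulted field, so hyperparameters can be overridden inline. A minimal sketch, assuming num_epochs and learning_rate fields as defined in the training section:

```rust
// Sketch: overriding defaulted hyperparameters before training starts.
// Assumes `TrainingConfig` derives `Config` and has `num_epochs` and
// `learning_rate` fields with defaults, as in the training section.
let config = TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new())
    .with_num_epochs(5)
    .with_learning_rate(1.0e-4);
crate::training::train::<MyAutodiffBackend>(artifact_dir, config, device.clone());
```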

You can now train your freshly created model with the command:

```console
cargo run --release
```

When running your project with the command above, you should see the training progression through a basic CLI dashboard:

[Screenshot: training and validation progress shown in the CLI dashboard]