Running ML models in a game (and in Wasm!)
François
Posted on February 19, 2021
Using Tract and Bevy, guess the number being drawn.
TL;DR: Try it here - Read the code here
One of my colleagues recently starred the Tract GitHub repository, which got me wondering how easy it would be to use. I know how to create an ONNX model with PyTorch Lightning, and that there are pretrained models available. Spoiler: it's very easy to integrate!
Over the last couple of months, I've been doing game jams (Ludum Dare 47 and Game Off 2020) with Bevy, so I wanted to check if I could easily use Tract in a game made with Bevy and built for a Wasm target.
Running an ONNX model with Tract
Tract example
Using Tract to run an ONNX model is fairly easy, and the example provided is great.
First, you load the model, specifying its input:
let model = tract_onnx::onnx()
// load the model
.model_for_path("mobilenetv2-1.0.onnx")?
// specify input type and shape
.with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 3, 224, 224)))?
// optimize the model
.into_optimized()?
// make the model runnable and fix its inputs and outputs
.into_runnable()?;
Then take an image and transform it to an array with the expected shape and normalize the values:
// open image, resize it and make a Tensor out of it
let image = image::open("grace_hopper.jpg").unwrap().to_rgb8();
let resized =
image::imageops::resize(&image, 224, 224, ::image::imageops::FilterType::Triangle);
let image: Tensor = tract_ndarray::Array4::from_shape_fn((1, 3, 224, 224), |(_, c, y, x)| {
let mean = [0.485, 0.456, 0.406][c];
let std = [0.229, 0.224, 0.225][c];
(resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std
})
.into();
And finally run the model and get the result with the best score:
// run the model on the input
let result = model.run(tvec!(image))?;
// find and display the max value with its index
let best = result[0]
.to_array_view::<f32>()?
.iter()
.cloned()
.zip(2..)
.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
println!("result: {:?}", best);
With MNIST
For MNIST, I use the model from ONNX MNIST. It takes an image of 28px by 28px as input, so its shape is (1, 1, 28, 28).
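Loading it works just like the example above. Here is a minimal sketch, where the local file name mnist-8.onnx is an assumption:
let model = tract_onnx::onnx()
    // load the MNIST model (the path is an assumption)
    .model_for_path("mnist-8.onnx")?
    // a single 28x28 grayscale image, as (batch, channel, height, width)
    .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 1, 28, 28)))?
    // optimize and make runnable, as in the example above
    .into_optimized()?
    .into_runnable()?;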
Its output is an array of 10 float numbers, representing the score of each digit. I can get the digit with the best score:
let result = model.model.run(tvec!(image)).unwrap();
if let Some((value, score)) = result[0]
.to_array_view::<f32>()
.unwrap()
.iter()
.cloned()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
{
if score > 10. {
println!("{:?}", value);
}
}
Game setup
Loading the model
To load the model from an .onnx file, I need to create a custom asset loader for this format. Following the Bevy custom asset loader example, I first declare my asset type OnnxModel:
#[derive(Debug, TypeUuid)]
#[uuid = "578fae90-a8de-41ab-a4dc-3aca66a31eed"]
pub struct OnnxModel {
pub model: SimplePlan<
TypedFact,
Box<dyn TypedOp>,
tract_onnx::prelude::Graph<TypedFact, Box<dyn TypedOp>>,
>,
}
And then I implement AssetLoader for OnnxModelLoader. Even though I know the input shape of my model, I do not call with_input_fact in the loader, to stay independent of the model being loaded:
#[derive(Default)]
pub struct OnnxModelLoader;
impl AssetLoader for OnnxModelLoader {
fn load<'a>(
&'a self,
mut bytes: &'a [u8],
load_context: &'a mut LoadContext,
) -> BoxedFuture<'a, Result<(), anyhow::Error>> {
Box::pin(async move {
let model = tract_onnx::onnx()
.model_for_read(&mut bytes)
.unwrap()
.into_optimized()?
.into_runnable()?;
load_context.set_default_asset(LoadedAsset::new(OnnxModel { model }));
Ok(())
})
}
fn extensions(&self) -> &[&str] {
&["onnx"]
}
}
I create a struct to hold the Handle to the model, so that I can reuse the loaded model without needing to reload it every time:
struct State {
model: Handle<OnnxModel>,
}
impl FromResources for State {
fn from_resources(resources: &Resources) -> Self {
let asset_server = resources.get::<AssetServer>().unwrap();
State {
model: asset_server.load("model.onnx"),
}
}
}
And finally, I add my asset loader and my resource to the Bevy app:
App::build()
.add_asset::<OnnxModel>()
.init_asset_loader::<OnnxModelLoader>()
.init_resource::<State>();
Drawing to a texture
To display a texture, I use standard Bevy UI components, here an ImageBundle. I also mark this entity with the Interaction and FocusPolicy components from bevy::ui so that it can react to mouse clicks and movements.
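The entity setup itself isn't shown here. A minimal sketch of spawning the drawable node, where the Drawable marker component, the 280px size and the material handle are all assumptions:
// Marker component used by the queries in the systems below
struct Drawable;

// Sketch: `material` is assumed to be a Handle<ColorMaterial> wrapping the
// texture we will draw on
fn spawn_drawable(commands: &mut Commands, material: Handle<ColorMaterial>) {
    commands
        .spawn(ImageBundle {
            style: Style {
                size: Size::new(Val::Px(280.), Val::Px(280.)),
                ..Default::default()
            },
            material,
            ..Default::default()
        })
        // react to mouse clicks and hovering
        .with(Interaction::default())
        // don't let clicks pass through to entities behind
        .with(bevy::ui::FocusPolicy::default())
        .with(Drawable);
}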
To be input type independent (touch and mouse), I create an event Draw that takes position coordinates. I can then trigger this event on CursorMoved events when the mouse is clicked over the texture.
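The Event enum and the INPUT_SIZE constant used below aren't defined in the post; a minimal sketch of what they could look like (the exact definitions are assumptions):
// The model expects a 28x28 input
const INPUT_SIZE: u32 = 28;

// Custom event carrying the position to draw at, registered on the app
// with .add_event::<Event>()
enum Event {
    Draw(Vec2),
}
And here is the system that turns cursor movements into Draw events: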
fn drawing_mouse(
(mut reader, events): (Local<EventReader<CursorMoved>>, Res<Events<CursorMoved>>),
mut last_mouse_position: Local<Option<Vec2>>,
mut texture_events: ResMut<Events<Event>>,
state: Res<State>,
drawable: Query<(&Interaction, &GlobalTransform, &Style), With<Drawable>>,
) {
for (interaction, transform, style) in drawable.iter() {
if let Interaction::Clicked = interaction {
// Get the width and height of the texture
let width = if let Val::Px(x) = style.size.width {
x
} else {
0.
};
let height = if let Val::Px(x) = style.size.height {
x
} else {
0.
};
// For every `CursorMoved` event
for event in reader.iter(&events) {
if let Some(last_mouse_position) = *last_mouse_position {
// If movement is fast, interpolate positions between last known position
// and current position from event
let steps =
(last_mouse_position.distance(event.position) as u32 / INPUT_SIZE + 1) * 3;
for i in 0..steps {
let lerped =
last_mouse_position.lerp(event.position, i as f32 / steps as f32);
// Change cursor position from window to texture
let x = lerped.x - transform.translation.x + width / 2.;
let y = lerped.y - transform.translation.y + height / 2.;
// And send the event to draw at this position
texture_events.send(Event::Draw(Vec2::new(x, y)));
}
} else {
let x = event.position.x - transform.translation.x + width / 2.;
let y = event.position.y - transform.translation.y + height / 2.;
texture_events.send(Event::Draw(Vec2::new(x, y)));
}
*last_mouse_position = Some(event.position);
}
} else {
*last_mouse_position = None;
}
}
}
And to actually draw on the texture, I listen for this event and use a brush to color the texture around the event coordinates:
fn update_texture(
(mut reader, events): (Local<EventReader<Event>>, Res<Events<Event>>),
materials: Res<Assets<ColorMaterial>>,
mut textures: ResMut<Assets<Texture>>,
mut state: ResMut<State>,
drawable: Query<(&bevy::ui::Node, &Handle<ColorMaterial>), With<Drawable>>,
) {
for event in reader.iter(&events) {
// Retrieving the texture data from its `Handle`
// First, getting the `Handle<ColorMaterial>` of the `ImageBundle`
let (node, mat) = drawable.iter().next().unwrap();
// Then, getting the `ColorMaterial` matching this handle
let material = materials.get(mat).unwrap();
// Finally, getting the texture itself from the `texture` field of the `ColorMaterial`
let texture = textures
.get_mut(material.texture.as_ref().unwrap())
.unwrap();
match event {
Event::Draw(pos) => {
// Use a large round brush instead of drawing pixel by pixel
// `node.size` is the displayed size of the texture
// `texture.size` is the actual size of the texture data
// `INPUT_SIZE` is the expected input size by the model
// The brush will be bigger if the drawing area is bigger to provide
// a smoother drawing experience
let radius = (1.3 * node.size.x / INPUT_SIZE as f32 / 2.) as i32;
let scale = (texture.size.width as f32 / node.size.x) as i32;
for i in -radius..(radius + 1) {
for j in -radius..(radius + 1) {
let target_point = Vec2::new(pos.x + i as f32, pos.y + j as f32);
if pos.distance(target_point) < radius as f32 {
for i in 0..=scale {
for j in 0..=scale {
set_pixel(
(target_point.x as i32) * scale + i,
((node.size.y as f32 - target_point.y) as i32) * scale + j,
255,
texture,
)
}
}
}
}
}
}
}
}
}
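The set_pixel and get_pixel helpers aren't shown in the post. A minimal sketch, assuming an RGBA texture (4 bytes per pixel) where a pixel counts as colored once its red channel has been painted:
// Write `value` to the RGB channels of the pixel at (x, y), ignoring
// out-of-bounds coordinates (the implementation is an assumption)
fn set_pixel(x: i32, y: i32, value: u8, texture: &mut Texture) {
    let (width, height) = (texture.size.width as i32, texture.size.height as i32);
    if x < 0 || y < 0 || x >= width || y >= height {
        return;
    }
    let offset = ((y * width + x) * 4) as usize;
    texture.data[offset] = value;
    texture.data[offset + 1] = value;
    texture.data[offset + 2] = value;
    texture.data[offset + 3] = 255;
}

// Return 1 if the pixel at (x, y) has been painted, 0 otherwise
fn get_pixel(x: i32, y: i32, texture: &Texture) -> u8 {
    let (width, height) = (texture.size.width as i32, texture.size.height as i32);
    if x < 0 || y < 0 || x >= width || y >= height {
        return 0;
    }
    (texture.data[((y * width + x) * 4) as usize] > 127) as u8
}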
Getting model input from the texture, and inferring the digit
I can now run my model on my texture and guess the digit! This model is fast enough that it can run every frame:
fn infer(
state: Res<State>,
materials: Res<Assets<ColorMaterial>>,
textures: Res<Assets<Texture>>,
models: Res<Assets<OnnxModel>>,
drawable: Query<&Handle<ColorMaterial>, With<Drawable>>,
mut display: Query<&mut Text>,
) {
for mat in drawable.iter() {
// Get the texture from the `Handle<ColorMaterial>`
let material = materials.get(mat).unwrap();
let texture = textures.get(material.texture.as_ref().unwrap()).unwrap();
// As the texture is much larger than the model input, each point in the
// model input will be 1 if at least half of the point in a square
// of `pixel_size` in the texture are colored
let pixel_size = (texture.size.width as u32 / INPUT_SIZE) as i32;
let image = tract_ndarray::Array4::from_shape_fn(
(1, 1, INPUT_SIZE as usize, INPUT_SIZE as usize),
|(_, _, y, x)| {
let mut val = 0;
for i in 0..pixel_size as i32 {
for j in 0..pixel_size as i32 {
val += get_pixel(
x as i32 * pixel_size + i,
y as i32 * pixel_size + j,
texture,
) as i32;
}
}
if val > pixel_size * pixel_size / 2 {
    1.0f32
} else {
    0.0f32
}
},
)
.into();
if let Some(model) = models.get(state.model.as_weak::<OnnxModel>()) {
// Run the model on the input
let result = model.model.run(tvec!(image)).unwrap();
// Get the best prediction, and display it if its score is high enough
if let Some((value, score)) = result[0]
.to_array_view::<f32>()
.unwrap()
.iter()
.cloned()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
{
if score > 10. {
display.iter_mut().next().unwrap().value = format!("{:?}", value);
} else {
display.iter_mut().next().unwrap().value = "".to_string();
}
}
}
}
}
Build for Wasm target
Thanks to bevy_webgl2, this is actually very straightforward. I just need to add the WebGL2Plugin plugin and disable Bevy's default features to enable only the ones available on Wasm.
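A minimal sketch of the plugin registration, assuming it should only happen when compiling for the Wasm target:
// On Wasm, register the WebGL2 renderer; native builds keep Bevy's
// default rendering backend (sketch, assuming bevy_webgl2's plugin API)
#[cfg(target_arch = "wasm32")]
app.add_plugin(bevy_webgl2::WebGL2Plugin);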
There is a bevy webgl2 template if you want to build a game that works on all platforms.
GitHub Actions to deploy to itch.io
To host this POC (and my other games) on itch.io, I have a GitHub Actions workflow that builds the game for Windows, Linux, macOS and Wasm, creates a release on GitHub, and pushes everything to itch.io. This workflow is triggered on tags.
It is made of two jobs. The first one will, on each platform:
- set up the environment, install dependencies and tools
- build the game
- perform platform-specific steps (strip, wasm-bindgen, ...) and copy assets if needed
- create an archive for the platform (dmg for macOS, zip for the others)
- create a release on GitHub and add those archives to the release
- save the archives as artifacts
The second job will take the artifacts from the first job and send them to itch.io using butler, itch.io's official command line tool.