Working through the fast.ai book in Rust - Part 3

favilo

Favil Orbedios

Posted on November 21, 2023


Introduction

In Part 2, we went over downloading and extracting our datasets from the internet.

Today in Part 3, we are going to go over reading that data and transforming it into a form that our models can actually understand.

Tensor objects are multi-dimensional arrays that store numbers. Since files on a computer are ultimately just numbers too, we can use that to our advantage and turn each of our images into something called a Rank 3 Tensor. This is a 3-dimensional structure that represents a single image in its entirety. The dimensions of the Tensor represent the color channels, the height, and the width of the image.

Specifically, for the ResNet model that chapter 1 wants us to work with, these images need to have the shape 3x224x224. That corresponds to 3 color channels, with height and width both set to 224. Because of this size requirement, we need to resize our images before we store them in our Tensors.

Refactoring

First things first, in the last Part of this series, we just wrote all of our code into a single module. That is going to get overwhelming quickly, so let's refactor our code into separate modules.

This is how our tardyai/src/lib.rs will look:

pub mod download;
pub mod error;

pub use self::{
    download::{untar_images, Url},
    error::Error,
};

To accomplish this, we can pull the Error enum out into tardyai/src/error.rs:

#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("reqwest error: {0}")]
    Reqwest(#[from] reqwest::Error),

    #[error("io error: {0}")]
    IO(#[from] std::io::Error),

    #[error("homedir error: {0}")]
    Home(#[from] homedir::GetHomeError),

    #[error("tar entry error: {0}")]
    TarEntry(&'static str),
}

And we can pull everything else out into tardyai/src/download.rs.

use std::{
    fs::File,
    io::{self, Seek},
    path::PathBuf,
};

use flate2::read::GzDecoder;
use tar::Archive;

use crate::error::Error;

const S3_BASE: &str = "https://s3.amazonaws.com/fast-ai-";
const S3_IMAGE: &str = "imageclas/";

#[derive(Debug, Clone, Copy)]
pub enum Url {
    Pets,
}

impl Url {
    pub fn url(self) -> String {
        match self {
            Self::Pets => {
                format!("{S3_BASE}{S3_IMAGE}oxford-iiit-pet.tgz")
            }
        }
    }
}

fn ensure_dir(path: &PathBuf) -> Result<(), Error> {
    if !path.exists() {
        std::fs::create_dir_all(path)?;
    }
    Ok(())
}

pub fn untar_images(url: Url) -> Result<PathBuf, Error> {
    let home = &homedir::get_my_home()?
        .expect("home directory needs to exist")
        .join(".tardyai");
    let dest_dir = home.join("archive");
    ensure_dir(&dest_dir)?;
    let archive_file = download_archive(url, &dest_dir)?;

    let dest_dir = home.join("data");
    let dir = extract_archive(&archive_file, &dest_dir)?;

    Ok(dir)
}

fn download_archive(url: Url, dest_dir: &PathBuf) -> Result<PathBuf, Error> {
    let mut response = reqwest::blocking::get(url.url())?;
    let archive_name = response
        .url()
        .path_segments()
        .and_then(|s| s.last())
        .and_then(|name| if name.is_empty() { None } else { Some(name) })
        .unwrap_or("tmp.tar.gz");

    let archive_file = dest_dir.join(archive_name);

    // TODO: check if the archive is valid and exists
    if archive_file.exists() {
        log::info!("Archive already exists: {}", archive_file.display());
        return Ok(archive_file);
    }

    log::info!(
        "Downloading {} to archive: {}",
        url.url(),
        archive_file.display()
    );
    let mut dest = File::create(&archive_file)?;
    response.copy_to(&mut dest)?;
    Ok(archive_file)
}

fn extract_archive(archive_file: &PathBuf, dest_dir: &PathBuf) -> Result<PathBuf, Error> {
    let tar_gz = File::open(archive_file)?;
    let tar = GzDecoder::new(tar_gz);
    let mut archive = Archive::new(tar);

    log::info!(
        "Extracting archive {} to: {}",
        archive_file.display(),
        dest_dir.display()
    );
    let dir = {
        let entry = &archive
            .entries()?
            .next()
            .ok_or(Error::TarEntry("No entries in archive"))??;
        entry.path()?.into_owned()
    };
    let archive_dir = dest_dir.join(dir);
    if archive_dir.exists() {
        log::info!("Archive already extracted to: {}", archive_dir.display());
        return Ok(archive_dir);
    }

    let tar = archive.into_inner();
    let mut tar_gz = tar.into_inner();
    tar_gz.seek(io::SeekFrom::Start(0))?;
    let tar = GzDecoder::new(tar_gz);
    let mut archive = Archive::new(tar);
    archive.unpack(dest_dir)?;

    Ok(archive_dir)
}

And that is done. Super simple, but much more organized.

On to finding images

Now that that's out of the way, we can get to the actual meat of this article.

As a good starting place, let's create a new module for our datasets, along with the associated tardyai/src/datasets.rs file. The module declaration goes in tardyai/src/lib.rs:

pub mod datasets;

And now we need to create a struct that can keep track of all the image files we want to read.

use std::path::PathBuf;

pub struct DirectoryImageDataset {
    files: Vec<PathBuf>,
}

Let's also create a constructor that takes the parent directory, walks all the files, and collects all the images into the files field. This uses two crates. The walkdir crate walks the parent directory so we can pick out the files whose extensions match. The image crate lists the file extensions that are supported, and will eventually decode the images and load them into memory.

use walkdir::WalkDir;

use crate::error::Error;

impl DirectoryImageDataset {
    pub fn new(parent: PathBuf) -> Result<Self, Error> {
        let exts = image_extensions();

        let walker = WalkDir::new(parent).follow_links(true).into_iter();
        let files = walker
            .filter_map(|entry| {
                let entry = entry.ok()?;
                entry
                    .path()
                    .extension()
                    .and_then(|ext| Some(exts.contains(ext.to_str()?)))?
                    .then_some(entry)
            })
            .map(|entry| entry.path().to_owned())
            .collect();
        Ok(Self { files })
    }
}

This uses a function called image_extensions that returns a HashSet of all the supported file extensions that I feel like including right now.

use std::collections::HashSet;

use image::ImageFormat;

fn image_extensions() -> HashSet<&'static str> {
    let mut set = HashSet::default();
    set.extend(ImageFormat::Jpeg.extensions_str());
    set.extend(ImageFormat::Png.extensions_str());
    set.extend(ImageFormat::Gif.extensions_str());
    set.extend(ImageFormat::WebP.extensions_str());
    set.extend(ImageFormat::Tiff.extensions_str());
    set.extend(ImageFormat::Bmp.extensions_str());
    set.extend(ImageFormat::Qoi.extensions_str());
    set
}
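The extension check in the constructor is fairly dense, so here is a minimal standalone sketch of the same idea using only the standard library. The names has_image_extension and filter_images are hypothetical and not part of tardyai. The sketch also lowercases the extension before the lookup, which is an assumption beyond the code above: the original comparison is case-sensitive, so a file named PHOTO.JPG would be skipped.

```rust
use std::collections::HashSet;
use std::path::{Path, PathBuf};

/// Returns true when `path` has an extension contained in `exts`.
/// The extension is lowercased first, so `PHOTO.JPG` matches `jpg`.
/// (Hypothetical helper, not part of tardyai.)
fn has_image_extension(path: &Path, exts: &HashSet<&str>) -> bool {
    path.extension()
        // Non-UTF-8 extensions are treated as "not an image".
        .and_then(|ext| ext.to_str())
        .map(|ext| exts.contains(ext.to_ascii_lowercase().as_str()))
        .unwrap_or(false)
}

/// Keeps only the paths whose extension is in `exts`.
fn filter_images(paths: Vec<PathBuf>, exts: &HashSet<&str>) -> Vec<PathBuf> {
    paths
        .into_iter()
        .filter(|p| has_image_extension(p, exts))
        .collect()
}
```

Given a set containing "jpg" and "png", filter_images keeps a.png and drops b.txt. Whether the case-insensitive match is worth it depends on the dataset; the Pets file names appear to use lowercase extensions already.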

I've also added a utility method for fetching the list of files that the dataset contains:

impl DirectoryImageDataset {
    // pub fn new(...)

    pub fn files(&self) -> &[PathBuf] {
        &self.files
    }
}

Let's get to the Tensors now

Now that we've discovered all the images, we just need to get them into memory. For now, I'm planning on doing this the naive way and just loading the images each time they are requested. We can optimize it later by adding a cache.

The dfdx crate has a trait called ExactSizeDataset that it can use to power the training epochs and calculate the validation loss at the end.

It's a pretty straightforward interface with only two required methods: len() and get(). len() is pretty self-explanatory. get() takes an index and returns an item of whatever type you specify. Typically that is a tuple of a Tensor and the label that the Tensor represents.

impl ExactSizeDataset for DirectoryImageDataset<'_> {
    type Item<'a> = Result<(Tensor<Rank3<3, 224, 224>, f32, AutoDevice>, bool), Error>
    where
        Self: 'a;

    fn get(&self, index: usize) -> Self::Item<'_> {
        let image_file = &self.files[index];
        // Read the image and resize it to 224x224, and 3 channels
        let image = ImageReader::open(image_file)?
            .decode()?
            .resize_exact(224, 224, FilterType::Triangle)
            .into_rgb8();

        // Shrink the byte values to f32 between [0, 1]
        let bytes: Vec<f32> = image.as_bytes().iter().map(|&b| b as f32 / 255.0).collect();

        // Create the tensor and the label
        Ok((
            self.dev
                .tensor_from_vec(bytes, (Const::<3>, Const::<224>, Const::<224>)),
            (self.label_fn)(image_file),
        ))
    }

    fn len(&self) -> usize {
        self.files.len()
    }
}
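One thing worth double-checking in the get() method above is the memory layout. An RGB image buffer from the image crate stores its bytes interleaved, pixel by pixel (RGBRGB..., i.e. height x width x channel), while a 3x224x224 Tensor is read as channel x height x width. Handing the interleaved bytes straight to tensor_from_vec gives a Tensor with the right shape but with the values in the wrong positions. Below is a minimal sketch of the reordering, using only the standard library. The function name hwc_bytes_to_chw_f32 is hypothetical, and this is one possible way to do it rather than what tardyai actually does.

```rust
/// Reorders interleaved pixel data (height x width x channel, as in
/// `RGBRGB...`) into planar order (channel x height x width), scaling
/// each byte to an f32 in [0, 1] along the way.
/// (Hypothetical helper, not part of tardyai.)
fn hwc_bytes_to_chw_f32(bytes: &[u8], height: usize, width: usize, channels: usize) -> Vec<f32> {
    assert_eq!(bytes.len(), height * width * channels);
    let mut out = vec![0.0f32; bytes.len()];
    for y in 0..height {
        for x in 0..width {
            for c in 0..channels {
                // Index into the interleaved source buffer.
                let src = (y * width + x) * channels + c;
                // Index into the planar destination buffer.
                let dst = (c * height + y) * width + x;
                out[dst] = bytes[src] as f32 / 255.0;
            }
        }
    }
    out
}
```

In get(), the result of hwc_bytes_to_chw_f32(image.as_bytes(), 224, 224, 3) could be passed to tensor_from_vec in place of the plain byte-to-float mapping.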

I'm doing a lot here. Probably the most important change is that I need to modify the struct to contain the device in order to actually construct the Tensor objects. This is because the Tensor might be stored on the GPU, and the device facilitates constructing them there. Also note that the ? on decode() requires our Error type to implement From<image::ImageError>, so the Error enum needs one more variant for that.

I've also made the struct slightly generic by storing a function that maps the path to the label. In this case the label is hard-coded to a bool. We'll need to take care of that later, but I don't want to borrow tomorrow's problem just yet.

Here's the new definition of the struct, along with the associated constructor changes to support it. I unfortunately had to add a lifetime to handle the label function, but it's not too bad: we can get away with lifetime elision everywhere except when we are constructing it. I suppose I could have required 'static, but that would have ruled out closures that borrow local variables, which I am hoping to eventually use. (So much for not borrowing tomorrow's problems, I guess.)

pub struct DirectoryImageDataset<'fun> {
    files: Vec<PathBuf>,
    dev: AutoDevice,
    label_fn: Box<dyn Fn(&Path) -> bool + 'fun>,
}

impl<'fun> DirectoryImageDataset<'fun> {
    pub fn new(
        parent: PathBuf,
        dev: AutoDevice,
        label_fn: impl Fn(&Path) -> bool + 'fun,
    ) -> Result<Self, Error> {
        let exts = image_extensions();

        let walker = WalkDir::new(parent).follow_links(true).into_iter();
        let files = walker
            .filter_map(|entry| {
                let entry = entry.ok()?;
                entry
                    .path()
                    .extension()
                    .and_then(|ext| Some(exts.contains(ext.to_str()?)))?
                    .then_some(entry)
            })
            .map(|entry| entry.path().to_owned())
            .collect();
        Ok(Self {
            files,
            dev,
            label_fn: Box::new(label_fn),
        })
    }
}
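To see what the 'fun lifetime buys us, here is a stripped-down sketch using only the standard library. Labeler and starts_with_prefix are hypothetical names for illustration; Labeler keeps just the boxed label function from the struct above. The closure borrows a local &str, so it is not 'static, and it is the lifetime parameter that lets it be boxed anyway.

```rust
use std::path::Path;

/// Hypothetical, stripped-down stand-in for the dataset: it only keeps
/// the boxed label function so the lifetime handling is easy to see.
struct Labeler<'fun> {
    label_fn: Box<dyn Fn(&Path) -> bool + 'fun>,
}

impl<'fun> Labeler<'fun> {
    fn new(label_fn: impl Fn(&Path) -> bool + 'fun) -> Self {
        Labeler {
            label_fn: Box::new(label_fn),
        }
    }

    fn label(&self, path: &Path) -> bool {
        (self.label_fn)(path)
    }
}

/// Labels `path` by checking whether its file name starts with `prefix`.
fn starts_with_prefix(prefix: &str, path: &Path) -> bool {
    // This closure borrows `prefix`, so it cannot satisfy a 'static bound.
    let labeler = Labeler::new(|p: &Path| {
        p.file_name()
            .and_then(|n| n.to_str())
            .map(|n| n.starts_with(prefix))
            .unwrap_or(false)
    });
    labeler.label(path)
}
```

With a 'static bound on label_fn, starts_with_prefix would fail to compile because the closure captures a borrowed prefix. A closure that captures nothing, or that moves owned data in, would still work either way.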

How do we optimize it?

If we started training right now, we'd run into a problem with speed. We aren't caching the Tensors, so they get dropped and rebuilt from the image files on every pass through the training loop.

That will take much longer than if we just cached all the values the first time. So by spending some memory (roughly 4 GiB by back-of-the-envelope math: about 7,400 images at 3 x 224 x 224 x 4 bytes each), we can improve our performance significantly.
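As a rough check on the memory cost, the arithmetic can be sketched in a few lines. The function names are hypothetical, and the image count is an assumption; dataset.files().len() gives the real number for whatever directory was scanned.

```rust
/// Bytes needed to cache `image_count` tensors of shape
/// channels x height x width, stored as f32 (4 bytes per element).
fn cache_bytes(image_count: u64, channels: u64, height: u64, width: u64) -> u64 {
    image_count * channels * height * width * std::mem::size_of::<f32>() as u64
}

/// Converts a byte count to GiB (2^30 bytes).
fn to_gib(bytes: u64) -> f64 {
    bytes as f64 / (1024.0 * 1024.0 * 1024.0)
}
```

A single 3x224x224 f32 tensor takes cache_bytes(1, 3, 224, 224) = 602,112 bytes, or 588 KiB. Assuming around 7,390 images, to_gib(cache_bytes(7390, 3, 224, 224)) comes out to a little over 4.1 GiB.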

Now, the get() method on ExactSizeDataset only takes a shared reference to self, so we can't just use a HashMap and be done: the one place where we would add tensors to the cache doesn't have a mutable reference.

So to help us fix that issue, and help us in the event that we are calling get() in parallel (say if we are trying to load all the images with rayon, for example), we can use the dashmap crate. DashMap is a concurrent hashmap, which means we can update the keys and values concurrently in multiple threads. This also means that we can insert values with only a shared reference.

pub struct DirectoryImageDataset<'fun> {
    files: Vec<PathBuf>,
    dev: AutoDevice,
    label_fn: Box<dyn Fn(&Path) -> bool + 'fun>,
    tensors: DashMap<PathBuf, Tensor<Rank3<3, 224, 224>, f32, AutoDevice>>,
}

We also need to set this in the constructor:

// pub fn new(...) ... {
        // ...
        Ok(Self {
            files,
            dev,
            label_fn: Box::new(label_fn),
            tensors: Default::default(),
        })
}

And the final step is to pipe it through the get() method.

    fn get(&self, index: usize) -> Self::Item<'_> {
        let image_file = &self.files[index];

        // v---- New stuff here ---v
        let label = (self.label_fn)(image_file);
        // A single lookup instead of contains_key() followed by get()
        if let Some(tensor) = self.tensors.get(image_file) {
            return Ok((tensor.value().clone(), label));
        }
        // Read the image and resize it to 224x224, and 3 channels
        let image = ImageReader::open(image_file)?
            .decode()?
            .resize_exact(224, 224, FilterType::Triangle)
            .into_rgb8();

        // Shrink the byte values to f32 between [0, 1]
        let bytes: Vec<f32> = image.as_bytes().iter().map(|&b| b as f32 / 255.0).collect();

        // Create the tensor and the label
        let tensor = self
            .dev
            .tensor_from_vec(bytes, (Const::<3>, Const::<224>, Const::<224>));

        // v--- And here ---v
        self.tensors.insert(image_file.clone(), tensor.clone());
        Ok((tensor, label))
    }
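The core trick here is filling a cache through a shared reference. For anyone who wants to play with the pattern without pulling in dfdx or dashmap, here is a minimal sketch using only the standard library. It swaps DashMap for a Mutex<HashMap>, which is simpler but serializes all access, and it swaps the Tensor for a plain Vec<f32>. TensorCache, expensive_load, and load_count are hypothetical names.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;

/// Hypothetical cache that fills itself through `&self`.
/// A Mutex<HashMap> stands in for DashMap in this sketch.
struct TensorCache {
    values: Mutex<HashMap<usize, Vec<f32>>>,
    /// Counts how often the slow path ran, for demonstration only.
    loads: AtomicUsize,
}

impl TensorCache {
    fn new() -> Self {
        TensorCache {
            values: Mutex::new(HashMap::new()),
            loads: AtomicUsize::new(0),
        }
    }

    /// Stand-in for decoding and resizing an image.
    fn expensive_load(&self, index: usize) -> Vec<f32> {
        self.loads.fetch_add(1, Ordering::SeqCst);
        vec![index as f32; 4]
    }

    fn get(&self, index: usize) -> Vec<f32> {
        // The lock guard is a temporary, so it is released at the end of
        // this statement and is not held during the slow load below.
        let cached = self.values.lock().unwrap().get(&index).cloned();
        if let Some(hit) = cached {
            return hit;
        }
        let value = self.expensive_load(index);
        self.values.lock().unwrap().insert(index, value.clone());
        value
    }

    fn load_count(&self) -> usize {
        self.loads.load(Ordering::SeqCst)
    }
}
```

Calling get(3) twice runs expensive_load only once. As in the get() method above, two threads that miss on the same key at the same moment can both do the load; the second insert just overwrites the first with an identical value, so the result is still correct.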

Finally, I've updated the chapter1/src/main.rs file to use the new struct.

    // ...
    let path: PathBuf = untar_images(Url::Pets)
        .context("downloading Pets")?
        .join("images");
    log::info!("Images are in: {}", path.display());

    // v--- New stuff ---v
    let dev = AutoDevice::default();

    // Silly thing about the Pets dataset, all the cats have a capital first letter in their
    // filename, all the dogs are lowercase only
    let is_cat = |path: &Path| {
        path.file_name()
            .and_then(|n| n.to_str())
            .and_then(|n| n.chars().next().map(|c| c.is_uppercase()))
            .unwrap_or(false)
    };

    let dataset = DirectoryImageDataset::new(path, dev.clone(), is_cat)?;
    log::info!("Found {} files", dataset.files().len());

    // ...
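The labeling rule is easy to sanity-check on its own. The sketch below is the same logic as the is_cat closure, written as a free function so it can be tested without downloading anything. The sample file names are illustrative and follow the naming pattern of the Pets images.

```rust
use std::path::Path;

/// Same rule as the `is_cat` closure: a file is a cat when its file name
/// (not the directory) starts with an uppercase letter.
fn is_cat(path: &Path) -> bool {
    path.file_name()
        .and_then(|n| n.to_str())
        .and_then(|n| n.chars().next().map(|c| c.is_uppercase()))
        .unwrap_or(false)
}
```

Here is_cat(Path::new("images/Abyssinian_1.jpg")) is true and is_cat(Path::new("images/beagle_1.jpg")) is false. A path with no file name at all falls through to unwrap_or(false) and is labeled as not a cat.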

Conclusion

Well, that was more satisfying than the last two articles. We're finally able to store things on the GPU, or at least in a Tensor. To actually build this with GPU support, you need to build it with the cuda feature enabled.

Check out the code on GitHub.

In Part 4, we're going to go through the process of building the ResNet-34 model in Rust. Stay tuned! I'm so excited!
