Working through the fast.ai book in Rust - Part 3
Favil Orbedios
Posted on November 21, 2023
Introduction
In Part 2, we went over downloading and extracting our datasets from the internet.
Today in Part 3, we are going to go over reading that data and transforming it into a form that our models can actually understand.
Tensor objects are multi-dimensional arrays that store numbers. Since files on a computer are all just numbers, we can use that to our advantage and create something called a Rank 3 Tensor from each of our images. This is a 3-dimensional structure that represents a single image in its entirety. The dimensions of the Tensor represent the height, width, and color channels of the image.
Specifically, for the ResNet model that chapter 1 wants us to work with, these images need to be in the shape of 3x224x224. That corresponds to 3 color channels, with height and width both set to 224. Because of this size requirement, we are going to need to resize our images before we store them in our Tensors.
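To make the 3x224x224 shape a bit more concrete, here is a minimal sketch of how a channel-first buffer can be indexed, and how interleaved RGB bytes (the order most decoders hand back) could be rearranged into it. The names `chw_index` and `interleaved_to_chw` are made up for this illustration and are not part of any crate; the sketch uses only the standard library.

```rust
/// Flat index of element (channel, row, col) in a channel-first buffer
/// of shape channels x height x width.
fn chw_index(channel: usize, row: usize, col: usize, height: usize, width: usize) -> usize {
    channel * height * width + row * width + col
}

/// Convert interleaved RGB bytes (height x width x 3) into a channel-first
/// f32 buffer (3 x height x width) with values scaled to [0, 1].
fn interleaved_to_chw(pixels: &[u8], height: usize, width: usize) -> Vec<f32> {
    assert_eq!(pixels.len(), height * width * 3, "expected 3 bytes per pixel");
    let mut out = vec![0.0f32; pixels.len()];
    for row in 0..height {
        for col in 0..width {
            for channel in 0..3 {
                // Source is interleaved: all three channels of a pixel sit together
                let src = (row * width + col) * 3 + channel;
                out[chw_index(channel, row, col, height, width)] =
                    f32::from(pixels[src]) / 255.0;
            }
        }
    }
    out
}
```

For a 1x2 image holding one red and one green pixel, `interleaved_to_chw(&[255, 0, 0, 0, 255, 0], 1, 2)` groups the values by channel: first the red plane, then the green plane, then the blue plane.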
Refactoring
First things first: in the last part of this series, we wrote all of our code in a single module. That is going to get overwhelming quickly, so let's refactor our code into separate modules.
This is how our tardyai/src/lib.rs will look:
pub mod download;
pub mod error;

pub use self::{
    download::{untar_images, Url},
    error::Error,
};
To accomplish this, we can pull the Error enum out into tardyai/src/error.rs:
#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("reqwest error: {0}")]
    Reqwest(#[from] reqwest::Error),

    #[error("io error: {0}")]
    IO(#[from] std::io::Error),

    #[error("homedir error: {0}")]
    Home(#[from] homedir::GetHomeError),

    #[error("tar entry error: {0}")]
    TarEntry(&'static str),
}
And we can pull everything else out into tardyai/src/download.rs:
use std::{
    fs::File,
    io::{self, Seek},
    path::PathBuf,
};

use flate2::read::GzDecoder;
use tar::Archive;

use crate::error::Error;

const S3_BASE: &str = "https://s3.amazonaws.com/fast-ai-";
const S3_IMAGE: &str = "imageclas/";

#[derive(Debug, Clone, Copy)]
pub enum Url {
    Pets,
}

impl Url {
    pub fn url(self) -> String {
        match self {
            Self::Pets => {
                format!("{S3_BASE}{S3_IMAGE}oxford-iiit-pet.tgz")
            }
        }
    }
}

fn ensure_dir(path: &PathBuf) -> Result<(), Error> {
    if !path.exists() {
        std::fs::create_dir_all(path)?;
    }
    Ok(())
}

pub fn untar_images(url: Url) -> Result<PathBuf, Error> {
    let home = &homedir::get_my_home()?
        .expect("home directory needs to exist")
        .join(".tardyai");
    let dest_dir = home.join("archive");
    ensure_dir(&dest_dir)?;
    let archive_file = download_archive(url, &dest_dir)?;

    let dest_dir = home.join("data");
    let dir = extract_archive(&archive_file, &dest_dir)?;

    Ok(dir)
}

fn download_archive(url: Url, dest_dir: &PathBuf) -> Result<PathBuf, Error> {
    let mut response = reqwest::blocking::get(url.url())?;
    let archive_name = response
        .url()
        .path_segments()
        .and_then(|s| s.last())
        .and_then(|name| if name.is_empty() { None } else { Some(name) })
        .unwrap_or("tmp.tar.gz");

    let archive_file = dest_dir.join(archive_name);

    // TODO: check if the archive is valid and exists
    if archive_file.exists() {
        log::info!("Archive already exists: {}", archive_file.display());
        return Ok(archive_file);
    }

    log::info!(
        "Downloading {} to archive: {}",
        url.url(),
        archive_file.display()
    );
    let mut dest = File::create(&archive_file)?;
    response.copy_to(&mut dest)?;
    Ok(archive_file)
}

fn extract_archive(archive_file: &PathBuf, dest_dir: &PathBuf) -> Result<PathBuf, Error> {
    let tar_gz = File::open(archive_file)?;
    let tar = GzDecoder::new(tar_gz);
    let mut archive = Archive::new(tar);

    log::info!(
        "Extracting archive {} to: {}",
        archive_file.display(),
        dest_dir.display()
    );
    // The first entry in the archive is its top-level directory
    let dir = {
        let entry = &archive
            .entries()?
            .next()
            .ok_or(Error::TarEntry("No entries in archive"))??;
        entry.path()?.into_owned()
    };
    let archive_dir = dest_dir.join(dir);
    if archive_dir.exists() {
        log::info!("Archive already extracted to: {}", archive_dir.display());
        return Ok(archive_dir);
    }

    // Rewind the file and start over so that unpacking sees every entry
    let tar = archive.into_inner();
    let mut tar_gz = tar.into_inner();
    tar_gz.seek(io::SeekFrom::Start(0))?;
    let tar = GzDecoder::new(tar_gz);
    let mut archive = Archive::new(tar);
    archive.unpack(dest_dir)?;

    Ok(archive_dir)
}
And that is done. Super simple, but much more organized.
On to finding images
Now that that's out of the way, we can get to the actual meat of this article.
As a good starting place, let's create a new module for our datasets in lib.rs, along with the associated tardyai/src/datasets.rs file.
pub mod datasets;
And now we need to create a struct that can keep track of all the image files we want to read.
use std::path::PathBuf;

pub struct DirectoryImageDataset {
    files: Vec<PathBuf>,
}
Let's also create a constructor that takes the parent directory, walks all the files, and collects all the images into the files field. This uses the walkdir crate to walk the parent directory and pick out the files whose extensions match, and the image crate to list the supported file extensions and, eventually, to decode the images and load them into memory.
// These imports join the existing one at the top of datasets.rs
use std::collections::HashSet;

use image::ImageFormat;
use walkdir::WalkDir;

use crate::error::Error;

impl DirectoryImageDataset {
    pub fn new(parent: PathBuf) -> Result<Self, Error> {
        let exts = image_extensions();
        let walker = WalkDir::new(parent).follow_links(true).into_iter();
        let files = walker
            .filter_map(|entry| {
                // Skip entries that could not be read
                let entry = entry.ok()?;
                // Keep only entries whose extension is in the supported set
                entry
                    .path()
                    .extension()
                    .and_then(|ext| Some(exts.contains(ext.to_str()?)))?
                    .then_some(entry)
            })
            .map(|entry| entry.path().to_owned())
            .collect();
        Ok(Self { files })
    }
}
This uses a function called image_extensions that returns a HashSet of all the supported file extensions that I feel like including right now.
fn image_extensions() -> HashSet<&'static str> {
    let mut set = HashSet::default();
    set.extend(ImageFormat::Jpeg.extensions_str());
    set.extend(ImageFormat::Png.extensions_str());
    set.extend(ImageFormat::Gif.extensions_str());
    set.extend(ImageFormat::WebP.extensions_str());
    set.extend(ImageFormat::Tiff.extensions_str());
    set.extend(ImageFormat::Bmp.extensions_str());
    set.extend(ImageFormat::Qoi.extensions_str());
    set
}
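One thing worth knowing about this kind of set lookup is that it is case-sensitive, so a file named photo.JPG would not match an entry of jpg. As a rough, standard-library-only sketch of the extension check, here is a variant that lowercases the extension first. The helper name `has_image_extension` and the extension list in the usage note are assumptions for illustration, not code from the tardyai crate.

```rust
use std::collections::HashSet;
use std::path::Path;

/// Returns true when `path` has an extension contained in `exts`,
/// ignoring ASCII case so that `photo.JPG` matches `jpg`.
fn has_image_extension(path: &Path, exts: &HashSet<&str>) -> bool {
    path.extension()
        // Extensions that are not valid UTF-8 are treated as "no match"
        .and_then(|ext| ext.to_str())
        .map(|ext| exts.contains(ext.to_ascii_lowercase().as_str()))
        .unwrap_or(false)
}
```

With a set built from `["jpg", "jpeg", "png"]`, this accepts both `Abyssinian_1.jpg` and `photo.JPG`, and rejects `notes.txt` as well as files with no extension at all.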
I've also added a utility method for fetching the list of files that the dataset represents:
impl DirectoryImageDataset {
    // pub fn new(...)

    pub fn files(&self) -> &[PathBuf] {
        &self.files
    }
}
Let's get to the Tensors now
Now that we've discovered all the images, we just need to get them into memory. For now, I'm planning on doing this the naive way and just loading the images each time they are requested. We can optimize it later by adding a cache.
The dfdx crate has a trait called ExactSizeDataset that it can use to power the training epochs and calculate the validation loss at the end.
It's a pretty straightforward interface with only two required methods: len() and get(). len() is pretty self-explanatory. get() takes an index and returns something you specify. Generally this is a pair of a Tensor and the label that the Tensor represents.
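To get a feel for the shape of such an interface before wiring in images, here is a small sketch using a local stand-in trait. `SizedDataset`, `Numbers`, and `labels` are hypothetical names for this illustration only; the trait mirrors the two-method idea described above but is not dfdx's actual definition.

```rust
/// Local stand-in for a sized dataset interface (not the dfdx trait itself).
trait SizedDataset {
    /// What `get` hands back; it may borrow from the dataset.
    type Item<'a>
    where
        Self: 'a;

    fn get(&self, index: usize) -> Self::Item<'_>;
    fn len(&self) -> usize;
}

/// Toy dataset: each item is a number paired with an "is even" label.
struct Numbers(Vec<u32>);

impl SizedDataset for Numbers {
    type Item<'a> = (u32, bool)
    where
        Self: 'a;

    fn get(&self, index: usize) -> Self::Item<'_> {
        let value = self.0[index];
        (value, value % 2 == 0)
    }

    fn len(&self) -> usize {
        self.0.len()
    }
}

/// Visit every index once, the way a single pass over the data would.
fn labels(dataset: &Numbers) -> Vec<bool> {
    (0..dataset.len()).map(|i| dataset.get(i).1).collect()
}
```

Anything that can report its length and produce an item for each index fits this pattern, which is why a list of image paths is enough to back a dataset.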
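// Imports this snippet relies on; the paths assume dfdx 0.13 and image 0.24
use dfdx::{data::ExactSizeDataset, prelude::*};
use image::{imageops::FilterType, io::Reader as ImageReader};

impl ExactSizeDataset for DirectoryImageDataset<'_> {
    type Item<'a> = Result<(Tensor<Rank3<3, 224, 224>, f32, AutoDevice>, bool), Error>
    where
        Self: 'a;

    fn get(&self, index: usize) -> Self::Item<'_> {
        let image_file = &self.files[index];

        // Read the image and resize it to 224x224 with 3 channels.
        // `decode()` fails with `image::ImageError`, so `Error` needs a
        // variant that converts from it for `?` to work here.
        let image = ImageReader::open(image_file)?
            .decode()?
            .resize_exact(224, 224, FilterType::Triangle)
            .into_rgb8();

        // Scale the byte values to f32 in [0, 1]. `into_rgb8()` stores pixels
        // interleaved (height x width x channel), so reorder them into the
        // channel-first (3 x 224 x 224) layout of the tensor.
        const PLANE: usize = 224 * 224;
        let mut bytes = vec![0.0f32; 3 * PLANE];
        for (i, pixel) in image.as_raw().chunks_exact(3).enumerate() {
            for (c, &value) in pixel.iter().enumerate() {
                bytes[c * PLANE + i] = f32::from(value) / 255.0;
            }
        }

        // Create the tensor and the label
        Ok((
            self.dev
                .tensor_from_vec(bytes, (Const::<3>, Const::<224>, Const::<224>)),
            (self.label_fn)(image_file),
        ))
    }

    fn len(&self) -> usize {
        self.files.len()
    }
}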
I'm doing a lot here. Probably the most important change is that I need to modify the struct to contain the device in order to actually construct the Tensor objects. This is because the Tensor might be stored on the GPU, and the device facilitates constructing it there.
I've also made the struct slightly generic by storing a function that maps the path to the label. In this case the label is hard coded to a bool. We'll need to take care of that later, but I don't want to borrow tomorrow's problem just yet.
Here's the new definition of the struct, along with the associated constructor changes to support it. I unfortunately had to add a lifetime to handle the label function, but it's not too bad: we can get away with lifetime elision except when we are constructing it. I suppose I could have gone with 'static, but that would have ruled out closures that borrow local data, like the ones I am hoping to eventually use. (So much for not borrowing tomorrow's problems, I guess.)
pub struct DirectoryImageDataset<'fun> {
    files: Vec<PathBuf>,
    dev: AutoDevice,
    label_fn: Box<dyn Fn(&Path) -> bool + 'fun>,
}

impl<'fun> DirectoryImageDataset<'fun> {
    pub fn new(
        parent: PathBuf,
        dev: AutoDevice,
        label_fn: impl Fn(&Path) -> bool + 'fun,
    ) -> Result<Self, Error> {
        let exts = image_extensions();
        let walker = WalkDir::new(parent).follow_links(true).into_iter();
        let files = walker
            .filter_map(|entry| {
                let entry = entry.ok()?;
                entry
                    .path()
                    .extension()
                    .and_then(|ext| Some(exts.contains(ext.to_str()?)))?
                    .then_some(entry)
            })
            .map(|entry| entry.path().to_owned())
            .collect();
        Ok(Self {
            files,
            dev,
            label_fn: Box::new(label_fn),
        })
    }
}
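To see what the extra lifetime buys us, here is a stripped-down sketch of the same pattern using only the standard library. `Labeler` and `labeler_for_prefix` are hypothetical names for this illustration; the point is that the boxed closure is allowed to borrow `prefix` rather than own it, which a `'static` bound would reject.

```rust
use std::path::Path;

/// Holds a labelling closure that may borrow data living for `'fun`.
struct Labeler<'fun> {
    label_fn: Box<dyn Fn(&Path) -> bool + 'fun>,
}

impl<'fun> Labeler<'fun> {
    fn new(label_fn: impl Fn(&Path) -> bool + 'fun) -> Self {
        Self {
            label_fn: Box::new(label_fn),
        }
    }

    fn label(&self, path: &Path) -> bool {
        (self.label_fn)(path)
    }
}

/// Build a labeler whose closure borrows `prefix` instead of owning it.
fn labeler_for_prefix(prefix: &str) -> Labeler<'_> {
    Labeler::new(move |path: &Path| {
        path.file_name()
            .and_then(|name| name.to_str())
            .map(|name| name.starts_with(prefix))
            .unwrap_or(false)
    })
}
```

The returned `Labeler` cannot outlive the string it borrows from, and the compiler enforces that through the `'fun` parameter. Callers that only pass plain functions or owning closures never have to spell the lifetime out.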
How do we optimize it?
If we started training right now, we'd run into a problem with speed as we run through the epochs. We aren't caching the Tensors, so they get dropped and rebuilt each time through the training loop.
That will take much longer than if we just cached all the values the first time. So by spending a little memory (about 2 GiB from my back-of-the-envelope math), we can improve our performance significantly.
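The back-of-the-envelope math can be reproduced with a couple of small helpers. This is only a sketch: `tensor_bytes` and `cache_gib` are made-up names, and the image count is left as a parameter because the real total depends on how many files the dataset actually finds (for example `dataset.files().len()`).

```rust
/// Bytes needed to hold one f32 tensor of shape channels x height x width.
const fn tensor_bytes(channels: usize, height: usize, width: usize) -> usize {
    channels * height * width * std::mem::size_of::<f32>()
}

/// Approximate cache size in GiB for `image_count` tensors of shape 3x224x224.
/// Only the raw tensor data is counted, not map or allocator overhead.
fn cache_gib(image_count: usize) -> f64 {
    (image_count * tensor_bytes(3, 224, 224)) as f64 / (1024.0 * 1024.0 * 1024.0)
}
```

Each cached 3x224x224 tensor takes 602,112 bytes (588 KiB), so the cache grows linearly with the number of images.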
Now, the ExactSizeDataset trait only gives us a shared reference to self in get(), so we can't just use a HashMap and be done: none of the places where we could add tensors to the cache have a mutable reference.
To fix that, and to help us in the event that we call get() in parallel (say, if we try to load all the images with rayon), we can use the dashmap crate. DashMap is a concurrent hash map, which means we can update the keys and values concurrently from multiple threads. It also means that we can insert values with only a shared reference.
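The key idea, inserting through a shared reference, can be tried out with nothing but the standard library. The sketch below swaps DashMap for a `RwLock<HashMap>` so that it has no dependencies; that is a different (and coarser-grained) technique than the one used in this article. `Cache` and `get_or_load` are hypothetical names, and a plain `Vec<f32>` stands in for the Tensor.

```rust
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::RwLock;

/// A cache that can be filled through `&self` thanks to interior mutability.
struct Cache {
    entries: RwLock<HashMap<PathBuf, Vec<f32>>>,
}

impl Cache {
    fn new() -> Self {
        Self {
            entries: RwLock::new(HashMap::new()),
        }
    }

    /// Return the cached value for `key`, computing and storing it on a miss.
    fn get_or_load(&self, key: &Path, load: impl FnOnce() -> Vec<f32>) -> Vec<f32> {
        // The read guard is dropped at the end of this statement
        let cached = self.entries.read().unwrap().get(key).cloned();
        if let Some(hit) = cached {
            return hit;
        }
        // Do the expensive work without holding any lock
        let value = load();
        let mut entries = self.entries.write().unwrap();
        // If another thread inserted in the meantime, keep its value
        let stored = entries.entry(key.to_path_buf()).or_insert(value).clone();
        stored
    }
}
```

Calling `get_or_load` twice with the same key runs the loader only once; the second call returns a clone of the stored value. DashMap offers the same shared-reference insertion while locking at a finer granularity than a single lock around the whole map.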
pub struct DirectoryImageDataset<'fun> {
    files: Vec<PathBuf>,
    dev: AutoDevice,
    label_fn: Box<dyn Fn(&Path) -> bool + 'fun>,
    tensors: DashMap<PathBuf, Tensor<Rank3<3, 224, 224>, f32, AutoDevice>>,
}
We also need to set this in the constructor:
    // pub fn new(...) ... {
        // ...
        Ok(Self {
            files,
            dev,
            label_fn: Box::new(label_fn),
            tensors: Default::default(),
        })
    }
And the final step is to pipe it through the get() method.
    fn get(&self, index: usize) -> Self::Item<'_> {
        let image_file = &self.files[index];

        // v---- New stuff here ---v
        let label = (self.label_fn)(image_file);
        // Return the cached tensor if this image has been loaded before
        if let Some(tensor) = self.tensors.get(image_file) {
            return Ok((tensor.value().clone(), label));
        }

        // Read the image and resize it to 224x224 with 3 channels
        let image = ImageReader::open(image_file)?
            .decode()?
            .resize_exact(224, 224, FilterType::Triangle)
            .into_rgb8();

        // Scale the byte values to f32 in [0, 1]. `into_rgb8()` stores pixels
        // interleaved (height x width x channel), so reorder them into the
        // channel-first (3 x 224 x 224) layout of the tensor.
        const PLANE: usize = 224 * 224;
        let mut bytes = vec![0.0f32; 3 * PLANE];
        for (i, pixel) in image.as_raw().chunks_exact(3).enumerate() {
            for (c, &value) in pixel.iter().enumerate() {
                bytes[c * PLANE + i] = f32::from(value) / 255.0;
            }
        }

        // Create the tensor
        let tensor = self
            .dev
            .tensor_from_vec(bytes, (Const::<3>, Const::<224>, Const::<224>));

        // v--- And here ---v
        self.tensors.insert(image_file.clone(), tensor.clone());

        Ok((tensor, label))
    }
Finally, I've updated the chapter1/src/main.rs file to use the new struct.
    // ...
    let path: PathBuf = untar_images(Url::Pets)
        .context("downloading Pets")?
        .join("images");
    log::info!("Images are in: {}", path.display());

    // v--- New stuff ---v
    let dev = AutoDevice::default();

    // Silly thing about the Pets dataset, all the cats have a capital first letter in their
    // filename, all the dogs are lowercase only
    let is_cat = |path: &Path| {
        path.file_name()
            .and_then(|n| n.to_str())
            .and_then(|n| n.chars().next().map(|c| c.is_uppercase()))
            .unwrap_or(false)
    };

    let dataset = DirectoryImageDataset::new(path, dev.clone(), is_cat)?;
    log::info!("Found {} files", dataset.files().len());
    // ...
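The labelling rule is easy to sanity-check on its own. Below is the same logic written as a free function so it can be exercised without downloading anything; the file names in the usage note are illustrative examples in the dataset's naming style, not a listing taken from the archive.

```rust
use std::path::Path;

/// Label a Pets image as a cat when its file name starts with an uppercase letter.
fn is_cat(path: &Path) -> bool {
    path.file_name()
        .and_then(|n| n.to_str())
        .and_then(|n| n.chars().next())
        .map(|c| c.is_uppercase())
        // Paths without a usable file name are treated as "not a cat"
        .unwrap_or(false)
}
```

A path such as `images/Abyssinian_1.jpg` is labelled as a cat and `images/beagle_1.jpg` is not. Only the file name is inspected, so the case of the parent directories does not matter.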
Conclusion
Well, that was more satisfying than the last two articles. We're finally able to store things on the GPU, or at least in a Tensor. To actually build this with GPU support, you need to build it with the cuda feature enabled.
Check out the code on GitHub.
In Part 4, we're going to go through the process of building the ResNet-34 model in Rust. Stay tuned! I'm so excited!