Working through the fast.ai book in Rust - Part 2
Favil Orbedios
Posted on November 19, 2023
Introduction
In Part 1 we introduced the dfdx crate, but we didn't get into actually implementing any of the fast.ai book projects.
In Part 2 we are going to see how far we can get into chapter 1 of the book. Since this isn't Python, and we don't have the fastai library, we are going to have to do everything ourselves.
If you want to follow along, and don't have a copy of the book, you can read it online for free here.
In particular this is what the book wants us to write:
from fastai.vision.all import *
path = untar_data(URLs.PETS)/'images'
def is_cat(x): return x[0].isupper()
dls = ImageDataLoaders.from_name_func(
path, get_image_files(path), valid_pct=0.2, seed=42,
label_func=is_cat, item_tfms=Resize(224))
learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
As we can see, it isn't much code. But the reason I don't like it, and the reason I'm writing this series, is that it's just a bunch of magic. It gets you on your feet quickly by hiding all the fun parts behind its façade.
In this sample we can see that it:
- Automatically downloads and extracts the images to an images folder.
- Defines a label function.
- Automatically loads the images from the path with ImageDataLoaders.
- Constructs a learner from a pretrained resnet34 model, whose weights it also downloads for you.
- Runs a learning algorithm on it for a single cycle.
Now, this is too much to cover in a single article, so for Part 2 I'm going to focus on the first item: downloading and extracting the images.
Creating a new Rust package
I realized while writing this article that the structure of my code needs refinement, so I'm going to throw away the old code and set up a repo with several crates in a Rust workspace.
If you want to follow along, I've created a git repo called tardyai, where I will be committing all of my code.
To fetch the specific tag from the repo use the following command:
git clone --branch START_HERE https://github.com/favilo/tardyai.git
That will download the repo and put you at the same starting point as me. Specifically, it contains a Rust workspace with two member crates, tardyai and chapter1, both of which are the default packages created by cargo new.
tardyai will be a small, incomplete port of the fastai library. It is a library crate, so it won't run anything by itself; for now it just contains the logic for downloading images.
Let's add URLs
It would be very nice if we could take the same URLs that are in the Python library and use them the same way in Rust.
I'm envisioning an interface similar to the following, which I'm adding to our chapter1/src/main.rs file.
use std::path::PathBuf;
fn main() {
let path: PathBuf = tardyai::untar_images(tardyai::Url::Pets)
.join("images");
}
Now, to make that a reality, let's edit tardyai/src/lib.rs:
use std::path::PathBuf;
pub enum Url {
Pets,
}
pub fn untar_images(url: Url) -> PathBuf {
todo!()
}
This just panics, but at least everything compiles.
From here we need to convert the Url::Pets variant into an actual URL. For the fastai library this is https://s3.amazonaws.com/fast-ai-imageclas/oxford-iiit-pet.tgz. So let's add a method to the Url type that returns the URL.
const S3_BASE: &str = "https://s3.amazonaws.com/fast-ai-";
const S3_IMAGE: &str = "imageclas/";
// v-- I decided that we need to derive some sane traits by default.
#[derive(Debug, Clone, Copy)]
pub enum Url {
Pets,
}
impl Url {
pub fn url(self) -> String {
match self {
Self::Pets => {
format!("{S3_BASE}{S3_IMAGE}oxford-iiit-pet.tgz")
}
}
}
}
This defines the url() method. I also created the constants S3_BASE and S3_IMAGE to collect the common prefixes, which will let us quickly add new datasets and their corresponding URLs.
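As a quick sanity check, the snippet below (a standalone copy of the code above plus a main function) confirms that the two constants concatenate into exactly the fastai URL quoted earlier. In the real crate a check like this would more naturally live in a #[cfg(test)] module in tardyai/src/lib.rs.

```rust
const S3_BASE: &str = "https://s3.amazonaws.com/fast-ai-";
const S3_IMAGE: &str = "imageclas/";

#[derive(Debug, Clone, Copy)]
pub enum Url {
    Pets,
}

impl Url {
    pub fn url(self) -> String {
        match self {
            // S3_BASE + S3_IMAGE + archive name
            Self::Pets => format!("{S3_BASE}{S3_IMAGE}oxford-iiit-pet.tgz"),
        }
    }
}

fn main() {
    let url = Url::Pets.url();
    // Must match the URL used by the Python fastai library
    assert_eq!(url, "https://s3.amazonaws.com/fast-ai-imageclas/oxford-iiit-pet.tgz");
    println!("{url}");
}
```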
Actually download something, why don't you?
Now we need to actually connect to the internet and download our archive from S3. To do this I'm going to use the reqwest crate, the de facto crate for making HTTP requests in Rust. It offers both an async and a blocking API, and we are going to use the blocking API for now. (Maybe in a future article I'll convert everything over to async/await.)
➜ cargo add reqwest -p tardyai -F blocking
Updating crates.io index
Adding reqwest v0.11.22 to dependencies.
Features:
+ __tls
+ blocking
+ default-tls
+ hyper-tls
+ native-tls-crate
+ tokio-native-tls
38 deactivated features
This adds the latest version of reqwest with the blocking feature turned on.
Then we edit tardyai/src/lib.rs
pub fn untar_images(url: Url) -> PathBuf {
let response = reqwest::blocking::get(url.url()).expect("get failed");
// ...
}
That .expect() looks pretty ugly. Let's clean it up with our own custom error type, derived with the help of thiserror.
➜ cargo add -p tardyai thiserror
Updating crates.io index
Adding thiserror v1.0.50 to dependencies.
NOTE: I'm going to stop writing down the steps to add a crate. They are almost always the same. Instead I'll mention the crate and any features we need to add to get it to work for us.
thiserror lets us derive a standard std::error::Error implementation for our type, so it works with some nice error-reporting crates that I'll talk about later.
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("io error: {0}")]
IO(#[from] std::io::Error),
}
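If the derive feels like magic, here is roughly what it generates, written out by hand. This is a simplified sketch, not thiserror's literal output: only the IO variant is included so that it compiles with std alone, and read_config is a hypothetical function used to show the ? operator calling From::from.

```rust
use std::fmt;

#[derive(Debug)]
pub enum Error {
    IO(std::io::Error),
}

// Generated from #[error("io error: {0}")]
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::IO(e) => write!(f, "io error: {e}"),
        }
    }
}

// #[from] also marks the field as the error's source
impl std::error::Error for Error {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::IO(e) => Some(e),
        }
    }
}

// Generated from #[from]; this is what lets `?` convert io::Error into Error
impl From<std::io::Error> for Error {
    fn from(e: std::io::Error) -> Self {
        Error::IO(e)
    }
}

// Hypothetical function, only here to exercise the conversion
fn read_config(path: &str) -> Result<String, Error> {
    let text = std::fs::read_to_string(path)?; // io::Error -> Error::IO
    Ok(text)
}

fn main() {
    match read_config("/definitely/not/a/real/file") {
        Ok(_) => println!("unexpectedly found the file"),
        Err(err) => println!("{err}"),
    }
}
```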
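Then we can change the signature of untar_images to return a Result, so the ? operator can convert a reqwest::Error into our Error. (The log::info! call comes from the log crate, which tardyai now also depends on.)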
pub fn untar_images(url: Url) -> Result<PathBuf, Error> {
let response = reqwest::blocking::get(url.url())?;
log::info!("response: {:?}", response);
Ok(todo!())
}
So that is us requesting the file from the URL. Of course, this is useless as it stands because we haven't saved anything to the hard disk. It also downloads very little: reqwest::blocking::get returns once the response headers arrive, and we never read the body.
Save it to the hard disk already
The fastai library downloads archive files to ~/.fastai/archive/. I'm going to do the same thing, but in ~/.tardyai/archive/ instead.
So first we need to make sure that the directory exists, and we need to find the user's home directory in a cross-platform manner. For that I'm using the homedir crate.
fn ensure_dir(path: &PathBuf) -> Result<(), Error> {
if !path.exists() {
std::fs::create_dir_all(path)?;
}
Ok(())
}
pub fn untar_images(url: Url) -> Result<PathBuf, Error> {
let dest_dir = homedir::get_my_home()?
.expect("home directory needs to exist")
.join(".tardyai")
.join("archive");
ensure_dir(&dest_dir)?;
// ...
}
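A small simplification is possible here: std::fs::create_dir_all already returns Ok(()) when the directory exists, so the exists() check is not strictly needed. Taking &Path instead of &PathBuf also lets callers pass either type. A sketch, with a temporary directory standing in for ~/.tardyai/archive:

```rust
use std::path::Path;

fn ensure_dir(path: &Path) -> std::io::Result<()> {
    // Creates any missing parents; succeeds without doing anything
    // if the directory is already there.
    std::fs::create_dir_all(path)
}

fn main() -> std::io::Result<()> {
    // Stand-in for ~/.tardyai/archive
    let dir = std::env::temp_dir().join("tardyai_demo").join("archive");
    ensure_dir(&dir)?;
    ensure_dir(&dir)?; // second call is a no-op, not an error
    println!("{} exists: {}", dir.display(), dir.is_dir());
    Ok(())
}
```

Because &PathBuf derefs to &Path, the existing call ensure_dir(&dest_dir)? would keep compiling unchanged.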
This required creating a new variant for our Error enum, which I called Home.
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("io error: {0}")]
IO(#[from] std::io::Error),
#[error("homedir error: {0}")]
Home(#[from] homedir::GetHomeError),
}
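To save the archive to disk, let's create a new function. It names the file after the last segment of the response URL (falling back to tmp.tar.gz), and returns early without reading the body if that file already exists.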
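// At the top of lib.rs: use std::fs::File;
fn download_archive(url: Url, dest_dir: &PathBuf) -> Result<PathBuf, Error> {
// error_for_status() turns 4xx/5xx responses into errors instead of saving the error page to disk
let mut response = reqwest::blocking::get(url.url())?.error_for_status()?;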
let archive_name = response
.url()
.path_segments()
.and_then(|s| s.last())
.and_then(|name| if name.is_empty() { None } else { Some(name) })
.unwrap_or("tmp.tar.gz");
let archive_file = dest_dir.join(archive_name);
// TODO: check if the archive is valid and exists
if archive_file.exists() {
log::info!("Archive already exists: {}", archive_file.display());
return Ok(archive_file);
}
log::info!(
"Downloading {} to archive: {}",
url.url(),
archive_file.display()
);
let mut dest = File::create(&archive_file)?;
response.copy_to(&mut dest)?;
Ok(archive_file)
}
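One gap the TODO hints at: if a download is interrupted, a truncated file is left at archive_file, and the exists() check will happily accept it on the next run. A common remedy, sketched here under my own assumptions (write_via_part_file is a hypothetical helper, not part of tardyai), is to stream into a sibling .part file and rename it into place only after the copy finishes.

```rust
use std::fs::{self, File};
use std::io::{self, Read};
use std::path::{Path, PathBuf};

/// Hypothetical helper: copy `reader` into `dest` without ever leaving
/// a half-written file at `dest` itself.
fn write_via_part_file<R: Read>(reader: &mut R, dest: &Path) -> io::Result<u64> {
    // e.g. oxford-iiit-pet.tgz -> oxford-iiit-pet.tgz.part
    let mut part_name = dest.as_os_str().to_owned();
    part_name.push(".part");
    let part_path = PathBuf::from(part_name);

    let mut file = File::create(&part_path)?;
    let written = io::copy(reader, &mut file)?;
    file.sync_all()?; // make sure the bytes hit the disk
    drop(file); // close before renaming

    // Only a complete download ever appears under the final name.
    fs::rename(&part_path, dest)?;
    Ok(written)
}

fn main() -> io::Result<()> {
    let dest = std::env::temp_dir().join("tardyai_part_demo.tgz");
    let mut data: &[u8] = b"pretend this is a tarball";
    let n = write_via_part_file(&mut data, &dest)?;
    println!("wrote {n} bytes to {}", dest.display());
    Ok(())
}
```

Since reqwest::blocking::Response implements std::io::Read, download_archive could call write_via_part_file(&mut response, &archive_file)? in place of the File::create and copy_to lines. If the copy fails, only the .part file is left behind. On Unix-like systems a rename within one filesystem is atomic.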
We have the archive, now what?
Well, let's decompress and extract it, of course. For decompression I'm going to use the flate2 crate with the rust_backend feature, and for extracting the resulting tar file I'll use the tar crate.
// At the top of lib.rs: use std::fs::File; use flate2::read::GzDecoder; use tar::Archive;
fn extract_archive(archive_file: &PathBuf, dest_dir: &PathBuf) -> Result<(), Error> {
let tar_gz = File::open(archive_file)?;
let tar = GzDecoder::new(tar_gz);
let mut archive = Archive::new(tar);
log::info!(
"Extracting archive {} to: {}",
archive_file.display(),
dest_dir.display()
);
archive.unpack(dest_dir)?;
Ok(())
}
Very straightforward. However, this doesn't give us the same path that the Python version does: the Python version returns the path of the extracted directory. So we're going to have to do that next.
// The rewind below also needs: use std::io::{self, Seek};
fn extract_archive(archive_file: &PathBuf, dest_dir: &PathBuf) -> Result<PathBuf, Error> {
let tar_gz = File::open(archive_file)?;
let tar = GzDecoder::new(tar_gz);
let mut archive = Archive::new(tar);
log::info!(
"Extracting archive {} to: {}",
archive_file.display(),
dest_dir.display()
);
let dir = {
let entry = &archive
.entries()?
.next()
.ok_or(Error::TarEntry("No entries in archive"))??;
entry.path()?.into_owned()
};
let archive_dir = dest_dir.join(dir);
if archive_dir.exists() {
log::info!("Archive already extracted to: {}", archive_dir.display());
return Ok(archive_dir);
}
let tar = archive.into_inner();
let mut tar_gz = tar.into_inner();
tar_gz.seek(io::SeekFrom::Start(0))?;
let tar = GzDecoder::new(tar_gz);
let mut archive = Archive::new(tar);
archive.unpack(dest_dir)?;
Ok(archive_dir)
}
This is a hack that I'm using to fetch the first entry in the tar archive, which is generally the top-level directory stored inside. Reading that entry advances the underlying reader, so I then have to unwind it by unwrapping the inner Reader, seeking back to 0, and reconstructing the archive.
If anyone knows of a more sane way to do this, please let me know in the comments.
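One possible refinement, which I haven't tested against this particular archive: some tar files start with a ./ entry, or with a file rather than a directory entry, and in those cases the first entry's full path is not the top-level directory. Taking the first normal component of that path covers both cases. top_level_dir is a hypothetical helper. (For the seeking dance itself, simply calling File::open on archive_file a second time would also work.)

```rust
use std::path::{Component, Path, PathBuf};

/// Hypothetical helper: the top-level directory an archive entry lives under.
fn top_level_dir(entry_path: &Path) -> Option<PathBuf> {
    entry_path
        .components()
        // Skip a leading "./" (CurDir) or root and keep the first real name
        .find(|c| matches!(c, Component::Normal(_)))
        .map(|c| PathBuf::from(c.as_os_str()))
}

fn main() {
    for p in ["oxford-iiit-pet/", "./oxford-iiit-pet/images/Abyssinian_1.jpg"] {
        println!("{p} -> {:?}", top_level_dir(Path::new(p)));
    }
}
```

In extract_archive it could replace the into_owned() call, for example top_level_dir(&entry.path()?) followed by an ok_or into Error::TarEntry.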
This also required me to create another variant for our Error enum, TarEntry.
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("io error: {0}")]
IO(#[from] std::io::Error),
#[error("homedir error: {0}")]
Home(#[from] homedir::GetHomeError),
#[error("tar entry error: {0}")]
TarEntry(&'static str),
}
I also threw in a condition to return early if the archive has already been extracted. In the future we may want to change this to use SHA-1 hashes to verify that the data is the same as what was downloaded.
Conclusion
Well, so far we've managed to download and extract our dataset to a centralized location. This is a good first step, and the line that fetches the images looks very similar to its Python counterpart. In chapter1 I've also set up env_logger and color_eyre for logging and error reporting:
use std::path::PathBuf;
use color_eyre::eyre::{Context, Result};
use tardyai::{untar_images, Url};
fn main() -> Result<()> {
env_logger::Builder::new()
.filter_level(log::LevelFilter::Info)
.init();
color_eyre::install()?;
let path: PathBuf = untar_images(Url::Pets)
.context("downloading Pets")?
.join("images");
log::info!("Images are in: {}", path.display());
Ok(())
}
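Looking ahead to the labels, the is_cat function from the Python sample translates almost directly. In the Oxford-IIIT Pet dataset, cat breed filenames start with an uppercase letter and dog breed filenames with a lowercase one, and fastai's from_name_func passes only the file name to the label function. The sketch below is a hypothetical Rust port that accepts a full path and looks only at its file name.

```rust
use std::path::Path;

/// Hypothetical port of the book's `is_cat`: true when the file name
/// starts with an uppercase letter.
fn is_cat(path: &Path) -> bool {
    path.file_name()
        .and_then(|name| name.to_str())
        .and_then(|name| name.chars().next())
        .map_or(false, |c| c.is_uppercase())
}

fn main() {
    for name in ["Abyssinian_1.jpg", "american_bulldog_10.jpg"] {
        println!("{name}: is_cat = {}", is_cat(Path::new(name)));
    }
}
```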
In Part 3, we will figure out how to turn our images on disk into an ExactSizeDataset that can provide the images as Tensor structs with their associated labels, and enable batching and other useful functions.
And if you want to see the code from this stage, you can either check out the article-2 tag from git with
git checkout article-2
or browse it on GitHub.