Rock Paper Scissors classifier using fastai and a pretrained ResNet-34 model
Abhishek M Kori
Posted on August 18, 2019
We will run the code below in a Google Colab notebook (https://colab.research.google.com).
Step 1: Import the fastai vision library
from fastai.vision import *
Step 2: Create a path and directory
path = Path('data/rps')
path.mkdir(parents=True, exist_ok=True)
Step 3: Download the kaggle.json API token from your Kaggle profile and upload it to the current directory (/content).
Step 4: Copy kaggle.json to ~/.kaggle, change its file permissions so only you can read it, and download the drgfreeman/rockpaperscissors dataset.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d drgfreeman/rockpaperscissors
Step 5: Copy the dataset archive to the data/rps/ directory and unzip it
!cp rockpaperscissors.zip data/rps/rockpaperscissors.zip
!unzip data/rps/rockpaperscissors.zip -d data/rps
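Step 6: Create an ImageDataBunch object, which will:
Create the training and validation datasets
Apply get_transforms, which returns fastai's default data augmentations to add variety to the dataset: it may flip, rotate, zoom, change the lighting of, or slightly warp each image
Use valid_pct=0.2, i.e. 80% of the images for training and 20% for validation
Resize the images to 224 pixels
Normalize the pixels with ImageNet statistics, the same statistics the pretrained model was trained with, so the input values stay in the range the model expects
Setting np.random.seed(42) first makes the random train/validation split reproducible.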
np.random.seed(42)
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)
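valid_pct and normalize(imagenet_stats) hide two simple operations. Below is a minimal pure-Python sketch of what they amount to, not fastai's actual implementation: random_split and normalize_pixel are hypothetical helpers, the file names are made up, and only the mean/std values are the standard ImageNet statistics that imagenet_stats holds.

```python
import random

# Standard ImageNet per-channel (R, G, B) statistics, as held by imagenet_stats
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def random_split(items, valid_pct=0.2, seed=42):
    """Hypothetical helper: shuffle a copy of items and hold out valid_pct for validation."""
    rng = random.Random(seed)  # fixed seed -> the same split every run
    shuffled = list(items)
    rng.shuffle(shuffled)
    n_valid = int(len(shuffled) * valid_pct)
    return shuffled[n_valid:], shuffled[:n_valid]  # (train, valid)


def normalize_pixel(rgb):
    """Hypothetical helper: normalize one RGB pixel in [0, 1] with ImageNet statistics."""
    return tuple((c - m) / s for c, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


files = [f"img_{i}.png" for i in range(100)]  # made-up file names
train, valid = random_split(files)
print(len(train), len(valid))  # 80 20
print(normalize_pixel((0.485, 0.456, 0.406)))  # (0.0, 0.0, 0.0)
```

A pixel equal to the ImageNet mean maps to zero, and the other values are scaled by the per-channel standard deviation.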
Let's see what our dataset looks like
data.show_batch(rows=3, figsize=(7,8))
Inspect the DataBunch to see the class names, the number of classes, and the sizes of the training and validation sets
data.classes, data.c, len(data.train_ds), len(data.valid_ds)
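from_folder labels each image by the name of the folder it sits in, so data.classes is the list of distinct folder names and data.c is its length. Here is a minimal sketch of that idea on made-up paths; labels_from_folders is a hypothetical helper, and sorting the class names alphabetically is an assumption about the ordering.

```python
from pathlib import PurePosixPath


def labels_from_folders(paths):
    """Hypothetical helper: label each image by its parent folder name."""
    labels = [PurePosixPath(p).parent.name for p in paths]
    classes = sorted(set(labels))  # assumption: class names are sorted alphabetically
    return labels, classes


paths = [
    "data/rps/rock/a.png",
    "data/rps/paper/b.png",
    "data/rps/scissors/c.png",
    "data/rps/rock/d.png",
]  # made-up file paths
labels, classes = labels_from_folders(paths)
print(classes, len(classes))  # ['paper', 'rock', 'scissors'] 3
```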
Step 7: Let's create a CNN learner by passing in the DataBunch object and the pretrained ResNet-34 model (which fastai downloads for us), and tell the learner that error_rate is the metric we want to measure
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
Step 8: Let's fit the model by calling fit_one_cycle.
This method takes an integer: the number of epochs, i.e. how many times the learner should see our whole training set. In our case we will train for 4 epochs.
Each pass over the data lets the model learn a bit more about the features of our images.
If we train for too many epochs, we risk overfitting the model to our dataset. We want it to learn only the general features of our images.
learn.fit_one_cycle(4)
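Besides the number of epochs, fit_one_cycle also varies the learning rate during training following the "one-cycle" policy: it warms up from a small value to a maximum and then anneals down to a tiny value. The sketch below computes such a schedule with cosine curves. The defaults (max_lr=3e-3, div_factor=25, pct_start=0.3, final_div=25e4) are what I understand fastai v1 to use, but treat them as assumptions; one_cycle_lr and cos_anneal are hypothetical helpers, and momentum scheduling is left out.

```python
import math


def cos_anneal(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)


def one_cycle_lr(step, total_steps, max_lr=3e-3, div_factor=25.0,
                 pct_start=0.3, final_div=25e4):
    """Learning rate at a given step of a one-cycle schedule (defaults are assumptions)."""
    warmup_steps = int(total_steps * pct_start)
    if step < warmup_steps:
        # Phase 1: warm up from max_lr / div_factor to max_lr
        return cos_anneal(max_lr / div_factor, max_lr, step / warmup_steps)
    # Phase 2: anneal from max_lr down to max_lr / final_div
    pct = (step - warmup_steps) / max(1, total_steps - 1 - warmup_steps)
    return cos_anneal(max_lr, max_lr / final_div, pct)


total = 100  # hypothetical: e.g. 4 epochs x 25 batches
lrs = [one_cycle_lr(s, total) for s in range(total)]
print(f"start={lrs[0]:.2e} peak={max(lrs):.2e} end={lrs[-1]:.2e}")
```

The warm-up keeps early updates gentle, and the long annealing phase lets the model settle into a good minimum at the end.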
Let's analyze our metrics:
epoch: one complete pass over the training data
train_loss: how far the predictions are from the ground truth on the training set
valid_loss: how far the predictions are from the ground truth on the validation set
error_rate: the fraction of validation images whose prediction is wrong (1 - accuracy)
You can see that the training loss is higher than the validation loss. This is not because the model has already seen the validation images; it never trains on them. The training loss is a running average over the batches of the epoch, while the weights are still improving, and it is computed with data augmentation and dropout switched on. The validation loss is measured at the end of the epoch without them, so early in training it often looks lower.
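For single-label classification fastai measures the loss with cross-entropy, so train_loss and valid_loss are averages of a per-image cross-entropy, while error_rate simply counts wrong predictions. A minimal pure-Python sketch of both on made-up model outputs (the logits and targets are invented for illustration):

```python
import math


def softmax(logits):
    """Turn raw model outputs into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Loss for one image: -log(probability assigned to the true class)."""
    return -math.log(softmax(logits)[target])


def error_rate(predictions, targets):
    """Fraction of predictions that do not match the ground truth (1 - accuracy)."""
    wrong = sum(p != t for p, t in zip(predictions, targets))
    return wrong / len(targets)


# Made-up outputs for 4 images; classes: 0=paper, 1=rock, 2=scissors
batch_logits = [[2.0, 0.1, 0.1], [0.2, 3.0, 0.1], [1.5, 1.0, 0.2], [0.1, 0.3, 2.5]]
targets = [0, 1, 1, 2]
preds = [max(range(3), key=lambda i: row[i]) for row in batch_logits]
mean_loss = sum(cross_entropy(r, t) for r, t in zip(batch_logits, targets)) / len(targets)
print(preds, error_rate(preds, targets))  # [0, 1, 0, 2] 0.25
```

The third image is misclassified, so the error rate is 1/4; the loss also penalizes correct predictions made with low confidence.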
Step 9: Let's save our model!
It writes a .pth file to the models directory inside our data path (here data/rps/models/stage-1.pth)
learn.save('stage-1')
Step 10: Now let's draw some insights from our trained model.
ClassificationInterpretation.from_learner collects the model's predictions and losses on the validation set so we can analyze them.
interp = ClassificationInterpretation.from_learner(learn)
Let's see which images our model got wrong and why.
Each image is titled with the predicted class, the actual class, the loss, and the probability the model assigned to the actual class.
Notice that for the bottom 2 images the model predicted the correct class, but with low confidence.
interp.plot_top_losses(4, figsize=(15,11))
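Conceptually, "top losses" is just the validation images sorted by their individual loss, worst first. The sketch below reproduces that ranking on made-up records; top_losses is a hypothetical helper, and the title format only imitates the predicted/actual/loss/probability layout described above. Note how a correct but unconfident prediction can still rank high.

```python
import math


def top_losses(records, k):
    """Hypothetical helper: return the k records with the highest loss, worst first."""
    return sorted(records, key=lambda r: r["loss"], reverse=True)[:k]


# Made-up per-image results: predicted class, actual class, probability of the actual class
records = [
    {"pred": "rock", "actual": "paper", "p_actual": 0.10},
    {"pred": "paper", "actual": "paper", "p_actual": 0.95},
    {"pred": "paper", "actual": "rock", "p_actual": 0.30},
    {"pred": "scissors", "actual": "scissors", "p_actual": 0.55},
]
for r in records:
    r["loss"] = -math.log(r["p_actual"])  # cross-entropy for that image

for r in top_losses(records, 2):
    print(f'{r["pred"]}/{r["actual"]} / {r["loss"]:.2f} / {r["p_actual"]:.2f}')
```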
Let's explore our confusion matrix, which shows the exact number of correct and incorrect predictions for each pair of classes.
In our example we had 2 incorrect predictions:
It predicted rock, but the image was paper
It predicted paper, but the image was rock
You can see those incorrect predictions in the top-losses plot above. With this information you can clean your data and tweak your model to get better results.
interp.plot_confusion_matrix(figsize=(12,12), dpi=60)
most_confused lists the pairs of classes the model mixed up most often, as (actual, predicted, count)
interp.most_confused()
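Both the confusion matrix and most_confused come from counting (actual, predicted) pairs. Here is a minimal sketch on made-up validation results that contain the same two mistakes as above; confusion_matrix and most_confused here are hypothetical stand-ins, not fastai's code. As far as I know, fastai's plot also puts actual classes on the rows and predicted classes on the columns.

```python
from collections import Counter


def confusion_matrix(actuals, preds, classes):
    """Rows are actual classes, columns are predicted classes."""
    index = {c: i for i, c in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for a, p in zip(actuals, preds):
        matrix[index[a]][index[p]] += 1
    return matrix


def most_confused(actuals, preds, min_val=1):
    """List (actual, predicted, count) for wrong predictions, most frequent first."""
    counts = Counter((a, p) for a, p in zip(actuals, preds) if a != p)
    return sorted(((a, p, n) for (a, p), n in counts.items() if n >= min_val),
                  key=lambda t: t[2], reverse=True)


classes = ["paper", "rock", "scissors"]
# Made-up validation results containing the two mistakes described above
actuals = ["paper", "rock", "rock", "paper", "scissors"]
preds = ["rock", "paper", "rock", "paper", "scissors"]
print(confusion_matrix(actuals, preds, classes))  # [[1, 1, 0], [1, 1, 0], [0, 0, 1]]
print(most_confused(actuals, preds))  # [('paper', 'rock', 1), ('rock', 'paper', 1)]
```

Correct predictions sit on the diagonal; everything off the diagonal is a mistake, and most_confused just ranks those off-diagonal cells.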
Let's unfreeze our model and experiment a bit.
What does unfreezing mean? Until now we have trained only the last few layers (the new head fastai added on top of the pretrained model). unfreeze() lets us train all the layers of the model.
learn.unfreeze()
Train all the layers of our model for 1 epoch.
As you can see, this was a bad idea: the training loss, validation loss, and error rate are all much higher after 1 epoch of training all the layers than when we trained only the last few layers.
This is because we are asking the model to relearn all the fine-grained features of an image, using the same default learning rate for every layer. It is better to leave the early layers mostly as they are, because the pretrained model is already good at detecting edges and simple shapes, which are common to all kinds of images.
Relearning all those features from our limited data is not a good fit.
learn.fit_one_cycle(1)
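A way to picture freeze/unfreeze: the model is split into layer groups, and freezing marks every group except the last (the new head) as not trainable. The sketch below models only that bookkeeping with a hypothetical LayerGroup class and three made-up groups. In the real library this is done by switching requires_grad on the parameters, and as far as I know fastai keeps BatchNorm layers trainable even while the body is frozen; neither detail is modeled here.

```python
from dataclasses import dataclass


@dataclass
class LayerGroup:
    """Hypothetical stand-in for a group of layers in a pretrained CNN."""
    name: str
    trainable: bool = True


def freeze(groups):
    """Train only the last group (the new head); keep the pretrained body fixed."""
    for g in groups[:-1]:
        g.trainable = False
    groups[-1].trainable = True


def unfreeze(groups):
    """Make every group trainable, including the pretrained body."""
    for g in groups:
        g.trainable = True


# Hypothetical split of a pretrained CNN into three layer groups
groups = [LayerGroup("early body"), LayerGroup("late body"), LayerGroup("new head")]
freeze(groups)
print([g.name for g in groups if g.trainable])  # ['new head']
unfreeze(groups)
print([g.name for g in groups if g.trainable])  # ['early body', 'late body', 'new head']
```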
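Step 11: Let's load back our saved model and try to find a good learning rate.
The learning rate is the step size used by stochastic gradient descent: how far the weights move along the negative gradient at each update.
lr_find trains the model briefly while increasing the learning rate, and recorder.plot() shows how the loss changes with the learning rate.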
learn.load('stage-1');
learn.lr_find()
learn.recorder.plot()
From the plot we can see that the loss shoots up once the learning rate goes past about 1e-3 (10 to the power -3). A good rule of thumb is to use a maximum learning rate about 10 times lower than that point, which is why we use 1e-4 as the upper limit.
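The learning-rate finder tries exponentially increasing learning rates, records the loss at each one, and stops once the loss explodes. The sketch below generates such a sweep and applies the "10 times lower than the blow-up point" rule from the text. The sweep range (1e-7 to 10 over 100 steps) is what I believe fastai v1 uses by default, and the 4x blow-up threshold is an assumption; lr_sweep and suggest_max_lr are hypothetical helpers, and the losses are toy values, not measurements from this model.

```python
def lr_sweep(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Learning rates tried by an exponential range test, from start_lr to end_lr."""
    return [start_lr * (end_lr / start_lr) ** (i / (num_it - 1)) for i in range(num_it)]


def suggest_max_lr(lrs, losses, blowup_factor=4.0):
    """Find where the loss first exceeds blowup_factor x the best loss so far,
    then return a learning rate 10 times lower than that point."""
    best = float("inf")
    for lr, loss in zip(lrs, losses):
        if loss > blowup_factor * best:
            return lr / 10
        best = min(best, loss)
    return lrs[-1] / 10  # the loss never blew up within the tested range


lrs = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2]  # toy learning rates
losses = [1.0, 0.9, 0.7, 3.0, 9.0]    # toy losses, not real measurements
print(suggest_max_lr(lrs, losses))  # about 1e-4
```

In practice you read the plot by eye; the helper just makes the rule of thumb explicit.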
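Step 12: Now let's retrain our model with discriminative learning rates: passing slice(1e-6, 1e-4) gives the earliest layers a learning rate of 1e-6 (10 to the power -6) and the last layers 1e-4 (10 to the power -4), with the layers in between spread across that range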
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4))
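The learning rates for the groups in between are spaced evenly on a log scale. A minimal sketch, assuming the model is split into 3 layer groups (which I believe is what cnn_learner does for ResNet-34); log_spaced_lrs is a hypothetical helper, not fastai's function.

```python
def log_spaced_lrs(start, stop, n):
    """Hypothetical helper: n learning rates evenly spaced on a log scale, start to stop."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))  # constant ratio between neighbouring groups
    return [start * step ** i for i in range(n)]


# Assumption: the model is split into 3 layer groups (early body, late body, head)
lrs = log_spaced_lrs(1e-6, 1e-4, 3)
for group, lr in zip(["early body", "late body", "head"], lrs):
    print(f"{group}: {lr:.0e}")
```

Early layers, which hold the generic edge and shape detectors, barely move, while the head adapts the most.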
Step 13: Let's save our final model
learn.save('rock-paper-scissors-1')
Time to test!
I took a photo of my hand, removed the background using https://www.remove.bg, and replaced it with a green screen, similar to the green background of the dataset images.
img = open_image('/content/sci-removebg-preview.png')
img
learn.predict(img)
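learn.predict returns a tuple of three things: the predicted category, its index, and the probability for each class. The sketch below mimics that shape with plain lists from made-up logits; predict_from_logits is a hypothetical helper, and the class order is assumed to match data.classes.

```python
import math


def predict_from_logits(logits, classes):
    """Hypothetical helper: return (class name, class index, probabilities)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    idx = max(range(len(probs)), key=lambda i: probs[i])  # most probable class
    return classes[idx], idx, probs


classes = ["paper", "rock", "scissors"]  # assumed to match data.classes
label, idx, probs = predict_from_logits([0.2, -1.0, 2.3], classes)  # made-up model output
print(label, idx, [round(p, 3) for p in probs])
```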
This was a short tutorial on transfer learning: taking an existing pretrained model and retraining it for your own use case.
The fastai library is really good for programmers, and I strongly recommend it. I learned this concept from the fast.ai course; for more, see https://course.fast.ai/
If you want to experiment yourself, check out this Jupyter notebook and run it on https://colab.research.google.com.
Make sure you change the runtime type to GPU.
It's time for you to try!
https://gist.github.com/abhishekori/4e4697ba7cb7d2b49fece674ce31cc00
Let me know what you think <3