Using Fastai for Image Classification

Learn how to build a state-of-the-art image classifier in no time with fastai

I recently took a peek at Jeremy Howard’s 2019 course on deep learning. I had never used the Fastai library before, so I was amazed by its level of abstraction, which lets you create state-of-the-art neural networks in minutes with a ridiculously small amount of code. In the following, I offer a short tutorial on building a CNN image classifier with Fastai, meant to serve both as a summary of the key concepts for myself and as a clear overview for newcomers.

I chose the plant seedlings dataset from Kaggle as an example dataset.

About Fastai

If you haven’t heard of Fastai yet, I recommend taking a look at their homepage. Their mission statement says that they not only want to accelerate and help deep learning research, but also want to lower the barrier to entry for everyone. This line is from their post introducing PyTorch for Fastai back in September 2017 (Source):

Everybody should be able to use deep learning to solve their problems with no more education than it takes to use a smart phone. Therefore, each year our main research goal is to be able to teach a wider range of deep learning applications, that run faster, and are more accurate, to people with less prerequisites.

You can feel this ambition when you watch Fastai’s 2019 course, and it gives me a good feeling about the direction the field is headed. As often happens in research, knowledge and tools are accessible only to a small minority. This is even more pronounced in deep learning research, where many problems require powerful GPUs with large amounts of memory (check this unrelated video on GPU vs CPU).

In addition, Fastai is now supported on many cloud computing platforms, such as Paperspace, Crestle and AWS, to name just a few. But since these services all cost money, I will stick to Google’s recently announced Colaboratory, which lets you use Google’s GPUs for free. Yes, zero cost! Aren’t these incredible times to be alive?

Setting up Google Colab

To avoid repeating what is already online, please check this Medium post by Manikanta Yadunanda for a quick introduction to using Google Colaboratory (Colab) with Fastai. Since that introduction was written, Google has added official Fastai and PyTorch support to Colab, so you probably don’t even need to install anything after connecting to a runtime. You can check with the following line of code whether all of the important pip packages are installed. If not, uncomment the last line and use it to install Fastai and PyTorch for Python 3.6.x and CUDA 9.x.
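As a stand-in for that check, here is a minimal sketch that uses only the standard library (importlib.metadata, Python 3.8+) to report which packages are installed. It is not the notebook's original code, and the package names fastai, torch and torchvision are simply the PyPI distribution names I assume you want to check.

```python
# Sketch: report installed versions of the packages this tutorial relies on.
# Assumption: PyPI distribution names "fastai", "torch" and "torchvision".
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version string, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(["fastai", "torch", "torchvision"]).items():
    print(f"{name}: {version or 'NOT INSTALLED'}")
```

Keep in mind that an unpinned pip install today gives you a newer fastai whose API differs from the v1 calls (ImageDataBunch, create_cnn) used in this tutorial, so you may need to pin an older version.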

Cloning this Notebook

To access this notebook and run the computations yourself, you can import it directly from my GitHub repository. Just go to File… Open Notebook…, select the GitHub tab and enter ‘verrannt/Tutorials’ in the search bar. Select ‘fastai-plant-seedlings-classification.ipynb’ and you’re done. Simple as that. After you have checked that Colab is configured for Fastai support, you can continue with the data preparation part.

Data Preparation

Get the Data

Since we are using Kaggle as our data source, you need a Kaggle account to download the dataset. If you have one, head over to your account settings and create a new API token, which you can use with the Kaggle CLI.

The following code installs the CLI and registers a new API key. We then download the plant seedling dataset and unzip it. Make sure that you are in the /content folder of the fastai directory.
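A plain-Python sketch of these steps follows. It is not the original notebook cell, and several details are assumptions: the token location ~/.kaggle/kaggle.json (where the Kaggle CLI looks by default), the competition slug plant-seedlings-classification, and the target folder name plant-seedlings-data. The CLI itself is installed with pip install kaggle.

```python
# Sketch of the Kaggle setup: store the API token, download, unzip.
# Assumptions: default token location ~/.kaggle/kaggle.json and the
# competition slug "plant-seedlings-classification".
import json
import os
import stat
import subprocess
import zipfile
from pathlib import Path


def write_kaggle_token(username, key, config_dir=None):
    """Write kaggle.json where the Kaggle CLI expects it, readable only by the owner."""
    config_dir = Path(config_dir) if config_dir else Path.home() / ".kaggle"
    config_dir.mkdir(parents=True, exist_ok=True)
    token_path = config_dir / "kaggle.json"
    token_path.write_text(json.dumps({"username": username, "key": key}))
    os.chmod(token_path, stat.S_IRUSR | stat.S_IWUSR)  # equivalent to chmod 600
    return token_path


def download_command(competition, dest):
    """Build the Kaggle CLI call as an argument list for subprocess.run."""
    return ["kaggle", "competitions", "download", "-c", competition, "-p", str(dest)]


def unzip_all(folder):
    """Extract every .zip archive directly inside `folder` into that folder."""
    folder = Path(folder)
    extracted = []
    for archive in sorted(folder.glob("*.zip")):
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(folder)
        extracted.append(archive.name)
    return extracted


def fetch_dataset(competition="plant-seedlings-classification", dest="plant-seedlings-data"):
    """Download the competition files with the Kaggle CLI and unzip them (needs network access)."""
    Path(dest).mkdir(parents=True, exist_ok=True)
    subprocess.run(download_command(competition, dest), check=True)
    return unzip_all(dest)
```

Usage would be write_kaggle_token("your-username", "your-key") followed by fetch_dataset(); the username and key placeholders stand for the values in your downloaded token file.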

Inspect the Data

Let’s inspect the data. The folder we downloaded and extracted from Kaggle has 12 subfolders, each of which corresponds to one type of seedling with its respective images inside. These will be our labels for the classification task.

If we print the number of files in each of these folders, we can see that the folders contain different numbers of images. The differences are large: for example, ‘Loose Silky-bent’ has the most images (762), while ‘Common wheat’ has the fewest (253). We will see later whether this translates into differences in prediction accuracy.

This outputs the following:

No. of labels: 12
Small-flowered Cranesbill, 576 files
Common wheat, 253 files
Charlock, 452 files
Sugar beet, 463 files
Maize, 257 files
Black-grass, 309 files
Loose Silky-bent, 762 files
Fat Hen, 538 files
Cleavers, 335 files
Shepherd’s Purse, 274 files
Scentless Mayweed, 607 files
Common Chickweed, 713 files
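A minimal sketch that produces a listing in this format, assuming the 12 label folders sit directly inside the data folder (the path in the usage comment is hypothetical). Labels are sorted alphabetically here, so the order differs from the listing above.

```python
# Sketch: count the image files in each label folder.
from pathlib import Path


def count_images_per_label(data_dir):
    """Return {label: number of files} for every subfolder of data_dir, sorted by label."""
    return {
        folder.name: sum(1 for f in folder.iterdir() if f.is_file())
        for folder in sorted(Path(data_dir).iterdir())
        if folder.is_dir()
    }


def format_counts(counts):
    """Render the counts in the same style as the listing above."""
    lines = [f"No. of labels: {len(counts)}"]
    lines += [f"{label}, {n} files" for label, n in counts.items()]
    return "\n".join(lines)


# Usage (hypothetical path):
# print(format_counts(count_images_per_label("plant-seedlings-data")))
```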

Let’s have a look at those images. For each of the 12 labels, we will display one randomly chosen seedling image.
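One way to do this is sketched below with the standard library plus matplotlib (an assumption; the original notebook's plotting code is not shown here). The seed parameter only makes the random choice reproducible, and the path in the usage comment is hypothetical.

```python
# Sketch: pick one random image per label and show them in a grid.
import random
from pathlib import Path


def random_sample_per_label(data_dir, seed=None):
    """Return {label: path of one randomly chosen file} for every non-empty label folder."""
    rng = random.Random(seed)
    samples = {}
    for folder in sorted(Path(data_dir).iterdir()):
        if folder.is_dir():
            files = sorted(f for f in folder.iterdir() if f.is_file())
            if files:
                samples[folder.name] = rng.choice(files)
    return samples


def show_samples(samples, cols=4):
    """Plot the sampled images with their labels (matplotlib is imported lazily)."""
    if not samples:
        return
    import matplotlib.image as mpimg
    import matplotlib.pyplot as plt

    rows = -(-len(samples) // cols)  # ceiling division
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")
    for ax, (label, path) in zip(axes.flat, samples.items()):
        ax.imshow(mpimg.imread(path))
        ax.set_title(label)
    fig.tight_layout()
    plt.show()


# Usage (hypothetical path): show_samples(random_sample_per_label("plant-seedlings-data"))
```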

Alright, they look good and quite distinguishable, except for ‘Loose Silky-bent’ and ‘Black-grass’. These might be harder for the network to recognize, but we will see. Let’s get to it!

Creating the Fastai Model

We can now create the CNN model using the Fastai library. Since its major update to v1, the library has become much clearer and more consistent, so we only need to import the vision module and the accuracy metric.

from fastai.vision import *
from fastai.metrics import accuracy

Fastai has a really nice class for handling everything related to the input images for vision tasks. It is called ImageDataBunch, and it provides different factory functions for the different ways the data can be organized. Since our images are placed in folders whose names correspond to the image labels, we will use the ImageDataBunch.from_folder() function to create an object that contains our image data. This makes it incredibly easy to read the data into our model, as you will see in a bit.

What’s even more handy is that Fastai can automatically split our data into train and validation sets, so we don’t even need to create these on our own.
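The idea behind such a random split fits in a few lines. The sketch below is plain Python, not fastai's implementation: it shuffles the indices and sends a valid_pct fraction (rounded down) to the validation set. With valid_pct=0.2, an assumed value here, and the 5539 images of our dataset, it yields 4432 training and 1107 validation images, the same sizes as in the data summary further down.

```python
# Sketch of a random train/validation split (not fastai's own code).
import random


def random_split(items, valid_pct=0.2, seed=None):
    """Shuffle the indices and put int(valid_pct * len(items)) items into the validation set."""
    rng = random.Random(seed)
    indices = list(range(len(items)))
    rng.shuffle(indices)
    n_valid = int(valid_pct * len(items))
    valid_idx = set(indices[:n_valid])
    train = [x for i, x in enumerate(items) if i not in valid_idx]
    valid = [x for i, x in enumerate(items) if i in valid_idx]
    return train, valid


train, valid = random_split(list(range(5539)), valid_pct=0.2, seed=42)
print(len(train), len(valid))  # 4432 1107
```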

All we need to specify now is the path to our dataset, the size of the inputs, and the batch size for each gradient descent iteration. To keep things simple, the ImageDataBunch object will resize all images to squares of size × size pixels unless instructed otherwise.

A quick note on image size: the bigger an image, the more details the CNN can pick out of it, but bigger images also mean longer computation times. Similarly, your GPU may run out of memory if the batch size is too large; if that happens, halve the batch size.
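To make the trade-off concrete, here is a small back-of-the-envelope helper (my own illustration) for the memory of one float32 input batch. The input is only a small part of what the GPU holds, and the layer activations come on top, but they grow in roughly the same way: linearly with the batch size and quadratically with the image side length.

```python
# Back-of-the-envelope memory for one input batch (activations come on top of this).
def input_batch_bytes(batch_size, size, channels=3, bytes_per_value=4):
    """Bytes for a float32 batch of `batch_size` RGB images of size x size pixels."""
    return batch_size * channels * size * size * bytes_per_value


MIB = 2 ** 20
print(input_batch_bytes(64, 224) / MIB)  # 36.75  (bs=64, size=224)
print(input_batch_bytes(32, 224) / MIB)  # 18.375 (halving the batch halves it)
print(input_batch_bytes(64, 448) / MIB)  # 147.0  (doubling the side quadruples it)
```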

path = "./plant-seedlings-data/"
size = 224
bs = 64

We will create a variable called data in which we place the ImageDataBunch object. We create this object with the from_folder() function discussed above. Besides the path to our data, the image size and the batch size, it also takes:

  • a ds_tfms argument, to which we pass the result of get_transforms(); this function returns the default data-augmentation transforms for the training and validation sets.
  • a parameter valid_pct, which controls the fraction of images (e.g. 0.2 for 20%) that will be randomly chosen for the validation set.
  • a parameter flip_vert (passed to get_transforms(), not to from_folder() itself), which enables vertical flips and 90° rotations in addition to horizontal flips. (Since our plant images are taken from above, we can use these without a problem; they would not make sense for, e.g., face data.)
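The difference flip_vert makes is easy to see on a tiny grid. With horizontal flips alone, an image has 2 possible orientations; adding vertical flips and 90° rotations gives all 8 symmetries of a square (the dihedral group). The sketch below enumerates them with plain lists; it only illustrates the idea and is not fastai's transform code.

```python
# Sketch: the orientations reachable with and without flip_vert, on a 2x2 "image".
def rot90(grid):
    """Rotate a list-of-lists image 90 degrees clockwise."""
    return [list(row) for row in zip(*grid[::-1])]


def hflip(grid):
    """Mirror the image left to right."""
    return [row[::-1] for row in grid]


def orientations(grid, flip_vert):
    """Horizontal flip only, or all 8 rotation/flip combinations when flip_vert is True."""
    if not flip_vert:
        return [grid, hflip(grid)]
    variants, current = [], grid
    for _ in range(4):
        variants += [current, hflip(current)]
        current = rot90(current)
    return variants


image = [[1, 2], [3, 4]]
print(len(orientations(image, flip_vert=False)))  # 2
print(len(orientations(image, flip_vert=True)))   # 8
```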

To normalize the data in our object, we simply call normalize() on it. We can pass ImageNet, CIFAR or MNIST statistics (imagenet_stats, cifar_stats, mnist_stats) as templates here; if left empty, this function grabs a batch of data from our object, computes its statistics (per-channel mean and standard deviation) and normalizes the data accordingly. Since the ResNet architecture we will use was pretrained on ImageNet, we use the ImageNet stats.
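Per channel, normalization subtracts a mean and divides by a standard deviation. The sketch below spells this out in plain Python with the standard ImageNet statistics (mean 0.485, 0.456, 0.406 and standard deviation 0.229, 0.224, 0.225 for R, G, B on pixel values scaled to [0, 1]), the values behind fastai's imagenet_stats. The channel_stats helper shows what computing the statistics from a batch instead amounts to. It is an illustration, not fastai's code.

```python
# Sketch: per-channel normalization with ImageNet statistics (pixel values in [0, 1]).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (value - mean) / std to each channel of one RGB pixel."""
    return tuple((v - m) / s for v, m, s in zip(rgb, mean, std))


def channel_stats(pixels):
    """Per-channel mean and population standard deviation of a list of RGB pixels."""
    n = len(pixels)
    means = tuple(sum(p[c] for p in pixels) / n for c in range(3))
    stds = tuple(
        (sum((p[c] - means[c]) ** 2 for p in pixels) / n) ** 0.5 for c in range(3)
    )
    return means, stds
```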


This outputs a summary:


Train: LabelList
y: CategoryList (4432 items)
[Category Small-flowered Cranesbill, Category Small-flowered Cranesbill, Category Small-flowered Cranesbill, Category Small-flowered Cranesbill, Category Small-flowered Cranesbill]...
Path: plant-seedlings-data
x: ImageItemList (4432 items)
[Image (3, 237, 237), Image (3, 497, 497), Image (3, 94, 94), Image (3, 551, 551), Image (3, 246, 246)]...
Path: plant-seedlings-data;

Valid: LabelList
y: CategoryList (1107 items)
[Category Maize, Category Black-grass, Category Common Chickweed, Category Cleavers, Category Charlock]...
Path: plant-seedlings-data
x: ImageItemList (1107 items)
[Image (3, 529, 529), Image (3, 945, 945), Image (3, 171, 171), Image (3, 125, 125), Image (3, 163, 163)]...
Path: plant-seedlings-data;

Test: None

That’s it: two lines of code for preparing our dataset for training, including different kinds of transformations and normalization! I would kindly ask you to stop here for a moment and appreciate the beauty of this. The beauty of high-level libraries: with just two lines of code we have increased the diversity of our dataset tremendously. I can already hear my ugly batchnorm code crying in the trashbin.

All that is left now is to create the actual network and train it, which could not be more simple.

Fastai supplies us with a function called create_cnn() in its vision module. This function creates what is called a learner object, which we’ll put into an aptly named variable. Note that we specify the ResNet-18 architecture (models.resnet18) as our base model for transfer learning. On the first call, its pretrained weights are downloaded and stored locally.

We will use accuracy as our metric; the docs list the other available metrics. Passing the ShowGraph callback makes the learner update a plot of the training and validation losses after each epoch, which is very useful for seeing whether the model is still improving.
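What the accuracy metric computes is simple: take the highest-scoring class for each image and count how often it matches the label. A plain-Python sketch of that calculation (the function name plain_accuracy is my own; fastai's version works on PyTorch tensors):

```python
# Sketch: accuracy = share of samples whose arg-max prediction equals the target.
def plain_accuracy(scores, targets):
    """scores: one list of class scores per sample; targets: the true class indices."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in scores]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
print(plain_accuracy(scores, [1, 0, 0]))  # 0.6666666666666666
```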

learner = create_cnn(data, models.resnet18, metrics=[accuracy], callback_fns=ShowGraph)

Finding the learning rate

The learner object we created comes with a built-in function, lr_find(), for finding a good learning rate, or range of learning rates, for training. It does this by training the model for a short run of mini-batches while increasing the learning rate exponentially, and recording the loss at each step.

We want to choose a learning rate at which the loss is still decreasing; that is, we do not want the learning rate with the minimum loss, but the one with the steepest downward slope.
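The idea can be sketched in plain Python: spread the learning rates exponentially over a range, record a loss for each, and look for the steepest drop per decade rather than the lowest loss. The defaults of lr_schedule mirror fastai v1's lr_find (start_lr=1e-7, end_lr=10, num_it=100); the loss values in the example are made up for illustration, and a real finder would also smooth the noisy losses before looking at slopes.

```python
# Sketch of a learning-rate range test's two ingredients (not fastai's code).
import math


def lr_schedule(start_lr=1e-7, end_lr=10.0, num_it=100):
    """num_it learning rates spaced exponentially from start_lr to end_lr (num_it >= 2)."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]


def steepest_lr(lrs, losses):
    """Lower learning rate of the segment where the loss falls fastest per decade."""
    best_lr, best_slope = None, math.inf
    for i in range(1, len(lrs)):
        slope = (losses[i] - losses[i - 1]) / (math.log10(lrs[i]) - math.log10(lrs[i - 1]))
        if slope < best_slope:
            best_lr, best_slope = lrs[i - 1], slope
    return best_lr


lrs = lr_schedule(1e-5, 1.0, num_it=6)       # 1e-5, 1e-4, ..., 1
losses = [2.0, 1.95, 1.8, 1.2, 0.9, 3.0]     # made-up loss curve
print(min(zip(losses, lrs))[1])              # lowest loss at about 0.1
print(steepest_lr(lrs, losses))              # steepest drop starts at about 0.001
```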

The resulting plot of the loss against the learning rate is stored in the recorder object of our learner and can be displayed with learner.recorder.plot(). It shows that the loss falls most steeply for learning rates between 0.001 and 0.01.


First fit and evaluation

Now let’s fit our model for 8 epochs using the one-cycle policy, with learning rates between 0.001 and 0.01:

learner.fit_one_cycle(8, max_lr=slice(1e-3, 1e-2))
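Passing a slice as max_lr gives discriminative learning rates: in fastai v1, the layer groups of the model receive rates spaced geometrically between the two ends, so the earliest, most general layers change least. A plain-Python sketch of that spacing, assuming three layer groups for the ResNet:

```python
# Sketch: geometric spacing of learning rates across layer groups for slice(start, stop).
def even_mults(start, stop, n):
    """n values from start to stop, each a constant factor larger than the previous."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]


print(even_mults(1e-3, 1e-2, 3))  # roughly [0.001, 0.00316, 0.01]
print(even_mults(1e-5, 1e-4, 3))  # roughly [1e-05, 3.16e-05, 0.0001]
```

While the pretrained body is still frozen, it is mainly the newly added head (the last group) that gets trained, so the upper end of the slice matters most in this first fit.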

That already looks good! The loss decreases sharply during roughly the first fifth of the iterations, and more slowly but continuously afterwards.

Let’s see where the algorithm is making the most mistakes:

interpreter = ClassificationInterpretation.from_learner(learner)
[('Black-grass', 'Loose Silky-bent', 28),
 ('Loose Silky-bent', 'Black-grass', 7),
 ("Shepherd's Purse", 'Scentless Mayweed', 4)]

This shows that the model most often confuses the classes ‘Black-grass’ and ‘Loose Silky-bent’, mostly by predicting ‘Loose Silky-bent’ for Black-grass images. We already saw in the sample images displayed earlier that these two look the most alike, so this makes sense.
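Such a list is essentially a count of (actual, predicted) pairs over the misclassified validation images, sorted by frequency. A plain-Python sketch with made-up labels (the function mirrors the idea of most_confused, not its implementation):

```python
# Sketch: rank (actual, predicted) pairs of misclassifications by frequency.
from collections import Counter


def most_confused(actual, predicted, min_count=1):
    """Return (actual, predicted, count) for wrong predictions, most frequent first."""
    errors = Counter((a, p) for a, p in zip(actual, predicted) if a != p)
    return [(a, p, n) for (a, p), n in errors.most_common() if n >= min_count]


actual = ["Black-grass"] * 3 + ["Loose Silky-bent"] * 2 + ["Maize"]
predicted = ["Loose Silky-bent", "Loose Silky-bent", "Black-grass",
             "Black-grass", "Loose Silky-bent", "Maize"]
print(most_confused(actual, predicted))
# [('Black-grass', 'Loose Silky-bent', 2), ('Loose Silky-bent', 'Black-grass', 1)]
```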

Improving the Model

Unfreezing and fine-tuning

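Before we unfreeze the layers and train again, we save the weights so that we can go back in case we mess up. After unfreezing, we train the whole network for 12 more epochs with much lower learning rates:

learner.save('stage-1')    # learner.load('stage-1') restores these weights
learner.unfreeze()
learner.fit_one_cycle(12, max_lr=slice(1e-5, 1e-4))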

We will stop here, because the validation loss is starting to exceed the training loss, and it looks like this gap will only widen. If we continued training from this point onwards, the model would start to overfit the training data!
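Here the stopping point was chosen by looking at the loss plot. A common way to automate that judgment, swapped in here and not used in the original notebook, is early stopping: keep track of the epoch with the lowest validation loss and stop once it has not improved for a few epochs (the patience). A plain-Python sketch with made-up loss values:

```python
# Sketch of early stopping on the validation loss (illustrative values only).
def early_stopping_epoch(valid_losses, patience=2):
    """Return the best epoch once the validation loss has not improved for
    `patience` consecutive epochs, or None if training should continue."""
    best_epoch, best_loss = 0, float("inf")
    for epoch, loss in enumerate(valid_losses):
        if loss < best_loss:
            best_epoch, best_loss = epoch, loss
        elif epoch - best_epoch >= patience:
            return best_epoch
    return None


print(early_stopping_epoch([0.50, 0.40, 0.35, 0.36, 0.38]))  # 2
print(early_stopping_epoch([0.50, 0.40, 0.35]))              # None
```

fastai v1 also provides callbacks for this, such as EarlyStoppingCallback and SaveModelCallback in fastai.callbacks.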

Good job, we successfully trained a state-of-the-art image classifier for a custom dataset, achieving 96.5% accuracy in just a handful of lines of code!
