Training a deep learning model from scratch to classify Pokémon


In this post I document what I learned in Chapter 2 of the FastAI course. I will show how you can train a deep learning model from scratch, and by the end of the post we will have a model that can classify Pokémon.

Result

Installing Packages

First we install the fastai library. In a Jupyter Notebook you can prefix a line with !, which means everything after the exclamation mark is run as a shell command. -Uqq combines two pip options: -U means upgrade, and -qq means very quiet, so the installation prints far less output to the console.

!pip install -Uqq fastai

To train our model, we will first search for images on DuckDuckGo (through the duckduckgo_search package) and then download them with the fastdownload package from FastAI.

!pip install duckduckgo_search fastdownload
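
Since -qq hides most of the pip output, it can be handy to confirm that the packages really are installed before moving on. The following is a minimal sketch that uses only the Python standard library; the helper name installed_version is my own and not part of FastAI.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(package_name):
    """Return the installed version string of a package, or None if it is missing."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return None

# Report the status of the three packages installed above
for name in ("fastai", "duckduckgo_search", "fastdownload"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

If one of the lines says "not installed", rerun the matching pip cell without -qq to see the full error output.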

Setting up the imports

Now we are ready to import the stuff we need from the installed packages. Go ahead and create a new code cell and import:

from fastai.vision.all import *
from duckduckgo_search import DDGS
import ipywidgets as widgets
from ipywidgets import VBox, HBox
from IPython.display import display

Search and download Images

In the next step we will create a function that searches for images, and then use it to download a few images for each Pokémon:

def search_images_ddg(term, max_images=5):
    print(f"Searching for '{term}'")
    # Search for images by using DuckDuckGo. 'with' is a context manager that ensures that the session is closed after the search.
    with DDGS() as ddgs:
        return L(ddgs.images(term, max_results=max_images)).itemgot('image')

# Define pokemon categories
pokemon_types = 'charizard', 'blastoise', 'venusaur'

# Create a path object
path = Path('pokemon')

# Create the folder if it does not exist yet
path.mkdir(exist_ok=True)

# Loop through the pokemon_types and download images for each type
for o in pokemon_types:
    dest = (path/o)
    dest.mkdir(exist_ok=True)
    results = search_images_ddg(o)
    download_images(dest, urls=results)

# Verify that each image can be opened and delete the files that cannot
fns = get_image_files(path)
failed = verify_images(fns)
failed.map(Path.unlink)
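
Some downloads fail and some files get deleted by the verification step, so it is worth checking how many images are actually left per category. Below is a small sketch in plain Python. The function name count_images_per_class and the list of file extensions are assumptions for illustration, not FastAI functionality.

```python
from pathlib import Path

# Assumed set of image extensions, adjust as needed
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}

def count_images_per_class(root):
    """Map each sub-folder name under root to the number of image files it contains."""
    counts = {}
    for folder in sorted(Path(root).iterdir()):
        if folder.is_dir():
            counts[folder.name] = sum(
                1 for f in folder.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts

# Only runs if the 'pokemon' folder from the download step exists
if Path("pokemon").exists():
    print(count_images_per_class("pokemon"))
```

If a category ends up with very few images, raising max_images in search_images_ddg is the easiest fix.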

Data Blocks

The DataBlock is a key concept in FastAI that defines how to load and preprocess your data. We will specify that we are working with images and tell FastAI about our 3 Pokemon categories. Furthermore we will define how we split our data into training and validation sets.

pokemon = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms()
)
  • blocks=(ImageBlock, CategoryBlock): Specifies that we’re working with image data and category labels.
  • get_items=get_image_files: Defines how to get the list of items (image files in this case).
  • splitter=RandomSplitter(valid_pct=0.2, seed=42): Splits the data into training and validation sets (80% train, 20% validation).
  • get_y=parent_label: Specifies how to get the label for each item (using the parent folder name).
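  • item_tfms=RandomResizedCrop(224, min_scale=0.5): Crops a random region covering at least 50% of each training image and resizes it to 224×224 pixels.
  • batch_tfms=aug_transforms(): Applies FastAI's default set of augmentations (such as flips, rotation, zoom and lighting changes) to each batch.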
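
To make the get_y and splitter arguments more concrete, here is a simplified sketch of the two ideas in plain Python. It is not FastAI's implementation: label_from_parent and random_split are hypothetical helpers, and the shuffled order will differ from what RandomSplitter produces with the same seed.

```python
import random
from pathlib import Path

def label_from_parent(file_path):
    """Use the name of the containing folder as the label."""
    return Path(file_path).parent.name

def random_split(items, valid_pct=0.2, seed=42):
    """Shuffle the indices reproducibly and hold out valid_pct of them for validation."""
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)
    cut = int(valid_pct * len(items))
    return indices[cut:], indices[:cut]  # (train indices, validation indices)

# Made-up file names: 5 images for each of the 3 categories
files = [Path("pokemon") / name / f"{i}.jpg"
         for name in ("charizard", "blastoise", "venusaur") for i in range(5)]

train_idx, valid_idx = random_split(files)
print(len(train_idx), len(valid_idx))  # 12 3
print(label_from_parent(files[0]))     # charizard
```

The fixed seed is what makes the split reproducible: running the cell again holds out the same images for validation.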

Creating DataLoaders

In this step we will take our datablock definition and load the data.

# Create a DataLoaders object
dls = pokemon.dataloaders(path)

Training the model

Now we can train our model. For this we will call the vision_learner function, which creates a Learner object. The fastai documentation describes a Learner as something that:

"Groups together a model, some dls and a loss_func to handle training" (fastai documentation)

We will use ResNet18, a pre-trained model that is a common choice for image classification tasks.

learn = vision_learner(dls, resnet18, metrics=error_rate)
# The fine_tune method then trains this model on our specific dataset. It uses transfer learning, which means it starts with the pre-trained weights and adapts them to our task.
# The parameter 4 is the number of epochs (full passes over the training data). By default, fine_tune first trains one extra epoch with the pre-trained layers frozen.
learn.fine_tune(4)
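
The error_rate metric passed to vision_learner is the fraction of validation images that the model classifies incorrectly, in other words 1 minus the accuracy. The sketch below shows the idea in plain Python with made-up predictions. FastAI computes the metric on tensors, so treat this as an illustration only.

```python
def error_rate(predictions, targets):
    """Fraction of predictions that do not match the true label (1 - accuracy)."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    wrong = sum(p != t for p, t in zip(predictions, targets))
    return wrong / len(targets)

# Made-up example: one out of four predictions is wrong
preds   = ["charizard", "blastoise", "venusaur", "charizard"]
targets = ["charizard", "blastoise", "venusaur", "venusaur"]
print(error_rate(preds, targets))  # 0.25
```

Keep in mind that with only a handful of validation images, a single wrong prediction moves this number a lot.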

Exporting the model

# Export the model for later usage
learn.export('pokemon_classifier.pkl')

This line exports the trained model to a file named pokemon_classifier.pkl. It contains the model weights and the architecture. The architecture is like a blueprint of the neural network. It defines how the model is structured and how information flows through it.

Setting up for Inference

Inference basically means putting our trained model to work. It’s the process of using the model to make predictions on new images. To do this, we will load the trained model by calling the load_learner function. This will recreate the entire learner object, including its model architecture.

learn_inf = load_learner('pokemon_classifier.pkl')
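
The .pkl extension hints at how this works: the exported file is a pickle, Python's built-in format for saving objects to disk. The round trip can be sketched with the standard library alone. The dictionary below is a hypothetical stand-in for a learner, not what FastAI actually stores.

```python
import pickle
import tempfile
from pathlib import Path

# Hypothetical stand-in for a trained learner: structure plus learned numbers
model_state = {
    "architecture": {"layers": [4, 3]},
    "weights": [0.1, -0.2, 0.3],
}

with tempfile.TemporaryDirectory() as tmp:
    file_path = Path(tmp) / "tiny_model.pkl"
    file_path.write_bytes(pickle.dumps(model_state))   # roughly what exporting does
    restored = pickle.loads(file_path.read_bytes())    # roughly what loading does

print(restored == model_state)  # True
```

Because unpickling can execute arbitrary code, only load .pkl files that come from a source you trust.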

Creating the ‘UI’ to test our model

In the final step we are going to implement a UI with two buttons: one to upload an image and one to call the predict method on the inference object from the last section.

# Create a button which will open a file upload dialog
btn_upload = widgets.FileUpload()
out_pl = widgets.Output()
lbl_pred = widgets.Label()
btn_run = widgets.Button(description='Classify')

def on_click_classify(change):
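    # Take the most recently uploaded file. Note: .data is the ipywidgets 7 API.
    # It was removed in ipywidgets 8, where btn_upload.value[-1].content.tobytes() should work instead.
    img = PILImage.create(btn_upload.data[-1])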
    out_pl.clear_output()
    with out_pl:
        display(img.to_thumb(128,128))
        
    # Call our predict method on our inference object
    pred, pred_idx, probs = learn_inf.predict(img)
    lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'
    
# On click handler
btn_run.on_click(on_click_classify)

# Setup the UI
widgets_panel = VBox([
    widgets.Label('Select your Pokemon!'),
    btn_upload,
    btn_run,
    out_pl,
    lbl_pred
])
display(widgets_panel)
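
The predict call in the click handler returns three values: the predicted label, its index among the categories, and one probability per category. How they relate can be sketched in plain Python. The scores below are made up, and the category order assumes that the class names are sorted alphabetically, which is the default behaviour as far as I know.

```python
import math

def softmax(logits):
    """Turn raw model scores into probabilities that sum to 1."""
    highest = max(logits)
    exps = [math.exp(x - highest) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

vocab = ["blastoise", "charizard", "venusaur"]  # assumed category order
logits = [0.5, 3.0, -1.0]                       # made-up scores for one image

probs = softmax(logits)
pred_idx = max(range(len(probs)), key=probs.__getitem__)  # index of the highest probability
pred = vocab[pred_idx]
print(f"Prediction: {pred}; Probability: {probs[pred_idx]:.04f}")
```

This is why the label in the UI shows probs[pred_idx]: it is the probability that belongs to the predicted category.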