Training a deep learning model from scratch to classify Pokémon
In this post I document what I learned in Chapter 2 of the FastAI course. I will demonstrate how you can train a deep learning model from scratch, and by the end of the post we will have a model that can classify Pokémon.
Installing Packages
First we install the fastai library. To do this in a Jupyter Notebook you can prefix the line with a !, which means everything after the exclamation mark is interpreted as a shell command.
The -Uqq flag is a set of options for pip: -U means upgrade and -qq stands for very quiet, so that there is less console output during the installation.
!pip install -Uqq fastai
To train our model, we will first search for images using the DuckDuckGo search API and then download them. For this we install the duckduckgo_search package together with the fastdownload package from FastAI.
!pip install duckduckgo_search fastdownload
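To confirm that the installs above worked, a small check like the following can help. It is a minimal sketch that uses only the standard library; installed_version is a hypothetical helper name, not part of pip or fastai.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Report the packages this post relies on
for name in ("fastai", "duckduckgo_search", "fastdownload"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

If one of the lines reports "not installed", re-running the matching pip cell is usually the first thing to try.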
Setting up the imports
Now we are ready to import the stuff we need from the installed packages. Go ahead and create a new code cell and import:
from fastai.vision.all import *
from duckduckgo_search import DDGS
import ipywidgets as widgets
from ipywidgets import VBox, HBox
from IPython.display import display
Search and download Images
In the next step we will create a function to search for images and download them:
def search_images_ddg(term, max_images=5):
    print(f"Searching for '{term}'")
    # Search for images by using DuckDuckGo. 'with' is a context manager that ensures that the session is closed after the search.
    with DDGS() as ddgs:
        return L(ddgs.images(term, max_results=max_images)).itemgot('image')
# Define pokemon categories
pokemon_types = 'charizard', 'blastoise', 'venusaur'
# Create a path object
path = Path('pokemon')
# If the path does not exist, create it
if not path.exists():
    path.mkdir()
# Loop through the pokemon_types and download images for each type
for o in pokemon_types:
    dest = (path/o)
    dest.mkdir(exist_ok=True)
    results = search_images_ddg(f'{o}')
    download_images(dest, urls=results)
# Verify that each image can be opened
fns = get_image_files(path)
failed = verify_images(fns)
failed.map(Path.unlink)
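After removing broken files it can be useful to see how many images are left per category. The following is a minimal sketch using only the standard library; count_images_per_class and the IMAGE_SUFFIXES set are assumptions made for illustration and are not part of fastai.

```python
from collections import Counter
from pathlib import Path

# Assumed set of file extensions that count as images
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}


def count_images_per_class(root):
    """Count image files under `root`, grouped by the name of their parent folder."""
    root = Path(root)
    return Counter(
        p.parent.name
        for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


# Only report if the download folder from the previous step exists
if Path("pokemon").exists():
    for label, n in sorted(count_images_per_class("pokemon").items()):
        print(f"{label}: {n} images")
```

A category with very few remaining images is a hint to raise max_images or to search again before training.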
Data Blocks
The DataBlock is a key concept in FastAI that defines how to load and preprocess your data. We will specify that we are working with images and tell FastAI about our 3 Pokemon categories. Furthermore we will define how we split our data into training and validation sets.
pokemon = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms()
)
blocks=(ImageBlock, CategoryBlock): Specifies that we’re working with image data and category labels.
get_items=get_image_files: Defines how to get the list of items (image files in this case).
splitter=RandomSplitter(valid_pct=0.2, seed=42): Splits the data into training and validation sets (80% train, 20% validation).
get_y=parent_label: Specifies how to get the label for each item (using the parent folder name).
item_tfms=RandomResizedCrop(224, min_scale=0.5): Applies random resized cropping to each item.
batch_tfms=aug_transforms(): Applies additional augmentations to each batch.
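To get a feeling for what random resized cropping with min_scale=0.5 does, here is a simplified sketch in plain Python. It only picks the crop box; the selected region would then be resized to 224 by 224 pixels. The function name random_crop_box and the aspect ratio range of 3/4 to 4/3 are assumptions for illustration, and the real fastai transform handles edge cases differently.

```python
import math
import random


def random_crop_box(width, height, min_scale=0.5, ratio=(3 / 4, 4 / 3), rng=random):
    """Pick a random crop box whose target area is between `min_scale` and 100% of the image.

    Returns (left, top, crop_width, crop_height), clamped to the image bounds.
    """
    area = rng.uniform(min_scale, 1.0) * width * height
    # Sample the aspect ratio on a log scale so wide and tall crops are equally likely
    aspect = math.exp(rng.uniform(math.log(ratio[0]), math.log(ratio[1])))
    crop_w = min(width, max(1, round(math.sqrt(area * aspect))))
    crop_h = min(height, max(1, round(math.sqrt(area / aspect))))
    # Choose a random position so that the box stays inside the image
    left = rng.randint(0, width - crop_w)
    top = rng.randint(0, height - crop_h)
    return left, top, crop_w, crop_h


# Each call gives a different box, which is why every epoch sees a slightly different view
print(random_crop_box(640, 480))
```

Because the box changes on every call, the model rarely sees exactly the same pixels twice, which acts as a cheap form of data augmentation.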
Creating DataLoaders
In this step we will take our datablock definition and load the data.
# Create a DataLoaders object
dls = pokemon.dataloaders(path)
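Calling dataloaders is the point where the labelling and splitting rules from the DataBlock are actually applied to the files. The sketch below mimics those two steps in plain Python so the idea is easier to see. label_from_parent and random_split are hypothetical stand-ins for parent_label and RandomSplitter, and the file names are made up, so the exact items that land in each set will differ from what fastai produces.

```python
import random
from pathlib import Path


def label_from_parent(file_path):
    """Use the name of the containing folder as the label."""
    return Path(file_path).parent.name


def random_split(items, valid_pct=0.2, seed=42):
    """Shuffle indices with a fixed seed and hold out `valid_pct` of them for validation."""
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)
    cut = int(valid_pct * len(items))
    valid = [items[i] for i in indices[:cut]]
    train = [items[i] for i in indices[cut:]]
    return train, valid


# Made-up file list: 5 images for each of the 3 categories
files = [
    f"pokemon/{name}/{i}.jpg"
    for name in ("charizard", "blastoise", "venusaur")
    for i in range(5)
]
train, valid = random_split(files)
print(len(train), len(valid))
print(label_from_parent(files[0]))
```

Fixing the seed means the same images end up in the validation set on every run, which keeps results comparable between experiments.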
Training the model
Now we can train our model. For this we will call the vision_learner function, which creates a learner object. The fastai documentation describes a learner as something that:
"Groups together a model, some dls and a loss_func to handle training" (fastai documentation)
We will use a pre-trained model called ResNet18, which is a common choice for image classification tasks.
learn = vision_learner(dls, resnet18, metrics=error_rate)
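# The fine_tune method then trains this model on our specific dataset. It uses transfer learning, which means it starts with the pre-trained weights and adapts them to our task.
# The parameter 4 means we will train for 4 epochs (full passes over the training data). By default, fine_tune first runs one extra epoch in which only the newly added head is trained.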
learn.fine_tune(4)
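The error_rate metric passed to vision_learner is reported after every epoch. Conceptually it is the share of validation images that were classified incorrectly, i.e. 1 minus the accuracy. Here is a minimal sketch of that calculation in plain Python. simple_error_rate is a hypothetical helper that works on lists of labels, whereas fastai's own metric works on tensors.

```python
def simple_error_rate(predictions, targets):
    """Fraction of predictions that differ from the true labels (1 - accuracy)."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    if not targets:
        raise ValueError("need at least one example")
    wrong = sum(p != t for p, t in zip(predictions, targets))
    return wrong / len(targets)


# Made-up example: one of four predictions is wrong
preds = ["charizard", "blastoise", "venusaur", "charizard"]
truth = ["charizard", "blastoise", "charizard", "charizard"]
print(simple_error_rate(preds, truth))  # 0.25
```

With a validation set of only a handful of images, a single mistake moves this number a lot, so it is worth reading it together with the size of the validation set.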
Exporting the model
# Export the model for later usage
learn.export('pokemon_classifier.pkl')
This line exports the trained model to a file named pokemon_classifier.pkl. It contains the model weights and the architecture. The architecture is like a blueprint of the neural network. It defines how the model is structured and how information flows through it.
Setting up for Inference
Inference basically means putting our trained model to work. It’s the process of using the model to make predictions on new images.
To do this, we will load the trained model by calling the load_learner function. This will recreate the entire learner object, including its model architecture.
learn_inf = load_learner('pokemon_classifier.pkl')
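The .pkl extension hints at how this works: the exported file uses Python's pickle format. The round trip can be illustrated with the standard library alone. This is only an analogy under stated assumptions: toy_model is a made-up dictionary standing in for a trained learner, and the file name is hypothetical.

```python
import pickle
import tempfile
from pathlib import Path

# Hypothetical stand-in for a trained model: plain data describing structure and weights
toy_model = {
    "architecture": "resnet18",
    "vocab": ["blastoise", "charizard", "venusaur"],
    "weights": [0.12, -0.53, 0.98],
}

with tempfile.TemporaryDirectory() as tmp:
    export_path = Path(tmp) / "toy_classifier.pkl"
    # "Export": write the object to disk
    with open(export_path, "wb") as f:
        pickle.dump(toy_model, f)
    # "Load": read it back as a new, equal object
    with open(export_path, "rb") as f:
        restored = pickle.load(f)

print(restored == toy_model)  # True
```

Two practical consequences follow from the pickle format: any custom functions used when building the learner need to be defined again in the session that loads it, and pickle files should only be loaded from sources you trust.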
Creating the ‘UI’ to test our model
In the final step we are going to implement a UI with two buttons: one to upload the image and a second one to call the predict method on our inference object from the last section.
# Create a button which will open a file upload dialog
btn_upload = widgets.FileUpload()
out_pl = widgets.Output()
lbl_pred = widgets.Label()
btn_run = widgets.Button(description='Classify')
def on_click_classify(change):
    # Note: .data is the ipywidgets 7.x API; ipywidgets 8.x exposes uploads through btn_upload.value instead
    img = PILImage.create(btn_upload.data[-1])
    out_pl.clear_output()
    with out_pl:
        display(img.to_thumb(128,128))
    # Call our predict method on our inference object
    pred, pred_idx, probs = learn_inf.predict(img)
    lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'
# On click handler
btn_run.on_click(on_click_classify)
# Setup the UI
widgets_panel = VBox([
    widgets.Label('Select your Pokemon!'),
    btn_upload,
    btn_run,
    out_pl,
    lbl_pred
])
display(widgets_panel)
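The predict call in the click handler returns three things: the predicted label, its index, and one probability per category. The sketch below shows in plain Python how the label text is derived from those values. format_prediction is a hypothetical helper and the numbers are made up. The category order is assumed to be alphabetical, which matches how fastai sorts its vocabulary by default.

```python
def format_prediction(vocab, probs):
    """Pick the most likely class and format it like the label in the UI above."""
    # Index of the highest probability
    pred_idx = max(range(len(probs)), key=lambda i: probs[i])
    pred = vocab[pred_idx]
    return f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'


# Made-up probabilities for one uploaded image
vocab = ['blastoise', 'charizard', 'venusaur']
probs = [0.0312, 0.9541, 0.0147]
print(format_prediction(vocab, probs))  # Prediction: charizard; Probability: 0.9541
```

Keep in mind that the model always picks one of its three categories, so an image of any other Pokémon will still be labelled as one of them.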