The fastai library has a layered API, as summarized by this graph:
If you are following this tutorial, you are probably already familiar with the applications; here we will see how they are powered by the high-level and mid-level APIs.
Imagenette is a subset of ImageNet with 10 very different classes. It's great for quick experiments before trying a fleshed-out technique on the full ImageNet dataset. In this tutorial we will show how to train a model on it using the usual high-level APIs, then delve inside the fastai library to show you how to use the mid-level APIs we designed. This way you'll be able to customize your own data collection or training as needed.
We will look at several ways to get our data in DataLoaders: first we will use ImageDataLoaders factory methods (application layer), then the data block API (high-level API) and lastly, how to do the same thing with the mid-level API.
This is the most basic way of assembling the data that we have presented in all the beginner tutorials, so hopefully it should be familiar to you by now.
First, we import everything inside the vision application:
from fastai.vision.all import *
Then we download the dataset and decompress it (if needed) and get its location:
path = untar_data(URLs.IMAGENETTE_160)
We use ImageDataLoaders.from_folder to get everything (since our data is organized in an ImageNet-style format):
dls = ImageDataLoaders.from_folder(path, valid='val',
                                   item_tfms=RandomResizedCrop(128, min_scale=0.35),
                                   batch_tfms=Normalize.from_stats(*imagenet_stats))
And we can have a look at our data:
dls.show_batch()
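From here, the application layer makes training just as short. As a minimal sketch (the choice of resnet34, the accuracy metric and a single epoch are just for illustration; in older fastai versions vision_learner was called cnn_learner):
learn = vision_learner(dls, resnet34, metrics=accuracy)  # pretrained model with a head adapted to our 10 classes
learn.fit_one_cycle(1)                                   # one epoch with the 1cycle policy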
Now let's see how to build the same DataLoaders with the data block API. As we saw in previous tutorials, the get_image_files function helps get all the images in subfolders:
fnames = get_image_files(path)
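The result behaves like a list of paths, so a quick sanity check is easy (this assumes the standard Imagenette layout, where each image sits inside a folder named after its class):
len(fnames), fnames[0].parent.name  # total number of images, and the label folder of the first one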
Let's begin with an empty DataBlock.
dblock = DataBlock()
By itself, a DataBlock is just a blueprint for how to assemble your data. It does not do anything until you pass it a source. You can choose to then convert that source into a Datasets or a DataLoaders by using the DataBlock.datasets or DataBlock.dataloaders method. Since we haven't done anything to get our data ready for batches, the dataloaders method will fail here, but we can have a look at how the source gets converted into a Datasets. This is where we pass the source of our data, here all of our filenames:
dsets = dblock.datasets(fnames)
dsets.train[0]
By default, the data block API assumes we have an input and a target, which is why we see our filename repeated twice.
The first thing we can do is to use a get_items function to actually assemble our items inside the data block:
dblock = DataBlock(get_items = get_image_files)
The difference is that you then pass the folder with the images as a source, not all the filenames:
dsets = dblock.datasets(path)
dsets.train[0]
Our inputs are ready to be processed as images (since images can be built from filenames), but our target is not. We need to convert that filename to a class name. For this, fastai provides parent_label:
parent_label(fnames[0])
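parent_label simply returns the name of the immediate parent folder of a file, so on a made-up path it behaves like this (the path below is purely illustrative):
parent_label('train/n03417042/some_image.JPEG')  # -> 'n03417042'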
This is not very readable, but since we can write any labelling function we want, let's convert those obscure labels to something we can read:
lbl_dict = dict(
    n01440764='tench',
    n02102040='English springer',
    n02979186='cassette player',
    n03000684='chain saw',
    n03028079='church',
    n03394916='French horn',
    n03417042='garbage truck',
    n03425413='gas pump',
    n03445777='golf ball',
    n03888257='parachute'
)
def label_func(fname):
    return lbl_dict[parent_label(fname)]
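A quick check that our labelling function behaves as expected, reusing the fnames we grabbed earlier:
label_func(fnames[0])  # a readable class name such as 'garbage truck'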
We can then tell our data block to use it to label our target by passing it as get_y:
dblock = DataBlock(get_items = get_image_files,
                   get_y     = label_func)
dsets = dblock.datasets(path)
dsets.train[0]
Now that our inputs and targets are ready, we can specify types to tell the data block API that our inputs are images and our targets are categories. Types are represented by blocks in the data block API; here we use ImageBlock and CategoryBlock:
dblock = DataBlock(blocks    = (ImageBlock, CategoryBlock),
                   get_items = get_image_files,
                   get_y     = label_func)
dsets = dblock.datasets(path)
dsets.train[0]
We can see how the DataBlock automatically added the transforms necessary to open the image, and how it changed the class name to an index (with a special tensor type). To do this, it created a mapping from categories to indices called the "vocab", which we can access this way:
dsets.vocab
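The vocab also works in the other direction, mapping an index back to its readable label, which is handy for decoding items or predictions:
img, lbl_idx = dsets.train[0]
dsets.vocab[lbl_idx]  # the class name behind the index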
Note that you can mix and match any block for inputs and targets, which is why the API is named the data block API. You can also have more than two blocks (if you have multiple inputs and/or targets); you would just need to pass n_inp to the DataBlock to tell the library how many inputs there are (the rest would be targets), and pass a list of functions to get_x and/or get_y (to explain how to process each item to be ready for its type). See the object detection example below.
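In the meantime, here is a small, purely illustrative sketch of a multi-block setup (contrived on purpose: it labels each image twice, once with the raw folder name and once with the readable one); note how get_y becomes a list with one labelling function per target block:
dblock_multi = DataBlock(blocks    = (ImageBlock, CategoryBlock, CategoryBlock),
                         n_inp     = 1,                          # the first block is the input, the other two are targets
                         get_items = get_image_files,
                         get_y     = [parent_label, label_func]) # one labelling function per target block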
The next step is to control how our validation set is created. We do this by passing a splitter to DataBlock. For instance, here is how we split by grandparent folder:
dblock = DataBlock(blocks    = (ImageBlock, CategoryBlock),
                   get_items = get_image_files,
                   get_y     = label_func,
                   splitter  = GrandparentSplitter(valid_name='val'))  # Imagenette's validation folder is 'val', not the default 'valid'
dsets = dblock.datasets(path)
dsets.train[0]
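We can check the splitter picked up the right folders by looking at the size of each subset:
len(dsets.train), len(dsets.valid)  # number of items in the training and validation sets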
The last step is to specify item transforms and batch transforms (the same way we do in ImageDataLoaders factory methods):
dblock = DataBlock(blocks     = (ImageBlock, CategoryBlock),
                   get_items  = get_image_files,
                   get_y      = label_func,
                   splitter   = GrandparentSplitter(valid_name='val'),  # match Imagenette's 'val' folder
                   item_tfms  = RandomResizedCrop(128, min_scale=0.35),
                   batch_tfms = Normalize.from_stats(*imagenet_stats))
With that resize, we are now able to batch items together and can finally call dataloaders to convert our DataBlock to a DataLoaders object:
dls = dblock.dataloaders(path)
dls.show_batch()
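If one of these steps goes wrong, DataBlock.summary is a handy debugging tool: it runs the full pipeline on the source and reports what happens at each step:
dblock.summary(path)  # walks through get_items, get_y, the blocks and the transforms, showing each intermediate result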
Another way to compose several functions for get_y is to put them in a Pipeline:
imagenette = DataBlock(blocks     = (ImageBlock, CategoryBlock),
                       get_items  = get_image_files,
                       get_y      = Pipeline([parent_label, lbl_dict.__getitem__]),
                       splitter   = GrandparentSplitter(valid_name='val'),
                       item_tfms  = RandomResizedCrop(128, min_scale=0.35),
                       batch_tfms = Normalize.from_stats(*imagenet_stats))
dls = imagenette.dataloaders(path)
dls.show_batch()
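Note that a Pipeline is itself callable and applies its functions in order, so we can check what this get_y does on a single filename:
pipe = Pipeline([parent_label, lbl_dict.__getitem__])
pipe(fnames[0])  # folder name -> readable label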