Data block tutorial

Using the data block across all applications

In this tutorial, we’ll see how to use the data block API on a variety of tasks and how to debug data blocks. The data block API takes its name from the way it’s designed: every bit needed to build the DataLoaders object (type of inputs, targets, how to label, split…) is encapsulated in a block, and you can mix and match those blocks.

Building a DataBlock from scratch

The rest of this tutorial will give many examples, but let’s first build a DataBlock from scratch on the dogs versus cats problem we saw in the vision tutorial. First we import everything needed in vision.

from fastai.data.all import *
from fastai.vision.all import *

The first step is to download and decompress our data (if it’s not already done) and get its location:

path = untar_data(URLs.PETS)

And as we saw, all the filenames are in the “images” folder. The get_image_files function helps get all the images in subfolders:

fnames = get_image_files(path/"images")

Let’s begin with an empty DataBlock.

dblock = DataBlock()

By itself, a DataBlock is just a blueprint on how to assemble your data. It does not do anything until you pass it a source. You can then choose to convert that source into a Datasets or a DataLoaders by using the DataBlock.datasets or DataBlock.dataloaders method. Since we haven’t done anything to get our data ready for batches, the dataloaders method will fail here, but we can have a look at how it gets converted into Datasets. This is where we pass the source of our data, here all our filenames:

dsets = dblock.datasets(fnames)
dsets.train[0]
(Path('/home/jhoward/.fastai/data/oxford-iiit-pet/images/Birman_82.jpg'),
 Path('/home/jhoward/.fastai/data/oxford-iiit-pet/images/Birman_82.jpg'))

By default, the data block API assumes we have an input and a target, which is why we see our filename repeated twice.

The first thing we can do is use a get_items function to actually assemble our items inside the data block:

dblock = DataBlock(get_items = get_image_files)

The difference is that you then pass as a source the folder with the images and not all the filenames:

dsets = dblock.datasets(path/"images")
dsets.train[0]
(Path('/home/jhoward/.fastai/data/oxford-iiit-pet/images/english_cocker_spaniel_76.jpg'),
 Path('/home/jhoward/.fastai/data/oxford-iiit-pet/images/english_cocker_spaniel_76.jpg'))

Our inputs are ready to be processed as images (since images can be built from filenames), but our target is not. Since we are in a cat versus dog problem, we need to convert that filename to “cat” vs “dog” (or True vs False). Let’s build a function for this:

def label_func(fname):
    return "cat" if fname.name[0].isupper() else "dog"
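
As a quick sanity check, here is a minimal sketch that applies the same function to three filenames that appear in the outputs of this tutorial (directories omitted). It relies on the convention that label_func itself implies: filenames starting with an uppercase letter are cats.

```python
from pathlib import Path

def label_func(fname):
    # Filenames starting with an uppercase letter are treated as cats.
    return "cat" if fname.name[0].isupper() else "dog"

# Filenames taken from the outputs shown in this tutorial (directories omitted).
samples = [Path("Birman_82.jpg"),
           Path("english_cocker_spaniel_76.jpg"),
           Path("staffordshire_bull_terrier_77.jpg")]

labels = [label_func(p) for p in samples]
print(labels)
```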

We can then tell our data block to use it to label our target by passing it as get_y:

dblock = DataBlock(get_items = get_image_files,
                   get_y     = label_func)

dsets = dblock.datasets(path/"images")
dsets.train[0]
(Path('/home/jhoward/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_77.jpg'),
 'dog')

Now that our inputs and targets are ready, we can specify types to tell the data block API that our inputs are images and our targets are categories. Types are represented by blocks in the data block API, here we use ImageBlock and CategoryBlock:

dblock = DataBlock(blocks    = (ImageBlock, CategoryBlock),
                   get_items = get_image_files,
                   get_y     = label_func)

dsets = dblock.datasets(path/"images")
dsets.train[0]
(PILImage mode=RGB size=361x500, TensorCategory(1))

We can see how the DataBlock automatically added the transforms necessary to open the image and how it changed the name “dog” to the index 1 (with a special tensor type, TensorCategory). To do this, it created a mapping from categories to indices called “vocab” that we can access this way:

dsets.vocab
['cat', 'dog']
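
To make the idea of a vocab concrete, here is a plain-Python sketch of one way to build such a mapping: sort the unique labels, then use each label’s position as its index. It is only an illustration, not fastai’s implementation, and the list of labels is made up.

```python
# Labels as they might come out of label_func for a handful of items (made-up order).
labels = ["dog", "cat", "dog", "dog", "cat"]

# Sorted unique labels form the vocab; a label's position in the vocab is its index.
vocab = sorted(set(labels))
o2i = {label: i for i, label in enumerate(vocab)}

encoded = [o2i[label] for label in labels]
print(vocab)
print(encoded)
```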

Note that you can mix and match any block for input and targets, which is why the API is named data block API. You can also have more than two blocks (if you have multiple inputs and/or targets), you would just need to pass n_inp to the DataBlock to tell the library how many inputs there are (the rest would be targets) and pass a list of functions to get_x and/or get_y (to explain how to process each item to be ready for its type). See the bounding boxes example below for such a case.

The next step is to control how our validation set is created. We do this by passing a splitter to DataBlock. For instance, here is how to do a random split.

dblock = DataBlock(blocks    = (ImageBlock, CategoryBlock),
                   get_items = get_image_files,
                   get_y     = label_func,
                   splitter  = RandomSplitter())

dsets = dblock.datasets(path/"images")
dsets.train[0]
(PILImage mode=RGB size=320x480, TensorCategory(0))
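
Conceptually, a random splitter shuffles the item indices and sets a fraction of them aside for validation. The sketch below shows that idea in plain Python. The function name random_split is hypothetical, and the valid_pct and seed parameters are assumptions chosen to mirror the arguments RandomSplitter accepts; this is not the library’s code.

```python
import random

def random_split(n_items, valid_pct=0.2, seed=None):
    "Return (train_idxs, valid_idxs) for `n_items` items, picked at random."
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)   # a fixed seed makes the split reproducible
    cut = int(valid_pct * n_items)      # number of items kept for validation
    return idxs[cut:], idxs[:cut]

train_idxs, valid_idxs = random_split(10, valid_pct=0.2, seed=42)
print(len(train_idxs), len(valid_idxs))
```

Passing a seed is what makes the same validation set come back from one run to the next.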

The last step is to specify item transforms and batch transforms (the same way we do it in ImageDataLoaders factory methods):

dblock = DataBlock(blocks    = (ImageBlock, CategoryBlock),
                   get_items = get_image_files,
                   get_y     = label_func,
                   splitter  = RandomSplitter(),
                   item_tfms = Resize(224))

With that resize, we are now able to batch items together and can finally call dataloaders to convert our DataBlock to a DataLoaders object:

dls = dblock.dataloaders(path/"images")
dls.show_batch()

The way we usually build the data block in one go is by answering a list of questions:

  • what are the types of your inputs/targets? Here images and categories
  • where is your data? Here in filenames in subfolders
  • does something need to be applied to inputs? Here no
  • does something need to be applied to the target? Here the label_func function
  • how to split the data? Here randomly
  • do we need to apply something on formed items? Here a resize
  • do we need to apply something on formed batches? Here no

This gives us this design:

dblock = DataBlock(blocks    = (ImageBlock, CategoryBlock),
                   get_items = get_image_files,
                   get_y     = label_func,
                   splitter  = RandomSplitter(),
                   item_tfms = Resize(224))

For the two questions that got a no, the corresponding arguments we would pass if the answer was different would be get_x and batch_tfms.

Image classification

Let’s begin with examples of image classification problems. There are two kinds of image classification problems: problems with single-label (each image has one given label) or multi-label (each image can have multiple or no labels at all). We will cover those two kinds here.

from fastai.vision.all import *

MNIST (single label)

MNIST is a dataset of hand-written digits from 0 to 9. We can very easily load it in the data block API by answering the following questions:

  • what are the types of our inputs and targets? Black and white images and labels.
  • where is the data? In subfolders.
  • how do we know if a sample is in the training or the validation set? By looking at the grandparent folder.
  • how do we know the label of an image? By looking at the parent folder.

In terms of the API, those answers translate like this:

mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock), 
                  get_items=get_image_files, 
                  splitter=GrandparentSplitter(),
                  get_y=parent_label)

Our types become blocks: one for images (using the black and white PILImageBW class) and one for categories. Searching subfolders for all image filenames is done by the get_image_files function. The split training/validation is done by using a GrandparentSplitter. And the function to get our targets (often called y) is parent_label.
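
The folder-based logic can be sketched with pathlib alone. The helpers parent_name and grandparent_split below are hypothetical stand-ins for illustration, the paths are made up, and the folder names “train” and “valid” are an assumption about the dataset layout.

```python
from pathlib import Path

def parent_name(path):
    "Label: the name of the folder that directly contains the file."
    return path.parent.name

def grandparent_split(paths, train_name="train", valid_name="valid"):
    "Split indices by the name of the folder two levels above each file."
    train = [i for i, p in enumerate(paths) if p.parent.parent.name == train_name]
    valid = [i for i, p in enumerate(paths) if p.parent.parent.name == valid_name]
    return train, valid

# Made-up paths following a <root>/<train|valid>/<label>/<file> layout.
paths = [Path("mnist_tiny/train/7/0001.png"),
         Path("mnist_tiny/train/3/0002.png"),
         Path("mnist_tiny/valid/7/0003.png")]

labels = [parent_name(p) for p in paths]
train_idxs, valid_idxs = grandparent_split(paths)
print(labels, train_idxs, valid_idxs)
```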

To get an idea of the objects the fastai library provides for reading, labelling or splitting, check the data.transforms module.

In itself, a data block is just a blueprint. It does not do anything and does not check for errors. You have to feed it the source of the data to actually gather something. This is done with the .dataloaders method:

dls = mnist.dataloaders(untar_data(URLs.MNIST_TINY))
dls.show_batch(max_n=9, figsize=(4,4))

If something went wrong in the previous step, or if you’re just curious about what happened under the hood, use the summary method. It will go verbosely step by step, and you will see at which point the process failed.

mnist.summary(untar_data(URLs.MNIST_TINY))
Setting-up type transforms pipelines
Collecting items from /home/jhoward/.fastai/data/mnist_tiny
Found 2856 items
2 datasets of sizes 1418,1398
Setting up Pipeline: PILBase.create
Setting up Pipeline: parent_label -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}

Building one sample
  Pipeline: PILBase.create
    starting from
      /home/jhoward/.fastai/data/mnist_tiny/mnist_tiny/train/7/7420.png
    applying PILBase.create gives
      PILImageBW mode=L size=28x28
  Pipeline: parent_label -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}
    starting from
      /home/jhoward/.fastai/data/mnist_tiny/mnist_tiny/train/7/7420.png
    applying parent_label gives
      7
    applying Categorize -- {'vocab': None, 'sort': True, 'add_na': False} gives
      TensorCategory(1)

Final sample: (PILImageBW mode=L size=28x28, TensorCategory(1))


Collecting items from /home/jhoward/.fastai/data/mnist_tiny
Found 2856 items
2 datasets of sizes 1418,1398
Setting up Pipeline: PILBase.create
Setting up Pipeline: parent_label -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}
Setting up after_item: Pipeline: ToTensor
Setting up before_batch: Pipeline: 
Setting up after_batch: Pipeline: IntToFloatTensor -- {'div': 255.0, 'div_mask': 1}

Building one batch
Applying item_tfms to the first sample:
  Pipeline: ToTensor
    starting from
      (PILImageBW mode=L size=28x28, TensorCategory(1))
    applying ToTensor gives
      (TensorImageBW of size 1x28x28, TensorCategory(1))

Adding the next 3 samples

No before_batch transform to apply

Collating items in a batch

Applying batch_tfms to the batch built
  Pipeline: IntToFloatTensor -- {'div': 255.0, 'div_mask': 1}
    starting from
      (TensorImageBW of size 4x1x28x28, TensorCategory([1, 1, 1, 1], device='cuda:0'))
    applying IntToFloatTensor -- {'div': 255.0, 'div_mask': 1} gives
      (TensorImageBW of size 4x1x28x28, TensorCategory([1, 1, 1, 1], device='cuda:0'))

Let’s go over another example!

Pets (single label)

The Oxford IIIT Pets dataset is a dataset of pictures of dogs and cats, with 37 different breeds. A slight (but very important) difference with MNIST is that images are now not all of the same size. In MNIST they were all 28 by 28 pixels, but here they have different aspect ratios or dimensions. Therefore, we will need to add something to make them all the same size to be able to assemble them together in a batch. We will also see how to add data augmentation.

So let’s go over the same questions as before and add two more:

  • what are the types of our inputs and targets? Images and labels.
  • where is the data? In subfolders.
  • how do we know if a sample is in the training or the validation set? We’ll take a random split.
  • how do we know the label of an image? By applying a regular expression to the filename.
  • do we want to apply a function to a given sample? Yes, we need to resize everything to a given size.
  • do we want to apply a function to a batch after it’s created? Yes, we want data augmentation.

pets = DataBlock(blocks=(ImageBlock, CategoryBlock),
                 get_items=get_image_files, 
                 splitter=RandomSplitter(),
                 get_y=Pipeline([attrgetter("name"), RegexLabeller(pat = r'^(.*)_\d+\.jpg$')]),
                 item_tfms=Resize(128),
                 batch_tfms=aug_transforms())

And like for MNIST, we can see how the answers to those questions directly translate in the API. Our types become blocks: one for images and one for categories. Searching subfolders for all image filenames is done by the get_image_files function. The split training/validation is done by using a RandomSplitter. The function to get our targets (often called y) is a composition of two transforms: we get the name attribute of our Path filenames, then apply a regular expression to get the class. To compose those two transforms into one, we use a Pipeline.
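
To see what the regular expression captures, here is a small sketch using only Python’s re module on filenames that appear earlier in this tutorial. The pattern is the one from the DataBlock, written here with the dot before jpg escaped. The helper regex_label is hypothetical and only mimics the labelling step.

```python
import re

# Everything before the final "_<digits>.jpg" is captured as the breed.
pat = re.compile(r'^(.*)_\d+\.jpg$')

def regex_label(name):
    "Return the first captured group, i.e. the breed part of the filename."
    match = pat.search(name)
    assert match, f"Pattern did not match {name!r}"
    return match.group(1)

names = ["Birman_82.jpg",
         "english_cocker_spaniel_76.jpg",
         "staffordshire_bull_terrier_77.jpg"]
breeds = [regex_label(n) for n in names]
print(breeds)
```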

Finally, we apply a resize at the item level and aug_transforms() at the batch level.

dls = pets.dataloaders(untar_data(URLs.PETS)/"images")
dls.show_batch(max_n=9)

Now let’s see how we can use the same API for a multi-label problem.

Pascal (multi-label)

The Pascal dataset is originally an object detection dataset (we have to predict where some objects are in pictures). But it contains lots of pictures with various objects in them, so it gives a great example for a multi-label problem. Let’s download it and have a look at the data:

pascal_source = untar_data(URLs.PASCAL_2007)
df = pd.read_csv(pascal_source/"train.csv")
df.head()
   fname       labels        is_valid
0  000005.jpg  chair         True
1  000007.jpg  car           True
2  000009.jpg  horse person  True
3  000012.jpg  car           False
4  000016.jpg  bicycle       True

So it looks like we have one column with filenames, one column with the labels (separated by space) and one column that tells us if the filename should go in the validation set or not.

There are multiple ways to put this in a DataBlock, let’s go over them, but first, let’s answer our usual questionnaire:

  • what are the types of our inputs and targets? Images and multiple labels.
  • where is the data? In a dataframe.
  • how do we know if a sample is in the training or the validation set? A column of our dataframe.
  • how do we get an image? By looking at the column fname.
  • how do we know the label of an image? By looking at the column labels.
  • do we want to apply a function to a given sample? Yes, we need to resize everything to a given size.
  • do we want to apply a function to a batch after it’s created? Yes, we want data augmentation.

Notice how there is one more question compared to before: we won’t have to use a get_items function here because we already have all our data in one place. But we will need to do something to the raw dataframe to get our inputs: read the first column and add the proper folder before the filename. This is what we pass as get_x.

pascal = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
                   splitter=ColSplitter(),
                   get_x=ColReader(0, pref=pascal_source/"train"),
                   get_y=ColReader(1, label_delim=' '),
                   item_tfms=Resize(224),
                   batch_tfms=aug_transforms())

Again, we can see how the answers to the questions directly translate in the API. Our types become blocks: one for images and one for multi-categories. The split is done by a ColSplitter (it defaults to the column named is_valid). The function to get our inputs (often called x) is a ColReader on the first column with a prefix, the function to get our targets (often called y) is ColReader on the second column, with a space delimiter. We apply a resize at the item level and aug_transforms() at the batch level.
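
Stripped of the library, the two readers do very little: one joins a folder and a filename, the other splits a string on spaces. The sketch below shows this on plain dictionaries that copy the first rows of df.head(). The folder name pascal_2007 is a hypothetical stand-in for pascal_source, and the get_x and get_y functions here are illustrations, not ColReader itself.

```python
from pathlib import Path

# Hypothetical stand-in for pascal_source; rows copy the first lines of df.head().
source = Path("pascal_2007")
rows = [
    {"fname": "000005.jpg", "labels": "chair",        "is_valid": True},
    {"fname": "000007.jpg", "labels": "car",          "is_valid": True},
    {"fname": "000009.jpg", "labels": "horse person", "is_valid": True},
    {"fname": "000012.jpg", "labels": "car",          "is_valid": False},
]

def get_x(row):
    # Prefix the filename with the folder that holds the training images.
    return source/"train"/row["fname"]

def get_y(row):
    # Space-delimited labels become a list of labels.
    return row["labels"].split(" ")

xs = [get_x(r) for r in rows]
ys = [get_y(r) for r in rows]
print(xs[2], ys[2])
```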

dls = pascal.dataloaders(df)
dls.show_batch()

Another way to do this is by directly using functions for get_x and get_y:

pascal = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
                   splitter=ColSplitter(),
                   get_x=lambda x:pascal_source/"train"/f'{x[0]}',
                   get_y=lambda x:x[1].split(' '),
                   item_tfms=Resize(224),
                   batch_tfms=aug_transforms())

dls = pascal.dataloaders(df)
dls.show_batch()

Alternatively, we can use the names of the columns as attributes (since rows of a dataframe are pandas series).

pascal = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
                   splitter=ColSplitter(),
                   get_x=lambda o:f'{pascal_source}/train/'+o.fname,
                   get_y=lambda o:o.labels.split(),
                   item_tfms=Resize(224),
                   batch_tfms=aug_transforms())

dls = pascal.dataloaders(df)
dls.show_batch()

The most efficient way (to avoid iterating over the rows of the dataframe, which can take a long time) is to use the from_columns method. It will use get_items to convert the columns into numpy arrays. The drawback is that since we lose the dataframe after extracting the relevant columns, we can’t use a ColSplitter anymore. Here we use an IndexSplitter after manually extracting the indices of the validation set from the dataframe:

def _pascal_items(x): return (
    f'{pascal_source}/train/'+x.fname, x.labels.str.split())
valid_idx = df[df['is_valid']].index.values

pascal = DataBlock.from_columns(blocks=(ImageBlock, MultiCategoryBlock),
                   get_items=_pascal_items,
                   splitter=IndexSplitter(valid_idx),
                   item_tfms=Resize(224),
                   batch_tfms=aug_transforms())
dls = pascal.dataloaders(df)
dls.show_batch()
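
The index-based split can be sketched without pandas: collect the positions where is_valid is true, then treat every other position as training data. The values below copy the is_valid column shown in df.head(). The helper index_split is hypothetical and only illustrates the idea behind splitting on explicit indices.

```python
# is_valid column as shown in df.head() above.
is_valid = [True, True, True, False, True]

# Positions of the validation rows.
valid_idx = [i for i, v in enumerate(is_valid) if v]

def index_split(n_items, valid_idx):
    "Everything listed in `valid_idx` is validation; the rest is training."
    valid = set(valid_idx)
    train = [i for i in range(n_items) if i not in valid]
    return train, sorted(valid)

train_idxs, valid_idxs = index_split(len(is_valid), valid_idx)
print(train_idxs, valid_idxs)
```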

Image localization

There are various problems that fall in the image localization category: image segmentation (which is a task where you have to predict the class of each pixel of an image), coordinate predictions (predict one or several key points on an image) and object detection (draw a box around objects to detect).

Let’s see an example of each of those and how to use the data block API in each case.

Segmentation

We will use a small subset of the CamVid dataset for our example.

path = untar_data(URLs.CAMVID_TINY)

Let’s go over our usual questionnaire:

  • what are the types of our inputs and targets? Images and segmentation masks.
  • where is the data? In subfolders.
  • how do we know if a sample is in the training or the validation set? We’ll take a random split.
  • how do we know the label of an image? By looking at a corresponding file in the “labels” folder.
  • do we want to apply a function to a batch after it’s created? Yes, we want data augmentation.

camvid = DataBlock(blocks=(ImageBlock, MaskBlock(codes = np.loadtxt(path/'codes.txt', dtype=str))),
    get_items=get_image_files,
    splitter=RandomSplitter(),
    get_y=lambda o: path/'labels'/f'{o.stem}_P{o.suffix}',
    batch_tfms=aug_transforms())

The MaskBlock is generated with the codes that give the correspondence between pixel value of the masks and the object they correspond to (like car, road, pedestrian…). The rest should look pretty familiar by now.
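
The get_y lambda above only rewrites a path: it keeps the image’s stem and suffix, appends _P to the stem, and looks in the labels folder. Here is a minimal sketch of that mapping. The root folder camvid_tiny and the image name are made up for illustration.

```python
from pathlib import Path

# Hypothetical stand-in for the folder returned by untar_data(URLs.CAMVID_TINY).
path = Path("camvid_tiny")

def get_mask_path(o):
    # Same stem with a "_P" suffix, same extension, in the "labels" folder.
    return path/"labels"/f"{o.stem}_P{o.suffix}"

img = path/"images"/"frame_0001.png"   # made-up image filename
mask = get_mask_path(img)
print(mask)
```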

dls = camvid.dataloaders(path/"images")
dls.show_batch()

Points

For this example we will use a small sample of the BiWi kinect head pose dataset. It contains pictures of people and the task is to predict where the center of their head is. We have saved this small dataset with a dictionary mapping each filename to its center:

biwi_source = untar_data(URLs.BIWI_SAMPLE)
fn2ctr = load_pickle(biwi_source/'centers.pkl')

Then we can go over our usual questions:

  • what are the types of our inputs and targets? Images and points.
  • where is the data? In subfolders.
  • how do we know if a sample is in the training or the validation set? We’ll take a random split.
  • how do we know the label of an image? By using the fn2ctr dictionary.
  • do we want to apply a function to a batch after it’s created? Yes, we want data augmentation.

biwi = DataBlock(blocks=(ImageBlock, PointBlock),
                 get_items=get_image_files,
                 splitter=RandomSplitter(),
                 get_y=lambda o:fn2ctr[o.name].flip(0),
                 batch_tfms=aug_transforms())

And we can use it to create a DataLoaders:

dls = biwi.dataloaders(biwi_source)
dls.show_batch(max_n=9)

Bounding boxes

For this task, we will use a small subset of the COCO dataset. It contains pictures with day-to-day objects and the goal is to predict where the objects are by drawing a rectangle around them.

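The fastai library comes with a function called get_annotations that will interpret the content of train.json and give us the list of filenames and the matching list of (bounding boxes, labels) pairs, which we zip into a dictionary from filename to (bounding boxes, labels).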

coco_source = untar_data(URLs.COCO_TINY)
images, lbl_bbox = get_annotations(coco_source/'train.json')
img2bbox = dict(zip(images, lbl_bbox))

Then we can go over our usual questions:

  • what are the types of our inputs and targets? Images and bounding boxes.
  • where is the data? In subfolders.
  • how do we know if a sample is in the training or the validation set? We’ll take a random split.
  • how do we know the label of an image? By using the img2bbox dictionary.
  • do we want to apply a function to a given sample? Yes, we need to resize everything to a given size.
  • do we want to apply a function to a batch after it’s created? Yes, we want data augmentation.

coco = DataBlock(blocks=(ImageBlock, BBoxBlock, BBoxLblBlock),
                 get_items=get_image_files,
                 splitter=RandomSplitter(),
                 get_y=[lambda o: img2bbox[o.name][0], lambda o: img2bbox[o.name][1]], 
                 item_tfms=Resize(128),
                 batch_tfms=aug_transforms(),
                 n_inp=1)

Note that we provide three types, because we have two targets: the bounding boxes and the labels. That’s why we pass n_inp=1 at the end, to tell the library where the inputs stop and the targets begin.

This is also why we pass a list to get_y: since we have two targets, we must tell the library how to label for each of them (you can use noop if you don’t want to do anything for one).
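
The interplay between a list of get_y functions and n_inp can be sketched in plain Python. Each getter produces one element of the sample, and n_inp marks where the inputs stop. The annotations, box coordinates and the build_sample helper below are all made up for illustration; this is not how the library assembles samples internally.

```python
# Made-up annotations: filename -> (bounding boxes, labels).
img2bbox = {
    "img_a.jpg": ([[10, 20, 50, 60], [30, 40, 80, 90]], ["chair", "vase"]),
    "img_b.jpg": ([[5, 5, 25, 35]], ["tv"]),
}

get_x = lambda name: name                       # the input is the item itself
get_y = [lambda name: img2bbox[name][0],        # first target: the boxes
         lambda name: img2bbox[name][1]]        # second target: their labels

def build_sample(item, n_inp=1):
    "Apply every getter to the item, then cut the tuple at `n_inp`."
    sample = (get_x(item), *[f(item) for f in get_y])
    return sample[:n_inp], sample[n_inp:]

inputs, targets = build_sample("img_a.jpg")
print(inputs)
print(targets)
```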

dls = coco.dataloaders(coco_source)
dls.show_batch(max_n=9)

Text

We will show two examples: language modeling and text classification. Note that with the data block API, you can adapt the earlier multi-label example to a problem where the inputs are texts.

from fastai.text.all import *

Language model

We will use a dataset composed of movie reviews from IMDb. As usual, we can download it in one line of code with untar_data.

path = untar_data(URLs.IMDB_SAMPLE)
df = pd.read_csv(path/'texts.csv')
df.head()
label text is_valid
0 negative Un-bleeping-believable! Meg Ryan doesn't even look her usual pert lovable self in this, which normally makes me forgive her shallow ticky acting schtick. Hard to believe she was the producer on this dog. Plus Kevin Kline: what kind of suicide trip has his career been on? Whoosh... Banzai!!! Finally this was directed by the guy who did Big Chill? Must be a replay of Jonestown - hollywood style. Wooofff! False
1 positive This is a extremely well-made film. The acting, script and camera-work are all first-rate. The music is good, too, though it is mostly early in the film, when things are still relatively cheery. There are no really superstars in the cast, though several faces will be familiar. The entire cast does an excellent job with the script.<br /><br />But it is hard to watch, because there is no good end to a situation like the one presented. It is now fashionable to blame the British for setting Hindus and Muslims against each other, and then cruelly separating them into two countries. There is som... False
2 negative Every once in a long while a movie will come along that will be so awful that I feel compelled to warn people. If I labor all my days and I can save but one soul from watching this movie, how great will be my joy.<br /><br />Where to begin my discussion of pain. For starters, there was a musical montage every five minutes. There was no character development. Every character was a stereotype. We had swearing guy, fat guy who eats donuts, goofy foreign guy, etc. The script felt as if it were being written as the movie was being shot. The production value was so incredibly low that it felt li... False
3 positive Name just says it all. I watched this movie with my dad when it came out and having served in Korea he had great admiration for the man. The disappointing thing about this film is that it only concentrate on a short period of the man's life - interestingly enough the man's entire life would have made such an epic bio-pic that it is staggering to imagine the cost for production.<br /><br />Some posters elude to the flawed characteristics about the man, which are cheap shots. The theme of the movie "Duty, Honor, Country" are not just mere words blathered from the lips of a high-brassed offic... False
4 negative This movie succeeds at being one of the most unique movies you've seen. However this comes from the fact that you can't make heads or tails of this mess. It almost seems as a series of challenges set up to determine whether or not you are willing to walk out of the movie and give up the money you just paid. If you don't want to feel slighted you'll sit through this horrible film and develop a real sense of pity for the actors involved, they've all seen better days, but then you realize they actually got paid quite a bit of money to do this and you'll lose pity for them just like you've alr... False

We can see it’s composed of (pretty long!) reviews labeled positive or negative. Let’s go over our usual questions:

  • what are the types of our inputs and targets? Texts, and we don’t really have targets, since the targets are derived from the inputs.
  • where is the data? In a dataframe.
  • how do we know if a sample is in the training or the validation set? We have an is_valid column.
  • how do we get our inputs? In the text column.

imdb_lm = DataBlock(blocks=TextBlock.from_df('text', is_lm=True),
                    get_x=ColReader('text'),
                    splitter=ColSplitter())

Since there are no targets here, we only have one block to specify. TextBlocks are a bit special compared to other TransformBlocks: to be able to efficiently tokenize all texts during setup, you need to use the class methods from_folder or from_df.
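
In a language model the target is simply the input shifted by one token, which is why no separate target block is needed. The sketch below shows that shift on a short, already tokenized string borrowed from the kind of output shown in this tutorial. The window length of 4 is arbitrary and chosen only to keep the example small.

```python
# A short run of tokens, in the style of the tokenized texts in this tutorial.
tokens = "xxbos xxmaj not sure if it was right or wrong".split()

# Inputs are every token but the last; targets are every token but the first.
x, y = tokens[:-1], tokens[1:]

seq_len = 4                       # arbitrary small window for illustration
x_window, y_window = x[:seq_len], y[:seq_len]
print(x_window)
print(y_window)
```

Each target token is the token that follows the corresponding input token, so the model learns to predict the next word.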

Note: the TextBlock tokenization process puts tokenized inputs into a column called text. The ColReader for get_x will always reference text, even if the original text inputs were in a column with another name in the dataframe.

We can then get our data into DataLoaders by passing the dataframe to the dataloaders method:

dls = imdb_lm.dataloaders(df, bs=64, seq_len=72)
dls.show_batch(max_n=6)
text text_
0 xxbos xxmaj not sure if it was right or wrong , but i read thru the other comments before watching the xxunk have to say i disagree with most of the negative comments or problems people have had with it . \n\n xxmaj as a first time " lone xxmaj wolf " director / producer , i like to see things that i can xxunk to , not necessarily from the pro xxmaj not sure if it was right or wrong , but i read thru the other comments before watching the xxunk have to say i disagree with most of the negative comments or problems people have had with it . \n\n xxmaj as a first time " lone xxmaj wolf " director / producer , i like to see things that i can xxunk to , not necessarily from the pro 's
1 and each and every actor . xxmaj it 's like they all think they 're the main part of the movie and scream " notice xxup me ! " over and over again . xxmaj the bad guy has his bad - guy music going on and says sinister bad - guy - like things , just in case you did n't quite catch on . xxmaj the good guy does brave each and every actor . xxmaj it 's like they all think they 're the main part of the movie and scream " notice xxup me ! " over and over again . xxmaj the bad guy has his bad - guy music going on and says sinister bad - guy - like things , just in case you did n't quite catch on . xxmaj the good guy does brave and
2 innocently helps the xxmaj confederate hide . xxmaj later , when he returns to kill her father , the little girl 's xxunk is remembered . a sweet , small story from director xxup xxunk . xxmaj griffith . xxmaj location footage and humanity are xxunk displayed . \n\n▁ xxrep 4 * xxmaj in the xxmaj border xxmaj states ( 6 / 13 / 10 ) xxup xxunk . xxmaj griffith ~ helps the xxmaj confederate hide . xxmaj later , when he returns to kill her father , the little girl 's xxunk is remembered . a sweet , small story from director xxup xxunk . xxmaj griffith . xxmaj location footage and humanity are xxunk displayed . \n\n▁ xxrep 4 * xxmaj in the xxmaj border xxmaj states ( 6 / 13 / 10 ) xxup xxunk . xxmaj griffith ~ xxmaj
3 when they real winner should of been xxmaj xxunk xxmaj fiennes for " sunshine " . xxmaj if you have n't seen this movie yet , watch it and you 'll agree . " eyes xxmaj wide xxmaj shut " when released xxunk no nominations . xxmaj and as far as this year goes , well , the bad choices were all over the place ! xxmaj xxunk xxmaj xxunk gets no they real winner should of been xxmaj xxunk xxmaj fiennes for " sunshine " . xxmaj if you have n't seen this movie yet , watch it and you 'll agree . " eyes xxmaj wide xxmaj shut " when released xxunk no nominations . xxmaj and as far as this year goes , well , the bad choices were all over the place ! xxmaj xxunk xxmaj xxunk gets no "
4 xxmaj in this case , however , when combined with the moody atmosphere , and the fact that the small town of xxmaj red xxmaj rock seems almost empty of normal daily life , the coincidences and unlikely timing suggest a story that , beyond " xxunk " , is … surreal . xxmaj it 's almost as if fate deliberately xxunk with improbable events so as to force xxmaj michael to in this case , however , when combined with the moody atmosphere , and the fact that the small town of xxmaj red xxmaj rock seems almost empty of normal daily life , the coincidences and unlikely timing suggest a story that , beyond " xxunk " , is … surreal . xxmaj it 's almost as if fate deliberately xxunk with improbable events so as to force xxmaj michael to come
5 is not over the top and enough twists and turns to keep you interested until the end . \n\n xxmaj well directed , well acted and a good story . xxbos xxmaj not the worst movie xxmaj i 've seen but definitely not very good either . i myself am a paintball player , used to play airball a lot and going from woods to airball is quite a large change . not over the top and enough twists and turns to keep you interested until the end . \n\n xxmaj well directed , well acted and a good story . xxbos xxmaj not the worst movie xxmaj i 've seen but definitely not very good either . i myself am a paintball player , used to play airball a lot and going from woods to airball is quite a large change . xxmaj

Text classification

For the text classification, let’s go over our usual questions:

  • what are the types of our inputs and targets? Texts and categories.
  • where is the data? In a dataframe.
  • how do we know if a sample is in the training or the validation set? We have an is_valid column.
  • how do we get our inputs? In the text column.
  • how do we get our targets? In the label column.

imdb_clas = DataBlock(blocks=(TextBlock.from_df('text', seq_len=72, vocab=dls.vocab), CategoryBlock),
                      get_x=ColReader('text'),
                      get_y=ColReader('label'),
                      splitter=ColSplitter())

Like in the previous example, we use a class method to build a TextBlock. We can pass it the vocabulary of our language model (very useful for the ULMFit approach). We also pass the seq_len argument explicitly (it defaults to 72) as a reminder that the same value must be used here and in your text_classifier_learner.

Warning

You need to make sure to use the same seq_len in TextBlock and the Learner you will define later on.

dls = imdb_clas.dataloaders(df, bs=64)
dls.show_batch()
text category
0 xxbos xxmaj raising xxmaj victor xxmaj vargas : a xxmaj review \n\n xxmaj you know , xxmaj raising xxmaj victor xxmaj vargas is like sticking your hands into a big , xxunk bowl of xxunk . xxmaj it 's warm and gooey , but you 're not sure if it feels right . xxmaj try as i might , no matter how warm and gooey xxmaj raising xxmaj victor xxmaj vargas became i was always aware that something did n't quite feel right . xxmaj victor xxmaj vargas suffers from a certain xxunk on the director 's part . xxmaj apparently , the director thought that the ethnic backdrop of a xxmaj latino family on the lower east side , and an xxunk storyline would make the film critic proof . xxmaj he was right , but it did n't fool me . xxmaj raising xxmaj victor xxmaj vargas is negative
1 xxbos xxup the xxup shop xxup around xxup the xxup corner is one of the xxunk and most feel - good romantic comedies ever made . xxmaj there 's just no getting around that , and it 's hard to actually put one 's feeling for this film into words . xxmaj it 's not one of those films that tries too hard , nor does it come up with the xxunk possible scenarios to get the two protagonists together in the end . xxmaj in fact , all its charm is xxunk , contained within the characters and the setting and the plot … which is highly believable to xxunk . xxmaj it 's easy to think that such a love story , as beautiful as any other ever told , * could * happen to you … a feeling you do n't often get from other romantic comedies positive
2 xxbos xxmaj now that xxmaj che(2008 ) has finished its relatively short xxmaj australian cinema run ( extremely limited xxunk screen in xxmaj xxunk , after xxunk ) , i can xxunk join both xxunk of " at xxmaj the xxmaj movies " in taking xxmaj steven xxmaj soderbergh to task . \n\n xxmaj it 's usually satisfying to watch a film director change his style / subject , but xxmaj soderbergh 's most recent stinker , xxmaj the xxmaj girlfriend xxmaj xxunk ) , was also missing a story , so narrative ( and editing ? ) seem to suddenly be xxmaj soderbergh 's main challenge . xxmaj strange , after 20 - odd years in the business . xxmaj he was probably never much good at narrative , just xxunk it well inside " edgy " projects . \n\n xxmaj none of this excuses him this present , negative
3 xxbos xxmaj this film sat on my xxmaj xxunk for weeks before i watched it . i xxunk a self - indulgent xxunk flick about relationships gone bad . i was wrong ; this was an xxunk xxunk into the screwed - up xxunk of xxmaj new xxmaj xxunk . \n\n xxmaj the format is the same as xxmaj max xxmaj xxunk ' " la xxmaj xxunk , " based on a play by xxmaj arthur xxmaj xxunk , who is given an " inspired by " credit . xxmaj it starts from one person , a prostitute , standing on a street corner in xxmaj brooklyn . xxmaj she is picked up by a home contractor , who has sex with her on the hood of a car , but ca n't come . xxmaj he refuses to pay her . xxmaj when he 's off xxunk , she positive
4 xxbos i really wanted to love this show . i truly , honestly did . \n\n xxmaj for the first time , gay viewers get their own version of the " the xxmaj bachelor " . xxmaj with the help of his obligatory " hag " xxmaj xxunk , xxmaj james , a good looking , well - to - do thirty - something has the chance of love with 15 suitors ( or " mates " as they are referred to in the show ) . xxmaj the only problem is half of them are straight and xxmaj james does n't know this . xxmaj if xxmaj james picks a gay one , they get a trip to xxmaj new xxmaj zealand , and xxmaj if he picks a straight one , straight guy gets $ 25 , xxrep 3 0 . xxmaj how can this not be fun negative
5 xxbos xxmaj many neglect that this is n't just a classic due to the fact that it 's the first 3d game , or even the first xxunk - up . xxmaj it 's also one of the first xxunk games , one of the xxunk definitely the first ) truly claustrophobic games , and just a pretty well - xxunk gaming experience in general . xxmaj with graphics that are terribly dated today , the game xxunk you into the role of xxunk even * think * xxmaj i 'm going to attempt spelling his last name ! ) , an xxmaj american xxup xxunk . caught in an underground bunker . xxmaj you fight and search your way through xxunk in order to achieve different xxunk for the six xxunk , let 's face it , most of them are just an excuse to hand you a weapon positive
6 xxbos xxmaj the xxmaj blob starts with one of the most bizarre theme songs ever , xxunk by an uncredited xxmaj burt xxmaj xxunk of all people ! xxmaj you really have to hear it to believe it , xxmaj the xxmaj blob may be worth watching just for this song alone & my user comment summary is just a little taste of the classy lyrics … xxmaj after this xxunk opening credits sequence xxmaj the xxmaj blob introduces us , the viewer that is , to xxmaj steve xxmaj xxunk ( steve mcqueen as xxmaj steven mcqueen ) & his girlfriend xxmaj jane xxmaj martin ( xxunk xxmaj xxunk ) who are xxunk on their own somewhere & witness what looks like a meteorite falling to xxmaj earth in nearby woods . xxmaj an old man ( xxunk xxmaj xxunk as xxmaj xxunk xxmaj xxunk ) who lives in negative
7 xxbos xxmaj the year 2005 saw no xxunk than 3 filmed productions of xxup h. xxup g. xxmaj wells ' great novel , " war of the xxmaj worlds " . xxmaj this is perhaps the least well - known and very probably the best of them . xxmaj no other version of xxunk has ever attempted not only to present the story very much as xxmaj wells wrote it , but also to create the atmosphere of the time in which it was supposed to take place : the last year of the 19th xxmaj century , 1900 …▁ using xxmaj wells ' original setting , in and near xxmaj xxunk , xxmaj england . \n\n imdb seems xxunk to what they regard as " spoilers " . xxmaj that might apply with some films , where the ending might actually be a surprise , but with regard to positive
8 xxbos xxmaj well , what can i say . \n\n " what the xxmaj xxunk do we xxmaj know " has achieved the nearly impossible - leaving behind such masterpieces of the genre as " the xxmaj xxunk " , " the xxmaj xxunk xxmaj master " , " xxunk " , and so fourth , it will go down in history as the single worst movie i have ever seen in its xxunk . xxmaj and that , ladies and gentlemen , is impressive indeed , for i have seen many a bad movie . \n\n xxmaj this masterpiece of modern cinema consists of two xxunk parts , xxunk between a silly and contrived plot about an extremely annoying photographer , abandoned by her husband and forced to take anti - xxunk to survive , and a bunch of talking heads going on about how quantum physics supposedly xxunk negative

Tabular data

Tabular data doesn’t really use the data block API, as it relies on a separate API built around TabularPandas for efficient preprocessing and batching (a less efficient API that plays nicely with the data block API is planned for the near future). You can still use different blocks for the targets.

from fastai.tabular.core import *

For our example, we will look at a subset of the adult dataset which contains some census data and where the task is to predict if someone makes more than 50k or not.

adult_source = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(adult_source/'adult.csv')
df.head()
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country salary
0 49 Private 101320 Assoc-acdm 12.0 Married-civ-spouse NaN Wife White Female 0 1902 40 United-States >=50k
1 44 Private 236746 Masters 14.0 Divorced Exec-managerial Not-in-family White Male 10520 0 45 United-States >=50k
2 38 Private 96185 HS-grad NaN Divorced NaN Unmarried Black Female 0 0 32 United-States <50k
3 38 Self-emp-inc 112847 Prof-school 15.0 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander Male 0 0 40 United-States >=50k
4 42 Self-emp-not-inc 82297 7th-8th NaN Married-civ-spouse Other-service Wife Black Female 0 0 50 United-States <50k

In a tabular problem, we need to split the columns between the ones that represent continuous variables (like the age) and the ones that represent categorical variables (like the education):

cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']

The standard preprocessing in fastai uses these pre-processors:

procs = [Categorify, FillMissing, Normalize]

Categorify will change the categorical columns into indices, FillMissing will fill the missing values in the continuous columns (if any) and add an na categorical column (if necessary). Normalize will normalize the continuous columns (subtract the mean and divide by the standard deviation).
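
As a rough illustration of what these three steps do to the data, the sketch below reproduces them by hand with plain pandas on a tiny made-up table. It is not fastai's implementation: the toy values, the choice of index 0 for missing categories, and the use of the median as the fill value are assumptions made for this example.

```python
import pandas as pd

# Tiny made-up table with one categorical and one continuous column.
toy = pd.DataFrame({
    "education": ["HS-grad", "Masters", "HS-grad", "Bachelors"],
    "education-num": [9.0, None, 9.0, 13.0],
})

# Categorify (sketch): map each category to an integer index.
# Index 0 is kept free for missing/unknown values (an assumption of this sketch).
categories = sorted(toy["education"].dropna().unique())
cat_to_idx = {c: i + 1 for i, c in enumerate(categories)}
toy["education_idx"] = toy["education"].map(cat_to_idx).fillna(0).astype(int)

# FillMissing (sketch): record which rows were missing, then fill the gaps.
# The median is used as the fill value here (an assumption of this sketch).
toy["education-num_na"] = toy["education-num"].isna()
toy["education-num"] = toy["education-num"].fillna(toy["education-num"].median())

# Normalize (sketch): subtract the mean and divide by the standard deviation.
mean = toy["education-num"].mean()
std = toy["education-num"].std(ddof=0)
toy["education-num_norm"] = (toy["education-num"] - mean) / std

print(toy)
```

The education-num_na column plays the same role as the column of that name that appears in the batch shown at the end of this section.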

We can still use any splitter to create the splits as we’d like them:

splits = RandomSplitter()(range_of(df))
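
A splitter of this kind takes the list of row indices and returns two lists: training indices and validation indices. The sketch below shows the idea in pure Python; random_split is a hypothetical helper written for this example, and the 20% validation fraction is an assumption of the sketch rather than something stated in this tutorial.

```python
import random

def random_split(n_items, valid_pct=0.2, seed=None):
    """Hypothetical helper: shuffle indices and cut off a validation fraction."""
    rng = random.Random(seed)
    idxs = list(range(n_items))
    rng.shuffle(idxs)
    cut = int(valid_pct * n_items)
    # First `cut` shuffled indices go to validation, the rest to training.
    return idxs[cut:], idxs[:cut]

train_idxs, valid_idxs = random_split(10, valid_pct=0.2, seed=42)
print(len(train_idxs), len(valid_idxs))  # 8 2
```

Any other pair of index lists would work just as well as the splits argument, which is why the choice of splitter is independent of the rest of the tabular pipeline.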

And then everything goes in a TabularPandas object:

to = TabularPandas(df, procs, cat_names, cont_names, y_names="salary", splits=splits, y_block=CategoryBlock)

We pass y_block=CategoryBlock just to show how to customize the block for the targets; it is usually inferred from the data, so you normally don’t need to pass it.

dls = to.dataloaders()
dls.show_batch()
workclass education marital-status occupation relationship race education-num_na age fnlwgt education-num salary
0 Self-emp-not-inc Prof-school Never-married Prof-specialty Not-in-family White False 34.0 204374.999924 15.0 >=50k
1 Private Some-college Never-married Adm-clerical Not-in-family White False 62.0 141307.999756 10.0 <50k
2 Private Assoc-acdm Never-married Other-service Not-in-family White False 23.0 152188.999004 12.0 <50k
3 Private HS-grad Divorced Craft-repair Unmarried White False 38.0 27407.999090 9.0 <50k
4 Private Bachelors Never-married Prof-specialty Not-in-family White False 32.0 340917.004812 13.0 >=50k
5 Private Bachelors Never-married Prof-specialty Not-in-family White False 22.0 153515.999598 13.0 <50k
6 Self-emp-not-inc Doctorate Never-married Prof-specialty Not-in-family White False 46.0 165754.000335 16.0 <50k
7 Private Masters Married-civ-spouse Prof-specialty Husband White False 33.0 202050.999896 14.0 <50k
8 Private Assoc-acdm Divorced Sales Unmarried White False 40.0 197919.000079 12.0 <50k
9 ? Some-college Never-married ? Own-child White False 18.0 264924.000434 10.0 <50k