bb = TensorBBox([[-2,-0.5,0.5,1.5], [-0.5,-0.5,0.5,0.5], [1,0.5,0.5,0.75], [-0.5,-0.5,0.5,0.5], [-2, -0.5, -1.5, 0.5]])
bb,lbl = clip_remove_empty(bb, TensorMultiCategory([1,2,3,2,5]))
test_eq(bb, TensorBBox([[-1,-0.5,0.5,1.], [-0.5,-0.5,0.5,0.5], [-0.5,-0.5,0.5,0.5]]))
test_eq(lbl, TensorMultiCategory([1,2,2]))
Vision data
Helper functions to get DataLoaders in the vision application and the higher-level class ImageDataLoaders
The main classes defined in this module are ImageDataLoaders and SegmentationDataLoaders, so you probably want to jump to their definitions. They provide factory methods that are a great way to quickly get your data ready for training; see the vision tutorial for examples.
Helper functions
get_grid
get_grid (n:int, nrows:int=None, ncols:int=None, figsize:tuple=None, double:bool=False, title:str=None, return_fig:bool=False, flatten:bool=True, imsize:int=3, suptitle:str=None, sharex:"bool|Literal['none','all','row','col']"=False, sharey:"bool|Literal['none','all','row','col']"=False, squeeze:bool=True, width_ratios:Sequence[float]|None=None, height_ratios:Sequence[float]|None=None, subplot_kw:dict[str,Any]|None=None, gridspec_kw:dict[str,Any]|None=None)
Return a grid of n axes, rows by cols
Parameter | Type | Default | Details
---|---|---|---
n | int | | Number of axes in the returned grid
nrows | int | None | Number of rows in the returned grid, defaulting to int(math.sqrt(n))
ncols | int | None | Number of columns in the returned grid, defaulting to ceil(n/rows)
figsize | tuple | None | Width, height in inches of the returned figure
double | bool | False | Whether to double the number of columns and n
title | str | None | If passed, title set to the figure
return_fig | bool | False | Whether to return the figure created by subplots
flatten | bool | True | Whether to flatten the matplot axes such that they can be iterated over with a single loop
imsize | int | 3 | Size (in inches) of images that will be displayed in the returned figure
suptitle | str | None | Title to be set to returned figure
sharex | bool \| Literal['none', 'all', 'row', 'col'] | False |
sharey | bool \| Literal['none', 'all', 'row', 'col'] | False |
squeeze | bool | True |
width_ratios | Sequence[float] \| None | None |
height_ratios | Sequence[float] \| None | None |
subplot_kw | dict[str, Any] \| None | None |
gridspec_kw | dict[str, Any] \| None | None |
Returns | (plt.Figure, plt.Axes) | | Returns just axs by default, and (fig, axs) if return_fig is set to True
This is used by the type-dispatched versions of show_batch and show_results for the vision application. The default figsize is (cols*imsize, rows*imsize+0.6). imsize is passed down to subplots. suptitle, sharex, sharey, squeeze, subplot_kw and gridspec_kw are all passed down to plt.subplots. If return_fig is True, returns fig,axs, otherwise just axs.
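The documented defaults can be sketched in a few lines of plain Python. This is only an illustration of the rules stated above (int(math.sqrt(n)) rows, ceil(n/rows) columns, a figsize of (cols*imsize, rows*imsize+0.6)), not the library's code; the helper name grid_shape is hypothetical.

```python
import math

def grid_shape(n, nrows=None, ncols=None, imsize=3, double=False):
    """Hypothetical helper mirroring the documented defaults of get_grid."""
    # Rows default to the integer square root of n.
    nrows = nrows or int(math.sqrt(n))
    # Columns default to however many are needed to hold n axes.
    ncols = ncols or math.ceil(n / nrows)
    if double:
        # `double` doubles the number of columns (and therefore of axes).
        ncols *= 2
    # Default figure size as stated in the text above.
    figsize = (ncols * imsize, nrows * imsize + 0.6)
    return nrows, ncols, figsize

nrows, ncols, figsize = grid_shape(10)
print(nrows, ncols, figsize)
```

With n=10 this gives 3 rows and 4 columns, so two of the twelve axes stay unused.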
clip_remove_empty
clip_remove_empty (bbox:fastai.vision.core.TensorBBox, label:fastai.torch_core.TensorMultiCategory)
Clip bounding boxes with image border and remove empty boxes along with corresponding labels
Parameter | Type | Details
---|---|---
bbox | TensorBBox | Coordinates of bounding boxes
label | TensorMultiCategory | Labels of the bounding boxes
This is used in bb_pad
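To make the behavior concrete, here is a plain-Python sketch of the idea. It assumes boxes are given as [x1, y1, x2, y2] in coordinates normalized so that the image border sits at -1 and 1, and it treats a box as empty when its clipped area is not positive. The function name clip_remove_empty_sketch is hypothetical, and the real function works on tensors rather than lists.

```python
def clip_remove_empty_sketch(boxes, labels):
    """Clip [x1, y1, x2, y2] boxes to [-1, 1] and drop boxes with no area."""
    kept_boxes, kept_labels = [], []
    for box, label in zip(boxes, labels):
        # Clamp every coordinate to the assumed image border.
        x1, y1, x2, y2 = (max(-1.0, min(1.0, c)) for c in box)
        # Keep the box (and its label) only if some area is left.
        if (x2 - x1) * (y2 - y1) > 0:
            kept_boxes.append([x1, y1, x2, y2])
            kept_labels.append(label)
    return kept_boxes, kept_labels

boxes = [[-2, -0.5, 0.5, 1.5], [-0.5, -0.5, 0.5, 0.5], [1, 0.5, 0.5, 0.75],
         [-0.5, -0.5, 0.5, 0.5], [-2, -0.5, -1.5, 0.5]]
kept_boxes, kept_labels = clip_remove_empty_sketch(boxes, [1, 2, 3, 2, 5])
print(kept_boxes, kept_labels)
```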
bb_pad
bb_pad (samples:list, pad_idx=0)
Function that collects samples of labelled bboxes and adds padding with pad_idx.
Parameter | Type | Default | Details
---|---|---|---
samples | list | | List of 3-tuples like (image, bounding_boxes, labels)
pad_idx | int | 0 | Label that will be used to pad each list of labels
This is used in BBoxBlock
img1,img2 = TensorImage(torch.randn(16,16,3)),TensorImage(torch.randn(16,16,3))
bb1 = tensor([[-2,-0.5,0.5,1.5], [-0.5,-0.5,0.5,0.5], [1,0.5,0.5,0.75], [-0.5,-0.5,0.5,0.5]])
lbl1 = tensor([1, 2, 3, 2])
bb2 = tensor([[-0.5,-0.5,0.5,0.5], [-0.5,-0.5,0.5,0.5]])
lbl2 = tensor([2, 2])
samples = [(img1, bb1, lbl1), (img2, bb2, lbl2)]
res = bb_pad(samples)
non_empty = tensor([True,True,False,True])
test_eq(res[0][0], img1)
test_eq(res[0][1], tensor([[-1,-0.5,0.5,1.], [-0.5,-0.5,0.5,0.5], [-0.5,-0.5,0.5,0.5]]))
test_eq(res[0][2], tensor([1,2,2]))
test_eq(res[1][0], img2)
test_eq(res[1][1], tensor([[-0.5,-0.5,0.5,0.5], [-0.5,-0.5,0.5,0.5], [0,0,0,0]]))
test_eq(res[1][2], tensor([2,2,0]))
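The padding step on its own can be sketched with plain lists. This sketch assumes the boxes have already been clipped and filtered, pads boxes with all-zero rows and pads labels with pad_idx up to the longest label list in the batch. The name bb_pad_sketch is hypothetical and the code is an illustration, not the library's implementation.

```python
def bb_pad_sketch(samples, pad_idx=0):
    """Pad every sample's boxes and labels to the longest label list."""
    max_len = max(len(labels) for _, _, labels in samples)
    padded = []
    for img, boxes, labels in samples:
        missing = max_len - len(labels)
        # Zero boxes and pad_idx labels fill the missing slots.
        pad_boxes = [[0, 0, 0, 0] for _ in range(missing)]
        padded.append((img, boxes + pad_boxes, labels + [pad_idx] * missing))
    return padded

samples = [
    ("img1", [[-1, -0.5, 0.5, 1], [-0.5, -0.5, 0.5, 0.5], [-0.5, -0.5, 0.5, 0.5]], [1, 2, 2]),
    ("img2", [[-0.5, -0.5, 0.5, 0.5], [-0.5, -0.5, 0.5, 0.5]], [2, 2]),
]
res = bb_pad_sketch(samples)
print(res[1])
```

Padding to a common length is what lets samples with different numbers of boxes be stacked into one batch.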
TransformBlocks for vision
These are the blocks the vision application provides for the data block API.
ImageBlock
ImageBlock (cls:fastai.vision.core.PILBase=<class 'fastai.vision.core.PILImage'>)
A TransformBlock for images of cls
MaskBlock
MaskBlock (codes:list=None)
A TransformBlock for segmentation masks, potentially with codes
Parameter | Type | Default | Details
---|---|---|---
codes | list | None | Vocab labels for segmentation masks
PointBlock
A TransformBlock for points in an image
BBoxBlock
A TransformBlock for bounding boxes in an image
BBoxLblBlock
BBoxLblBlock (vocab:list=None, add_na:bool=True)
A TransformBlock for labeled bounding boxes, potentially with vocab
Parameter | Type | Default | Details
---|---|---|---
vocab | list | None | Vocab labels for bounding boxes
add_na | bool | True | Add NaN as a background class
If add_na is True, a new category is added for NaN (that will represent the background class).
ImageDataLoaders
ImageDataLoaders (*loaders, path:str|pathlib.Path='.', device=None)
Basic wrapper around several DataLoaders with factory methods for computer vision problems
This class should not be used directly; prefer one of the factory methods instead. All of those factory methods accept the following arguments:
item_tfms: one or several transforms applied to the items before batching them
batch_tfms: one or several transforms applied to the batches once they are formed
bs: the batch size
val_bs: the batch size for the validation DataLoader (defaults to bs)
shuffle: whether to shuffle the training DataLoader
device: the PyTorch device to use (defaults to default_device())
ImageDataLoaders.from_folder
ImageDataLoaders.from_folder (path, train='train', valid='valid', valid_pct=None, seed=None, vocab=None, item_tfms=None, batch_tfms=None, img_cls=<class 'fastai.vision.core.PILImage'>, bs:int=64, val_bs:int=None, shuffle:bool=True, device=None)
Create from imagenet style dataset in path with train and valid subfolders (or provide valid_pct)
Parameter | Type | Default | Details
---|---|---|---
path | str \| pathlib.Path | . | Path to put in DataLoaders
train | str | train |
valid | str | valid |
valid_pct | NoneType | None |
seed | NoneType | None |
vocab | NoneType | None |
item_tfms | NoneType | None |
batch_tfms | NoneType | None |
img_cls | BypassNewMeta | PILImage |
bs | int | 64 | Size of batch
val_bs | int | None | Size of batch for validation DataLoader
shuffle | bool | True | Whether to shuffle data
device | NoneType | None | Device to put DataLoaders
If valid_pct is provided, a random split is performed (with an optional seed) by setting aside that percentage of the data for the validation set (instead of looking at the grandparent folder). If a vocab is passed, only the folders with names in vocab are kept.
Here is an example loading a subsample of MNIST:
path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(path, img_cls=PILImageBW)
x,y = dls.one_batch()
test_eq(x.shape, [64, 1, 28, 28])
Passing valid_pct will ignore the valid/train folders and do a new random split:
dls = ImageDataLoaders.from_folder(path, valid_pct=0.2)
dls.valid_ds.items[:3]
[Path('/home/jhoward/.fastai/data/mnist_tiny/train/7/9307.png'),
Path('/home/jhoward/.fastai/data/mnist_tiny/train/3/8241.png'),
Path('/home/jhoward/.fastai/data/mnist_tiny/valid/3/8924.png')]
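For comparison, the folder-based behavior (used when valid_pct is not given) can be sketched with pathlib: the label comes from the parent folder and the split from the grandparent folder. The helper names and the sample paths below are hypothetical, and this is a sketch of the idea rather than the library's code.

```python
from pathlib import PurePosixPath

def parent_label(fname):
    """Label of an item: the name of the folder that contains it."""
    return PurePosixPath(fname).parent.name

def split_by_grandparent(fnames, train="train", valid="valid"):
    """Split items on the name of the folder two levels up."""
    train_items, valid_items = [], []
    for f in fnames:
        folder = PurePosixPath(f).parent.parent.name
        if folder == train:
            train_items.append(f)
        elif folder == valid:
            valid_items.append(f)
    return train_items, valid_items

fnames = ["mnist_tiny/train/3/7463.png", "mnist_tiny/train/7/9307.png",
          "mnist_tiny/valid/3/8924.png"]
train_items, valid_items = split_by_grandparent(fnames)
print(len(train_items), len(valid_items), parent_label(fnames[0]))
```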
ImageDataLoaders.from_path_func
ImageDataLoaders.from_path_func (path, fnames, label_func, valid_pct=0.2, seed=None, item_tfms=None, batch_tfms=None, img_cls=<class 'fastai.vision.core.PILImage'>, bs:int=64, val_bs:int=None, shuffle:bool=True, device=None)
Create from list of fnames in paths with label_func
Parameter | Type | Default | Details
---|---|---|---
path | str \| pathlib.Path | . | Path to put in DataLoaders
fnames | | |
label_func | | |
valid_pct | float | 0.2 |
seed | NoneType | None |
item_tfms | NoneType | None |
batch_tfms | NoneType | None |
img_cls | BypassNewMeta | PILImage |
bs | int | 64 | Size of batch
val_bs | int | None | Size of batch for validation DataLoader
shuffle | bool | True | Whether to shuffle data
device | NoneType | None | Device to put DataLoaders
The validation set is a random subset containing a fraction valid_pct of the data, optionally created with seed for reproducibility.
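A seeded random split of this kind can be sketched with the standard library. The function random_split is hypothetical and uses Python's random module as a stand-in, so the exact items selected will not match what the library picks for the same seed.

```python
import random

def random_split(items, valid_pct=0.2, seed=None):
    """Set aside a random fraction `valid_pct` of `items` for validation."""
    rng = random.Random(seed)          # a fixed seed makes the split repeatable
    idxs = list(range(len(items)))
    rng.shuffle(idxs)
    cut = int(valid_pct * len(items))  # number of validation items
    valid = [items[i] for i in idxs[:cut]]
    train = [items[i] for i in idxs[cut:]]
    return train, valid

items = [f"img_{i}.png" for i in range(10)]
train, valid = random_split(items, valid_pct=0.2, seed=42)
print(len(train), len(valid))
```

Calling the function again with the same seed returns the same two lists, which is the point of passing seed.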
Here is how to create the same DataLoaders on the MNIST dataset as the previous example with a label_func:
fnames = get_image_files(path)
def label_func(x): return x.parent.name
dls = ImageDataLoaders.from_path_func(path, fnames, label_func)
Here is another example on the pets dataset. There, the filenames are all in an "images" folder and their names have the form class_name_123.jpg. One way to properly label them is thus to throw away everything after the last _.
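A labelling function along those lines might look like the sketch below. It splits the file name on underscores and drops the last piece; the sample file names are hypothetical.

```python
from pathlib import Path

def label_func(fname):
    """Keep everything before the last underscore of the file name."""
    return "_".join(Path(fname).name.split("_")[:-1])

print(label_func("images/great_pyrenees_123.jpg"))  # great_pyrenees
```

Splitting on the last underscore only matters for class names that themselves contain underscores, as in this example.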
ImageDataLoaders.from_path_re
ImageDataLoaders.from_path_re (path, fnames, pat, valid_pct=0.2, seed=None, item_tfms=None, batch_tfms=None, img_cls=<class 'fastai.vision.core.PILImage'>, bs:int=64, val_bs:int=None, shuffle:bool=True, device=None)
Create from list of fnames in paths with re expression pat
Parameter | Type | Default | Details
---|---|---|---
path | str \| pathlib.Path | . | Path to put in DataLoaders
fnames | | |
pat | | |
valid_pct | float | 0.2 |
seed | NoneType | None |
item_tfms | NoneType | None |
batch_tfms | NoneType | None |
img_cls | BypassNewMeta | PILImage |
bs | int | 64 | Size of batch
val_bs | int | None | Size of batch for validation DataLoader
shuffle | bool | True | Whether to shuffle data
device | NoneType | None | Device to put DataLoaders
The validation set is a random subset containing a fraction valid_pct of the data, optionally created with seed for reproducibility.
Here is how to create the same DataLoaders on the MNIST dataset as the previous example (you will need to replace the first two / with \ on Windows):
pat = r'/([^/]*)/\d+\.png$'
dls = ImageDataLoaders.from_path_re(path, fnames, pat)
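To see what such a pattern captures, it can be tried directly with the re module. The sketch below assumes the label is the first capture group found by searching the full path as a string; the helper label_from_path and the sample path are hypothetical.

```python
import re

pat = r'/([^/]*)/\d+\.png$'

def label_from_path(fname, pat=pat):
    """Return the first capture group of `pat` found in the full path."""
    match = re.search(pat, str(fname))
    if match is None:
        raise ValueError(f"pattern {pat!r} does not match {fname!r}")
    return match.group(1)

print(label_from_path("/data/mnist_tiny/train/3/7463.png"))  # 3
```

The group ([^/]*) grabs the folder name that sits directly above the numbered .png file.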
ImageDataLoaders.from_name_func
ImageDataLoaders.from_name_func (path:str|Path, fnames:list, label_func:callable, valid_pct=0.2, seed=None, item_tfms=None, batch_tfms=None, img_cls=<class 'fastai.vision.core.PILImage'>, bs:int=64, val_bs:int=None, shuffle:bool=True, device=None)
Create from the name attrs of fnames in paths with label_func
Parameter | Type | Default | Details
---|---|---|---
path | str \| Path | | Set the default path to a directory that a Learner can use to save files like models
fnames | list | | A list of os.PathLike objects pointing to individual image files
label_func | callable | | A function that receives a string (the file name) and outputs a label
valid_pct | float | 0.2 |
seed | NoneType | None |
item_tfms | NoneType | None |
batch_tfms | NoneType | None |
img_cls | BypassNewMeta | PILImage |
bs | int | 64 | Size of batch
val_bs | int | None | Size of batch for validation DataLoader
shuffle | bool | True | Whether to shuffle data
device | NoneType | None | Device to put DataLoaders
Returns | DataLoaders | |
The validation set is a random subset containing a fraction valid_pct of the data, optionally created with seed for reproducibility. This method does the same as ImageDataLoaders.from_path_func except label_func is applied to the name of each file rather than to its full path.
ImageDataLoaders.from_name_re
ImageDataLoaders.from_name_re (path, fnames, pat, bs:int=64, val_bs:int=None, shuffle:bool=True, device=None)
Create from the name attrs of fnames in paths with re expression pat
Parameter | Type | Default | Details
---|---|---|---
path | str \| pathlib.Path | . | Path to put in DataLoaders
fnames | | |
pat | | |
bs | int | 64 | Size of batch
val_bs | int | None | Size of batch for validation DataLoader
shuffle | bool | True | Whether to shuffle data
device | NoneType | None | Device to put DataLoaders
The validation set is a random subset containing a fraction valid_pct of the data, optionally created with seed for reproducibility. This method does the same as ImageDataLoaders.from_path_re except pat is applied to the name of each file rather than to its full path.
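Because only the file name is matched, the pattern no longer needs any path separators. As a sketch, a hypothetical pattern for names of the form class_name_123.jpg could be checked like this (the pattern, the helper label_from_name and the sample path are illustrative assumptions):

```python
import re
from pathlib import PurePosixPath

# Hypothetical pattern: class name, underscore, digits, ".jpg".
name_pat = r'^(.*)_\d+\.jpg$'

def label_from_name(fname, pat=name_pat):
    """Apply `pat` to the file name only, ignoring the folders."""
    name = PurePosixPath(fname).name
    match = re.search(pat, name)
    if match is None:
        raise ValueError(f"pattern {pat!r} does not match {name!r}")
    return match.group(1)

print(label_from_name("pets/images/great_pyrenees_123.jpg"))  # great_pyrenees
```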
ImageDataLoaders.from_df
ImageDataLoaders.from_df (df, path='.', valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='', label_col=1, label_delim=None, y_block=None, valid_col=None, item_tfms=None, batch_tfms=None, img_cls=<class 'fastai.vision.core.PILImage'>, bs:int=64, val_bs:int=None, shuffle:bool=True, device=None)
Create from df using fn_col and label_col
Parameter | Type | Default | Details
---|---|---|---
df | | |
path | str \| pathlib.Path | . | Path to put in DataLoaders
valid_pct | float | 0.2 |
seed | NoneType | None |
fn_col | int | 0 |
folder | NoneType | None |
suff | str | '' |
label_col | int | 1 |
label_delim | NoneType | None |
y_block | NoneType | None |
valid_col | NoneType | None |
item_tfms | NoneType | None |
batch_tfms | NoneType | None |
img_cls | BypassNewMeta | PILImage |
bs | int | 64 | Size of batch
val_bs | int | None | Size of batch for validation DataLoader
shuffle | bool | True | Whether to shuffle data
device | NoneType | None | Device to put DataLoaders
The validation set is a random subset containing a fraction valid_pct of the data, optionally created with seed for reproducibility. Alternatively, if your df contains a valid_col, give its name or its index to that argument (the column should have True for the elements going to the validation set).
You can add an additional folder to the filenames in df if they should not be concatenated directly to path. If they do not contain the proper extensions, you can add suff. If your label column contains multiple labels on each row, you can use label_delim to tell the library you have a multi-label problem.
Pass y_block when the task automatically picked by the library is wrong; in that case, give CategoryBlock, MultiCategoryBlock or RegressionBlock. For more advanced uses, you should use the data block API.
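How folder, suff and label_delim combine for a single row can be sketched as follows. The helper resolve_row is hypothetical and only illustrates the description above: the filename is joined to path (through folder if given), suff is appended, and the label string is split on label_delim when one is passed.

```python
from pathlib import PurePosixPath

def resolve_row(path, fname, label, folder=None, suff="", label_delim=None):
    """Build the full file path and the label(s) for one dataframe row."""
    base = PurePosixPath(path)
    if folder is not None:
        base = base / folder                 # optional extra folder
    full_name = base / f"{fname}{suff}"      # optional missing extension
    # With a delimiter, one row can carry several labels.
    labels = label.split(label_delim) if label_delim is not None else label
    return full_name, labels

full_name, labels = resolve_row("pascal_2007", "000009", "horse person",
                                folder="train", suff=".jpg", label_delim=" ")
print(full_name, labels)
```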
The tiny MNIST example from before also comes with its labels in a dataframe:
path = untar_data(URLs.MNIST_TINY)
df = pd.read_csv(path/'labels.csv')
df.head()
 | name | label
---|---|---
0 | train/3/7463.png | 3
1 | train/3/9829.png | 3
2 | train/3/7881.png | 3
3 | train/3/8065.png | 3
4 | train/3/7046.png | 3
Here is how to load it using ImageDataLoaders.from_df:
dls = ImageDataLoaders.from_df(df, path)
Here is another example with a multi-label problem:
path = untar_data(URLs.PASCAL_2007)
df = pd.read_csv(path/'train.csv')
df.head()
 | fname | labels | is_valid
---|---|---|---
0 | 000005.jpg | chair | True
1 | 000007.jpg | car | True
2 | 000009.jpg | horse person | True
3 | 000012.jpg | car | False
4 | 000016.jpg | bicycle | True
dls = ImageDataLoaders.from_df(df, path, folder='train', valid_col='is_valid', label_delim=' ')
Note that you can also pass 2 to valid_col (the column index, starting at 0).
ImageDataLoaders.from_csv
ImageDataLoaders.from_csv (path, csv_fname='labels.csv', header='infer', delimiter=None, quoting=0, valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='', label_col=1, label_delim=None, y_block=None, valid_col=None, item_tfms=None, batch_tfms=None, img_cls=<class 'fastai.vision.core.PILImage'>, bs:int=64, val_bs:int=None, shuffle:bool=True, device=None)
Create from path/csv_fname using fn_col and label_col
Parameter | Type | Default | Details
---|---|---|---
path | str \| pathlib.Path | . | Path to put in DataLoaders
csv_fname | str | labels.csv |
header | str | infer |
delimiter | NoneType | None |
quoting | int | 0 |
valid_pct | float | 0.2 |
seed | NoneType | None |
fn_col | int | 0 |
folder | NoneType | None |
suff | str | '' |
label_col | int | 1 |
label_delim | NoneType | None |
y_block | NoneType | None |
valid_col | NoneType | None |
item_tfms | NoneType | None |
batch_tfms | NoneType | None |
img_cls | BypassNewMeta | PILImage |
bs | int | 64 | Size of batch
val_bs | int | None | Size of batch for validation DataLoader
shuffle | bool | True | Whether to shuffle data
device | NoneType | None | Device to put DataLoaders
Same as ImageDataLoaders.from_df after loading the file with header and delimiter.
Here is how to load the same dataset as before with this method:
dls = ImageDataLoaders.from_csv(path, 'train.csv', folder='train', valid_col='is_valid', label_delim=' ')
ImageDataLoaders.from_lists
ImageDataLoaders.from_lists (path, fnames, labels, valid_pct=0.2, seed:int=None, y_block=None, item_tfms=None, batch_tfms=None, img_cls=<class 'fastai.vision.core.PILImage'>, bs:int=64, val_bs:int=None, shuffle:bool=True, device=None)
Create from list of fnames and labels in path
Parameter | Type | Default | Details
---|---|---|---
path | str \| pathlib.Path | . | Path to put in DataLoaders
fnames | | |
labels | | |
valid_pct | float | 0.2 |
seed | int | None |
y_block | NoneType | None |
item_tfms | NoneType | None |
batch_tfms | NoneType | None |
img_cls | BypassNewMeta | PILImage |
bs | int | 64 | Size of batch
val_bs | int | None | Size of batch for validation DataLoader
shuffle | bool | True | Whether to shuffle data
device | NoneType | None | Device to put DataLoaders
The validation set is a random subset containing a fraction valid_pct of the data, optionally created with seed for reproducibility. y_block can be passed to specify the type of the targets.
path = untar_data(URLs.PETS)
fnames = get_image_files(path/"images")
labels = ['_'.join(x.name.split('_')[:-1]) for x in fnames]
dls = ImageDataLoaders.from_lists(path, fnames, labels)
SegmentationDataLoaders
SegmentationDataLoaders (*loaders, path:str|pathlib.Path='.', device=None)
Basic wrapper around several DataLoaders with factory methods for segmentation problems
SegmentationDataLoaders.from_label_func
SegmentationDataLoaders.from_label_func (path, fnames, label_func, valid_pct=0.2, seed=None, codes=None, item_tfms=None, batch_tfms=None, img_cls=<class 'fastai.vision.core.PILImage'>, bs:int=64, val_bs:int=None, shuffle:bool=True, device=None)
Create from list of fnames in paths with label_func.
Parameter | Type | Default | Details
---|---|---|---
path | str \| pathlib.Path | . | Path to put in DataLoaders
fnames | | |
label_func | | |
valid_pct | float | 0.2 |
seed | NoneType | None |
codes | NoneType | None |
item_tfms | NoneType | None |
batch_tfms | NoneType | None |
img_cls | BypassNewMeta | PILImage |
bs | int | 64 | Size of batch
val_bs | int | None | Size of batch for validation DataLoader
shuffle | bool | True | Whether to shuffle data
device | NoneType | None | Device to put DataLoaders
The validation set is a random subset containing a fraction valid_pct of the data, optionally created with seed for reproducibility. codes contains the mapping from index to label.
path = untar_data(URLs.CAMVID_TINY)
fnames = get_image_files(path/'images')
def label_func(x): return path/'labels'/f'{x.stem}_P{x.suffix}'
codes = np.loadtxt(path/'codes.txt', dtype=str)

dls = SegmentationDataLoaders.from_label_func(path, fnames, label_func, codes=codes)
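The two ingredients of this example can be illustrated without downloading anything. In the sketch below, mask_path mirrors the naming rule used by label_func above (the mask for an image lives in a labels folder with _P appended to the stem), and decode_mask shows what "mapping from index to label" means for the pixel values of a mask. The helper names, the three codes and the tiny mask are all hypothetical.

```python
from pathlib import PurePosixPath

def mask_path(root, image_path):
    """Mask file for an image: <root>/labels/<stem>_P<suffix>."""
    image_path = PurePosixPath(image_path)
    return PurePosixPath(root) / "labels" / f"{image_path.stem}_P{image_path.suffix}"

def decode_mask(mask, codes):
    """Replace every pixel index in a 2D mask by its label name."""
    return [[codes[i] for i in row] for row in mask]

codes = ["Sky", "Building", "Road"]   # hypothetical; index i -> codes[i]
mask = [[0, 1], [2, 2]]               # hypothetical 2x2 mask of class indices

print(mask_path("camvid_tiny", "camvid_tiny/images/0001.png"))
print(decode_mask(mask, codes))
```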