Basic classes to contain the data for model training.

Get your data ready for training

This module defines the basic DataBunch object, which is used inside Learner to train a model. It is the generic class and can take any kind of fastai Dataset or DataLoader. The data module of every application provides helper functions that create this DataBunch for you directly.

class DataBunch

DataBunch(train_dl:DataLoader, valid_dl:DataLoader, test_dl:Optional[DataLoader]=None, device:device=None, tfms:Optional[Collection[Callable]]=None, path:PathOrStr='.', collate_fn:Callable='data_collate')

Bind together a train_dl, a valid_dl and optionally a test_dl, ensure they are on device, and apply tfms to the batches as they are drawn. path is used internally to store temporary files. collate_fn is passed to the PyTorch DataLoader (replacing the one there) and specifies how to collate the samples picked for a batch. By default, it replaces each object sent by its data attribute (see vision.image for why this can be important).
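To make the default collation concrete, here is a framework-free sketch. It assumes that objects expose their underlying value through a `data` attribute; `Item`, `to_data` and `simple_collate` are hypothetical names, and plain Python lists stand in for the tensors a real PyTorch collate function would stack.

```python
from dataclasses import dataclass
from typing import Any, List, Sequence


@dataclass
class Item:
    """Hypothetical wrapper object holding its raw value in `data`."""
    data: Any


def to_data(x: Any) -> Any:
    """Recursively replace wrapper objects by their `data` attribute."""
    if isinstance(x, (list, tuple)):
        return type(x)(to_data(o) for o in x)
    return x.data if hasattr(x, "data") else x


def simple_collate(samples: Sequence[Any]) -> List[list]:
    """Unwrap each sample, then group field i of every sample together
    (lists stand in for the stacked tensors a real collate would build)."""
    unwrapped = [to_data(s) for s in samples]
    return [list(field) for field in zip(*unwrapped)]


batch = simple_collate([(Item([0.1, 0.2]), 0), (Item([0.3, 0.4]), 1)])
print(batch)  # [[[0.1, 0.2], [0.3, 0.4]], [0, 1]]
```

Unwrapping before grouping means the collate step only ever sees raw values, never the wrapper objects.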

Normalization is a typical example of a transform to pass in tfms. train_dl, valid_dl and optionally test_dl are wrapped in DeviceDataLoader.


create(train_ds:Dataset, valid_ds:Dataset, test_ds:Dataset=None, path:PathOrStr='.', bs:int=64, num_workers:int=4, tfms:Optional[Collection[Callable]]=None, device:device=None, collate_fn:Callable='data_collate') → DataBunch

Create a DataBunch from train_ds, valid_ds and optionally test_ds, with batch size bs and num_workers worker processes. tfms and device are passed to the init method.
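The factory can be pictured as building one loader per dataset with a shared batch size. The framework-free sketch below assumes exactly that and nothing more: `batchify` and `create_loaders` are hypothetical helpers, datasets are plain lists, and each "loader" is just its list of batches, whereas the real method uses num_workers processes and hands tfms and device to the init method.

```python
from typing import Dict, List, Optional, Sequence


def batchify(ds: Sequence, bs: int) -> List[list]:
    """Split a dataset into consecutive batches of at most `bs` items."""
    return [list(ds[i:i + bs]) for i in range(0, len(ds), bs)]


def create_loaders(train_ds: Sequence, valid_ds: Sequence,
                   test_ds: Optional[Sequence] = None,
                   bs: int = 64) -> Dict[str, List[list]]:
    """Hypothetical analogue of the factory: one 'loader' per dataset."""
    loaders = {"train": batchify(train_ds, bs), "valid": batchify(valid_ds, bs)}
    if test_ds is not None:  # the test loader stays optional
        loaders["test"] = batchify(test_ds, bs)
    return loaders


loaders = create_loaders(list(range(150)), list(range(40)), bs=64)
print({k: len(v) for k, v in loaders.items()})  # {'train': 3, 'valid': 1}
```

With 150 training items and bs=64, the last batch holds the remaining 22 items.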


show_batch(rows:int=None, ds_type:DatasetType=<DatasetType.Train: 1>, **kwargs)

Show a batch of data in ds_type on a few rows.


dl(ds_type:DatasetType=<DatasetType.Valid: 2>) → DeviceDataLoader

Return the appropriate DeviceDataLoader for training, validation or test, depending on ds_type.
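The lookup can be sketched with an integer enum mirroring the Train, Valid, Test and Single members listed for the Enum in this module, numbered from 1 as the `<DatasetType.Valid: 2>` default suggests. `Bunch` is a hypothetical stand-in class, not the fastai implementation, and Single is deliberately left out of its mapping.

```python
from enum import IntEnum
from typing import Any, Optional

# Hypothetical mirror of DatasetType; the functional API numbers members from 1.
DatasetType = IntEnum("DatasetType", "Train Valid Test Single")


class Bunch:
    """Hypothetical stand-in holding one loader per split."""

    def __init__(self, train_dl: Any, valid_dl: Any, test_dl: Optional[Any] = None):
        self.train_dl, self.valid_dl, self.test_dl = train_dl, valid_dl, test_dl

    def dl(self, ds_type: DatasetType = DatasetType.Valid) -> Any:
        """Return the loader matching `ds_type` (validation by default)."""
        mapping = {DatasetType.Train: self.train_dl,
                   DatasetType.Valid: self.valid_dl,
                   DatasetType.Test: self.test_dl}
        if ds_type not in mapping:
            raise ValueError(f"no loader for {ds_type!r} in this sketch")
        return mapping[ds_type]


b = Bunch("train-loader", "valid-loader")
print(b.dl(), b.dl(DatasetType.Train))  # valid-loader train-loader
```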



Adds a transform to all dataloaders.

class DeviceDataLoader

DeviceDataLoader(dl:DataLoader, device:device, tfms:List[Callable]=None, collate_fn:Callable='data_collate', skip_size1:bool=False)

Put the batches of dl on device and apply an optional list of tfms to them. collate_fn replaces the collate function of dl. All dataloaders of a DataBunch are of this type.
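The wrapping behaviour can be sketched without PyTorch. `TfmLoader` is a hypothetical class: it iterates over any iterable of batches, passes each batch through a caller-supplied `to_device` function (a placeholder for moving tensors to device), then applies the transforms in order. Doing the device transfer before the tfms is an assumption of this sketch.

```python
from typing import Callable, Iterable, Iterator, List, Optional


class TfmLoader:
    """Hypothetical DeviceDataLoader-like wrapper (no PyTorch needed)."""

    def __init__(self, dl: Iterable, to_device: Callable,
                 tfms: Optional[List[Callable]] = None):
        self.dl, self.to_device = dl, to_device
        self.tfms = list(tfms or [])

    def process_batch(self, b):
        b = self.to_device(b)   # placeholder for moving the batch to `device`
        for tfm in self.tfms:   # then apply each transform in order
            b = tfm(b)
        return b

    def __iter__(self) -> Iterator:
        for b in self.dl:       # batches are processed lazily, as they are drawn
            yield self.process_batch(b)

    def __len__(self) -> int:
        return len(self.dl)


dl = TfmLoader([[1, 2], [3, 4]], to_device=lambda b: b,
               tfms=[lambda b: [x * 10 for x in b]])
print(list(dl))  # [[10, 20], [30, 40]]
```

Processing inside `__iter__` keeps the underlying loader untouched, so transforms can be added or removed between epochs.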

Factory method


create(dataset:Dataset, bs:int=64, shuffle:bool=False, device:device=device(type='cuda'), tfms:Collection[Callable]=None, num_workers:int=4, collate_fn:Callable='data_collate', **kwargs:Any)

Create a DeviceDataLoader on device from a dataset, with batch size bs, num_workers worker processes and the given collate_fn. The data is shuffled if shuffle is True, and tfms are passed to the init method. All remaining kwargs are passed to the PyTorch DataLoader constructor.



one_batch() → Collection[Tensor]

Get one batch from the data loader.
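Conceptually, getting one batch means drawing the first batch from a fresh iterator over the loader. A minimal sketch, assuming any re-iterable loader; `one_batch` here is a hypothetical free function, not the fastai method.

```python
from typing import Any, Iterable


def one_batch(loader: Iterable) -> Any:
    """Return the first batch produced by a fresh pass over `loader`."""
    return next(iter(loader))


batches = [[0, 1], [2, 3], [4]]
print(one_batch(batches))  # [0, 1]
```

Because each call starts a new iterator, calling it twice on an unshuffled loader returns the same batch.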



Add a transform (equivalent to self.tfms.append(tfm)).



Remove a transform.
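Adding and removing transforms amounts to editing a loader's tfms list, and adding a transform to all dataloaders repeats this for each one. In this hypothetical sketch, `Loader` and `add_tfm_to_all` are illustrative names; only the `self.tfms.append(tfm)` behaviour comes from the description of adding a transform, and silently ignoring a transform that was never added is a choice of the sketch.

```python
from typing import Callable, Iterable, List


class Loader:
    """Hypothetical loader holding a list of batch transforms."""

    def __init__(self) -> None:
        self.tfms: List[Callable] = []

    def add_tfm(self, tfm: Callable) -> None:
        self.tfms.append(tfm)          # same as self.tfms.append(tfm)

    def remove_tfm(self, tfm: Callable) -> None:
        if tfm in self.tfms:           # ignore transforms that were never added
            self.tfms.remove(tfm)


def add_tfm_to_all(loaders: Iterable[Loader], tfm: Callable) -> None:
    """Hypothetical helper: register `tfm` on every loader of a bunch."""
    for dl in loaders:
        dl.add_tfm(tfm)


def double(b):
    return [x * 2 for x in b]


train, valid = Loader(), Loader()
add_tfm_to_all([train, valid], double)
valid.remove_tfm(double)
print(len(train.tfms), len(valid.tfms))  # 1 0
```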


Enum = [Train, Valid, Test, Single]

Internal enumerator to name the training, validation and test dataset/dataloader.