Callbacks and helper functions to train in parallel or use distributed training

When using multiple GPUs, you will most probably want to fit using distributed training. See examples/ for a complete example. To use distributed training, there are only two required steps:

  1. Add with learn.distrib_ctx(): around your call to learn.fit
  2. Run your training script with python -m fastai.launch ...args...

After fastai.launch you can add --gpus 0,1 for instance, to use only GPUs 0 and 1.
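Putting the two steps together, the launch command looks like the following; train.py is a hypothetical script name standing in for your own training script, which wraps its learn.fit call in learn.distrib_ctx() as described in step 1:

```shell
# Launch distributed training on GPUs 0 and 1 only.
# train.py is a placeholder for your own training script.
python -m fastai.launch --gpus 0,1 train.py
```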

If you're using untar_data, or may be downloading or uncompressing data or models as part of your script, you should wrap that code with rank0_first, which forces that step to occur once on the master process first, before the remaining processes run it in parallel. E.g. instead of:

path = untar_data(URLs.IMAGEWOOF_320)

use:

path = rank0_first(untar_data, URLs.IMAGEWOOF_320)

See below for details on the full API and underlying helper functions, if needed -- however, note that you will not need anything except the above unless you need to change how the distributed training is implemented.




Patch required reset call into DataParallel

class ParallelTrainer[source]

ParallelTrainer(device_ids) :: Callback

Wrap a model in DataParallel automatically



Add ParallelTrainer callback to a Learner



Remove ParallelTrainer callback from a Learner



A context manager to adapt a learner to train in data parallel mode.


Helper functions



Patch required reset call into DistributedDataParallel



Set up this process to participate in distributed training



Free distributed training resources


class DistributedDL[source]

DistributedDL(dl, rank=None, world_size=None) :: TfmdDL

A TfmdDL which splits a batch into equal-size pieces for each worker

dl = TfmdDL(list(range(50)), bs=12, num_workers=2)
for i in range(4):
    dl1 = DistributedDL(dl, i, 4)
    test_eq(list(dl1), (torch.arange(i*13, i*13+12)%50,torch.tensor([i*13+12])%50))
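The index arithmetic behind this test can be sketched in plain Python: with 50 items and a world size of 4, each rank gets ceil(50/4) = 13 indices, and the final ranks are padded by wrapping around to the start of the dataset. The names below are illustrative stand-ins, not fastai API:

```python
import math

items, bs, world_size = 50, 12, 4
per_rank = math.ceil(items / world_size)  # 13 indices per rank
rank_idxs = []
for rank in range(world_size):
    start = rank * per_rank
    # pad the last ranks by wrapping around to the start of the dataset
    rank_idxs.append([(start + i) % items for i in range(per_rank)])
```

Each rank then batches its 13 indices with bs=12, giving one full batch of 12 and a final batch of 1, matching the test above; rank 3 wraps around to indices 0 and 1.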

class DistributedTrainer[source]

DistributedTrainer(cuda_id=0, sync_bn=True) :: Callback

Wrap model in DistributedDataParallel and dls in DistributedDL


Learner.to_distributed(cuda_id, sync_bn=True)

Add DistributedTrainer to a learner



Remove DistributedTrainer from a learner

distrib_ctx context manager


Learner.distrib_ctx(cuda_id=None, sync_bn=True)

A context manager to adapt a learner to train in distributed data parallel mode.

distrib_ctx prepares a learner to train in distributed data parallel mode. It assumes the required environment variables have all been set up properly, as they are for scripts launched by python -m fastai.launch.

Typical usage:

with learn.distrib_ctx():
    learn.fit(...)

It attaches a DistributedTrainer callback and DistributedDL data loader to the learner, then executes the training code placed inside the context. Upon exiting the context, it removes the DistributedTrainer and DistributedDL, and destroys any locally created distributed process group. The process is still attached to the GPU, though.
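The attach/run/detach pattern distrib_ctx follows can be sketched with a plain context manager. DummyLearner and distrib_ctx_sketch below are illustrative stand-ins, not fastai API (though add_cb/remove_cb mirror the Learner methods of the same names):

```python
from contextlib import contextmanager

class DummyLearner:
    """Stand-in for a Learner that just tracks its callbacks."""
    def __init__(self): self.cbs = []
    def add_cb(self, cb): self.cbs.append(cb)
    def remove_cb(self, cb): self.cbs.remove(cb)

@contextmanager
def distrib_ctx_sketch(learner, cb):
    learner.add_cb(cb)         # attach the DistributedTrainer-like callback
    try:
        yield learner          # training code (e.g. learn.fit) runs here
    finally:
        learner.remove_cb(cb)  # detach and clean up on exit, even on error

learn = DummyLearner()
with distrib_ctx_sketch(learn, "trainer"):
    inside = list(learn.cbs)   # callback is attached inside the context
after = list(learn.cbs)        # callback is gone after exiting
```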


rank0_first(func, *args, **kwargs)

Execute func in the Rank-0 process first, then in other ranks in parallel.

rank0_first calls func() in the rank-0 process first, then in parallel in the rest of the ranks, in distributed training mode. In single-process, non-distributed training mode, func() is called only once, as expected.

One application of rank0_first() is to make fresh downloads via untar_data safe in distributed training scripts launched by python -m fastai.launch <script>:

path = untar_data(URLs.IMDB)

becomes:


path = rank0_first(lambda: untar_data(URLs.IMDB))

Some learner factory methods may use untar_data to download pretrained models:

learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)

becomes:


learn = rank0_first(lambda: text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy))

Otherwise, multiple processes will download at the same time and corrupt the data.
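A single-process simulation of why the rank-0-first ordering avoids this corruption: rank 0 runs alone and populates the cache, so the remaining ranks only ever hit the cache and never write concurrently. The names below (untar_data_sim, the log list) are hypothetical, for illustration only:

```python
import os, tempfile

def untar_data_sim(dest, log):
    """Simulated cached download: only writes if the file is missing."""
    if not os.path.exists(dest):
        log.append("download")
        with open(dest, "w") as f:
            f.write("data")
    else:
        log.append("cache hit")
    return dest

log = []
with tempfile.TemporaryDirectory() as d:
    dest = os.path.join(d, "imdb")
    # rank0_first ordering: rank 0 runs to completion first,
    # then ranks 1..3 run (in a real job, in parallel)
    for rank in range(4):
        untar_data_sim(dest, log)
# Only rank 0 downloads; everyone else reads the finished cache.
```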