Integration with Weights & Biases

First things first: install wandb with

pip install wandb

Create a free account, then run

wandb login

in your terminal. Follow the link to get an API token, paste it when prompted, and you're all set!

class WandbCallback

WandbCallback(log='gradients', log_preds=True, log_model=True, log_dataset=False, dataset_name=None, valid_dl=None, n_preds=36, seed=12345, reorder=True) :: Callback

Saves model topology, losses & metrics

Optionally logs weights and/or gradients depending on log (can be "gradients", "parameters", "all" or None). If log_preds=True, sample predictions are logged as well: they come from valid_dl if it is given, otherwise from a random sample of the validation set (determined by seed). n_preds predictions are logged in this case.
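To make these options concrete, here is a minimal pure-Python sketch. It is not fastai's implementation: the helper names check_log_option and sample_indices are hypothetical, and the sampling strategy is only an assumption about how a fixed seed can keep the logged predictions identical from one run to the next.

```python
import random

# Values the `log` argument accepts, according to the description above.
ALLOWED_LOG = ("gradients", "parameters", "all", None)

def check_log_option(log):
    """Hypothetical helper: reject values of `log` that are not documented."""
    if log not in ALLOWED_LOG:
        raise ValueError(f"log must be one of {ALLOWED_LOG}, got {log!r}")
    return log

def sample_indices(n_items, n_preds=36, seed=12345):
    """Hypothetical helper: pick at most `n_preds` distinct indices, reproducibly."""
    rng = random.Random(seed)   # private RNG, the global random state is untouched
    k = min(n_preds, n_items)   # cannot sample more items than exist
    return sorted(rng.sample(range(n_items), k))

check_log_option("all")
first = sample_indices(1000)
second = sample_indices(1000)
print(len(first), first == second)  # prints: 36 True
```

Because the generator is re-created from the same seed on every call, two runs select the same validation items, which makes the logged predictions comparable across experiments.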

If used in combination with SaveModelCallback, the best model is saved as well (can be deactivated with log_model=False).

Datasets can also be tracked:

  • if log_dataset is True, the tracked folder is retrieved from learn.dls.path
  • log_dataset can explicitly be set to the folder to track
  • the name of the dataset can explicitly be given through dataset_name, otherwise it is set to the folder name
  • Note: the subfolder "models" is always ignored
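The rules in this list can be sketched in plain Python. This is an illustration under stated assumptions, not the callback's actual code: resolve_dataset and tracked_files are hypothetical helpers, and skipping only the top-level "models" folder is an assumption about what "always ignored" means.

```python
import tempfile
from pathlib import Path

def resolve_dataset(log_dataset, dls_path, dataset_name=None):
    """Hypothetical helper mirroring the rules listed above.

    Returns (folder, name), or None when dataset tracking is off.
    """
    if not log_dataset:  # False or None: nothing to track
        return None
    # True means "use the DataLoaders path"; anything else is an explicit folder.
    folder = Path(dls_path) if log_dataset is True else Path(log_dataset)
    name = dataset_name if dataset_name is not None else folder.name
    return folder, name

def tracked_files(folder):
    """List files under `folder`, skipping the top-level "models" subfolder."""
    folder = Path(folder)
    return sorted(
        p.relative_to(folder).as_posix()
        for p in folder.rglob("*")
        if p.is_file() and p.relative_to(folder).parts[0] != "models"
    )

# Small demonstration on a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "pets"
    (root / "images").mkdir(parents=True)
    (root / "models").mkdir()
    (root / "images" / "cat.jpg").write_bytes(b"")
    (root / "models" / "best.pth").write_bytes(b"")
    demo_resolved = resolve_dataset(True, root)
    demo_files = tracked_files(root)

print(demo_resolved[1], demo_files)  # prints: pets ['images/cat.jpg']
```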

For custom scenarios, you can also call the functions log_dataset and log_model directly to log your own datasets and models, respectively.

log_dataset

log_dataset(path, name=None, metadata={})

Log dataset folder

log_model

log_model(path, name=None, metadata={})

Log model file
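As a rough sketch of manual logging, the two functions could be called as below. The wrapper name log_artifacts_manually, the artifact names and the metadata values are all hypothetical; the sketch also assumes that a W&B run has to be active before anything can be logged to it.

```python
def log_artifacts_manually(dataset_dir, model_file):
    """Sketch: log a dataset folder and a model file to the active W&B run.

    Imports are done lazily so that defining this function does not
    require fastai or wandb to be installed.
    """
    import wandb
    from fastai.callback.wandb import log_dataset, log_model

    # Assumption: a run must exist before anything can be logged to it.
    if wandb.run is None:
        wandb.init()

    # `name` and `metadata` are the optional arguments from the signatures above;
    # the values used here are placeholders.
    log_dataset(dataset_dir, name="my-dataset", metadata={"source": "manual"})
    log_model(model_file, name="my-model", metadata={"stage": "custom"})
```

Here dataset_dir is expected to be a folder and model_file a single file, matching the "Log dataset folder" and "Log model file" descriptions.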

Example of use:

Once you have defined your Learner, and before you call fit or fit_one_cycle, you need to initialize wandb:

import wandb
wandb.init()

To use Weights & Biases without an account, you can call wandb.init(anonymous='allow').

Then add the callback either to your Learner or to a call to one of the fit methods, optionally together with SaveModelCallback if you want to save the best model:

from fastai.callback.wandb import *

# To log only during one training phase
learn.fit(..., cbs=WandbCallback())

# To log continuously for all training phases
learn = learner(..., cbs=WandbCallback())
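Putting the steps together with SaveModelCallback might look like the sketch below. The function name fit_and_log is hypothetical, learn is assumed to be an already constructed Learner, and calling wandb.finish() at the end is a design choice of this sketch rather than something the callback requires.

```python
def fit_and_log(learn, n_epoch=1):
    """Sketch: train `learn` while logging to W&B and saving the best model."""
    # Lazy imports: defining this function needs neither fastai nor wandb.
    import wandb
    from fastai.callback.tracker import SaveModelCallback
    from fastai.callback.wandb import WandbCallback

    wandb.init()  # must happen before training starts
    # Callbacks passed to fit are only active for this one training phase.
    learn.fit(n_epoch, cbs=[SaveModelCallback(), WandbCallback()])
    wandb.finish()  # close the run once training is done
```

With both callbacks active and log_model left at its default of True, the best model kept by SaveModelCallback is also logged to W&B.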

Datasets and models can be tracked through the callback or directly through log_model and log_dataset functions.

For more details, refer to the W&B documentation.