Various callbacks to customize the behavior of get_preds

MCDropoutCallback

Keeps dropout active during inference, so that repeated calls to Learner.get_preds return different stochastic predictions. The spread across those predictions approximates your model's uncertainty (Monte Carlo Dropout).

class MCDropoutCallback

MCDropoutCallback(after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, after_backward=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

Basic class handling tweaks of the training loop by changing a Learner in various events

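from fastai.test_utils import synth_learner
from fastai.callback.preds import MCDropoutCallback
import torch

learn = synth_learner()

# Call get_preds 10 times with dropout active, then stack the predictions.
# The result has shape [number of passes, number of validation items, ...].
dist_preds = []
for _ in range(10):
    preds, targs = learn.get_preds(cbs=[MCDropoutCallback()])
    dist_preds.append(preds)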

torch.stack(dist_preds).shape
torch.Size([10, 32, 1])
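
To show what the stacked predictions are for, here is a minimal plain-PyTorch sketch of the idea behind Monte Carlo Dropout. It is not the callback's actual implementation. The toy model, the random input, and the helper name mc_dropout_preds are all hypothetical and chosen only for illustration. The sketch keeps the model in eval mode, switches only the dropout layers back to training mode, runs several forward passes, and then summarizes them with a mean and a standard deviation.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model; any network containing nn.Dropout layers would do.
model = nn.Sequential(
    nn.Linear(4, 16), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(16, 1)
)
x = torch.randn(32, 4)  # 32 made-up input rows


def mc_dropout_preds(model, x, n_passes=10):
    """Run n_passes stochastic forward passes with dropout active and stack them."""
    model.eval()  # inference behavior for every other layer
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()  # re-enable only the dropout layers
    with torch.no_grad():
        preds = torch.stack([model(x) for _ in range(n_passes)])
    model.eval()  # switch dropout off again before returning
    return preds


dist_preds = mc_dropout_preds(model, x)  # shape [passes, items, outputs]
mean = dist_preds.mean(dim=0)  # point estimate for each item
std = dist_preds.std(dim=0)    # spread across passes, a rough uncertainty estimate
```

The same two reductions, mean(dim=0) and std(dim=0), can be applied to the tensor produced by torch.stack(dist_preds) in the example above. Note that a model without any dropout layers returns identical predictions on every pass, so the standard deviation would be zero.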