Various callbacks to customize get_preds behaviors
MCDropoutCallback
Turns dropout back on during inference, so that each call to Learner.get_preds samples a different dropout mask. Calling get_preds several times and comparing the results approximates your model's uncertainty using Monte Carlo Dropout.
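The underlying idea can be sketched in plain PyTorch, without fastai: put the model in eval mode, then switch only its dropout layers back to training mode so every forward pass draws a fresh mask. This is a minimal sketch, not necessarily how MCDropoutCallback is implemented; `enable_mc_dropout` is a hypothetical helper, and identifying dropout layers by class name is an assumption of this sketch.

```python
import torch
import torch.nn as nn

def enable_mc_dropout(model: nn.Module) -> None:
    """Hypothetical helper: eval mode everywhere, except dropout layers."""
    model.eval()  # layers such as BatchNorm keep using their running statistics
    for m in model.modules():
        # Assumption: dropout layers are recognized by class name
        # (matches nn.Dropout, nn.Dropout2d, nn.AlphaDropout, ...)
        if "dropout" in type(m).__name__.lower():
            m.train()

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(16, 1))
x = torch.randn(8, 4)

enable_mc_dropout(model)
with torch.no_grad():
    # Each forward pass samples a new dropout mask, so the outputs differ
    preds = torch.stack([model(x) for _ in range(10)])
model.eval()  # restore deterministic inference afterwards
print(preds.shape)  # torch.Size([10, 8, 1])
```

Calling `model.eval()` first and then re-enabling only dropout keeps any normalization layers in inference mode, so the only source of randomness between passes is the dropout mask.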
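import torch
from fastai.test_utils import synth_learner
from fastai.callback.preds import MCDropoutCallback

learn = synth_learner()
# Call get_preds 10 times with dropout active, then stack the predictions,
# yielding a tensor of shape [# of MC passes, # of validation items, ...]
# (synth_learner's toy model has no dropout layers, so this only illustrates the shapes)
dist_preds = []
for _ in range(10):
    preds, targs = learn.get_preds(cbs=[MCDropoutCallback()])
    dist_preds.append(preds)
torch.stack(dist_preds).shape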
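Once the passes are stacked, a common way to read off uncertainty is to reduce over the first (MC pass) dimension: the mean serves as the prediction and the standard deviation as a per-item spread. A minimal sketch, where `mc_summary` is a hypothetical helper and a random tensor stands in for `torch.stack(dist_preds)`:

```python
import torch

def mc_summary(stacked: torch.Tensor):
    """Hypothetical helper: reduce MC samples of shape [n_passes, n_items, ...]."""
    mean = stacked.mean(dim=0)  # averaged prediction per item
    std = stacked.std(dim=0)    # spread across passes (unbiased estimator)
    return mean, std

# Stand-in for torch.stack(dist_preds): 10 passes, 32 items, 1 output each
stacked = torch.randn(10, 32, 1)
mean, std = mc_summary(stacked)
print(mean.shape, std.shape)  # torch.Size([32, 1]) torch.Size([32, 1])
```

For a classifier, the same reduction is usually applied to the predicted probabilities rather than to raw logits.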