Progress and logging

Callbacks and helper functions to track the progress of training or log results

from fastai.test_utils import *

source

ProgressCallback

 ProgressCallback (after_create=None, before_fit=None, before_epoch=None,
                   before_train=None, before_batch=None, after_pred=None,
                   after_loss=None, before_backward=None,
                   after_cancel_backward=None, after_backward=None,
                   before_step=None, after_cancel_step=None,
                   after_step=None, after_cancel_batch=None,
                   after_batch=None, after_cancel_train=None,
                   after_train=None, before_validate=None,
                   after_cancel_validate=None, after_validate=None,
                   after_cancel_epoch=None, after_epoch=None,
                   after_cancel_fit=None, after_fit=None)

A Callback to handle the display of progress bars

learn = synth_learner()
learn.fit(5)
epoch  train_loss  valid_loss  time
0      14.523648   10.988108   00:00
1      12.395808    7.306935   00:00
2      10.121231    4.370981   00:00
3       8.065226    2.487984   00:00
4       6.374166    1.368232   00:00
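ProgressCallback is part of the Learner's default callbacks, so the bars above appear without adding it explicitly. As a minimal sketch (assuming remove_cb accepts a callback class, as in recent fastai versions), you can strip it from a learner entirely instead of suppressing it per call:

from fastai.test_utils import synth_learner
from fastai.callback.progress import ProgressCallback

learn = synth_learner()
# Without ProgressCallback no bars are drawn; the per-epoch values are still
# printed through the default logger, as in the no_bar example below.
learn.remove_cb(ProgressCallback)
learn.fit(1)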

no_bar

 no_bar ()

Context manager that deactivates the use of progress bars

learn = synth_learner()
with learn.no_bar(): learn.fit(5)
[0, 15.748106002807617, 12.352150917053223, '00:00']
[1, 13.818815231323242, 8.879858016967773, '00:00']
[2, 11.650713920593262, 5.857329845428467, '00:00']
[3, 9.595088005065918, 3.7397098541259766, '00:00']
[4, 7.814438343048096, 2.327916145324707, '00:00']
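no_bar only hides the bars; the per-epoch values are still written through the learner's logger, as shown above. A small sketch combining it with Learner.no_logging (assuming that context manager is available, as in recent fastai versions) silences training output completely:

learn = synth_learner()
# no_bar removes the progress bars, no_logging swallows the printed values
with learn.no_bar(), learn.no_logging():
    learn.fit(5)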

source

ProgressCallback.before_fit

 ProgressCallback.before_fit ()

Set up the master bar over the epochs


source

ProgressCallback.before_epoch

 ProgressCallback.before_epoch ()

Update the master bar


source

ProgressCallback.before_train

 ProgressCallback.before_train ()

Launch a progress bar over the training dataloader


source

ProgressCallback.before_validate

 ProgressCallback.before_validate ()

Launch a progress bar over the validation dataloader


source

ProgressCallback.after_batch

 ProgressCallback.after_batch ()

Update the current progress bar


source

ProgressCallback.after_train

 ProgressCallback.after_train ()

Close the progress bar over the training dataloader


source

ProgressCallback.after_validate

 ProgressCallback.after_validate ()

Close the progress bar over the validation dataloader


source

ProgressCallback.after_fit

 ProgressCallback.after_fit ()

Close the master bar


source

ShowGraphCallback

 ShowGraphCallback (after_create=None, before_fit=None, before_epoch=None,
                    before_train=None, before_batch=None, after_pred=None,
                    after_loss=None, before_backward=None,
                    after_cancel_backward=None, after_backward=None,
                    before_step=None, after_cancel_step=None,
                    after_step=None, after_cancel_batch=None,
                    after_batch=None, after_cancel_train=None,
                    after_train=None, before_validate=None,
                    after_cancel_validate=None, after_validate=None,
                    after_cancel_epoch=None, after_epoch=None,
                    after_cancel_fit=None, after_fit=None)

Update a graph of training and validation loss

learn = synth_learner(cbs=ShowGraphCallback())
learn.fit(5)
epoch  train_loss  valid_loss  time
0      17.683565   10.431150   00:00
1      15.232769    7.056944   00:00
2      12.470916    4.382421   00:00
3      10.000675    2.574951   00:00
4       7.943449    1.464153   00:00

learn.predict(torch.tensor([[0.1]]))
(tensor([1.8955]), tensor([1.8955]), tensor([1.8955]))

source

CSVLogger

 CSVLogger (fname='history.csv', append=False)

Log the results displayed in learn.path/fname

The results are appended to an existing file if append is True; otherwise the file is overwritten.

learn = synth_learner(cbs=CSVLogger())
learn.fit(5)
epoch  train_loss  valid_loss  time
0      15.606769   14.485189   00:00
1      13.840394   10.834929   00:00
2      11.842106    7.582738   00:00
3       9.937692    5.158300   00:00
4       8.244681    3.432087   00:00

source

CSVLogger.read_log

 CSVLogger.read_log ()

Convenience method to quickly access the log.

df = learn.csv_logger.read_log()
test_eq(df.columns.values, learn.recorder.metric_names)
for i,v in enumerate(learn.recorder.values):
    test_close(df.iloc[i][:3], [i] + v)
os.remove(learn.path/learn.csv_logger.fname)
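With append=True, rows from later training runs are added to the same file instead of overwriting it. A hedged sketch of logging two consecutive fit calls into one CSV (reusing the default file name history.csv):

learn = synth_learner(cbs=CSVLogger(fname='history.csv', append=True))
learn.fit(2)   # first run creates history.csv and logs 2 epochs
learn.fit(2)   # second run appends its rows (and, depending on the fastai
               # version, another header line) to the same file
df = learn.csv_logger.read_log()
os.remove(learn.path/learn.csv_logger.fname)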

source

CSVLogger.before_fit

 CSVLogger.before_fit ()

Prepare file with metric names.


source

CSVLogger.after_fit

 CSVLogger.after_fit ()

Close the file and clean up.