from fastai.test_utils import *

Progress and logging
Callbacks and helper functions to track the progress of training or log results
ProgressCallback
def ProgressCallback(
after_create:NoneType=None, before_fit:NoneType=None, before_epoch:NoneType=None, before_train:NoneType=None,
before_batch:NoneType=None, after_pred:NoneType=None, after_loss:NoneType=None, before_backward:NoneType=None,
after_cancel_backward:NoneType=None, after_backward:NoneType=None, before_step:NoneType=None,
after_cancel_step:NoneType=None, after_step:NoneType=None, after_cancel_batch:NoneType=None,
after_batch:NoneType=None, after_cancel_train:NoneType=None, after_train:NoneType=None,
before_validate:NoneType=None, after_cancel_validate:NoneType=None, after_validate:NoneType=None,
after_cancel_epoch:NoneType=None, after_epoch:NoneType=None, after_cancel_fit:NoneType=None,
after_fit:NoneType=None
):
A Callback to handle the display of progress bars
learn = synth_learner()
learn.fit(5)

| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 14.523648 | 10.988108 | 00:00 |
| 1 | 12.395808 | 7.306935 | 00:00 |
| 2 | 10.121231 | 4.370981 | 00:00 |
| 3 | 8.065226 | 2.487984 | 00:00 |
| 4 | 6.374166 | 1.368232 | 00:00 |
no_bar
def no_bar(
):
Context manager that deactivates the use of progress bars
learn = synth_learner()
with learn.no_bar(): learn.fit(5)

[0, 15.748106002807617, 12.352150917053223, '00:00']
[1, 13.818815231323242, 8.879858016967773, '00:00']
[2, 11.650713920593262, 5.857329845428467, '00:00']
[3, 9.595088005065918, 3.7397098541259766, '00:00']
[4, 7.814438343048096, 2.327916145324707, '00:00']
ProgressCallback.before_fit
def before_fit(
):
Set up the master bar over the epochs
ProgressCallback.before_epoch
def before_epoch(
):
Update the master bar
ProgressCallback.before_train
def before_train(
):
Launch a progress bar over the training dataloader
ProgressCallback.before_validate
def before_validate(
):
Launch a progress bar over the validation dataloader
ProgressCallback.after_batch
def after_batch(
):
Update the current progress bar
ProgressCallback.after_train
def after_train(
):
Close the progress bar over the training dataloader
ProgressCallback.after_validate
def after_validate(
):
Close the progress bar over the validation dataloader
ProgressCallback.after_fit
def after_fit(
):
Close the master bar
ShowGraphCallback
def ShowGraphCallback(
after_create:NoneType=None, before_fit:NoneType=None, before_epoch:NoneType=None, before_train:NoneType=None,
before_batch:NoneType=None, after_pred:NoneType=None, after_loss:NoneType=None, before_backward:NoneType=None,
after_cancel_backward:NoneType=None, after_backward:NoneType=None, before_step:NoneType=None,
after_cancel_step:NoneType=None, after_step:NoneType=None, after_cancel_batch:NoneType=None,
after_batch:NoneType=None, after_cancel_train:NoneType=None, after_train:NoneType=None,
before_validate:NoneType=None, after_cancel_validate:NoneType=None, after_validate:NoneType=None,
after_cancel_epoch:NoneType=None, after_epoch:NoneType=None, after_cancel_fit:NoneType=None,
after_fit:NoneType=None
):
Update a graph of training and validation loss
learn = synth_learner(cbs=ShowGraphCallback())
learn.fit(5)

| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 17.683565 | 10.431150 | 00:00 |
| 1 | 15.232769 | 7.056944 | 00:00 |
| 2 | 12.470916 | 4.382421 | 00:00 |
| 3 | 10.000675 | 2.574951 | 00:00 |
| 4 | 7.943449 | 1.464153 | 00:00 |

learn.predict(torch.tensor([[0.1]]))

(tensor([1.8955]), tensor([1.8955]), tensor([1.8955]))
CSVLogger
def CSVLogger(
fname:str='history.csv', append:bool=False
):
Log the results displayed in `learn.path/fname`.
The results are appended to an existing file if `append` is `True`; otherwise they overwrite it.
learn = synth_learner(cbs=CSVLogger())
learn.fit(5)

| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 15.606769 | 14.485189 | 00:00 |
| 1 | 13.840394 | 10.834929 | 00:00 |
| 2 | 11.842106 | 7.582738 | 00:00 |
| 3 | 9.937692 | 5.158300 | 00:00 |
| 4 | 8.244681 | 3.432087 | 00:00 |
CSVLogger.read_log
def read_log(
):
Convenience method to quickly access the log.
df = learn.csv_logger.read_log()
test_eq(df.columns.values, learn.recorder.metric_names)
for i,v in enumerate(learn.recorder.values):
test_close(df.iloc[i][:3], [i] + v)
os.remove(learn.path/learn.csv_logger.fname)

CSVLogger.before_fit
def before_fit(
):
Prepare the file with the metric names.
CSVLogger.after_fit
def after_fit(
):
Close the file and clean up.