Implementation of a flexible training API

TrainingPhase and General scheduler

Creates a scheduler that lets you train a model following a sequence of different TrainingPhase objects.

class TrainingPhase

TrainingPhase(`length`:int, `lrs`:Floats, `moms`:Floats, `lr_anneal`:AnnealFunc=`None`, `mom_anneal`:AnnealFunc=`None`)

Creates a phase for training a model for length iterations, following a schedule given by lrs and lr_anneal for the learning rate and by moms and mom_anneal for the momentum. More specifically, the phase makes the learning rate (or momentum) vary from the first value of lrs (or moms) to the second, following lr_anneal (or mom_anneal). If an annealing function is specified but lrs (or moms) is a float, the value decays from that float to 0. If no annealing function is specified, the default is linear annealing if lrs (or moms) is a tuple, and a constant value if it's a float.
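To make these defaulting rules concrete, here is a small framework-free sketch. The helpers `resolve` and `phase_schedule` are hypothetical, and the annealing functions only mirror the names `annealing_no`, `annealing_linear` and `annealing_cos`; this is not fastai's code. It samples progress evenly from 0 to 1 so that both endpoints appear, which may differ from the library's exact per-iteration stepping.

```python
import math
from typing import Callable, List, Optional, Tuple, Union

# Hypothetical aliases mirroring the signature above.
Floats = Union[float, Tuple[float, float]]
AnnealFunc = Callable[[float, float, float], float]

def annealing_no(start: float, end: float, pct: float) -> float:
    return start  # constant: progress is ignored

def annealing_linear(start: float, end: float, pct: float) -> float:
    return start + pct * (end - start)

def annealing_cos(start: float, end: float, pct: float) -> float:
    # Half a cosine wave going from start (pct=0) to end (pct=1).
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def resolve(vals: Floats, anneal: Optional[AnnealFunc] = None) -> Tuple[float, float, AnnealFunc]:
    """Apply the defaulting rules described above to one hyperparameter."""
    if isinstance(vals, tuple):
        start, end = vals
        return start, end, anneal or annealing_linear  # tuple: linear by default
    # float: decays to 0 if an annealing function is given, otherwise stays constant
    return vals, 0.0, anneal or annealing_no

def phase_schedule(length: int, vals: Floats, anneal: Optional[AnnealFunc] = None) -> List[float]:
    """Hyperparameter value at each of `length` iterations, progress running from 0 to 1."""
    start, end, fn = resolve(vals, anneal)
    return [fn(start, end, i / max(length - 1, 1)) for i in range(length)]

print(phase_schedule(5, (1.0, 0.0)))  # [1.0, 0.75, 0.5, 0.25, 0.0]  (tuple, linear default)
print(phase_schedule(3, 0.9))         # [0.9, 0.9, 0.9]              (float, constant default)
```

Passing a float together with `annealing_cos`, as the warm-restart example does for the learning rate, gives a cosine curve from that float down to 0.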

As an example, let's use this to implement SGD with warm restarts.

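from fastai.vision import *
from fastai.callbacks.general_sched import *

def fit_sgd_warm(learn, n_cycles, lr, mom, cycle_len, cycle_mult):
    n = len(learn.data.train_dl)  # number of batches (iterations) per epoch
    # One phase per cycle; cycle i lasts cycle_len * cycle_mult**i epochs.
    # lr is a float annealed with annealing_cos, so it decays to 0 in each cycle; mom stays constant.
    phases = [TrainingPhase(n * (cycle_len * cycle_mult**i), lr, mom, lr_anneal=annealing_cos) for i in range(n_cycles)]
    sched = GeneralScheduler(learn, phases)
    learn.callbacks.append(sched)  # register the scheduler so fit() calls its hooks
    # Total number of epochs: sum of the geometric series of cycle lengths
    if cycle_mult != 1:
        total_epochs = int(cycle_len * (1 - (cycle_mult)**n_cycles)/(1-cycle_mult))
    else: total_epochs = n_cycles * cycle_len
    learn.fit(total_epochs)

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)
fit_sgd_warm(learn, 3, 1e-3, 0.9, 1, 2)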
Total time: 00:16

epoch train_loss valid_loss accuracy
1 0.185262 0.164344 0.945044
2 0.140157 0.129574 0.954367
3 0.124761 0.123591 0.958292
4 0.109466 0.107876 0.964671
5 0.099668 0.091696 0.966143
6 0.087345 0.085187 0.970069
7 0.085803 0.084836 0.971050
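As a check on the arithmetic in fit_sgd_warm: with n_cycles=3, cycle_len=1 and cycle_mult=2, the cycles last 1, 2 and 4 epochs, so training runs 1 + 2 + 4 = 7 epochs, matching the seven rows of the table above. The sketch below reproduces that bookkeeping in plain Python; the helper names and the figure of 100 batches per epoch are hypothetical.

```python
from typing import List

def cycle_lengths(n_cycles: int, cycle_len: int, cycle_mult: int) -> List[int]:
    """Length in epochs of each warm-restart cycle: cycle_len * cycle_mult**i."""
    return [cycle_len * cycle_mult**i for i in range(n_cycles)]

def total_epochs(n_cycles: int, cycle_len: int, cycle_mult: int) -> int:
    """Closed form of the geometric series, as computed in fit_sgd_warm."""
    if cycle_mult != 1:
        return int(cycle_len * (1 - cycle_mult**n_cycles) / (1 - cycle_mult))
    return n_cycles * cycle_len

def phase_iterations(n_batches: int, n_cycles: int, cycle_len: int, cycle_mult: int) -> List[int]:
    """Iterations per TrainingPhase, given a (hypothetical) number of batches per epoch."""
    return [n_batches * c for c in cycle_lengths(n_cycles, cycle_len, cycle_mult)]

print(cycle_lengths(3, 1, 2))          # [1, 2, 4]
print(total_epochs(3, 1, 2))           # 7
print(phase_iterations(100, 3, 1, 2))  # [100, 200, 400]
```

For integer arguments the closed form is exact; summing `cycle_lengths(...)` gives the same total without relying on float division.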

class GeneralScheduler

GeneralScheduler(`learn`:Learner, `phases`) :: LearnerCallback

Schedules multiple TrainingPhase objects, run one after another, for a Learner.

Callback methods

You don't call these yourself; they're called automatically by fastai's Callback system to enable the class's functionality.


on_batch_end(`train`, **`kwargs`:Any)

Takes a step in the current phase and prepares the hyperparameters for the next batch.


on_train_begin(`n_epochs`:int, **`kwargs`:Any)

Sets the hyperparameters to the start values of the first phase.
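The interplay of these two hooks can be illustrated with a framework-free mock. Everything in it (`Phase`, `MockScheduler`, the returned "done" flag and the fake training loop) is hypothetical and schedules only the learning rate. It is not fastai's implementation, and it does not model how the real callback ends training once the phases run out.

```python
import math
from typing import Callable, List

AnnealFunc = Callable[[float, float, float], float]

def annealing_linear(start: float, end: float, pct: float) -> float:
    return start + pct * (end - start)

def annealing_cos(start: float, end: float, pct: float) -> float:
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

class Phase:
    """Hypothetical stand-in for TrainingPhase that only schedules the learning rate."""
    def __init__(self, length: int, start: float, end: float, anneal: AnnealFunc = annealing_linear):
        self.length, self.start, self.end, self.anneal = length, start, end, anneal

class MockScheduler:
    """Hypothetical callback mimicking on_train_begin / on_batch_end."""
    def __init__(self, phases: List[Phase]):
        self.phases = phases

    def on_train_begin(self, n_epochs: int, **kwargs) -> None:
        # Begin at step 0 of the first phase, using its start value.
        self.idx, self.n = 0, 0
        self.lr = self.phases[0].start

    def on_batch_end(self, train: bool, **kwargs) -> bool:
        """Advance one step on training batches; return True once all phases are exhausted."""
        if train and self.idx < len(self.phases):
            phase = self.phases[self.idx]
            self.n += 1
            self.lr = phase.anneal(phase.start, phase.end, self.n / phase.length)
            if self.n >= phase.length:  # current phase finished: switch to the next one
                self.idx, self.n = self.idx + 1, 0
        return self.idx >= len(self.phases)

# Fake training loop: a linear warm-up phase, then a cosine phase down to 0.
sched = MockScheduler([Phase(2, 1.0, 2.0), Phase(2, 2.0, 0.0, annealing_cos)])
sched.on_train_begin(n_epochs=1)
lrs = [sched.lr]
for _ in range(4):
    finished = sched.on_batch_end(train=True)
    lrs.append(sched.lr)
print([round(lr, 6) for lr in lrs], finished)  # [1.0, 1.5, 2.0, 1.0, 0.0] True
```

Validation batches (`train=False`) leave the schedule untouched, so only training iterations count toward each phase's length.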