Implementation of a flexible training API

TrainingPhase and GeneralScheduler

Creates a scheduler that lets you train a model following a sequence of different TrainingPhases.

class TrainingPhase[source]

TrainingPhase(length:int, lrs:Floats, moms:Floats, lr_anneal:AnnealFunc=None, mom_anneal:AnnealFunc=None)

Create a phase for training a model during length iterations, following a schedule given by lrs and lr_anneal, moms and mom_anneal. More specifically, the phase will make the learning rate (or momentum) vary from the first value of lrs (or moms) to the second, following lr_anneal (or mom_anneal). If an annealing function is specified but lrs (or moms) is a float, the value will decay to 0. If no annealing function is specified, the default is linear annealing when lrs (or moms) is a tuple, and a constant parameter when it's a float.
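
For instance, here is a minimal sketch of a two-phase warmup-then-decay schedule. It assumes n is the number of batches per epoch (len(learn.data.train_dl)) and that annealing_cos is in scope (it comes with the usual fastai star import); the specific values are illustrative, not prescribed:

phases = [
    # Warmup: lr rises from 1e-5 to 1e-3 over one epoch, cosine-shaped.
    TrainingPhase(n, lrs=(1e-5, 1e-3), moms=(0.95, 0.85), lr_anneal=annealing_cos),
    # Decay: lr falls to 1e-6 over the next two epochs. With no mom_anneal
    # given, the tuple moms anneals linearly by default.
    TrainingPhase(2*n, lrs=(1e-3, 1e-6), moms=(0.85, 0.95), lr_anneal=annealing_cos),
]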

class GeneralScheduler[source]

GeneralScheduler(learn:Learner, phases:Collection[TrainingPhase]) :: Callback

Schedules multiple TrainingPhases for a Learner.

on_batch_end[source]

on_batch_end(train, kwargs:Any)

Takes a step in the current phase and prepares the hyperparameters for the next batch.

on_train_begin[source]

on_train_begin(n_epochs:int, kwargs:Any)

Initializes the hyperparameters to the start values of the first phase.
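
Putting these together, a minimal sketch of wiring a GeneralScheduler into a Learner (assuming learn is an existing Learner and phases is the list from the sketch above):

sched = GeneralScheduler(learn, phases)
learn.callbacks.append(sched)  # the scheduler adjusts lr/mom on every batch
learn.fit(3)                   # 3 epochs = the combined length of the phases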

As an example, let's use this to implement SGD with warm restarts.

def fit_sgd_warm(learn, n_cycles, lr, mom, cycle_len, cycle_mult):
    n = len(learn.data.train_dl)  # iterations (batches) per epoch
    # One phase per cycle: cycle i lasts cycle_len * cycle_mult**i epochs,
    # with the learning rate cosine-annealed from lr down to 0.
    phases = [TrainingPhase(n * (cycle_len * cycle_mult**i), lr, mom,
                            lr_anneal=annealing_cos) for i in range(n_cycles)]
    sched = GeneralScheduler(learn, phases)
    learn.callbacks.append(sched)
    # Total epochs is the geometric series
    # cycle_len * (1 + cycle_mult + ... + cycle_mult**(n_cycles-1)).
    if cycle_mult != 1:
        total_epochs = int(cycle_len * (1 - cycle_mult**n_cycles) / (1 - cycle_mult))
    else:
        total_epochs = n_cycles * cycle_len
    learn.fit(total_epochs)
path = untar_data(URLs.MNIST_SAMPLE)            # download a small MNIST subset
data = ImageDataBunch.from_folder(path)
learn = Learner(data, simple_cnn((3,16,16,2)))
fit_sgd_warm(learn, 3, 1e-3, 0.9, 1, 2)         # 3 cycles of 1, 2 and 4 epochs
Total time: 00:16
epoch  train loss  valid loss
0      0.203685    0.176289    (00:02)
1      0.139156    0.147694    (00:02)
2      0.132314    0.131610    (00:02)
3      0.118946    0.118343    (00:02)
4      0.116849    0.105648    (00:02)
5      0.105146    0.105442    (00:02)
6      0.099159    0.102690    (00:02)

learn.recorder.plot_lr()
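
The resulting plot should show three cosine-annealed cycles of 1, 2 and 4 epochs, with the learning rate restarting at 1e-3 at the start of each cycle and decaying to 0 within it (since lr was passed as a float together with an annealing function).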