Learner, Metrics, Callbacks

You probably want to jump directly to the definition of Learner.

Utils function

replacing_yield

replacing_yield (o, attr, val)

Context manager to temporarily replace an attribute

```python
class _A:
    def __init__(self, a): self.a = a
    @contextmanager
    def a_changed(self, v): return replacing_yield(self, 'a', v)

a = _A(42)
with a.a_changed(32):
    test_eq(a.a, 32)
test_eq(a.a, 42)
```
mk_metric

mk_metric (m)

Convert m to an AvgMetric, unless it's already a Metric.

See the class Metric below for more information.
save_model

save_model (file, model, opt, with_opt=True, pickle_protocol=2, **torch_save_kwargs)

Save model to file along with opt (if available, and if with_opt).

file can be a Path object, a string or an opened file object. pickle_protocol and torch_save_kwargs are passed along to torch.save.
load_model

load_model (file, model, opt, with_opt=True, device=None, strict=True, **torch_load_kwargs)

Load model from file along with opt (if available, and if with_opt).

file can be a Path object, a string or an opened file object. If a device is passed, the model is loaded on it, otherwise it's loaded on the CPU.

If strict is True, the file must exactly contain weights for every parameter key in model; if strict is False, only the keys that are in the saved model are loaded in model.

You can pass in other kwargs to torch.load through torch_load_kwargs.
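Both functions operate in place on an existing model/optimizer pair. A minimal sketch, assuming model is a plain nn.Module and opt the (fastai or wrapped) optimizer you trained it with:

```python
# save weights and optimizer state, then restore them in place (a sketch)
save_model('tmp.pth', model, opt)
load_model('tmp.pth', model, opt, device='cpu')       # load on CPU explicitly
load_model('tmp.pth', model, None, with_opt=False)    # weights only, ignore optimizer state
```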
SkipToEpoch

SkipToEpoch (epoch:int)

Skip training up to epoch.
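A sketch of how you might use it to resume an interrupted run, assuming learn is an existing Learner (Learner.fit, shown below, can also add this callback for you through its start_epoch argument):

```python
# skip the first 2 epochs and resume training at epoch 2 (a sketch)
learn.fit(5, cbs=SkipToEpoch(2))
# roughly equivalent, letting fit create the callback from start_epoch
learn.fit(5, start_epoch=2)
```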
Learner

Group together a model, some dls and a loss_func to handle training.

opt_func will be used to create an optimizer when Learner.fit is called, with lr as a default learning rate. splitter is a function that takes self.model and returns a list of parameter groups (or just one parameter group if there are no different parameter groups). The default is trainable_params, which returns all trainable parameters of the model.

cbs is one or a list of Callbacks to pass to the Learner. Callbacks are used for every tweak of the training loop. Each Callback is registered as an attribute of Learner, under the snake_case version of its class name without the Callback suffix (for instance, TrainEvalCallback becomes learn.train_eval). At creation, all the callbacks in defaults.callbacks (TrainEvalCallback, Recorder and ProgressCallback) are associated to the Learner.

metrics is an optional list of metrics that can be either functions or Metrics (see below).

path and model_dir are used to save and/or load models. Often path will be inferred from dls, but you can override it or pass a Path object to model_dir. Make sure you can write in path/model_dir!

wd is the default weight decay used when training the model; moms, the default momentums used in Learner.fit_one_cycle. wd_bn_bias controls if weight decay is applied to BatchNorm layers and bias.

Lastly, train_bn controls if BatchNorm layers are trained even when they are supposed to be frozen according to the splitter. Our empirical experiments have shown that it's the best behavior for those layers in transfer learning.
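For instance, a minimal sketch (the names dls, simple_net and mae below are placeholders, assuming the usual fastai imports):

```python
from torch import nn

def mae(preds, targs): return (preds - targs).abs().mean()  # plain functions are wrapped in AvgMetric

# dls is any DataLoaders, simple_net any nn.Module matching its inputs/targets
learn = Learner(dls, simple_net,
                loss_func=nn.MSELoss(),  # any loss function
                metrics=mae,             # optional metrics
                wd=1e-2)                 # default weight decay
learn.fit(3, lr=1e-3)
```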
PyTorch interop

You can use regular PyTorch functionality for most of the arguments of the Learner, although the experience will be smoother with pure fastai objects and you will be able to use the full functionality of the library. The expectation is that the training loop will work smoothly even if you did not use fastai end to end. What you might lose are interpretation objects or showing functionality. The list below explains how to use plain PyTorch objects for all the arguments and what you might lose; a short sketch putting the pieces together follows the list.

The most important is opt_func. If you are not using a fastai optimizer, you will need to write a function that wraps your PyTorch optimizer in an OptimWrapper. See the optimizer module for more details. This is to ensure the library's schedulers/freeze API work with your code.

- dls is a DataLoaders object, that you can create from standard PyTorch dataloaders. By doing so, you will lose all showing functionality like show_batch/show_results. You can check the data block API or the mid-level data API tutorial to learn how to use fastai to gather your data!
- model is a standard PyTorch model. You can use any one you like, just make sure it accepts the number of inputs you have in your DataLoaders and returns as many outputs as you have targets.
- loss_func can be any loss function you like. It needs to be one of fastai's if you want to use Learn.predict or Learn.get_preds, or you will have to implement special methods (see more details after the BaseLoss documentation).
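A hedged sketch under these assumptions (the data, model and loss below are placeholders):

```python
from functools import partial
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from fastai.data.core import DataLoaders
from fastai.learner import Learner
from fastai.optimizer import OptimWrapper

# plain PyTorch data, model and loss
x, y = torch.randn(256, 10), torch.randn(256, 1)
train_dl = DataLoader(TensorDataset(x[:200], y[:200]), batch_size=32, shuffle=True)
valid_dl = DataLoader(TensorDataset(x[200:], y[200:]), batch_size=32)

dls = DataLoaders(train_dl, valid_dl)                    # wrap standard PyTorch dataloaders
model = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 1))
opt_func = partial(OptimWrapper, opt=torch.optim.AdamW)  # wrap a PyTorch optimizer

learn = Learner(dls, model, loss_func=nn.MSELoss(), opt_func=opt_func)
learn.fit(1)
```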
Training loop

Now let's look at the main thing the Learner class implements: the training loop.

Learner.fit

Learner.fit (n_epoch, lr=None, wd=None, cbs=None, reset_opt=False, start_epoch=0)

Fit self.model for n_epoch using cbs. Optionally reset_opt.

Uses lr and wd if they are provided, otherwise uses the default values given by the lr and wd attributes of Learner.
All the examples use synth_learner which is a simple Learner training a linear regression model.

```python
#Training a few epochs should make the model better
learn = synth_learner(lr=0.1)
learn(_before_epoch)
learn.model = learn.model.cpu()
xb,yb = learn.dls.one_batch()
init_loss = learn.loss_func(learn.model(xb), yb)
learn.fit(10)
xb,yb = learn.dls.one_batch()
final_loss = learn.loss_func(learn.model(xb), yb)
assert final_loss < init_loss, (final_loss,init_loss)
```
Learner.one_batch

Learner.one_batch (i, b)

Train or evaluate self.model on batch (xb,yb)

This is an internal method called by Learner.fit. If passed, i is the index of this iteration in the epoch. In training mode, this does a full training step on the batch (compute predictions, loss, gradients, update the model parameters and zero the gradients). In validation mode, it stops at the loss computation. Training or validation is controlled internally by the TrainEvalCallback through the training attribute.

Nothing is returned, but the attributes x, y, pred, loss of the Learner are set with the proper values:

```python
b = learn.dls.one_batch()
learn.one_batch(0, b)
test_eq(learn.x, b[0])
test_eq(learn.y, b[1])
out = learn.model(learn.x)
test_eq(learn.pred, out)
test_eq(learn.loss, learn.loss_func(out, b[1]))
```
Learner.all_batches

Learner.all_batches ()

Train or evaluate self.model on all the batches of self.dl
Learner.create_opt

Learner.create_opt ()

Create an optimizer with default hyper-parameters

This method is called internally to create the optimizer; the hyper-parameters are then adjusted by what you pass to Learner.fit or your particular schedulers (see callback.schedule).

```python
learn = synth_learner(n_train=5, cbs=VerboseCallback())
assert learn.opt is None
learn.create_opt()
assert learn.opt is not None
test_eq(learn.opt.hypers[0]['lr'], learn.lr)
```

after_create

```python
learn = synth_learner(n_train=5, cbs=VerboseCallback(), opt_func=partial(OptimWrapper, opt=torch.optim.Adam))
assert learn.opt is None
learn.create_opt()
assert learn.opt is not None
test_eq(learn.opt.hypers[0]['lr'], learn.lr)
```

after_create

```python
wrapper_lr = 1
learn = synth_learner(n_train=5, cbs=VerboseCallback(), opt_func=partial(OptimWrapper, opt=torch.optim.Adam, lr=wrapper_lr))
assert learn.opt is None
learn.create_opt()
assert learn.opt is not None
test_eq(learn.opt.hypers[0]['lr'], wrapper_lr)
```

after_create
Callback handling

We only describe the basic functionality linked to Callbacks here. To learn more about Callbacks and how to write them, check the callback.core module documentation.

Let's first see how the Callbacks become attributes of Learner:

```python
#Test init with callbacks
class TstCallback(Callback):
    def batch_begin(self): self.learn.a = self.a + 1

tst_learn = synth_learner()
test_eq(len(tst_learn.cbs), 1)
assert hasattr(tst_learn, ('train_eval'))

tst_learn = synth_learner(cbs=TstCallback())
test_eq(len(tst_learn.cbs), 2)
assert hasattr(tst_learn, ('tst'))
```
Learner.__call__

Learner.__call__ (event_name)

Call event_name for all Callbacks in self.cbs

This is how the Callbacks are called internally. For instance a VerboseCallback just prints the event names (can be useful for debugging):

```python
learn = synth_learner(cbs=VerboseCallback())
learn('after_fit')
```

after_create
after_fit
Learner.add_cb

Learner.add_cb (cb)

Add cb to the list of Callback and register self as their learner

```python
learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
test_eq(len(learn.cbs), 2)
assert isinstance(learn.cbs[1], TestTrainEvalCallback)
test_eq(learn.train_eval.learn, learn)
```
Learner.add_cbs

Learner.add_cbs (cbs)

Add cbs to the list of Callback and register self as their learner

```python
learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])
test_eq(len(learn.cbs), 4)
```
Learner.added_cbs

Learner.added_cbs (cbs)

```python
learn = synth_learner()
test_eq(len(learn.cbs), 1)
with learn.added_cbs(TestTrainEvalCallback()):
    test_eq(len(learn.cbs), 2)
```
Learner.ordered_cbs

Learner.ordered_cbs (event)

List of Callbacks, in order, for an event in the training loop

By order, we mean using the internal ordering of the Callbacks (see callback.core for more information on how it works).

```python
learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
learn.ordered_cbs('before_fit')
```

[TrainEvalCallback, TestTrainEvalCallback]
Learner.remove_cb

Learner.remove_cb (cb)

Remove cb from the list of Callback and deregister self as their learner

```python
learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
cb = learn.cbs[1]
learn.remove_cb(learn.cbs[1])
test_eq(len(learn.cbs), 1)
assert cb.learn is None
assert not getattr(learn,'test_train_eval',None)
```

cb can simply be the class of the Callback we want to remove (in which case all instances of that callback are removed).

```python
learn = synth_learner()
learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])
learn.remove_cb(TestTrainEvalCallback)
test_eq(len(learn.cbs), 1)
assert not getattr(learn,'test_train_eval',None)
```
Learner.remove_cbs

Learner.remove_cbs (cbs)

Remove cbs from the list of Callback and deregister self as their learner

Elements of cbs can either be types of callbacks or actual callbacks of the Learner.

```python
learn = synth_learner()
learn.add_cbs([TestTrainEvalCallback() for _ in range(3)])
cb = learn.cbs[1]
learn.remove_cbs(learn.cbs[1:])
test_eq(len(learn.cbs), 1)
```
Learner.removed_cbs

Learner.removed_cbs (cbs)

Elements of cbs can either be types of callbacks or actual callbacks of the Learner.

```python
learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
with learn.removed_cbs(learn.cbs[1]):
    test_eq(len(learn.cbs), 1)
test_eq(len(learn.cbs), 2)
```
Learner.show_training_loop
Learner.show_training_loop ()
Show each step in the training loop
At each step, callbacks are shown in order, which can help debugging.
```python
learn = synth_learner()
learn.show_training_loop()
```
Start Fit
- before_fit : [TrainEvalCallback]
Start Epoch Loop
- before_epoch : []
Start Train
- before_train : [TrainEvalCallback]
Start Batch Loop
- before_batch : []
- after_pred : []
- after_loss : []
- before_backward: []
- before_step : []
- after_step : []
- after_cancel_batch: []
- after_batch : [TrainEvalCallback]
End Batch Loop
End Train
- after_cancel_train: []
- after_train : []
Start Valid
- before_validate: [TrainEvalCallback]
Start Batch Loop
- **CBs same as train batch**: []
End Batch Loop
End Valid
- after_cancel_validate: []
- after_validate : []
End Epoch Loop
- after_cancel_epoch: []
- after_epoch : []
End Fit
- after_cancel_fit: []
- after_fit : []
before_batch_cb

before_batch_cb (f)

Shortcut for creating a Callback on the before_batch event, which takes and returns xb,yb

In order to change the data passed to your model, you will generally want to hook into the before_batch event, like so:

```python
class TstCallback(Callback):
    def before_batch(self):
        self.learn.xb = self.xb + 1000
        self.learn.yb = self.yb - 1000
```

Since that is so common, we provide the before_batch_cb decorator to make it easier.

```python
@before_batch_cb
def cb(self, xb, yb): return xb+1000,yb-1000
```
Serializing

Learner.save

Learner.save (file, with_opt=True, pickle_protocol=2)

Save model and optimizer state (if with_opt) to self.path/self.model_dir/file

file can be a Path, a string or a buffer. pickle_protocol is passed along to torch.save.

Learner.load

Learner.load (file, device=None, with_opt=True, strict=True)

Load model and optimizer state (if with_opt) from self.path/self.model_dir/file using device

file can be a Path, a string or a buffer. Use device to load the model/optimizer state on a device different from the one it was saved on.

```python
with tempfile.TemporaryDirectory() as d:
    learn = synth_learner(path=d)
    learn.fit(1)

    #Test save created a file
    learn.save('tmp')
    assert (Path(d)/'models/tmp.pth').exists()

    #Test load did load the model
    learn1 = synth_learner(path=d)
    learn1 = learn1.load('tmp')
    test_eq(learn.a, learn1.a)
    test_eq(learn.b, learn1.b)
    test_eq(learn.opt.state_dict(), learn1.opt.state_dict())
```
Learner.export

Learner.export (fname='export.pkl', pickle_module=pickle, pickle_protocol=2)

Export the content of self without the items and the optimizer state for inference

The Learner is saved in self.path/fname, using pickle_protocol. Note that serialization in Python saves the names of functions, not the code itself. Therefore, any custom code you have for models, data transformation, loss function etc… should be put in a module that you will import in your training environment before exporting, and in your deployment environment before loading it.

load_learner

load_learner (fname, cpu=True, pickle_module=pickle)

Load a Learner object in fname, by default putting it on the cpu

load_learner requires all your custom code be in the exact same place as when exporting your Learner (the main script, or the module you imported it from).
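A typical round trip looks like this (a sketch; 'export.pkl' is the default file name and item stands for whatever raw input your application expects):

```python
# training environment: writes self.path/'export.pkl'
learn.export('export.pkl')

# deployment environment: the same custom code must be importable here
learn_inf = load_learner('path/to/export.pkl', cpu=True)
full_dec, dec, out = learn_inf.predict(item)   # see Learner.predict below
```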
Learner.to_detach

Learner.to_detach (b, cpu=True, gather=True)

Calls to_detach if self.dl provides a .to_detach function, otherwise calls global to_detach

fastai provides to_detach which by default detaches tensor gradients, and gathers (calling maybe_gather) tensors from all ranks if running in distributed data parallel (DDP) mode.

When running in DDP mode all ranks need to have the same batch size, and DistributedDL takes care of padding batches as needed; however when gathering all tensors (e.g. for calculating metrics, inference, etc.) we need to discard the padded items. DistributedDL provides a method to_detach that removes padding appropriately.

Calling the learner's to_detach method will attempt to find a to_detach method in the learner's last used DataLoader dl and use that one if found, otherwise it will resort to the vanilla to_detach.
Metric

Metric ()

Blueprint for defining a metric

Metrics can be simple averages (like accuracy) but sometimes their computation is a little bit more complex and can't be averaged over batches (like precision or recall), which is why we need a special class for them. For simple functions that can be computed as averages over batches, we can use the class AvgMetric, otherwise you'll need to implement the following methods.

If your Metric has state depending on tensors, don't forget to store it on the CPU to avoid any potential memory leaks.

Metric.reset

Metric.reset ()

Reset inner state to prepare for new computation

Metric.accumulate

Metric.accumulate (learn)

Use learn to update the state with new results

Metric.value

Metric.value ()

Metric.name

Metric.name ()
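As an illustration, here is a hedged sketch of a metric that cannot be computed as a simple average of per-batch values (the class DatasetRMSE and its logic are ours, not part of the library; value and name are implemented as properties, as in AvgMetric below):

```python
class DatasetRMSE(Metric):
    "Root mean squared error over the full validation set (illustrative sketch)"
    def reset(self): self.sq_err,self.count = 0.,0
    def accumulate(self, learn):
        pred,targ = to_detach(learn.pred),to_detach(learn.yb[0])
        # keep running sums as plain Python floats so no GPU tensors are retained
        self.sq_err += (pred - targ).pow(2).sum().item()
        self.count  += targ.numel()
    @property
    def value(self): return (self.sq_err/self.count)**0.5 if self.count else None
    @property
    def name(self): return "dataset_rmse"
```

You would pass an instance in metrics, e.g. metrics=DatasetRMSE(); the square root is taken once over the whole epoch, which is why averaging per-batch values would not give the same result when batch sizes differ.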
AvgMetric

AvgMetric (func)

Average the values of func taking into account potential different batch sizes

```python
learn = synth_learner()
tst = AvgMetric(lambda x,y: (x-y).abs().mean())
t,u = torch.randn(100),torch.randn(100)
tst.reset()
for i in range(0,100,25):
    learn.pred,learn.yb = t[i:i+25],(u[i:i+25],)
    tst.accumulate(learn)
test_close(tst.value, (t-u).abs().mean())
```
AvgLoss

AvgLoss ()

Average the losses taking into account potential different batch sizes

```python
tst = AvgLoss()
t = torch.randn(100)
tst.reset()
for i in range(0,100,25):
    learn.yb,learn.loss = t[i:i+25],t[i:i+25].mean()
    tst.accumulate(learn)
test_close(tst.value, t.mean())
```
AvgSmoothLoss

AvgSmoothLoss (beta=0.98)

Smooth average of the losses (exponentially weighted with beta)

```python
tst = AvgSmoothLoss()
t = torch.randn(100)
tst.reset()
val = tensor(0.)
for i in range(4):
    learn.loss = t[i*25:(i+1)*25].mean()
    tst.accumulate(learn)
    val = val*0.98 + t[i*25:(i+1)*25].mean()*(1-0.98)
    test_close(val/(1-0.98**(i+1)), tst.value)
```
ValueMetric

ValueMetric (func, metric_name=None)

Use to include a pre-calculated metric value (for instance calculated in a Callback) and returned by func

```python
def metric_value_fn(): return 5e-3

vm = ValueMetric(metric_value_fn, 'custom_value_metric')
test_eq(vm.value, 5e-3)
test_eq(vm.name, 'custom_value_metric')

vm = ValueMetric(metric_value_fn)
test_eq(vm.name, 'metric_value_fn')
```
Recorder

Recorder (add_time=True, train_metrics=False, valid_metrics=True, beta=0.98)

Callback that registers statistics (lr, loss and metrics) during training

By default, metrics are computed on the validation set only, although that can be changed by adjusting train_metrics and valid_metrics. beta is the weight used to compute the exponentially weighted average of the losses (which gives the smooth_loss attribute to Learner).

The logger attribute of a Learner determines what happens to those metrics. By default, it just prints them:

```python
#Test printed output
def tst_metric(out, targ): return F.mse_loss(out, targ)
learn = synth_learner(n_train=5, metrics=tst_metric, default_cbs=False, cbs=[TrainEvalCallback, Recorder])
# pat = r"[tensor\(\d.\d*\), tensor\(\d.\d*\), tensor\(\d.\d*\), 'dd:dd']"
pat = r"\[\d, \d+.\d+, \d+.\d+, \d+.\d+, '\d\d:\d\d'\]"
test_stdout(lambda: learn.fit(1), pat, regex=True)
```
Internals
Recorder.before_fit
Recorder.before_fit ()
Prepare state for training
Recorder.before_epoch
Recorder.before_epoch ()
Set timer if self.add_time=True
Recorder.before_validate
Recorder.before_validate ()
Reset loss and metrics state
Recorder.after_batch
Recorder.after_batch ()
Update all metrics and records lr and smooth loss in training
Recorder.after_epoch
Recorder.after_epoch ()
Store and log the loss/metric values
Plotting tools
Recorder.plot_loss
Recorder.plot_loss (skip_start=5, with_valid=True, log=False, show_epochs=False, ax=None)
Plot the losses from skip_start and onward. Optionally, log=True for a logarithmic axis, show_epochs=True to indicate epochs, and a matplotlib axis ax to plot on.
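For example, after a training run (a sketch assuming learn has already been fit):

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
learn.recorder.plot_loss(skip_start=5, with_valid=True, ax=ax)  # training and validation losses
plt.show()
```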
CastToTensor
CastToTensor (after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, after_cancel_backward=None, after_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None)
Cast Subclassed Tensors to Tensor
Workaround for a bug in PyTorch where subclassed tensors, such as TensorBase, train up to ~20% slower than Tensor when passed to a model. Added to Learner by default.

CastToTensor's order is right before MixedPrecision so callbacks which make use of fastai's tensor subclasses can still use them.

If inputs are not a subclassed tensor or tuple of tensors, you may need to cast inputs in Learner.xb and Learner.yb to Tensor via your own callback or in the dataloader before Learner performs the forward pass.

If the CastToTensor workaround interferes with custom code, it can be removed:

```python
learn = Learner(...)
learn.remove_cb(CastToTensor)
```

You should verify your inputs are of type Tensor or implement a cast to Tensor via a custom callback or dataloader if CastToTensor is removed.
Inference functions
Learner.validate
Learner.validate (ds_idx=1, dl=None, cbs=None)
Validate on dl with potential new cbs.

```python
#Test result
learn = synth_learner(n_train=5, metrics=tst_metric)
res = learn.validate()
test_eq(res[0], res[1])
x,y = learn.dls.valid_ds.tensors
test_close(res[0], F.mse_loss(learn.model(x), y), 1e-3)
```
Learner.get_preds
Learner.get_preds (ds_idx:int=1, dl=None, with_input:bool=False, with_decoded:bool=False, with_loss:bool=False, act=None, inner:bool=False, reorder:bool=True, cbs:Callback|MutableSequence|None=None, save_preds:Path=None, save_targs:Path=None, with_preds:bool=True, with_targs:bool=True, concat_dim:int=0, pickle_protocol:int=2)
Get the predictions and targets on the ds_idx-th dataset or dl, optionally with_input and with_loss

| | Type | Default | Details |
|---|---|---|---|
| ds_idx | int | 1 | DataLoader to use for predictions if dl is None. 0: train. 1: valid |
| dl | NoneType | None | DataLoader to use for predictions, defaults to ds_idx=1 if None |
| with_input | bool | False | Return inputs with predictions |
| with_decoded | bool | False | Return decoded predictions |
| with_loss | bool | False | Return per item loss with predictions |
| act | NoneType | None | Apply activation to predictions, defaults to self.loss_func's activation |
| inner | bool | False | If False, create progress bar, show logger, use temporary cbs |
| reorder | bool | True | Reorder predictions on dataset indices, if applicable |
| cbs | Callback \| MutableSequence \| None | None | Temporary Callbacks to apply during prediction |
| save_preds | Path | None | Path to save predictions |
| save_targs | Path | None | Path to save targets |
| with_preds | bool | True | Whether to return predictions |
| with_targs | bool | True | Whether to return targets |
| concat_dim | int | 0 | Dimension to concatenate returned tensors |
| pickle_protocol | int | 2 | Pickle protocol used to save predictions and targets |
| Returns | tuple | | |
with_decoded will also return the decoded predictions using the decodes function of the loss function (if it exists). For instance, fastai's CrossEntropyFlat takes the argmax of predictions in its decodes.

Depending on the loss_func attribute of Learner, an activation function will be picked automatically so that the predictions make sense. For instance if the loss is a case of cross-entropy, a softmax will be applied, or if the loss is binary cross entropy with logits, a sigmoid will be applied. If you want to make sure a certain activation function is applied, you can pass it with act.

save_preds and save_targs should be used when your predictions are too big to fit all in memory. Give a Path object that points to a folder where the predictions and targets will be saved.

concat_dim is the batch dimension, where all the tensors will be concatenated.

inner is an internal attribute that tells get_preds it's called internally, inside another training loop, to avoid recursion errors.

If you want to use the option with_loss=True on a custom loss function, make sure you have implemented a reduction attribute that supports 'none'.
```python
#Test result
learn = synth_learner(n_train=5, metrics=tst_metric)
preds,targs = learn.get_preds()
x,y = learn.dls.valid_ds.tensors
test_eq(targs, y)
test_close(preds, learn.model(x))

preds,targs = learn.get_preds(act = torch.sigmoid)
test_eq(targs, y)
test_close(preds, torch.sigmoid(learn.model(x)))
```
Learner.predict

Learner.predict (item, rm_type_tfms=None, with_input=False)

Prediction on item, fully decoded, loss function decoded and probabilities

It returns a tuple of three elements with, in reverse order:

- the prediction from the model, potentially passed through the activation of the loss function (if it has one)
- the decoded prediction, using the potential decodes method from it
- the fully decoded prediction, using the transforms used to build the Datasets/DataLoaders

rm_type_tfms is a deprecated argument that should not be used and will be removed in a future version. with_input will add the decoded inputs to the result.

```python
class _FakeLossFunc(Module):
    reduction = 'none'
    def forward(self, x, y): return F.mse_loss(x,y)
    def activation(self, x): return x+1
    def decodes(self, x): return 2*x

class _Add1(Transform):
    def encodes(self, x): return x+1
    def decodes(self, x): return x-1

learn = synth_learner(n_train=5)
dl = TfmdDL(Datasets(torch.arange(50), tfms = [L(), [_Add1()]]))
learn.dls = DataLoaders(dl, dl)
learn.loss_func = _FakeLossFunc()

inp = tensor([2.])
out = learn.model(inp).detach()+1  #applying model + activation
dec = 2*out                        #decodes from loss function
full_dec = dec-1                   #decodes from _Add1
test_eq(learn.predict(inp), [full_dec,dec,out])
test_eq(learn.predict(inp, with_input=True), [inp,full_dec,dec,out])
```
Learner.show_results

Learner.show_results (ds_idx=1, dl=None, max_n=9, shuffle=True, **kwargs)

Show some predictions on ds_idx-th dataset or dl

Will show max_n samples (unless the batch size of ds_idx or dl is less than max_n, in which case it will show as many samples as it can) and shuffle the data unless you pass False to that flag. kwargs are application-dependent.

We can't show an example on our synthetic Learner, but check all the beginner tutorials, which will show you how that method works across applications.
The last functions in this section are used internally for inference, but should be less useful to you.
Learner.no_logging
Learner.no_logging ()
```python
learn = synth_learner(n_train=5, metrics=tst_metric)
with learn.no_logging():
    test_stdout(lambda: learn.fit(1), '')
test_eq(learn.logger, print)
```
Learner.loss_not_reduced
Learner.loss_not_reduced ()
This requires your loss function to either have a reduction attribute or a reduction argument (like all fastai and PyTorch loss functions).
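A hedged sketch of what this context manager is for, e.g. inspecting per-item losses on one validation batch:

```python
# per-item (unreduced) losses on one batch (a sketch)
xb, yb = learn.dls.valid.one_batch()
with learn.loss_not_reduced():
    per_item_losses = learn.loss_func(learn.model(xb), yb)  # one loss value per item
```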
Transfer learning
Learner.unfreeze
Learner.unfreeze ()
Unfreeze the entire model
Learner.freeze
Learner.freeze ()
Freeze up to last parameter group
Learner.freeze_to
Learner.freeze_to (n)
Freeze parameter groups up to n
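Together with a splitter that defines parameter groups, these drive the usual fine-tuning pattern. A sketch, assuming learn was created with a pretrained model and an appropriate splitter:

```python
# the usual transfer-learning schedule (a sketch)
learn.freeze()            # train only the last parameter group (e.g. the new head)
learn.fit_one_cycle(1)
learn.unfreeze()          # then train the whole model, with lower learning rates
learn.fit_one_cycle(3, lr_max=slice(1e-5, 1e-3))
```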
TTA
Learner.tta
Learner.tta (ds_idx=1, dl=None, n=4, item_tfms=None, batch_tfms=None, beta=0.25, use_max=False)
Return predictions on the ds_idx dataset or dl using Test Time Augmentation

In practice, we get the predictions n times with the transforms of the training set and average those. The final predictions are (1-beta) multiplied by this average + beta multiplied by the predictions obtained with the transforms of the dataset. Set beta to None to get a tuple of the predictions and tta results. You can also use the maximum of all predictions instead of an average by setting use_max=True.

If you want to use new transforms, you can pass them with item_tfms and batch_tfms.