Custom fastai loss functions

class BaseLoss[source]

BaseLoss(loss_cls, *args, axis=-1, flatten=True, floatify=False, is_2d=True, **kwargs)

Same as loss_cls, but flattens input and target.

Wrapping a general loss function inside of BaseLoss provides extra functionalities to your loss functions:

  • flattens the tensors before computing the loss, since that's more convenient (with a potential transpose to put axis at the end)
  • a potential activation method that tells the library if there is an activation fused in the loss (useful for inference and methods such as Learner.get_preds or Learner.predict)
  • a potential decodes method that is used on predictions in inference (for instance, an argmax in classification)

The args and kwargs will be passed to loss_cls during initialization to instantiate a loss function. axis is put at the end, since activations like softmax are often taken on the last axis. If floatify=True, the targs will be converted to floats (useful for losses that only accept float targets like BCEWithLogitsLoss), and is_2d determines if we flatten to a 2d tensor (keeping the class dimension, as a loss like cross entropy expects) or flatten the input completely. We want the former for losses like cross entropy, and the latter for pretty much anything else.
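
For instance, wrapping nn.MSELoss this way gives (roughly) the MSELossFlat helper documented below. This is only a sketch to illustrate the arguments, not the library's own definition:

flat_mse = BaseLoss(nn.MSELoss, floatify=True, is_2d=False)
output = torch.randn(32, 5, 10)
target = torch.randint(0, 2, (32, 5, 10))
#The wrapped loss flattens both tensors completely and converts the int targets to floats
_ = flat_mse(output, target)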

class CrossEntropyLossFlat[source]

CrossEntropyLossFlat(*args, axis=-1, weight=None, ignore_index=-100, reduction='mean', flatten=True, floatify=False, is_2d=True) :: BaseLoss

Same as nn.CrossEntropyLoss, but flattens input and target.

tst = CrossEntropyLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
#nn.CrossEntropyLoss would fail with those two tensors, but not our flattened version.
_ = tst(output, target)
test_fail(lambda: nn.CrossEntropyLoss()(output, target))

#Associated activation is softmax
test_eq(tst.activation(output), F.softmax(output, dim=-1))
#This loss function has a decodes which is argmax
test_eq(tst.decodes(output), output.argmax(dim=-1))
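#In a segmentation-style setting the class dimension is dim 1, so we pass axis=1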
tst = CrossEntropyLossFlat(axis=1)
output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))
_ = tst(output, target)

test_eq(tst.activation(output), F.softmax(output, dim=1))
test_eq(tst.decodes(output), output.argmax(dim=1))

class BCEWithLogitsLossFlat[source]

BCEWithLogitsLossFlat(*args, axis=-1, floatify=True, thresh=0.5, weight=None, reduction='mean', pos_weight=None, flatten=True, is_2d=True) :: BaseLoss

Same as nn.BCEWithLogitsLoss, but flattens input and target.

tst = BCEWithLogitsLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
#nn.BCEWithLogitsLoss would fail with those two tensors, but not our flattened version.
_ = tst(output, target)
test_fail(lambda: nn.BCEWithLogitsLoss()(output, target))
output = torch.randn(32, 5)
target = torch.randint(0,2,(32, 5))
#nn.BCEWithLogitsLoss would fail with int targets but not our flattened version.
_ = tst(output, target)
test_fail(lambda: nn.BCEWithLogitsLoss()(output, target))

tst = BCEWithLogitsLossFlat(pos_weight=torch.ones(10))
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
_ = tst(output, target)
test_fail(lambda: nn.BCEWithLogitsLoss()(output, target))

#Associated activation is sigmoid
test_eq(tst.activation(output), torch.sigmoid(output))
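
There is also a decodes method that thresholds the predictions at thresh. A quick check, assuming the default thresh=0.5 used by the tst above:

#decodes thresholds the predictions (thresh=0.5 by default)
test_eq(tst.decodes(output), output > 0.5)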

BCELossFlat[source]

BCELossFlat(*args, axis=-1, floatify=True, weight=None, reduction='mean')

Same as nn.BCELoss, but flattens input and target.

tst = BCELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
test_fail(lambda: nn.BCELoss()(output, target))

MSELossFlat[source]

MSELossFlat(*args, axis=-1, floatify=True, reduction='mean')

Same as nn.MSELoss, but flattens input and target.

tst = MSELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
test_fail(lambda: nn.MSELoss()(output, target))

L1LossFlat[source]

L1LossFlat(*args, axis=-1, floatify=True, reduction='mean')

Same as nn.L1Loss, but flattens input and target.
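
A minimal check in the same spirit as MSELossFlat above (the shapes are arbitrary, and floatify=True takes care of the int targets):

tst = L1LossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0, 2, (32, 5, 10))
_ = tst(output, target)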

class LabelSmoothingCrossEntropy[source]

LabelSmoothingCrossEntropy(eps:float=0.1, reduction='mean') :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

On top of the label smoothing formula (a mix of the regular cross entropy with weight 1-eps and the average of the cross entropy over all classes with weight eps), we define (see the example after this list):

  • a reduction attribute, that will be used when we call Learner.get_preds
  • an activation function that represents the activation fused in the loss (since we use cross entropy behind the scenes). It will be applied to the output of the model when calling Learner.get_preds or Learner.predict
  • a decodes function that converts the output of the model to a format similar to the target (here indices). This is used in Learner.predict and Learner.show_results to decode the predictions
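
A minimal sketch of those three pieces; the shapes are arbitrary, and the checks assume the softmax/argmax behavior described above:

tst = LabelSmoothingCrossEntropy()
output = torch.randn(32, 10)
target = torch.randint(0, 10, (32,))
loss = tst(output, target)
#The fused activation is a softmax on the last axis, and decodes takes the argmax
test_eq(tst.activation(output), F.softmax(output, dim=-1))
test_eq(tst.decodes(output), output.argmax(dim=-1))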

class LabelSmoothingCrossEntropyFlat[source]

LabelSmoothingCrossEntropyFlat(*args, axis=-1, eps=0.1, reduction='mean', flatten=True, floatify=False, is_2d=True) :: BaseLoss

Same as LabelSmoothingCrossEntropy, but flattens input and target.
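
As with CrossEntropyLossFlat, the flattened version accepts outputs and targets with extra dimensions. A minimal sketch (the shapes are arbitrary):

tst = LabelSmoothingCrossEntropyFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32, 5))
_ = tst(output, target)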