tst = CrossEntropyLossFlat(reduction='none')
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
#nn.CrossEntropy would fail with those two tensors, but not our flattened version.
_ = tst(output, target)
test_fail(lambda x: nn.CrossEntropyLoss()(output,target))

#Associated activation is softmax
test_eq(tst.activation(output), F.softmax(output, dim=-1))
#This loss function has a decodes which is argmax
test_eq(tst.decodes(output), output.argmax(dim=-1))
Loss Functions
BaseLoss
BaseLoss (loss_cls, *args, axis:int=-1, flatten:bool=True, floatify:bool=False, is_2d:bool=True, **kwargs)
Same as loss_cls, but flattens input and target.
| | Type | Default | Details |
|---|---|---|---|
| loss_cls | | | Uninitialized PyTorch-compatible loss |
| args | | | |
| axis | int | -1 | |
| flatten | bool | True | |
| floatify | bool | False | |
| is_2d | bool | True | |
| kwargs | | | |
Wrapping a general loss function inside of BaseLoss provides extra functionalities to your loss functions:
- it flattens the tensors before trying to take the losses, since it's more convenient (with a potential transpose to put axis at the end)
- a potential activation method that tells the library if there is an activation fused in the loss (useful for inference and methods such as Learner.get_preds or Learner.predict)
- a potential decodes method that is used on predictions in inference (for instance, an argmax in classification)
The args and kwargs will be passed to loss_cls during the initialization to instantiate a loss function. axis is put at the end for losses like softmax that are often performed on the last axis. If floatify=True, the targs will be converted to floats (useful for losses that only accept float targets like BCEWithLogitsLoss), and is_2d determines if we flatten while keeping the first dimension (batch size) or completely flatten the input. We want the first for losses like Cross Entropy, and the second for pretty much anything else.
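As an illustration (not from the original notebook), here is a minimal sketch of wrapping a plain PyTorch loss with BaseLoss directly, assuming the usual notebook imports (torch, nn and the fastai losses) are available; nn.SmoothL1Loss is just a stand-in for any uninitialized loss class:

#A sketch of using BaseLoss directly on an arbitrary PyTorch loss
smooth_l1 = BaseLoss(nn.SmoothL1Loss, floatify=True, is_2d=False)
output = torch.randn(32, 5, 10)
target = torch.randint(0, 2, (32, 5, 10))
#The int targets are converted to floats and both tensors are flattened before the loss is taken
_ = smooth_l1(output, target)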
CrossEntropyLossFlat
CrossEntropyLossFlat (*args, axis:int=-1, weight=None, ignore_index=-100, reduction='mean', flatten:bool=True, floatify:bool=False, is_2d:bool=True)
Same as nn.CrossEntropyLoss, but flattens input and target.
#In a segmentation task, we want to take the softmax over the channel dimension
tst = CrossEntropyLossFlat(axis=1)
output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))
_ = tst(output, target)
test_eq(tst.activation(output), F.softmax(output, dim=1))
test_eq(tst.decodes(output), output.argmax(dim=1))
Focal Loss is the same as cross entropy except easy-to-classify observations are down-weighted in the loss calculation. The strength of down-weighting is proportional to the size of the gamma parameter. Put another way, the larger gamma, the less the easy-to-classify observations contribute to the loss.
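As a quick numeric illustration (not part of the original notebook), using the standard focal formulation FL(p) = -(1 - p)^gamma * log(p) from the paper cited below:

#Down-weighting sketch: per-example contribution under plain CE vs. focal weighting with gamma=2
p_true = torch.tensor([0.95, 0.6, 0.2])   #predicted probability of the true class
ce_per_example = -torch.log(p_true)
focal_per_example = (1 - p_true)**2 * ce_per_example
#The easy example (p=0.95) keeps only a tiny fraction of its CE contribution,
#while the hard example (p=0.2) keeps most of it.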
FocalLossFlat
FocalLossFlat (*args, gamma:float=2.0, axis:int=-1, weight=None, reduction='mean', **kwargs)
Same as CrossEntropyLossFlat but with the focal parameter, gamma. Focal loss was introduced by Lin et al., https://arxiv.org/pdf/1708.02002.pdf. Note the class weighting factor in the paper, alpha, can be implemented through the PyTorch weight argument passed through to F.cross_entropy.
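For instance (a sketch, not from the original notebook), hypothetical per-class weights standing in for the paper's alpha could be passed like this:

#10 classes here; up-weight class 0 relative to the others
class_weights = torch.tensor([2., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
fl_weighted = FocalLossFlat(gamma=2, weight=class_weights)
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32, 5))
_ = fl_weighted(output, target)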
FocalLoss
FocalLoss (gamma:float=2.0, weight:Tensor=None, reduction:str='mean')
Same as nn.Module, but no need for subclasses to call super().__init__
| | Type | Default | Details |
|---|---|---|---|
| gamma | float | 2.0 | Focusing parameter. Higher values down-weight easy examples' contribution to loss |
| weight | Tensor | None | Manual rescaling weight given to each class |
| reduction | str | mean | PyTorch reduction to apply to the output |
#Compare focal loss with gamma = 0 to cross entropy
fl = FocalLossFlat(gamma=0)
ce = CrossEntropyLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
test_close(fl(output, target), ce(output, target))
#Test focal loss with gamma > 0 is different than cross entropy
fl = FocalLossFlat(gamma=2)
test_ne(fl(output, target), ce(output, target))
#In a segmentation task, we want to take the softmax over the channel dimension
fl = FocalLossFlat(gamma=0, axis=1)
ce = CrossEntropyLossFlat(axis=1)
output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))
test_close(fl(output, target), ce(output, target), eps=1e-4)
test_eq(fl.activation(output), F.softmax(output, dim=1))
test_eq(fl.decodes(output), output.argmax(dim=1))
BCEWithLogitsLossFlat
BCEWithLogitsLossFlat (*args, axis:int=-1, floatify:bool=True, thresh:float=0.5, weight=None, reduction='mean', pos_weight=None, flatten:bool=True, is_2d:bool=True)
Same as nn.BCEWithLogitsLoss, but flattens input and target.
tst = BCEWithLogitsLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
#nn.BCEWithLogitsLoss would fail with those two tensors, but not our flattened version.
_ = tst(output, target)
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))
output = torch.randn(32, 5)
target = torch.randint(0,2,(32, 5))
#nn.BCEWithLogitsLoss would fail with int targets but not our flattened version.
_ = tst(output, target)
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))
tst = BCEWithLogitsLossFlat(pos_weight=torch.ones(10))
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
_ = tst(output, target)
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))
#Associated activation is sigmoid
test_eq(tst.activation(output), torch.sigmoid(output))
BCELossFlat
BCELossFlat (*args, axis:int=-1, floatify:bool=True, weight=None, reduction='mean')
Same as nn.BCELoss, but flattens input and target.
tst = BCELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
test_fail(lambda x: nn.BCELoss()(output,target))
MSELossFlat
MSELossFlat (*args, axis:int=-1, floatify:bool=True, reduction='mean')
Same as nn.MSELoss, but flattens input and target.
tst = MSELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
test_fail(lambda x: nn.MSELoss()(output,target))
L1LossFlat
L1LossFlat (*args, axis=-1, floatify=True, reduction='mean')
Same as nn.L1Loss, but flattens input and target.
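No example cell accompanies L1LossFlat here; a minimal usage sketch (an assumption on my part, mirroring the MSELossFlat test above) would look like:

tst = L1LossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
#floatify=True converts the int targets to floats before the flattened loss is taken
_ = tst(output, target)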
LabelSmoothingCrossEntropy
LabelSmoothingCrossEntropy (eps:float=0.1, weight:Tensor=None, reduction:str='mean')
Same as nn.Module, but no need for subclasses to call super().__init__
| | Type | Default | Details |
|---|---|---|---|
| eps | float | 0.1 | The weight for the interpolation formula |
| weight | Tensor | None | Manual rescaling weight given to each class passed to F.nll_loss |
| reduction | str | mean | PyTorch reduction to apply to the output |
lmce = LabelSmoothingCrossEntropy()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
test_close(lmce(output.flatten(0,1), target.flatten()), lmce(output.transpose(-1,-2), target))
On top of the formula we define:
- a reduction attribute, that will be used when we call Learner.get_preds
- a weight attribute to pass to BCE
- an activation function that represents the activation fused in the loss (since we use cross entropy behind the scenes). It will be applied to the output of the model when calling Learner.get_preds or Learner.predict
- a decodes function that converts the output of the model to a format similar to the target (here indices). This is used in Learner.predict and Learner.show_results to decode the predictions
LabelSmoothingCrossEntropyFlat
LabelSmoothingCrossEntropyFlat (*args, axis:int=-1, eps=0.1, reduction='mean', flatten:bool=True, floatify:bool=False, is_2d:bool=True)
Same as LabelSmoothingCrossEntropy, but flattens input and target.
#These two should always equal each other since the Flat version is just passing data through
lmce = LabelSmoothingCrossEntropy()
lmce_flat = LabelSmoothingCrossEntropyFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
test_close(lmce(output.transpose(-1,-2), target), lmce_flat(output,target))
We present a general Dice loss for segmentation tasks. It is commonly used together with CrossEntropyLoss or FocalLoss in Kaggle competitions. This is very similar to the DiceMulti metric, but to be able to differentiate through it, we replace the argmax activation with a softmax and compare this with a one-hot encoded target mask. This function also adds a smooth parameter to help with numerical stability in the intersection-over-union division. If your network has trouble learning with this DiceLoss, try setting the square_in_union parameter in the DiceLoss constructor to True.
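Schematically, the loss is built from a soft-Dice score per class. The sketch below is illustrative only (tensor names, shapes and the per-class reduction are assumptions, not the internals of the class documented next):

#Soft-Dice sketch: softmax probabilities against a one-hot target mask
model_out = torch.randn(8, 3, 16, 16)                    #(bs, classes, H, W) logits
targ = torch.randint(0, 3, (8, 16, 16))                  #integer mask
probs = F.softmax(model_out, dim=1)
one_hot = F.one_hot(targ, probs.shape[1]).permute(0, 3, 1, 2).float()
inter = (probs * one_hot).sum(dim=(2, 3))                #per image and class
union = probs.sum(dim=(2, 3)) + one_hot.sum(dim=(2, 3))
smooth = 1e-6
dice = (2 * inter + smooth) / (union + smooth)
loss = (1 - dice).sum()                                  #default reduction is 'sum'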
DiceLoss
DiceLoss (axis:int=1, smooth:float=1e-06, reduction:str='sum', square_in_union:bool=False)
Dice loss for segmentation
| | Type | Default | Details |
|---|---|---|---|
| axis | int | 1 | Class axis |
| smooth | float | 1e-06 | Helps with numerical stabilities in the IoU division |
| reduction | str | sum | PyTorch reduction to apply to the output |
| square_in_union | bool | False | Squares predictions to increase slope of gradients |
dl = DiceLoss()
_x = tensor([[[1, 0, 2],
              [2, 2, 1]]])
_one_hot_x = tensor([[[[0, 1, 0],
                       [0, 0, 0]],
                      [[1, 0, 0],
                       [0, 0, 1]],
                      [[0, 0, 1],
                       [1, 1, 0]]]])
test_eq(dl._one_hot(_x, 3), _one_hot_x)
dl = DiceLoss()
model_output = tensor([[[[2., 1.],
                         [1., 5.]],
                        [[1., 2.],
                         [3., 1.]],
                        [[3., 0.],
                         [4., 3.]]]])
target = tensor([[[2, 1],
                  [2, 0]]])
dl_out = dl(model_output, target)
test_eq(dl.decodes(model_output), target)
dl = DiceLoss(reduction="mean")
#identical masks
model_output = tensor([[[.1], [.1], [100.]]])
target = tensor([[2]])
test_close(dl(model_output, target), 0)

#50% intersection
model_output = tensor([[[.1, 100.], [.1, .1], [100., .1]]])
target = tensor([[2, 1]])
test_close(dl(model_output, target), .66, eps=0.01)
As a test case for the Dice loss, consider satellite image segmentation. Let us say we have three classes: Background (0), River (1) and Road (2). Let us look at a specific target:
target = torch.zeros(100,100)
target[:,5] = 1
target[:,50] = 2
plt.imshow(target);
Nearly everything is background in this example, and we have a thin river at the left of the image as well as a thin road in the middle of the image. If all our data looks similar to this, we say that there is a class imbalance, meaning that some classes (like river and road) appear relatively infrequently. If our model just predicted “background” (i.e. the value 0) for all pixels, it would be correct for most pixels. But this would be a bad model, and the Dice loss should reflect that.
model_output_all_background = torch.zeros(3, 100,100)
# assign probability 1 to class 0 everywhere
# to get probability 1, we just need a high model output before softmax gets applied
model_output_all_background[0,:,:] = 100

# add a batch dimension
model_output_all_background = torch.unsqueeze(model_output_all_background,0)
target = torch.unsqueeze(target,0)
Our Dice score should be around 1/3 here, because the “background” class is predicted correctly (and that for nearly every pixel), but the other two classes are never predicted correctly. A Dice score of 1/3 means a Dice loss of 1 - 1/3 = 2/3:
test_close(dl(model_output_all_background, target), 0.67, eps=0.01)
If the model predicted everything correctly, the Dice loss should be zero:
correct_model_output = torch.zeros(3, 100,100)
correct_model_output[0,:,:] = 100
correct_model_output[0,:,5] = 0
correct_model_output[0,:,50] = 0
correct_model_output[1,:,5] = 100
correct_model_output[2,:,50] = 100
correct_model_output = torch.unsqueeze(correct_model_output, 0)

test_close(dl(correct_model_output, target), 0)
You could easily combine this loss with FocalLoss, defining a CombinedLoss, to balance between global (Dice) and local (Focal) features on the target mask.
class CombinedLoss:
    "Dice and Focal combined"
    def __init__(self, axis=1, smooth=1., alpha=1.):
        store_attr()
        self.focal_loss = FocalLossFlat(axis=axis)
        self.dice_loss = DiceLoss(axis, smooth)

    def __call__(self, pred, targ):
        return self.focal_loss(pred, targ) + self.alpha * self.dice_loss(pred, targ)

    def decodes(self, x): return x.argmax(dim=self.axis)
    def activation(self, x): return F.softmax(x, dim=self.axis)

cl = CombinedLoss()
output = torch.randn(32, 4, 5, 10)
target = torch.randint(0,2,(32, 5, 10))
_ = cl(output, target)
# Tests to catch future changes to pickle which cause some loss functions to be 'unpicklable'.
# This causes problems with `Learner.export` as the model can't be pickled with these particular loss functions.
losses_picklable = [
    (BCELossFlat(), True),
    (BCEWithLogitsLossFlat(), True),
    (CombinedLoss(), True),
    (CrossEntropyLossFlat(), True),
    (DiceLoss(), True),
    (FocalLoss(), True),
    (FocalLossFlat(), True),
    (L1LossFlat(), True),
    (LabelSmoothingCrossEntropyFlat(), True),
    (LabelSmoothingCrossEntropy(), True),
    (MSELossFlat(), True),
]

for loss, picklable in losses_picklable:
    try:
        pickle.dumps(loss, protocol=2)
    except (pickle.PicklingError, TypeError) as e:
        if picklable:
            # Loss was previously picklable but isn't currently
            raise e