Focal Loss

This module implements binary focal loss for tensors of arbitrary shape.

class focal_loss.BinaryFocalLoss(gamma=2, reduction='mean')

Bases: _Loss

Inherits from torch.nn.modules.loss._Loss. Computes the binary focal loss between each element of the input and target tensors.
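For reference, the standard per-element binary focal loss can be written as below. This assumes the input holds a predicted probability p for the positive class and the target holds a label y in {0, 1}; this page does not say whether the input is expected as probabilities or logits. With γ = 0 the expression reduces to ordinary binary cross-entropy.

```latex
p_t =
\begin{cases}
  p,     & y = 1 \\
  1 - p, & y = 0
\end{cases}
\qquad
\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log(p_t)
```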

Parameters
  • gamma (float (optional)) – exponent applied to (1-pt) when computing focal loss. Default is 2

  • reduction (string (optional)) – “sum”, “mean”, or “none”. If “sum”, the output is summed; if “mean”, the output is averaged; if “none”, no reduction is applied. Default is “mean”
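To illustrate what gamma and reduction control, here is a minimal pure-Python sketch. It assumes each prediction is a probability p in (0, 1) and each target is 0 or 1; the helpers focal_term and reduce_losses are hypothetical and are not part of the focal_loss package.

```python
import math
from typing import List, Union


def focal_term(p: float, y: int, gamma: float = 2.0) -> float:
    """Hypothetical helper: binary focal loss for one element.

    Assumes p is a probability in (0, 1) and y is 0 or 1.
    """
    # p_t is the probability assigned to the true class
    p_t = p if y == 1 else 1.0 - p
    # (1 - p_t) ** gamma down-weights elements that are already well classified
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def reduce_losses(losses: List[float], reduction: str = "mean") -> Union[float, List[float]]:
    """Hypothetical helper mirroring the three reduction modes."""
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "none":
        return losses
    raise ValueError(f"unsupported reduction: {reduction!r}")


if __name__ == "__main__":
    easy = focal_term(0.9, 1)  # confident and correct
    hard = focal_term(0.1, 1)  # confident and wrong
    print(easy, hard, reduce_losses([easy, hard], "mean"))
```

With gamma = 2, the easy element's log loss is scaled by (1 - 0.9)^2 = 0.01, while the hard element's is scaled by (1 - 0.1)^2 = 0.81, so well-classified elements contribute far less to the total. Setting gamma = 0 removes the scaling entirely.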

gamma

focusing parameter – exponent applied to (1-pt) when computing focal loss. Default is 2

Type

float (optional)

reduction

“sum”, “mean”, or “none”. If “sum”, the output is summed; if “mean”, the output is averaged; if “none”, no reduction is applied. Default is “mean”

Type

string (optional)

forward(input_tensor, target)

Compute binary focal loss for an input prediction map and target mask.

Parameters
  • input_tensor (torch.Tensor) – input prediction map

  • target (torch.Tensor) – target mask

Returns

loss_tensor – binary focal loss, summed, averaged, or unreduced (element-wise), depending on self.reduction

Return type

torch.Tensor
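A forward pass of this kind can be sketched as follows. This is a hypothetical re-implementation, not the package's source: the class name SketchBinaryFocalLoss and the eps clamp are illustrative assumptions, and it assumes input_tensor holds probabilities in (0, 1) rather than logits. It subclasses torch.nn.modules.loss._Loss, as the page states BinaryFocalLoss does, so self.reduction is set by the base class.

```python
import torch
from torch.nn.modules.loss import _Loss


class SketchBinaryFocalLoss(_Loss):
    """Hypothetical sketch; assumes input_tensor holds probabilities, not logits."""

    def __init__(self, gamma: float = 2.0, reduction: str = "mean", eps: float = 1e-7):
        super().__init__(reduction=reduction)  # base class stores self.reduction
        self.gamma = gamma
        self.eps = eps  # hypothetical clamp so log() stays finite

    def forward(self, input_tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        p = input_tensor.clamp(self.eps, 1.0 - self.eps)
        target = target.to(p.dtype)
        # p_t is p where target == 1 and (1 - p) where target == 0, element-wise
        p_t = target * p + (1.0 - target) * (1.0 - p)
        loss_tensor = -((1.0 - p_t) ** self.gamma) * torch.log(p_t)
        if self.reduction == "mean":
            return loss_tensor.mean()
        if self.reduction == "sum":
            return loss_tensor.sum()
        if self.reduction == "none":
            return loss_tensor  # same shape as input_tensor
        raise ValueError(f"unsupported reduction: {self.reduction!r}")


if __name__ == "__main__":
    pred = torch.tensor([[0.9, 0.2], [0.6, 0.3]])
    mask = torch.tensor([[1, 0], [1, 1]])
    print(SketchBinaryFocalLoss()(pred, mask))
    print(SketchBinaryFocalLoss(reduction="none")(pred, mask).shape)
```

Because every operation is element-wise, the sketch works for prediction maps and masks of any matching shape; with gamma = 0 it matches torch.nn.functional.binary_cross_entropy on probabilities.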
