Focal Loss
This module implements binary focal loss for tensors of arbitrary size and shape.
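For reference, this is the commonly used form of binary focal loss (Lin et al.), written without the optional class-balancing weight α. It assumes p is the predicted probability of the positive class and y ∈ {0, 1} is the target. This reference does not say whether the module expects probabilities or logits.

```latex
p_t =
\begin{cases}
  p     & \text{if } y = 1 \\
  1 - p & \text{if } y = 0
\end{cases}
\qquad
\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log(p_t)
```

With γ = 0 this reduces to binary cross-entropy, -log(p_t). Larger γ shrinks the loss on well-classified elements, where p_t is close to 1.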
- class focal_loss.BinaryFocalLoss(gamma=2, reduction='mean')
Bases: _Loss
Inherits from torch.nn.modules.loss._Loss. Computes the binary focal loss between each element of the input and target tensors.
- Parameters
gamma (float (optional)) – exponent applied to (1 - p_t) when computing focal loss. Default is 2.
reduction (string (optional)) – “sum”, “mean”, or “none”. If “sum”, the output is summed; if “mean”, it is averaged; if “none”, no reduction is applied. Default is “mean”.
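The following sketch illustrates what gamma controls. It evaluates the per-element focal term for an easy element and a hard element. It is plain Python and assumes p_t is already known. The helpers modulating_factor and focal_term are hypothetical names chosen for illustration and are not part of focal_loss.

```python
import math


def modulating_factor(p_t: float, gamma: float = 2.0) -> float:
    # Weight (1 - p_t)^gamma that focal loss places on the cross-entropy term.
    return (1.0 - p_t) ** gamma


def focal_term(p_t: float, gamma: float = 2.0) -> float:
    # Per-element focal loss -(1 - p_t)^gamma * log(p_t), for p_t in (0, 1].
    return -modulating_factor(p_t, gamma) * math.log(p_t)


# Compare an easy element (p_t = 0.9) with a hard one (p_t = 0.1).
for gamma in (0.0, 1.0, 2.0, 5.0):
    easy = focal_term(0.9, gamma)
    hard = focal_term(0.1, gamma)
    print(f"gamma={gamma}: easy={easy:.5f} hard={hard:.5f} hard/easy={hard / easy:.0f}")
```

At gamma = 0 both terms equal plain cross-entropy. As gamma grows, the easy element's loss shrinks much faster than the hard element's, so the hard/easy ratio increases. This is the "focusing" behaviour.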
- gamma
focusing parameter – exponent applied to (1 - p_t) when computing focal loss. Default is 2.
- Type
float (optional)
- reduction
“sum”, “mean”, or “none”. If “sum”, the output is summed; if “mean”, it is averaged; if “none”, no reduction is applied. Default is “mean”.
- Type
string (optional)
- forward(input_tensor, target)
Compute binary focal loss for an input prediction map and target mask.
- Parameters
input_tensor (torch.Tensor) – input prediction map
target (torch.Tensor) – target mask
- Returns
loss_tensor – binary focal loss, summed, averaged, or raw depending on self.reduction
- Return type
torch.Tensor
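The following is a minimal sketch of how a forward pass like this could compute the loss. It is not the focal_loss package's code. It makes these assumptions:

- input_tensor already holds probabilities in (0, 1). This reference does not say whether probabilities or logits are expected.
- target is a 0/1 mask with the same shape as input_tensor.
- The class name SketchBinaryFocalLoss and the clamping epsilon are hypothetical.

```python
import torch
from torch.nn.modules.loss import _Loss


class SketchBinaryFocalLoss(_Loss):
    """Hypothetical stand-in for BinaryFocalLoss; expects probabilities, not logits."""

    def __init__(self, gamma: float = 2.0, reduction: str = "mean"):
        # _Loss stores the reduction string as self.reduction.
        super().__init__(reduction=reduction)
        self.gamma = gamma

    def forward(self, input_tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Clamp so log() never sees 0 or 1; eps is an assumption of this sketch.
        eps = 1e-7
        p = input_tensor.clamp(eps, 1.0 - eps)
        target = target.to(p.dtype)
        # p_t = p where target is 1, and 1 - p where target is 0.
        p_t = target * p + (1.0 - target) * (1.0 - p)
        # Elementwise focal loss; same shape as input_tensor.
        loss_tensor = -((1.0 - p_t) ** self.gamma) * torch.log(p_t)
        if self.reduction == "mean":
            return loss_tensor.mean()
        if self.reduction == "sum":
            return loss_tensor.sum()
        if self.reduction == "none":
            return loss_tensor
        raise ValueError(f"unsupported reduction: {self.reduction!r}")


# Works for any shape, e.g. a batch of 2 single-channel 4x4 prediction maps.
torch.manual_seed(0)
probs = torch.rand(2, 1, 4, 4)
mask = (torch.rand(2, 1, 4, 4) > 0.5).float()
per_elem = SketchBinaryFocalLoss(reduction="none")(probs, mask)
print(per_elem.shape)  # torch.Size([2, 1, 4, 4])
print(SketchBinaryFocalLoss()(probs, mask))  # scalar mean over all elements
```

Subclassing _Loss mirrors the documented base class and supplies the reduction attribute. A quick sanity check is that, with gamma=0.0, the "none" output matches torch.nn.functional.binary_cross_entropy on the same clamped probabilities.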