Source code for dpipe.torch.functional

import warnings
from typing import Union, Callable

import numpy as np
import torch
from torch.nn import functional

from dpipe.im.axes import AxesLike

__all__ = [
    'focal_loss_with_logits', 'linear_focal_loss_with_logits', 'weighted_cross_entropy_with_logits',
    'tversky_loss', 'focal_tversky_loss', 'tversky_loss_with_logits', 'focal_tversky_loss_with_logits',
    'dice_loss', 'dice_loss_with_logits',
    'masked_loss', 'moveaxis', 'softmax',
]


def focal_loss_with_logits(logits: torch.Tensor, target: torch.Tensor, weight: torch.Tensor = None,
                           gamma: float = 2, alpha: float = 0.25, reduce: Union[Callable, None] = torch.mean):
    """
    Function that measures Focal Loss between target and output logits.

    Parameters
    ----------
    logits: torch.Tensor
        tensor of an arbitrary shape.
    target: torch.Tensor
        tensor of the same shape as ``logits``.
    weight: torch.Tensor, None, optional
        a manual rescaling weight. Must be broadcastable to ``logits``.
    gamma: float
        the power of the focal loss factor. Defaults to 2.
    alpha: float, None, optional
        weighting factor of the focal loss. If ``None``, no weighting will be performed. Defaults to 0.25.
    reduce: Callable, None, optional
        the reduction operation to be applied to the final loss. Defaults to ``torch.mean``.
        If ``None``, no reduction will be performed.

    References
    ----------
    `Focal Loss <https://arxiv.org/abs/1708.02002>`_
    """
    if target.size() != logits.size():
        raise ValueError("Target size ({}) must be the same as logits size ({})".format(target.size(), logits.size()))

    if alpha is not None:
        if not (0 <= alpha <= 1):
            raise ValueError(f'`alpha` should be between 0 and 1, {alpha} was given')
        rescale_w = (2 * alpha - 1) * target + 1 - alpha
    else:
        rescale_w = 1

    # numerically stable evaluation of sigmoid(logits) and the log-sum-exp terms
    min_val = -logits.clamp(min=0)
    max_val = (-logits).clamp(min=0)

    prob = (min_val + logits).exp() / (min_val.exp() + (min_val + logits).exp())
    loss = rescale_w * ((1 - 2 * prob) * target + prob) ** gamma * (
            logits - logits * target + max_val + ((-max_val).exp() + (-logits - max_val).exp()).log())

    if weight is not None:
        loss = loss * weight
    if reduce is not None:
        loss = reduce(loss)
    return loss
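

# Usage sketch (not part of the original module; shapes are illustrative):
#
#     logits = torch.randn(4, 1, 32, 32)                 # raw network outputs
#     target = (torch.rand(4, 1, 32, 32) > 0.5).float()  # binary ground truth
#     loss = focal_loss_with_logits(logits, target)      # scalar, reduced by torch.mean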


def linear_focal_loss_with_logits(logits: torch.Tensor, target: torch.Tensor, gamma: float, beta: float,
                                  weight: torch.Tensor = None, reduce: Union[Callable, None] = torch.mean):
    """
    Function that measures Linear Focal Loss between target and output logits.
    Equals to BinaryCrossEntropy(``gamma`` * ``logits`` + ``beta``, ``target``, ``weight``) / ``gamma``.

    Parameters
    ----------
    logits: torch.Tensor
        tensor of an arbitrary shape.
    target: torch.Tensor
        tensor of the same shape as ``logits``.
    gamma: float
        multiplication coefficient for the ``logits`` tensor.
    beta: float
        coefficient to be added to all the elements in the ``logits`` tensor.
    weight: torch.Tensor
        a manual rescaling weight. Must be broadcastable to ``logits``.
    reduce: Callable, None, optional
        the reduction operation to be applied to the final loss. Defaults to ``torch.mean``.
        If ``None``, no reduction will be performed.

    References
    ----------
    `Focal Loss <https://arxiv.org/abs/1708.02002>`_
    """
    loss = functional.binary_cross_entropy_with_logits(gamma * logits + beta, target, weight, reduction='none') / gamma
    if reduce is not None:
        loss = reduce(loss)
    return loss
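

# Sanity-check sketch (illustrative, not part of the original module): with gamma=1 and
# beta=0 this loss reduces to plain binary cross entropy with logits:
#
#     a = linear_focal_loss_with_logits(logits, target, gamma=1, beta=0)
#     b = functional.binary_cross_entropy_with_logits(logits, target)
#     assert torch.allclose(a, b)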


def weighted_cross_entropy_with_logits(logit: torch.Tensor, target: torch.Tensor, weight: torch.Tensor = None,
                                       alpha: float = 1, adaptive: bool = False,
                                       reduce: Union[Callable, None] = torch.mean):
    """
    Function that measures Binary Cross Entropy between target and output logits.
    This version of BCE has additional options of constant or adaptive weighting of positive examples.

    Parameters
    ----------
    logit: torch.Tensor
        tensor of an arbitrary shape.
    target: torch.Tensor
        tensor of the same shape as ``logit``.
    weight: torch.Tensor
        a manual rescaling weight. Must be broadcastable to ``logit``.
    alpha: float, optional
        a weight for the positive class examples.
    adaptive: bool, optional
        if ``True``, uses the adaptive weight ``[N - sum(p_i)] / sum(p_i)`` for the positive class examples.
    reduce: Callable, None, optional
        the reduction operation to be applied to the final loss. Defaults to ``torch.mean``.
        If ``None``, no reduction will be performed.

    References
    ----------
    `WCE <https://arxiv.org/abs/1707.03237>`_
    """
    if target.size() != logit.size():
        raise ValueError("Target size ({}) must be the same as logit size ({})".format(target.size(), logit.size()))

    if adaptive:
        # compute sum(sigmoid(logit)) once and reuse it
        pos_sum = torch.sigmoid(logit).sum()
        pos_weight = alpha * (logit.numel() - pos_sum) / pos_sum
    else:
        pos_weight = alpha

    # numerically stable log-sigmoid terms
    max_val = -logit.clamp(min=0)
    loss = -pos_weight * target * (logit + max_val - (max_val.exp() + (logit + max_val).exp()).log()) \
           + (1 - target) * (-max_val + (max_val.exp() + (logit + max_val).exp()).log())

    if weight is not None:
        loss = loss * weight
    if reduce is not None:
        loss = reduce(loss)
    return loss
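

# Usage sketch (illustrative): constant vs. adaptive positive-class weighting.
#
#     loss = weighted_cross_entropy_with_logits(logit, target, alpha=2)        # positives weighted 2x
#     loss = weighted_cross_entropy_with_logits(logit, target, adaptive=True)  # weight inferred from predictions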


def dice_loss(pred: torch.Tensor, target: torch.Tensor, epsilon=1e-7):
    """
    References
    ----------
    `Dice Loss <https://arxiv.org/abs/1606.04797>`_
    """
    if target.size() != pred.size():
        raise ValueError(
            "Target size ({}) must be the same as prediction size ({})".format(target.size(), pred.size()))

    # sum over all but the batch dimension
    sum_dims = list(range(1, target.dim()))

    dice = 2 * torch.sum(pred * target, dim=sum_dims) / (torch.sum(pred ** 2 + target ** 2, dim=sum_dims) + epsilon)
    loss = 1 - dice
    return loss.mean()
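

# Usage sketch (illustrative): `pred` is expected to contain probabilities,
# e.g. after a sigmoid; use `dice_loss_with_logits` below for raw logits.
#
#     pred = torch.sigmoid(torch.randn(4, 1, 32, 32))
#     target = (torch.rand(4, 1, 32, 32) > 0.5).float()
#     loss = dice_loss(pred, target)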


def tversky_loss(pred: torch.Tensor, target: torch.Tensor, alpha=0.5, epsilon=1e-7,
                 reduce: Union[Callable, None] = torch.mean):
    """
    References
    ----------
    `Tversky Loss <https://arxiv.org/abs/1706.05721>`_
    """
    if target.size() != pred.size():
        raise ValueError(
            "Target size ({}) must be the same as prediction size ({})".format(target.size(), pred.size()))
    if alpha < 0 or alpha > 1:
        raise ValueError("Invalid alpha value, expected to be in [0, 1] interval")

    # sum over all but the batch dimension
    sum_dims = list(range(1, target.dim()))
    beta = 1 - alpha
    intersection = pred * target
    fps, fns = pred * (1 - target), (1 - pred) * target

    numerator = torch.sum(intersection, dim=sum_dims)
    denominator = numerator + alpha * torch.sum(fps, dim=sum_dims) + beta * torch.sum(fns, dim=sum_dims)
    tversky = numerator / (denominator + epsilon)

    loss = 1 - tversky
    if reduce is not None:
        loss = reduce(loss)
    return loss
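

# Note (illustrative): with alpha=0.5 the Tversky index coincides with the Dice score;
# alpha > 0.5 penalizes false positives more, alpha < 0.5 penalizes false negatives more.
#
#     loss = tversky_loss(pred, target, alpha=0.7)  # stress precision over recall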


def focal_tversky_loss(pred: torch.Tensor, target: torch.Tensor, gamma=4 / 3, alpha=0.5, epsilon=1e-7):
    """
    References
    ----------
    `Focal Tversky Loss <https://arxiv.org/abs/1810.07842>`_
    """
    if gamma <= 1:
        warnings.warn("`gamma` is <= 1; to focus on less accurate predictions choose gamma > 1.")

    tl = tversky_loss(pred, target, alpha, epsilon, reduce=None)
    return torch.pow(tl, 1 / gamma).mean()
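

# Usage sketch (illustrative): gamma > 1 shifts the focus towards harder,
# less accurately segmented examples.
#
#     loss = focal_tversky_loss(pred, target, gamma=4 / 3, alpha=0.7)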


def loss_with_logits(criterion: Callable, logit: torch.Tensor, target: torch.Tensor, **kwargs):
    if target.size() != logit.size():
        raise ValueError("Target size ({}) must be the same as logit size ({})".format(target.size(), logit.size()))

    pred = torch.sigmoid(logit)
    return criterion(pred, target, **kwargs)


def dice_loss_with_logits(logit: torch.Tensor, target: torch.Tensor):
    return loss_with_logits(dice_loss, logit, target)


def tversky_loss_with_logits(logit: torch.Tensor, target: torch.Tensor, alpha=0.5):
    return loss_with_logits(tversky_loss, logit, target, alpha=alpha)


def focal_tversky_loss_with_logits(logit: torch.Tensor, target: torch.Tensor, gamma, alpha=0.5):
    return loss_with_logits(focal_tversky_loss, logit, target, gamma=gamma, alpha=alpha)
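

# Usage sketch (illustrative): the `*_with_logits` wrappers apply a sigmoid internally,
# so they can be fed raw network outputs.
#
#     logit = torch.randn(4, 1, 32, 32)
#     target = (torch.rand(4, 1, 32, 32) > 0.5).float()
#     loss = dice_loss_with_logits(logit, target)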


def masked_loss(mask: torch.Tensor, criterion: Callable, prediction: torch.Tensor, target: torch.Tensor, **kwargs):
    """
    Calculates the ``criterion`` between the masked ``prediction`` and ``target``.
    ``kwargs`` are passed to ``criterion`` as additional arguments.

    If the ``mask`` is empty, returns 0 wrapped in a torch tensor.
    """
    if not mask.any():
        # an empty mask would make `criterion` fail; return a zero that stays in the computation graph
        # https://github.com/neuro-ml/deep_pipe/issues/75
        return 0 * prediction.flatten()[0]

    return criterion(prediction[mask], target[mask], **kwargs)
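

# Usage sketch (illustrative): compute the loss only inside a region of interest.
#
#     mask = target > 0  # boolean tensor of the same shape as `prediction`
#     loss = masked_loss(mask, functional.binary_cross_entropy_with_logits, prediction, target)

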
# simply copied from np.moveaxis
def moveaxis(x: torch.Tensor, source: AxesLike, destination: AxesLike):
    """
    Move axes of a torch.Tensor to new positions.
    Other axes remain in their original order.
    """
    source = np.core.numeric.normalize_axis_tuple(source, x.ndim, 'source')
    destination = np.core.numeric.normalize_axis_tuple(destination, x.ndim, 'destination')
    if len(source) != len(destination):
        raise ValueError('`source` and `destination` arguments must have the same number of elements')

    order = [n for n in range(x.ndim) if n not in source]
    for dest, src in sorted(zip(destination, source)):
        order.insert(dest, src)
    return x.permute(*order)
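

# Usage sketch (illustrative):
#
#     moveaxis(torch.zeros(2, 3, 4), 0, -1).shape  # torch.Size([3, 4, 2])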


def softmax(x: torch.Tensor, axis: AxesLike):
    """
    A multidimensional version of softmax.
    """
    source = np.core.numeric.normalize_axis_tuple(axis, x.ndim, 'axis')
    dim = len(source)
    destination = range(-dim, 0)

    # gather the softmax axes at the end, flatten them, apply softmax, then restore the layout
    x = moveaxis(x, source, destination)
    shape = x.shape
    x = x.reshape(*shape[:-dim], -1)
    x = functional.softmax(x, -1).reshape(*shape)
    x = moveaxis(x, destination, source)
    return x
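

# Usage sketch (illustrative): softmax computed jointly over the last two axes,
# so the probabilities sum to one across both of them.
#
#     x = torch.randn(2, 3, 4)
#     p = softmax(x, axis=(1, 2))
#     p.sum(dim=(1, 2))  # a tensor of ones, one per batch element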