from typing import Dict, Callable, Union, Sequence

import numpy as np
from scipy.ndimage.morphology import distance_transform_edt, binary_erosion

from ..checks import add_check_bool, add_check_shapes, check_shapes, check_bool
from dpipe.itertools import zip_equal

# Public API of this module. The list literal was previously left unclosed (syntax error);
# ``cross_entropy_with_logits`` is a documented public function and is exported as well.
__all__ = [
    'dice_score', 'sensitivity', 'specificity', 'precision', 'recall', 'iou', 'assd', 'hausdorff_distance',
    'convert_to_aggregated', 'to_aggregated', 'fraction', 'cross_entropy_with_logits',
]

def to_aggregated(metric: Callable, aggregate: Callable = np.mean, *args, **kwargs):
    """
    Converts a ``metric`` that receives two values to a metric that receives two sequences
    and returns an aggregated value.

    ``args`` and ``kwargs`` are passed as additional arguments to ``aggregate``.
    Extra positional and keyword arguments given to the returned function are forwarded
    to ``metric`` for every pair.

    Examples
    --------
    >>> mean_dice = to_aggregated(dice_score)
    >>> worst_dice = to_aggregated(dice_score, aggregate=np.min)
    """

    def wrapper(xs: Sequence, ys: Sequence, *args_, **kwargs_):
        # zip_equal (also used by ``aggregate_metric``) instead of plain ``zip``: ``zip``
        # silently truncates to the shorter input, which would skew the aggregate.
        # NOTE(review): relies on dpipe.itertools.zip_equal rejecting unequal lengths.
        values = [metric(x, y, *args_, **kwargs_) for x, y in zip_equal(xs, ys)]
        return aggregate(values, *args, **kwargs)

    return wrapper
def fraction(numerator, denominator, empty_val: float = 1):
    """
    Divide ``numerator`` by ``denominator``, guarding against a zero denominator.

    Asserts that ``numerator <= denominator``; returns ``empty_val`` when the
    denominator equals zero, otherwise the quotient.
    """
    assert numerator <= denominator, f'{numerator}, {denominator}'
    if denominator == 0:
        return empty_val
    return numerator / denominator
@add_check_bool
@add_check_shapes
def dice_score(x: np.ndarray, y: np.ndarray) -> float:
    """
    Dice coefficient ``2 * sum(x & y) / (sum(x) + sum(y))``.

    Inputs are validated by ``add_check_bool`` and ``add_check_shapes``.
    When the denominator is zero, ``fraction``'s default of 1 is returned.
    """
    overlap = np.sum(x & y)
    total = np.sum(x) + np.sum(y)
    return fraction(2 * overlap, total)
@add_check_bool
@add_check_shapes
def sensitivity(y_true, y_pred):
    """
    True positive rate ``sum(y_pred & y_true) / sum(y_true)``.

    Returns ``fraction``'s default of 1 when ``sum(y_true)`` is zero.
    """
    true_positives = np.sum(y_pred & y_true)
    positives = np.sum(y_true)
    return fraction(true_positives, positives)
@add_check_bool
@add_check_shapes
def specificity(y_true, y_pred):
    """
    True negative rate ``TN / (TN + FP)``.

    Returns 0 when ``TN + FP`` is zero.
    """
    # ``~`` binds tighter than ``&``, so this is (~y_true) & (~y_pred)
    true_negatives = np.sum(~y_true & ~y_pred)
    false_positives = np.sum(y_pred & ~y_true)
    return fraction(true_negatives, true_negatives + false_positives, empty_val=0)
@add_check_bool
@add_check_shapes
def recall(y_true, y_pred):
    """
    Recall ``TP / (TP + FN)``.

    Same ratio as ``sensitivity`` but returns 0 (not 1) when ``TP + FN`` is zero.
    """
    hits = np.count_nonzero(np.logical_and(y_pred, y_true))
    misses = np.count_nonzero(np.logical_and(~y_pred, y_true))
    return fraction(hits, hits + misses, empty_val=0)
@add_check_bool
@add_check_shapes
def precision(y_true, y_pred):
    """
    Precision ``TP / (TP + FP)``.

    Returns 0 when ``TP + FP`` is zero.
    """
    hits = np.count_nonzero(y_pred & y_true)
    false_alarms = np.count_nonzero(y_pred & ~y_true)
    return fraction(hits, hits + false_alarms, empty_val=0)
@add_check_bool
@add_check_shapes
def iou(x: np.ndarray, y: np.ndarray) -> float:
    """
    Intersection over union ``sum(x & y) / sum(x | y)``.

    Returns ``fraction``'s default of 1 when the union is empty.
    """
    intersection = np.sum(x & y)
    union = np.sum(x | y)
    return fraction(intersection, union)
def get_area(start, stop):
    """
    Area (volume in N-D) of an axis-aligned box given its ``start`` and ``stop`` corners.

    Per-axis extents ``stop - start`` below zero are clipped to 0, so an inverted
    or degenerate box has zero area. Accepts numpy arrays or plain sequences.
    """
    # np.prod instead of np.product: the latter is deprecated and removed in NumPy 2.0
    extents = np.asarray(stop) - np.asarray(start)
    return np.prod(np.maximum(extents, 0))


def box_iou(a_start_stop, b_start_stop):
    """
    Intersection over union of two axis-aligned boxes.

    Each box is a ``(start, stop)`` pair of per-axis corner coordinates.
    Returns 1 (``fraction``'s default) when the union has zero area, i.e. both boxes are empty.
    """
    a_start, a_stop = a_start_stop
    b_start, b_stop = b_start_stop
    intersection = get_area(np.maximum(a_start, b_start), np.minimum(a_stop, b_stop))
    union = get_area(a_start, a_stop) + get_area(b_start, b_stop) - intersection
    # The former debug ``print`` for ``union <= 0`` was removed: the intersection never exceeds
    # either box's area, so the union is only zero when both boxes are empty, handled by ``fraction``.
    return fraction(intersection, union)


def aggregate_metric(xs, ys, metric, aggregate_fn=np.mean):
    """Aggregate a `metric` computed on pairs from `xs` and `ys` (paired with ``zip_equal``)."""
    return aggregate_fn([metric(x, y) for x, y in zip_equal(xs, ys)])
def convert_to_aggregated(metrics: Dict[str, Callable], aggregate_fn: Callable = np.mean,
                          key_prefix: str = '', key_suffix: str = '', *args, **kwargs):
    """
    Wrap every metric of ``metrics`` with ``to_aggregated``.

    Returns a new dict keyed by ``key_prefix + name + key_suffix``. Each value applies the
    original metric pairwise over two sequences and reduces the results with ``aggregate_fn``.
    ``args`` and ``kwargs`` are forwarded to ``aggregate_fn``.
    """
    aggregated = {}
    for name, metric in metrics.items():
        aggregated[key_prefix + name + key_suffix] = to_aggregated(metric, aggregate_fn, *args, **kwargs)
    return aggregated
def surface_distances(y_true, y_pred, voxel_shape=None):
    """
    Distances from each border element of ``y_pred`` to the nearest border element of ``y_true``.

    A mask's border is the mask XOR its ``binary_erosion``. ``voxel_shape`` is passed as
    ``sampling`` to ``distance_transform_edt``. Returns a 1D array with one entry per
    border element of ``y_pred``.
    """
    check_bool(y_pred, y_true)
    check_shapes(y_pred, y_true)

    def border_of(mask):
        return np.logical_xor(mask, binary_erosion(mask))

    pred_border = border_of(y_pred)
    true_border = border_of(y_true)
    # distance_transform_edt measures the distance to the nearest zero, so zeros must mark the true border
    distance_to_true_border = distance_transform_edt(~true_border, sampling=voxel_shape)
    return distance_to_true_border[pred_border]
def assd(x, y, voxel_shape=None):
    """
    Average symmetric surface distance between masks ``x`` and ``y``.

    Computed as the mean of the two directional mean surface distances. Returns 0 when
    neither mask has a border, and ``np.nan`` when exactly one of them does.
    """
    x_to_y = surface_distances(y, x, voxel_shape=voxel_shape)
    y_to_x = surface_distances(x, y, voxel_shape=voxel_shape)
    x_empty = x_to_y.size == 0
    y_empty = y_to_x.size == 0

    if x_empty and y_empty:
        return 0
    if x_empty or y_empty:
        return np.nan
    return np.mean([x_to_y.mean(), y_to_x.mean()])
def hausdorff_distance(x, y, voxel_shape=None):
    """
    Symmetric Hausdorff distance between the borders of masks ``x`` and ``y``.

    Returns the larger of the two directional maximum surface distances; 0 when neither
    mask has a border, and ``np.nan`` when exactly one of them does.
    """
    x_to_y = surface_distances(y, x, voxel_shape=voxel_shape)
    y_to_x = surface_distances(x, y, voxel_shape=voxel_shape)
    x_empty = x_to_y.size == 0
    y_empty = y_to_x.size == 0

    if x_empty and y_empty:
        return 0
    if x_empty or y_empty:
        return np.nan
    return max(x_to_y.max(), y_to_x.max())
def cross_entropy_with_logits(target: np.ndarray, logits: np.ndarray, axis: int = 1,
                              reduce: Union[Callable, None] = np.mean):
    """
    A numerically stable cross entropy for numpy arrays.
    ``target`` and ``logits`` must have the same shape except for ``axis``.

    Parameters
    ----------
    target
        integer array of shape (d1, ..., di, dj, ..., dn)
    logits
        array of shape (d1, ..., di, k, dj, ..., dn)
    axis
        the axis containing the logits for each class: ``logits.shape[axis] == k``
    reduce
        the reduction operation to be applied to the final loss.
        If None - no reduction will be performed.

    Notes
    -----
    The loss is ``logsumexp(logits) - logits[target]``. The log-sum-exp is evaluated by
    shifting by the per-position maximum logit, so ``exp`` neither overflows for large
    logits nor underflows to zero for very negative ones.
    """
    # logit of the target class, with ``axis`` kept as a size-1 dimension
    main = np.take_along_axis(logits, np.expand_dims(target, axis), axis)
    # Shift by the true maximum. The former ``np.maximum(0, max)`` clamp let every
    # ``exp(logit)`` underflow to 0 when all logits were very negative, giving log(0) = -inf.
    max_ = logits.max(axis, keepdims=True)
    loss = -main + max_ + np.log(np.exp(logits - max_).sum(axis, keepdims=True))
    loss = loss.squeeze(axis)
    if reduce is not None:
        loss = reduce(loss)
    return loss