# Source code for dpipe.im.augmentation
from functools import partial
import numpy as np
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
from dpipe.itertools import extract
from .utils import apply_along_axes
from .axes import AxesLike, axis_from_dim
def elastic_transform(x: np.ndarray, amplitude: float, axis: AxesLike = None, order: int = 1, *,
                      sigma: float = 1):
    """
    Apply a gaussian elastic distortion with a given amplitude to a tensor along the given axes.

    Parameters
    ----------
    x
        the tensor to distort.
    amplitude
        bound of the raw displacement noise: for every distorted axis a field is drawn from
        ``uniform(-amplitude, amplitude)`` and then gaussian-smoothed, so the resulting
        displacements are typically smaller than ``amplitude``.
    axis
        axes along which the distortion is applied; normalized by ``axis_from_dim``
        (presumably ``None`` means all axes -- see that helper).
    order
        spline interpolation order passed to ``scipy.ndimage.map_coordinates``.
    sigma
        standard deviation of the gaussian used to smooth the random displacement fields.
        Larger values give smoother deformations. Keyword-only, defaults to 1.

    Notes
    -----
    The noise is drawn from numpy's global random state (``np.random``).
    Sampling positions that fall outside the input are filled with 0, which is the default
    ``mode='constant'``, ``cval=0`` behaviour of ``map_coordinates``.
    """
    axis = axis_from_dim(axis, x.ndim)
    grid_shape = extract(x.shape, axis)
    # one independent, smoothed displacement field per distorted axis
    deltas = [
        gaussian_filter(np.random.uniform(-amplitude, amplitude, grid_shape), sigma)
        for _ in grid_shape
    ]
    # identity sampling grid of shape (len(grid_shape), *grid_shape), shifted by the displacements
    grid = np.mgrid[tuple(map(slice, grid_shape))] + deltas
    return apply_along_axes(partial(map_coordinates, coordinates=grid, order=order), x, axis)