from functools import wraps
from typing import Union, Callable, Type
import numpy as np
from ..im.axes import broadcast_to_axis, AxesLike, AxesParams, axis_from_dim, resolve_deprecation
from ..im.grid import divide, combine, get_boxes, PatchCombiner, Average
from ..itertools import extract, pmap
from ..im.shape_ops import pad_to_shape, crop_to_shape, pad_to_divisible
from ..im.shape_utils import prepend_dims, extract_dims
# Names exported by a star-import of this module.
# NOTE(review): 'add_extract_dims' is listed here but is not defined in the visible part of
# this file - confirm that its definition exists elsewhere in the module.
__all__ = 'add_extract_dims', 'divisible_shape', 'patches_grid'
def divisible_shape(divisor: AxesLike, axis: AxesLike = None, padding_values: Union[AxesParams, Callable] = 0,
                    ratio: AxesParams = 0.5):
    """
    Pads an incoming array to be divisible by ``divisor`` along the ``axis``. Afterwards the padding is removed.

    Parameters
    ----------
    divisor
        a value an incoming array should be divisible by.
    axis
        axes along which the array will be padded. Resolved on every call via
        ``axis_from_dim(axis, x.ndim)``, so ``None`` is interpreted by that helper.
        NOTE(review): previously documented as "the last ``len(divisor)`` axes" - confirm
        against the current ``axis_from_dim`` behaviour.
    padding_values
        values to pad with. If Callable (e.g. ``numpy.min``) - ``padding_values(x)`` will be used.
    ratio
        the fraction of the padding that will be applied to the left, ``1 - ratio`` will be applied to the right.

    Returns
    -------
    A decorator for a ``predict(x, *args, **kwargs)`` function. The decorated function pads ``x``,
    calls ``predict`` on the padded array and crops the result back to the original shape of ``x``
    along the resolved axes.

    References
    ----------
    `pad_to_divisible`
    """

    def decorator(predict):
        @wraps(predict)
        def wrapper(x, *args, **kwargs):
            # everything that depends on the dimensionality is resolved per call,
            # so a single decorated function can serve inputs with different ``ndim``
            local_axis = axis_from_dim(axis, x.ndim)
            local_divisor, local_ratio = broadcast_to_axis(local_axis, divisor, ratio)
            # remember the unpadded extent along the affected axes before ``x`` is rebound
            shape = np.array(x.shape)[list(local_axis)]

            x = pad_to_divisible(x, local_divisor, local_axis, padding_values, local_ratio)
            result = predict(x, *args, **kwargs)
            # undo the padding using the same axes and left/right ratio
            return crop_to_shape(result, shape, local_axis, local_ratio)

        return wrapper

    return decorator
def patches_grid(patch_size: AxesLike, stride: AxesLike, axis: AxesLike = None,
                 padding_values: Union[AxesParams, Callable] = 0, ratio: AxesParams = 0.5,
                 combiner: Type[PatchCombiner] = Average, get_boxes: Callable = get_boxes):
    """
    Divide an incoming array into patches of corresponding ``patch_size`` and ``stride`` and then combine
    the predicted patches by aggregating the overlapping regions using the ``combiner`` - Average by default.

    If ``padding_values`` is not None, the array will be padded to an appropriate shape to make a valid division.
    Afterwards the padding is removed. Otherwise if input cannot be patched without remainder
    ``ValueError`` is raised.

    Parameters
    ----------
    patch_size
        the size of the patches the input is divided into.
    stride
        the step between consecutive patches.
    axis
        axes along which the input is divided. Resolved on every call via
        ``resolve_deprecation(axis, x.ndim, patch_size, stride)``.
    padding_values
        values passed to ``pad_to_shape``. ``None`` disables padding (and the final cropping).
    ratio
        passed to ``pad_to_shape`` and ``crop_to_shape``.
    combiner
        passed to ``grid.combine`` to aggregate the predicted patches.
    get_boxes
        passed to both ``grid.divide`` and ``grid.combine``.

    Raises
    ------
    ValueError
        if ``padding_values`` is None and the input is smaller than ``patch_size`` along some axis
        or cannot be covered by patches without remainder.

    References
    ----------
    `grid.divide`, `grid.combine`, `pad_to_shape`
    """
    valid = padding_values is not None

    def decorator(predict):
        @wraps(predict)
        def wrapper(x, *args, **kwargs):
            input_axis = resolve_deprecation(axis, x.ndim, patch_size, stride)
            local_size, local_stride = broadcast_to_axis(input_axis, patch_size, stride)
            # extent of the original (unpadded) input along the patched axes
            shape = extract(x.shape, input_axis)

            if valid:
                # the array must hold at least one whole patch ...
                padded_shape = np.maximum(shape, local_size)
                # ... and ``new_shape - local_size`` must be a multiple of the stride:
                # the modulo term is the smallest non-negative amount that achieves this
                new_shape = padded_shape + (local_stride - padded_shape + local_size) % local_stride
                x = pad_to_shape(x, new_shape, input_axis, padding_values, ratio)
            elif ((shape - local_size) < 0).any() or ((local_stride - shape + local_size) % local_stride).any():
                raise ValueError('Input cannot be patched without remainder.')

            patches = pmap(
                predict,
                divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes),
                *args, **kwargs
            )
            # NOTE(review): the raw ``axis`` (not ``input_axis``) is forwarded here and to
            # ``crop_to_shape`` below - presumably so that it is resolved against the prediction,
            # whose dimensionality may differ from the input's; confirm this is intended.
            prediction = combine(
                patches, extract(x.shape, input_axis), local_stride, axis,
                combiner=combiner, get_boxes=get_boxes
            )

            if valid:
                # remove the padding that was added above
                prediction = crop_to_shape(prediction, shape, axis, ratio)
            return prediction

        return wrapper

    return decorator