# Source code for dpipe.batch_iter.utils

from typing import Callable, Iterable, Sequence

import numpy as np

from dpipe.im.axes import AxesLike, AxesParams
from dpipe.itertools import lmap, squeeze_first
from dpipe.im import pad_to_shape


def pad_batch_equal(batch, padding_values: AxesParams = 0, ratio: AxesParams = 0.5):
    """
    Pad every element of ``batch`` up to a common shape and stack the result into one array.

    The target shape is the element-wise maximum over the shapes of all elements.
    ``padding_values`` and ``ratio`` are forwarded unchanged to ``pad_to_shape``.

    References
    ----------
    `pad_to_shape`
    """
    target_shape = np.max(lmap(np.shape, batch), axis=0)

    # scalar elements have an empty shape, so the target shape is empty and there is nothing to pad
    if target_shape.size:
        batch = [
            pad_to_shape(element, target_shape, padding_values=padding_values, ratio=ratio)
            for element in batch
        ]

    return np.array(batch)


def unpack_args(func: Callable, *args, **kwargs):
    """
    Wrap ``func`` so that it takes a single iterable and unpacks it into positional arguments.

    ``args`` and ``kwargs`` are bound now and passed to ``func`` as additional arguments.
    Extra positional arguments given at call time come right after the unpacked items,
    followed by the bound ``args``.

    Examples
    --------
    >>> def add(x, y):
    ...     return x + y
    >>>
    >>> add_ = unpack_args(add)
    >>> add(1, 2) == add_([1, 2])
    True
    """
    def unpacked(xs, *args_, **kwargs_):
        positional = (*xs, *args_, *args)
        return func(*positional, **kwargs_, **kwargs)

    return unpacked


def multiply(func: Callable, *args, **kwargs):
    """
    Build a function that applies ``func`` to each item of an iterable and collects the results in a tuple.

    Handy when several batch components need the same transformation.
    ``args`` and ``kwargs`` are bound now and passed to ``func`` as additional arguments,
    after any extra arguments supplied at call time.
    """
    def wrapped(xs: Iterable, *args_, **kwargs_) -> tuple:
        results = []
        for item in xs:
            results.append(func(item, *args_, *args, **kwargs_, **kwargs))
        return tuple(results)

    return wrapped


def apply_at(index: 'AxesLike', func: Callable, *args, **kwargs):
    """
    Returns a function that takes a sequence and applies ``func`` to the values at the positions in ``index``.

    ``index`` may be a single position or a collection of positions. Negative positions count
    from the end, as in regular Python indexing. Positions that refer to the same element
    (e.g. ``0`` and ``-len(xs)``) cause ``func`` to be applied only once.
    ``args`` and ``kwargs`` are passed to ``func`` as additional arguments,
    after any extra arguments supplied at call time.

    Raises
    ------
    IndexError
        When the returned function is called on a sequence for which any position in
        ``index`` is out of bounds. The message reports the position as it was passed.

    Examples
    --------
    >>> first_sqr = apply_at(0, np.square)
    >>> first_sqr([3, 2, 1])  # -> (9, 2, 1)
    """
    index = set(np.atleast_1d(index).tolist())

    def wrapped(xs: Sequence, *args_, **kwargs_) -> tuple:
        length = len(xs)
        positions = set()
        for i in index:
            position = i + length if i < 0 else i
            if not 0 <= position < length:
                # report the index the user gave, not its normalized (possibly still negative) form
                raise IndexError(f'Index {i} out of bounds.')
            positions.add(position)

        return tuple(
            func(x, *args_, *args, **kwargs_, **kwargs) if i in positions else x
            for i, x in enumerate(xs)
        )

    return wrapped


def zip_apply(*functions: Callable, **kwargs):
    """
    Build a function that pairs each item of a sequence with the function at the same position and applies it.

    The i-th item is passed to the i-th function. Pairing is done with ``zip``, so it stops
    at the shorter of ``functions`` and the input, and unmatched items or functions are
    silently ignored. ``kwargs`` are bound now and passed to every function as additional
    arguments, after any arguments supplied at call time.

    Examples
    --------
    >>> zipper = zip_apply(np.square, np.sqrt)
    >>> zipper([4, 9])  # -> (16, 3.0)
    """
    def wrapped(xs: Sequence, *args, **kwargs_) -> tuple:
        outputs = []
        for function, item in zip(functions, xs):
            outputs.append(function(item, *args, **kwargs_, **kwargs))
        return tuple(outputs)

    return wrapped


[docs]def random_apply(p: float, func: Callable, *args, **kwargs): """ Returns a function that applies ``func`` with a given probability ``p``. ``args`` and ``kwargs`` are passed to ``func`` as additional arguments. """ def wrapped(*args_, **kwargs_): if np.random.binomial(1, p): return func(*args_, *args, **kwargs_, **kwargs) return squeeze_first(args_) return wrapped
def sample_args(func: Callable, *args: Callable, **kwargs: Callable):
    """
    Build a function that calls ``func`` with some arguments drawn freshly from samplers on every call.

    Each element of ``args`` and each value of ``kwargs`` must be a zero-argument callable that
    returns a random value. On each call, the positional samplers are invoked first (in order),
    then the keyword samplers. Their results are appended after the call-time positional and
    keyword arguments respectively.

    Examples
    --------
    >>> from scipy.ndimage import rotate
    >>>
    >>> random_rotate = sample_args(rotate, angle=np.random.normal)
    >>> random_rotate(x)
    >>> # same as
    >>> rotate(x, angle=np.random.normal())
    """
    def wrapped(*args_, **kwargs_):
        sampled_positional = [sampler() for sampler in args]
        sampled_keywords = {name: sampler() for name, sampler in kwargs.items()}
        return func(*args_, *sampled_positional, **kwargs_, **sampled_keywords)

    return wrapped