Source code for dpipe.batch_iter.pipeline

import multiprocessing
from functools import partial
from itertools import islice
from typing import Iterable, Callable, Union

import numpy as np

from ..itertools import zip_equal
from ..im.axes import AxesParams
from .utils import pad_batch_equal

__all__ = [
    'Infinite',
    'Threads', 'Loky', 'Iterator',
    'combine_batches', 'combine_to_arrays', 'combine_pad',
]


def combine_batches(inputs):
    """
    Combines tuples from ``inputs`` into batches:
    [(x, y), (x, y)] -> [(x, x), (y, y)]
    """
    return tuple(zip_equal(*inputs))

def combine_to_arrays(inputs):
    """
    Combines tuples from ``inputs`` into batches of numpy arrays.
    """
    return tuple(map(np.array, combine_batches(inputs)))

def combine_pad(inputs, padding_values: AxesParams = 0, ratio: AxesParams = 0.5):
    """
    Combines tuples from ``inputs`` into batches and pads each batch
    in order to obtain a correctly shaped numpy array.

    Parameters
    ----------
    inputs
    padding_values
        values to pad with. If Callable (e.g. `numpy.min`) - ``padding_values(x)`` will be used.
    ratio
        the fraction of the padding that will be applied to the left,
        ``1 - ratio`` will be applied to the right.
        By default ``0.5``: the padding is applied uniformly to the left and right.

    References
    ----------
    `pad_to_shape`
    """
    batches = combine_batches(inputs)
    padding_values = np.broadcast_to(padding_values, [len(batches)])
    return tuple(pad_batch_equal(x, values, ratio) for x, values in zip(batches, padding_values))

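# A minimal usage sketch of the combiners above; the toy arrays are
# illustrative. `combine_to_arrays` requires equally shaped items, while
# `combine_pad` first pads each batch up to a common shape:
#
#     pairs = [(np.zeros(3), 0), (np.ones(3), 1)]
#     xs, ys = combine_to_arrays(pairs)  # xs.shape == (2, 3), ys.shape == (2,)
#
#     ragged = [(np.zeros((2, 2)), 0), (np.ones((3, 3)), 1)]
#     xs, ys = combine_pad(ragged)  # each x is padded to (3, 3), so xs.shape == (2, 3, 3)
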
class Transform:
    component = None

class Infinite:
    """
    Combine ``source`` and ``transformers`` into a batch iterator that yields batches of size ``batch_size``.

    Parameters
    ----------
    source: Iterable
        an infinite iterable.
    transformers: Callable
        callables that transform the objects generated by the previous element of the pipeline.
    batch_size: int, Callable
        the size of the batch.
    batches_per_epoch: int
        the number of batches to yield each epoch.
    buffer_size: int
        the number of objects to keep buffered in each pipeline element. Default is 1.
    combiner: Callable
        combines a chunk of objects into a batch, e.g.
        ``combiner([(x, y), (x, y)]) -> ([x, x], [y, y])``. Default is `combine_to_arrays`.
    kwargs
        additional keyword arguments passed to the ``combiner``.

    References
    ----------
    See the :doc:`tutorials/batch_iter` tutorial for more details.
    """

    def __init__(self, source: Iterable, *transformers: Union[Callable, Transform],
                 batch_size: Union[int, Callable], batches_per_epoch: int,
                 buffer_size: int = 1, combiner: Callable = combine_to_arrays, **kwargs):
        if batches_per_epoch <= 0:
            raise ValueError(f'Expected a positive amount of batches per epoch, but got {batches_per_epoch}')

        if not isinstance(combiner, Transform):
            combiner = Threads(partial(combiner, **kwargs))
        elif kwargs:
            raise ValueError('The `combiner` is already wrapped in a `Transform`, passing `kwargs` has no effect')

        self.batches_per_epoch = batches_per_epoch
        self.pipeline = None
        self._pipeline_factory = lambda: wrap_pipeline(
            source, *transformers, self._make_stacker(batch_size), combiner, buffer_size=buffer_size
        )

    @staticmethod
    def _make_stacker(batch_size):
        if callable(batch_size):
            should_add = batch_size
        elif isinstance(batch_size, int):
            if batch_size <= 0:
                raise ValueError(f'`batch_size` must be greater than zero, not {batch_size}.')

            def should_add(chunk, item):
                return len(chunk) < batch_size
        else:
            raise TypeError(f'`batch_size` must be either int or callable, not {type(batch_size)}.')

        def stacker(iterable):
            chunk = []
            for value in iterable:
                if not chunk or should_add(chunk, value):
                    chunk.append(value)
                else:
                    yield chunk
                    chunk = [value]

            if chunk:
                yield chunk

        return Iterator(stacker)
    def close(self):
        """Stop all background processes."""
        self.__exit__(None, None, None)
    @property
    def closing_callback(self):
        """
        A callback that makes this interface compatible with `Lightning` and allows for a safe release of resources.

        Examples
        --------
        >>> batch_iter = Infinite(...)
        >>> trainer = Trainer(callbacks=[batch_iter.closing_callback, ...])
        """
        from lightning.pytorch.callbacks import Callback

        class ClosingCallback(Callback):
            def teardown(self, trainer, pl_module, stage):
                this.close()

            def on_exception(self, trainer, pl_module, exception):
                this.close()

        this = self
        return ClosingCallback()

    def __iter__(self):
        return self()

    def __call__(self):
        if self.pipeline is None:
            self.pipeline = self._pipeline_factory()
        if not self.pipeline.pipeline_active:
            self.__enter__()

        return islice(self.pipeline, self.batches_per_epoch)

    def __enter__(self):
        if self.pipeline is None:
            self.pipeline = self._pipeline_factory()
        self.pipeline.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.pipeline is not None:
            self.pipeline, pipeline = None, self.pipeline
            return pipeline.__exit__(exc_type, exc_val, exc_tb)

    def __del__(self):
        self.close()

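# A minimal usage sketch of `Infinite`; the `sample_pairs` toy source is
# illustrative. Plain callables are applied to each object flowing through
# the pipeline:
#
#     def sample_pairs():  # an infinite source, as `Infinite` requires
#         while True:
#             x = np.random.randn(3)
#             yield x, x.sum()
#
#     batch_iter = Infinite(
#         sample_pairs(),
#         lambda pair: (pair[0] * 2, pair[1]),  # an ordinary per-object transformer
#         batch_size=4, batches_per_epoch=10,
#     )
#     with batch_iter:
#         for xs, ys in batch_iter():  # one epoch: 10 batches of 4 objects each
#             assert xs.shape == (4, 3)
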
def wrap_pipeline(source, *transformers, buffer_size=1):
    from ._pdp import Pipeline, ComponentDescription, Source, One2One

    def wrap(o):
        if isinstance(o, Transform):
            return o.component
        if not isinstance(o, ComponentDescription):
            return One2One(o, buffer_size=buffer_size)
        return o

    if not isinstance(source, ComponentDescription):
        source = Source(source, buffer_size=buffer_size)
    return Pipeline(source, *map(wrap, transformers))

class Iterator(Transform):
    """
    Apply ``transform`` to the iterator of values that flow through the batch iterator.

    Parameters
    ----------
    transform: Callable(Iterable) -> Iterable
        a function that takes an iterable and yields transformed values.
    n_workers: int
        the number of threads to which ``transform`` will be moved.
    buffer_size: int
        the number of objects to keep buffered.
    args
        additional positional arguments passed to ``transform``.
    kwargs
        additional keyword arguments passed to ``transform``.

    References
    ----------
    See the :doc:`tutorials/batch_iter` tutorial for more details.
    """

    def __init__(self, transform: Callable, *args, n_workers: int = 1, buffer_size: int = 1, **kwargs):
        from ._pdp import ComponentDescription, start_iter

        assert n_workers > 0
        assert buffer_size > 0
        self.component = ComponentDescription(partial(
            start_iter, transform=transform, n_workers=n_workers, args=args, kwargs=kwargs
        ), n_workers, buffer_size)

class Threads(Iterator):
    """
    Apply ``func`` concurrently to each object in the batch iterator by moving it to ``n_workers`` threads.

    Parameters
    ----------
    func: Callable
        a function applied to each value that flows through the pipeline.
    n_workers: int
        the number of threads to which ``func`` will be moved.
    buffer_size: int
        the number of objects to keep buffered.
    args
        additional positional arguments passed to ``func``.
    kwargs
        additional keyword arguments passed to ``func``.

    References
    ----------
    See the :doc:`tutorials/batch_iter` tutorial for more details.
    """

    def __init__(self, func: Callable, *args, n_workers: int = 1, buffer_size: int = 1, **kwargs):
        def transform_map(iterable):
            for value in iterable:
                yield func(value, *args, **kwargs)

        super().__init__(transform_map, n_workers=n_workers, buffer_size=buffer_size)

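# A minimal sketch of the relation between `Iterator` and `Threads`:
# `Iterator` receives the whole stream of values, while `Threads` lifts a
# per-object function into such a stream transform. The two pipeline
# elements below are equivalent:
#
#     def double(x):
#         return 2 * x
#
#     per_object = Threads(double, n_workers=2)
#     whole_stream = Iterator(lambda values: (double(v) for v in values), n_workers=2)
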
class Loky(Transform):
    """
    Apply ``func`` concurrently to each object in the batch iterator by moving it to ``n_workers`` processes.

    Parameters
    ----------
    func: Callable
        a function applied to each value that flows through the pipeline.
    n_workers: int
        the number of processes to which ``func`` will be moved.
    buffer_size: int
        the number of objects to keep buffered.
    args
        additional positional arguments passed to ``func``.
    kwargs
        additional keyword arguments passed to ``func``.

    Notes
    -----
    Process-based parallelism is implemented with the ``loky`` backend.

    References
    ----------
    See the :doc:`tutorials/batch_iter` tutorial for more details.
    """

    def __init__(self, func: Callable, *args, n_workers: int = 1, buffer_size: int = 1, **kwargs):
        from ._pdp import start_loky, ComponentDescription

        if n_workers < 0:
            n_workers = max(1, multiprocessing.cpu_count() + n_workers + 1)
        assert n_workers > 0
        assert buffer_size > 0
        self.component = ComponentDescription(partial(
            start_loky, transform=func, n_workers=n_workers, args=args, kwargs=kwargs
        ), n_workers, buffer_size)
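
# A minimal usage sketch of `Loky` with a negative worker count. As the code
# above shows, negative values count back from the number of CPUs: with
# ``n_workers=-1`` all cores are used, with ``-2`` all but one. The
# `slow_transform` function is a hypothetical stand-in for a CPU-heavy,
# picklable callable:
#
#     from itertools import count
#
#     def slow_transform(x):
#         return x ** 2
#
#     batch_iter = Infinite(
#         count(), Loky(slow_transform, n_workers=-1),
#         batch_size=8, batches_per_epoch=100,
#     )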