Source code for dpipe.dataset.wrappers

"""
Wrappers change the dataset's behaviour.
See the :doc:`tutorials/wrappers` tutorial for more details.
"""
import functools
from itertools import chain
from types import MethodType, FunctionType
from typing import Sequence, Callable, Iterable
from collections import ChainMap, namedtuple
from pathlib import Path

import numpy as np

from dpipe.checks import join
from dpipe.io import save_numpy, PathLike, load_or_create, load_numpy
from dpipe.itertools import zdict, collect
from dpipe.im.preprocessing import normalize
from .base import Dataset


class Proxy:
    """Base class for all wrappers."""

    def __init__(self, shadowed):
        self._shadowed = shadowed

    def __getattr__(self, name):
        return getattr(self._shadowed, name)

    def __dir__(self):
        return list(set(super().__dir__()) | set(dir(self._shadowed)))


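# Illustrative sketch (not part of the original module): attribute access on a Proxy falls through
# to the wrapped object via __getattr__, so a wrapper only needs to override what it changes.
# `base` stands for a hypothetical dataset-like object assumed to expose an `ids` attribute.
def _proxy_example(base):
    proxied = Proxy(base)
    # `ids` is not defined on Proxy, so it is looked up on the shadowed object.
    return proxied.ids

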
@collect
def _get_public_methods(instance):
    for attr in dir(instance):
        if not attr.startswith('_') and isinstance(getattr(instance, attr), (MethodType, FunctionType)):
            yield attr


def cache_methods(instance, methods: Iterable[str] = None, maxsize: int = None):
    """Cache the ``instance``'s ``methods``. If ``methods`` is None, all public methods will be cached."""
    if methods is None:
        methods = _get_public_methods(instance)

    cache = functools.lru_cache(maxsize)
    new_methods = {method: staticmethod(cache(getattr(instance, method))) for method in methods}
    proxy = type('Cached', (Proxy,), new_methods)
    return proxy(instance)


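# Usage sketch (illustrative): wrap a dataset so that repeated `load_image` calls with the same id
# are served from an in-memory functools.lru_cache instead of being recomputed.
# `dataset` is a hypothetical object exposing `ids` and `load_image(identifier)`.
def _cache_methods_example(dataset):
    cached = cache_methods(dataset, methods=['load_image'])
    first = cached.load_image(cached.ids[0])   # computed and stored
    second = cached.load_image(cached.ids[0])  # returned from the cache
    return first is second  # lru_cache returns the stored object itself

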
def cache_methods_to_disk(instance, base_path: PathLike, loader: Callable = load_numpy, saver: Callable = save_numpy,
                          **methods: str):
    """
    Cache the ``instance``'s ``methods`` to disk.

    Parameters
    ----------
    instance
        arbitrary object
    base_path: str
        the path that all other paths in ``methods`` are relative to.
    methods: str
        each keyword argument has the form ``method_name=path_to_cache``.
        The methods are assumed to take a single argument of type ``str``.
    loader
        loads a single object given its path.
    saver: Callable(value, path)
        saves a single object to the given path.
    """
    base_path = Path(base_path)

    def decorator(method, folder):
        method = getattr(instance, method)
        path = base_path / folder
        path.mkdir(parents=True, exist_ok=True)

        @functools.wraps(method)
        def wrapper(identifier, *args, **kwargs):
            return load_or_create(
                path / f'{identifier}.npy', method, identifier, *args, **kwargs, save=saver, load=loader)

        return staticmethod(wrapper)

    new_methods = {method: decorator(method, folder) for method, folder in methods.items()}
    proxy = type('CachedToDisk', (Proxy,), new_methods)
    return proxy(instance)


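# Usage sketch (illustrative): persist the results of `load_image` as .npy files under
# <base_path>/images, so later runs load them from disk instead of recomputing.
# `dataset`, the '/tmp/cache' base path and the 'images' folder name are hypothetical.
def _cache_to_disk_example(dataset):
    cached = cache_methods_to_disk(dataset, '/tmp/cache', load_image='images')
    # The first call saves /tmp/cache/images/<id>.npy via `saver`; subsequent calls use `loader`.
    return cached.load_image(cached.ids[0])

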
def apply(instance, **methods: Callable):
    """
    Applies a given function to the output of a given method.

    Parameters
    ----------
    instance
        arbitrary object
    methods: Callable
        each keyword argument has the form ``method_name=func_to_apply``.
        ``func_to_apply`` is applied to the output of the ``method_name`` method.

    Examples
    --------
    >>> # normalize will be applied to the output of load_image
    >>> dataset = apply(base_dataset, load_image=normalize)
    """

    def decorator(method, func):
        @functools.wraps(method)
        def wrapper(*args, **kwargs):
            return func(method(*args, **kwargs))

        return staticmethod(wrapper)

    new_methods = {method: decorator(getattr(instance, method), func) for method, func in methods.items()}
    proxy = type('Apply', (Proxy,), new_methods)
    return proxy(instance)


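# Usage sketch (illustrative): post-process a method's output, e.g. intensity-normalize images
# right after they are loaded. `dataset` is a hypothetical object with `ids` and `load_image`.
def _apply_example(dataset):
    normalized = apply(dataset, load_image=normalize)
    # Equivalent to normalize(dataset.load_image(identifier)); other attributes are forwarded untouched.
    return normalized.load_image(normalized.ids[0])

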
def set_attributes(instance, **attributes):
    """
    Sets or overwrites attributes with those provided as keyword arguments.

    Parameters
    ----------
    instance
        arbitrary object
    attributes
        each keyword argument has the form ``attr_name=attr_value``.
    """
    proxy = type('SetAttr', (Proxy,), attributes)
    return proxy(instance)


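# Usage sketch (illustrative): override a single attribute while leaving the rest of the object
# intact. The attribute name and value below are hypothetical.
def _set_attributes_example(dataset):
    patched = set_attributes(dataset, n_chans_image=1)
    return patched.n_chans_image  # 1; all other attributes are still forwarded to `dataset`

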
def change_ids(dataset: Dataset, change_id: Callable, methods: Iterable[str] = None) -> Dataset:
    """
    Change the ``dataset``'s ids according to the ``change_id`` function and adapt the provided ``methods``
    to work with the new ids.

    Parameters
    ----------
    dataset: Dataset
        the dataset whose ids will be changed.
    change_id: Callable(str) -> str
        the function used to change the ids. The resulting ids, like the original ones, must be unique.
    methods: Iterable[str]
        the list of methods to be adapted. Each method takes a single argument - the identifier.
    """
    if methods is None:
        methods = _get_public_methods(dataset)
    assert 'ids' not in methods

    ids = tuple(map(change_id, dataset.ids))
    if len(set(ids)) != len(ids):
        raise ValueError('The resulting ids are not unique.')
    new_to_old = zdict(ids, dataset.ids)

    def decorator(method):
        @functools.wraps(method)
        def wrapper(identifier):
            return method(new_to_old[identifier])

        return staticmethod(wrapper)

    attributes = {method: decorator(getattr(dataset, method)) for method in methods}
    attributes['ids'] = ids
    proxy = type('ChangedID', (Proxy,), attributes)
    return proxy(dataset)


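# Usage sketch (illustrative): prefix every id while keeping the methods usable with the new ids.
# The 'site1-' prefix is arbitrary and serves only as an example.
def _change_ids_example(dataset):
    renamed = change_ids(dataset, lambda identifier: f'site1-{identifier}', methods=['load_image'])
    # `renamed.ids` holds the prefixed ids; the adapted method maps them back to the original ones.
    return renamed.load_image(renamed.ids[0])

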
def merge(*datasets: Dataset, methods: Sequence[str] = None, attributes: Sequence[str] = ()) -> Dataset:
    """
    Merge several ``datasets`` into one by preserving the provided ``methods`` and ``attributes``.

    Parameters
    ----------
    datasets: Dataset
        sequence of datasets.
    methods: Sequence[str], None, optional
        the list of methods to be preserved. Each method should take an identifier as its first argument.
        If ``None``, all the common methods will be preserved.
    attributes: Sequence[str]
        the list of attributes to be preserved. Their values must be the same across all datasets.
        Default is the empty sequence ``()``.
    """
    if methods is None:
        methods = set(_get_public_methods(datasets[0]))
        for dataset in datasets:
            methods = methods & set(_get_public_methods(dataset))

    clash = set(methods) & set(attributes)
    if clash:
        raise ValueError(f'Method names clash with attribute names: {join(clash)}.')

    ids = tuple(id_ for dataset in datasets for id_ in dataset.ids)
    if len(set(ids)) != len(ids):
        raise ValueError('The ids are not unique.')

    preserved_attributes = []
    for attribute in attributes:
        # can't use a set here, because not all attributes can be hashed
        values = []
        for dataset in datasets:
            value = getattr(dataset, attribute)
            if value not in values:
                values.append(value)

        if len(values) != 1:
            raise ValueError(f'Datasets have different values of attribute "{attribute}".')
        preserved_attributes.append(values[0])

    def decorator(method_name):
        def wrapper(identifier, *args, **kwargs):
            if identifier not in id_to_dataset:
                raise KeyError(f"This dataset doesn't contain the id {identifier}")

            return getattr(id_to_dataset[identifier], method_name)(identifier, *args, **kwargs)

        return wrapper

    id_to_dataset = ChainMap(*({id_: dataset for id_ in dataset.ids} for dataset in datasets))
    Merged = namedtuple('Merged', chain(['ids'], methods, attributes))
    return Merged(*chain([ids], map(decorator, methods), preserved_attributes))


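# Usage sketch (illustrative): combine two datasets with disjoint ids into a single namedtuple-like
# object exposing the common `load_image` method and a shared `n_chans_image` attribute.
# Both argument names are hypothetical; `n_chans_image` must have the same value in both datasets.
def _merge_example(first_dataset, second_dataset):
    merged = merge(first_dataset, second_dataset, methods=['load_image'], attributes=['n_chans_image'])
    # `merged.load_image` dispatches each identifier to the dataset that contains it.
    return merged.load_image(merged.ids[0]), merged.n_chans_image

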
def apply_mask(dataset: Dataset, mask_modality_id: int = -1, mask_value: int = None) -> Dataset:
    """
    Applies the ``mask_modality_id`` modality as a binary mask to the other modalities
    and removes the mask from the sequence of modalities.

    Parameters
    ----------
    dataset: Dataset
        dataset which is used in the current task.
    mask_modality_id: int
        the index of the mask in the sequence of modalities.
        Default is ``-1``, which means the last modality will be used as the mask.
    mask_value: int, None, optional
        the value in the mask to filter the other modalities with.
        If ``None``, greater-than-zero filtering will be applied. Default is ``None``.

    Examples
    --------
    >>> modalities = ['flair', 't1', 'brain_mask']  # we are to apply the brain mask to the other modalities
    >>> target = 'target'
    >>>
    >>> dataset = apply_mask(
    ...     dataset=Wmh2017(
    ...         data_path=data_path,
    ...         modalities=modalities,
    ...         target=target
    ...     ),
    ...     mask_modality_id=-1,
    ...     mask_value=1
    ... )
    """

    class MaskedDataset(Proxy):
        def load_image(self, patient_id):
            images = self._shadowed.load_image(patient_id)
            mask = images[mask_modality_id]
            mask_bin = mask > 0 if mask_value is None else mask == mask_value
            if not np.sum(mask_bin) > 0:
                raise ValueError('The obtained mask is empty')
            images = [image * mask for image in images[:-1]]
            return np.array(images)

        @property
        def n_chans_image(self):
            return self._shadowed.n_chans_image - 1

    return MaskedDataset(dataset)