from contextlib import suppress
from functools import wraps
from itertools import chain
from operator import itemgetter
from typing import Iterable, Sized, Union, Callable, Sequence, Any, Tuple
import numpy as np
def pam(functions: Iterable[Callable], *args, **kwargs):
    """
    Inverse of `map`: apply each of several callables to the same fixed arguments.

    The results are produced lazily, one per callable, in the order given by ``functions``.

    Examples
    --------
    >>> list(pam([np.sqrt, np.square, np.cbrt], 64))
    [8.0, 4096, 4.0]
    """
    yield from (function(*args, **kwargs) for function in functions)
def zip_equal(*args: Union[Sized, Iterable]) -> Iterable[Tuple]:
    """
    zip over the given iterables, but enforce that all of them exhaust simultaneously.

    This is a generator: no check happens until iteration starts. Arguments that support
    ``len`` are compared when the first element is requested. All other arguments are
    checked as they are consumed, so for them the error is raised only after every
    complete tuple has been yielded.

    Raises
    ------
    ValueError
        if the arguments with a known length have different lengths, or if the iterables
        do not exhaust simultaneously.

    Examples
    --------
    >>> list(zip_equal([1, 2, 3], [4, 5, 6]))  # ok
    [(1, 4), (2, 5), (3, 6)]
    >>> list(zip_equal([1, 2, 3], [4, 5, 6, 7]))  # raises ValueError
    # ValueError is raised even if the lengths are not known
    >>> list(zip_equal([1, 2, 3], map(np.sqrt, [4, 5, 6])))  # ok
    >>> list(zip_equal([1, 2, 3], map(np.sqrt, [4, 5, 6, 7])))  # raises ValueError
    """
    from itertools import zip_longest

    if not args:
        return

    lengths = []
    all_lengths = []
    for arg in args:
        try:
            length = len(arg)
        except TypeError:
            # unsized iterable (e.g. a generator): its length is verified while iterating
            all_lengths.append('?')
        else:
            lengths.append(length)
            all_lengths.append(length)

    if len(set(lengths)) > 1:
        from .checks import join
        raise ValueError(f'The arguments have different lengths: {join(all_lengths)}.')

    # a private sentinel can never be confused with a real element (including None)
    missing = object()
    for values in zip_longest(*args, fillvalue=missing):
        if any(value is missing for value in values):
            raise ValueError('The iterables did not exhaust simultaneously.')
        yield values
def head_tail(iterable: Iterable) -> Tuple[Any, Iterable]:
    """
    Split ``iterable`` into its first element and an iterator over the remaining ones.

    Raises ``StopIteration`` if ``iterable`` is empty.

    Examples
    --------
    >>> head, tail = head_tail(iter([1, 4, 9]))
    >>> head, list(tail)
    (1, [4, 9])
    """
    rest = iter(iterable)
    first = next(rest)
    return first, rest
def peek(iterable: Iterable) -> Tuple[Any, Iterable]:
    """
    Return the first element of ``iterable`` together with an iterator over all its elements.

    Notes
    -----
    The incoming ``iterable`` may be advanced by this call, so use the returned iterator
    instead. Raises ``StopIteration`` if ``iterable`` is empty.

    Examples
    --------
    >>> original_iterable = iter([1, 4, 9])
    >>> head, iterable = peek(original_iterable)
    >>> head, list(iterable)
    (1, [1, 4, 9])
    # list(original_iterable) would have returned [4, 9] before the call above
    """
    source = iter(iterable)
    first = next(source)
    # put the consumed element back in front of the rest
    return first, chain((first,), source)
def lmap(func: Callable, *iterables: Iterable) -> list:
    """Like ``map``, but eagerly collect the results into a list."""
    return [*map(func, *iterables)]
def pmap(func: Callable, iterable: Iterable, *args, **kwargs) -> Iterable:
    """
    Partial map: lazily apply ``func`` to every item of ``iterable``.

    Each item is passed as the first positional argument, followed by the fixed
    ``args`` and ``kwargs``.
    """
    yield from (func(item, *args, **kwargs) for item in iterable)
def dmap(func: Callable, dictionary: dict, *args, **kwargs):
    """
    Build a new dict with the same keys as ``dictionary`` and ``func`` applied to each value.

    Every value is passed as the first positional argument, followed by the fixed
    ``args`` and ``kwargs``. The input dictionary is not modified.

    Examples
    --------
    >>> dmap(np.square, {'a': 1, 'b': 2})
    {'a': 1, 'b': 4}
    """
    transformed = {}
    for key, value in dictionary.items():
        transformed[key] = func(value, *args, **kwargs)
    return transformed
def zdict(keys: Iterable, values: Iterable) -> dict:
    """
    Pair ``keys`` with ``values`` into a dictionary.

    The two iterables are paired with ``zip_equal``, so a length mismatch raises ``ValueError``.
    For duplicate keys, the last value wins.
    """
    return {key: value for key, value in zip_equal(keys, values)}
def squeeze_first(inputs):
    """Return ``inputs[0]`` if ``inputs`` has exactly one element, otherwise ``inputs`` unchanged."""
    return inputs[0] if len(inputs) == 1 else inputs
def flatten(iterable: Iterable, iterable_types: Union[tuple, type] = None) -> list:
    """
    Recursively flattens an ``iterable`` as long as it is an instance of ``iterable_types``.

    Parameters
    ----------
    iterable
        the object to flatten.
    iterable_types
        the type(s) treated as containers to expand. Defaults to ``type(iterable)``.

    Returns
    -------
    flat: list
        the leaves in depth-first order. If ``iterable`` itself is not an instance of
        ``iterable_types``, the result is ``[iterable]``.

    Notes
    -----
    Single-character strings are always treated as leaves. Iterating them yields
    themselves, so expanding them would recurse forever.

    Examples
    --------
    >>> flatten([1, [2, 3], [[4]]])
    [1, 2, 3, 4]
    >>> flatten([1, (2, 3), [[4]]])
    [1, (2, 3), 4]
    >>> flatten([1, (2, 3), [[4]]], iterable_types=(list, tuple))
    [1, 2, 3, 4]
    """
    if iterable_types is None:
        iterable_types = type(iterable)

    # accumulate into a single list: linear, unlike the quadratic ``sum(lists, [])``
    flat = []

    def visit(value):
        is_atomic_char = isinstance(value, str) and len(value) == 1
        if is_atomic_char or not isinstance(value, iterable_types):
            flat.append(value)
        else:
            for item in value:
                visit(item)

    visit(iterable)
    return flat
def filter_mask(iterable: Iterable, mask: Iterable[bool]) -> Iterable:
    """
    Lazily keep the values of ``iterable`` whose corresponding ``mask`` entry is truthy.

    ``mask`` and ``iterable`` are paired with ``zip_equal``, so a length mismatch raises
    ``ValueError`` while the result is being consumed.
    """
    pairs = zip_equal(mask, iterable)
    selected = filter(itemgetter(0), pairs)
    return map(itemgetter(1), selected)
def negate_indices(indices: Iterable, length: int):
    """
    Return, as a sorted integer array, the indices of ``range(length)`` that do not appear in ``indices``.

    ``indices`` are applied with ordinary numpy fancy indexing, so negative entries count
    from the end and out-of-range entries raise ``IndexError``.
    """
    keep = np.full(length, True)
    keep[list(indices)] = False
    return np.flatnonzero(keep)
def make_chunks(iterable: Iterable, chunk_size: int, incomplete: bool = True):
    """
    Group ``iterable`` into tuples of ``chunk_size`` consecutive elements.

    Parameters
    ----------
    iterable
        the values to group.
    chunk_size
        the number of elements per chunk; must be positive.
    incomplete
        whether to yield the last chunk in case it has a smaller size.

    Raises
    ------
    ValueError
        if ``chunk_size`` is not positive. Since this is a generator, the error is raised
        when iteration starts.

    Examples
    --------
    >>> list(make_chunks(range(5), 2))
    [(0, 1), (2, 3), (4,)]
    >>> list(make_chunks(range(5), 2, incomplete=False))
    [(0, 1), (2, 3)]
    """
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be a positive integer, got {chunk_size}.')

    chunk = []
    for value in iterable:
        chunk.append(value)
        if len(chunk) == chunk_size:
            yield tuple(chunk)
            chunk = []

    # the trailing chunk is a tuple too, consistent with the complete ones
    if incomplete and chunk:
        yield tuple(chunk)
def collect(func: Callable):
    """
    Make a function that returns a list from a function that returns an iterator.

    The wrapper keeps ``func``'s metadata (via ``functools.wraps``) and reports ``list``
    as its return annotation, without modifying ``func``'s own annotations.

    Examples
    --------
    >>> @collect
    ... def squares(n):
    ...     for i in range(n):
    ...         yield i ** 2
    >>> squares(3)
    [0, 1, 4]
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        return list(func(*args, **kwargs))

    # ``wraps`` can hand ``wrapper`` the very same ``__annotations__`` dict as ``func``.
    # Mutating it in place would also rewrite ``func``'s return annotation, so build a new dict.
    wrapper.__annotations__ = {**wrapper.__annotations__, 'return': list}
    return wrapper
def stack(axis: int = 0, dtype: np.dtype = None):
    """
    Stack the values yielded by a generator function along a given ``axis``.

    ``dtype`` (if any) determines the data type of the resulting array. The decorated
    function keeps its metadata (via ``functools.wraps``) and reports ``np.ndarray`` as its
    return annotation, without modifying the original function's annotations.

    Examples
    --------
    >>> @stack(1)
    ... def consecutive(n):
    ...     for i in range(n):
    ...         yield i, i + 1
    >>> consecutive(3)
    array([[0, 1, 2],
           [1, 2, 3]])
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # np.stack needs a sequence, so materialize the generator first
            value = np.stack(list(func(*args, **kwargs)), axis=axis)
            if dtype is not None:
                value = value.astype(dtype)
            return value

        # ``wraps`` can share ``func.__annotations__`` with ``wrapper``, so assign a new dict
        # instead of mutating it; otherwise ``func``'s return annotation would change too.
        wrapper.__annotations__ = {**wrapper.__annotations__, 'return': np.ndarray}
        return wrapper

    return decorator
def recursive_conditional_map(xr, f, condition):
    """
    Walk recursively through the iterable data structure ``xr`` and apply ``f`` to every
    element that satisfies ``condition``.

    Elements that fail ``condition`` are descended into, so they must themselves be iterable.
    Every container level in the result is rebuilt as a tuple.
    """
    mapped = []
    for item in xr:
        if condition(item):
            mapped.append(f(item))
        else:
            mapped.append(recursive_conditional_map(item, f, condition))
    return tuple(mapped)