Source code for dpipe.batch_iter.expiration_pool

from typing import Iterable
from functools import partial

import numpy as np

from .pipeline import Iterator


class ExpirationPool(Iterator):
    """
    Pipeline step that applies `expiration_pool` with preset arguments.

    Intended for time consuming operations whose results don't fit into RAM.
    See `expiration_pool` for a description of the caching behaviour.

    Parameters
    ----------
    pool_size
        forwarded to `expiration_pool` as ``pool_size``.
    repetitions
        forwarded to `expiration_pool` as ``repetitions``.
    iterations
        forwarded to `expiration_pool` as ``iterations``.

    Examples
    --------
    >>> batch_iter = Infinite(
        # ... some expensive operations, e.g. loading from disk, or preprocessing
        ExpirationPool(pool_size, repetitions),
        # ... here are the values from pool
        # ... other lightweight operations
        # ...
    )
    """

    def __init__(self, pool_size: int, repetitions: int, iterations: int = 1):
        # Bind the pool settings now, so that the only argument left to supply
        # is the iterable itself.
        pool_settings = dict(pool_size=pool_size, repetitions=repetitions, iterations=iterations)
        transform = partial(expiration_pool, **pool_settings)
        super().__init__(transform)
def expiration_pool(iterable: Iterable, pool_size: int, repetitions: int, iterations: int = 1):
    """
    Lazily re-yield the items of ``iterable`` from a bounded, randomly sampled cache.

    Items are pulled from ``iterable`` one at a time and stored in a pool.
    Right after an item is stored, up to ``iterations`` values are drawn from the pool;
    then more values are drawn for as long as the pool holds at least ``pool_size`` items.
    Each draw picks one of the pooled items with ``np.random.choice``, and an item is
    removed from the pool once it has been drawn ``repetitions`` times.
    When ``iterable`` is exhausted, the leftover items are drawn until the pool is empty,
    so every item ends up being yielded exactly ``repetitions`` times.

    Parameters
    ----------
    iterable
        the source of items to cache.
    pool_size
        the pool is drained down below this number of items before the next
        item is requested from ``iterable``. Must be positive.
    repetitions
        how many times each item is yielded before it is discarded. Must be positive.
    iterations
        how many values are drawn right after a new item is added, which speeds up
        the pipeline at early stages, while the pool is not filled yet.

    Yields
    ------
    The cached items, in random order.

    Raises
    ------
    AssertionError
        on the first ``next`` call, if ``pool_size`` or ``repetitions`` is not positive.
    """
    assert pool_size > 0
    assert repetitions > 0

    # position in `iterable` -> [item, number of times it was drawn so far]
    # keyed by position, so equal items are still tracked separately
    pool = {}

    def draw():
        # TODO: use randomdict?
        key = np.random.choice(list(pool))
        entry = pool[key]
        entry[1] += 1
        # the item has served its last repetition - evict it
        if entry[1] == repetitions:
            del pool[key]
        return entry[0]

    for position, item in enumerate(iterable):
        pool[position] = [item, 0]

        # warm-up draws; the pool may run empty midway if `repetitions` is small
        for _ in range(iterations):
            if pool:
                yield draw()

        # free at least one slot before requesting the next item
        while len(pool) >= pool_size:
            yield draw()

    # the source is exhausted: yield whatever is left in the pool
    while pool:
        yield draw()