Training

Checkpoints

class dpipe.train.checkpoint.Checkpoints(base_path: Union[Path, str], objects: Iterable, frequency: Optional[int] = None)[source]

Bases: object

Saves checkpoints of the tracked objects at the most recent iteration to base_path and removes the previous checkpoint.

Parameters
  • base_path (str) – path to the folder in which checkpoints are saved and from which they are restored.

  • objects (Dict[PathLike, Any]) – objects to save. Each key-value pair represents the path relative to base_path and the corresponding object.

  • frequency (int) – the frequency with which the objects are stored. By default only the latest checkpoint is saved.

save(iteration: int, train_losses: Optional[Sequence] = None, metrics: Optional[dict] = None)[source]

Save the states of all tracked objects.

restore() → int[source]

Restore the most recent states of all tracked objects and return the index of the next iteration.
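
A minimal sketch of how Checkpoints might be wired into a training loop. The torch model, optimizer, paths, and loop body are illustrative assumptions, not part of the API, and the behaviour of restore() on a fresh run is also an assumption:

    import torch
    from dpipe.train.checkpoint import Checkpoints

    # hypothetical objects to track
    model = torch.nn.Linear(16, 1)
    optimizer = torch.optim.Adam(model.parameters())

    checkpoints = Checkpoints(
        'experiment/checkpoints',                          # base_path
        {'model.pth': model, 'optimizer.pth': optimizer},  # relative path -> object
    )

    # assumed to return 0 on a fresh run; after an interruption it restores
    # the latest saved states and returns the next iteration's index
    start = checkpoints.restore()
    for iteration in range(start, 1000):
        ...  # one training iteration (not shown)
        checkpoints.save(iteration)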

dpipe.train.checkpoint.CheckpointManager

alias of Checkpoints

Policies

class dpipe.train.policy.Policy[source]

Bases: object

Interface for various policies.

epoch_started(epoch: int)[source]

Update the policy before an epoch starts. Epoch numbering starts at zero.

train_step_started(epoch: int, iteration: int)[source]

Update the policy before a new train step. iteration denotes the iteration index inside the current epoch. Epoch and iteration numbering starts at zero.

train_step_finished(epoch: int, iteration: int, loss: Any)[source]

Update the policy after a train step. iteration denotes the iteration index inside the current epoch; loss is the value returned by the last train step. Epoch and iteration numbering starts at zero.

validation_started(epoch: int, train_losses: Sequence)[source]

Update the policy after the batch iterator has been depleted. Epoch numbering starts at zero.

The history of train_losses from the entire epoch is provided as additional information.

epoch_finished(epoch: int, train_losses: Sequence, metrics: Optional[dict] = None, policies: Optional[dict] = None)[source]

Update the policy after an epoch is finished. Epoch numbering starts at zero.

The histories of train_losses, metrics, and policies from the entire epoch are provided as additional information.
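
Concrete policies override only the callbacks they need. A minimal sketch of a custom policy (the class name and what it computes are hypothetical):

    from typing import Sequence
    from dpipe.train.policy import Policy

    class BestLossTracker(Policy):
        # hypothetical policy: remembers the best epoch-average train loss
        def __init__(self):
            self.best = float('inf')

        def epoch_finished(self, epoch: int, train_losses: Sequence,
                           metrics: dict = None, policies: dict = None):
            mean_loss = sum(train_losses) / len(train_losses)
            self.best = min(self.best, mean_loss)
            print(f'epoch {epoch}: mean loss {mean_loss:.4f}, best {self.best:.4f}')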

class dpipe.train.policy.ValuePolicy(initial)[source]

Bases: Policy

Interface for policies that have a value which changes over time.

value

the current value carried by the policy.
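
A sketch of a custom ValuePolicy, assuming (as the constructor above suggests) that initial becomes the starting value; the linear decay itself is a hypothetical example:

    from dpipe.train.policy import ValuePolicy

    class LinearDecay(ValuePolicy):
        # hypothetical: decrease `value` by a fixed step after every epoch
        def __init__(self, initial: float, step: float):
            super().__init__(initial)
            self.step = step

        def epoch_finished(self, epoch, train_losses, metrics=None, policies=None):
            self.value -= self.step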

dpipe.train.policy.Constant

alias of ValuePolicy

class dpipe.train.policy.DecreasingOnPlateau(*, initial: float, multiplier: float, patience: int, rtol, atol)[source]

Bases: ValuePolicy

Policy that traces the average train loss and, if it did not decrease within atol or rtol for patience epochs, multiplies the value by multiplier.

Parameters
  • atol – absolute tolerance for detecting a change in the training loss value.

  • rtol – relative tolerance for detecting a change in the training loss value.
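
For example, a learning-rate value that is halved after 5 epochs without sufficient improvement (all numbers are illustrative):

    from dpipe.train.policy import DecreasingOnPlateau

    lr = DecreasingOnPlateau(initial=1e-3, multiplier=0.5, patience=5,
                             rtol=1e-2, atol=1e-6)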

class dpipe.train.policy.Exponential(initial: float, multiplier: float, step_length: int = 1, floordiv: bool = True, min_value: float = -inf, max_value: float = inf)[source]

Bases: ValuePolicy

Exponentially change the value by a factor of multiplier every step_length epochs. If floordiv is False, the value is changed continuously rather than in discrete steps.
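
For example, a step-decay schedule (numbers illustrative):

    from dpipe.train.policy import Exponential

    # halve the value every 10 epochs, but never go below 1e-5
    lr = Exponential(initial=1e-3, multiplier=0.5, step_length=10, min_value=1e-5)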

class dpipe.train.policy.Schedule(initial: float, epoch2value_multiplier: Dict[int, float])[source]

Bases: ValuePolicy

Multiply the value by the multipliers given in epoch2value_multiplier at the corresponding epochs.
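
For example (numbers illustrative):

    from dpipe.train.policy import Schedule

    # 1e-2 until epoch 30, 1e-3 from epoch 30, 1e-4 from epoch 60
    lr = Schedule(initial=1e-2, epoch2value_multiplier={30: 0.1, 60: 0.1})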

class dpipe.train.policy.Switch(initial: float, epoch_to_value: Dict[int, Any])[source]

Bases: ValuePolicy

Changes the value at specific epochs to the values given in epoch_to_value.
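
For example, the schedule above expressed with explicit values rather than multipliers:

    from dpipe.train.policy import Switch

    lr = Switch(initial=1e-2, epoch_to_value={30: 1e-3, 60: 1e-4})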

class dpipe.train.policy.LambdaEpoch(func: Callable, *args, **kwargs)[source]

Bases: ValuePolicy

Use the passed function to calculate the value for the current epoch (starting with 0).
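
For example, a cosine schedule, assuming func receives the epoch index as its first argument (the formula is illustrative):

    import math
    from dpipe.train.policy import LambdaEpoch

    # cosine annealing from 1e-3 towards 0 over 100 epochs
    lr = LambdaEpoch(lambda epoch: 1e-3 * (1 + math.cos(math.pi * epoch / 100)) / 2)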

exception dpipe.train.policy.EarlyStopping[source]

Bases: StopIteration

Exception raised by policies in order to trigger early stopping.
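
A sketch of how a policy might use it; the class and its stopping criterion are hypothetical:

    from dpipe.train.policy import Policy, EarlyStopping

    class StopOnPlateau(Policy):
        # hypothetical policy: stop when the epoch-average train loss
        # has not improved for `patience` consecutive epochs
        def __init__(self, patience: int):
            self.patience = patience
            self.best = float('inf')
            self.bad_epochs = 0

        def epoch_finished(self, epoch, train_losses, metrics=None, policies=None):
            mean_loss = sum(train_losses) / len(train_losses)
            if mean_loss < self.best:
                self.best, self.bad_epochs = mean_loss, 0
            else:
                self.bad_epochs += 1
                if self.bad_epochs >= self.patience:
                    raise EarlyStopping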

class dpipe.train.policy.TQDM(loss: bool = True)[source]

Bases: Policy

Adds a tqdm progress bar. If loss is True, the progress bar also displays the current train loss.
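
For example:

    from dpipe.train.policy import TQDM

    progress = TQDM(loss=True)  # the bar will also show the running train loss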

Logging

Validation