# Source code for dpipe.layers.structure

from functools import partial
from typing import Callable, Sequence

import numpy as np
import torch
import torch.nn as nn

from dpipe.im.utils import build_slices, pam, identity


def make_consistent_seq(layer: Callable, channels: Sequence[int], *args, **kwargs):
    """
    Functional shortcut for ``ConsistentSequential``: chains ``layer`` instances so that
    each one maps ``channels[i]`` to ``channels[i + 1]``.

    ``args`` and ``kwargs`` are forwarded to every ``layer`` call.

    Examples
    --------
    >>> make_consistent_seq(nn.Conv2d, [16, 32, 64, 128], kernel_size=3, padding=1)
    >>> # same as
    >>> nn.Sequential(
    ...     nn.Conv2d(16, 32, kernel_size=3, padding=1),
    ...     nn.Conv2d(32, 64, kernel_size=3, padding=1),
    ...     nn.Conv2d(64, 128, kernel_size=3, padding=1),
    ... )
    """
    sequence = ConsistentSequential(layer, channels, *args, **kwargs)
    return sequence
class ConsistentSequential(nn.Sequential):
    """
    An ``nn.Sequential`` whose layers have chained input and output channels/features:
    the i-th layer is built as ``layer(channels[i], channels[i + 1], *args, **kwargs)``.

    Raises
    ------
    ValueError
        if ``channels`` has fewer than two elements (no layer could be built).

    Examples
    --------
    >>> ConsistentSequential(nn.Conv2d, [16, 32, 64, 128], kernel_size=3, padding=1)
    >>> # same as
    >>> nn.Sequential(
    ...     nn.Conv2d(16, 32, kernel_size=3, padding=1),
    ...     nn.Conv2d(32, 64, kernel_size=3, padding=1),
    ...     nn.Conv2d(64, 128, kernel_size=3, padding=1),
    ... )
    """

    def __init__(self, layer: Callable, channels: Sequence[int], *args, **kwargs):
        if len(channels) < 2:
            raise ValueError('`channels` must contain at least two elements.')

        # consecutive (in, out) pairs: (c0, c1), (c1, c2), ...
        stages = [
            layer(n_in, n_out, *args, **kwargs)
            for n_in, n_out in zip(channels[:-1], channels[1:])
        ]
        super().__init__(*stages)
class PreActivation(nn.Module):
    """
    Applies batch normalization, then the activation, then ``layer``:

        in -> (BN -> activation -> layer) -> out

    Parameters
    ----------
    in_features: int
        the number of incoming features/channels; also the size of the batch norm.
    out_features: int
        the number of the output features/channels.
    layer_module: Callable(in_features, out_features, **kwargs)
        builds the main layer, e.g. ``torch.nn.Conv3d`` or ``torch.nn.Linear``.
    batch_norm_module
        builds the batch normalization layer from ``in_features``, e.g. ``torch.nn.BatchNorm3d``.
        If None (default), no normalization is applied.
    activation_module
        builds the activation layer (called without arguments). Default is ``torch.nn.ReLU``.
    kwargs
        additional arguments passed to ``layer_module``.
    """

    def __init__(self, in_features: int, out_features: int, *, layer_module, batch_norm_module=None,
                 activation_module=nn.ReLU, **kwargs):
        super().__init__()
        # attribute assignment order (bn, activation, layer) fixes the submodule registration order
        self.bn = identity if batch_norm_module is None else batch_norm_module(in_features)
        self.activation = activation_module()
        self.layer = layer_module(in_features, out_features, **kwargs)

    def forward(self, x):
        normalized = self.bn(x)
        activated = self.activation(normalized)
        return self.layer(activated)
class PostActivation(nn.Module):
    """
    Applies ``layer``, then batch normalization, then the activation:

        in -> (layer -> BN -> activation) -> out

    Parameters
    ----------
    in_features: int
        the number of incoming features/channels.
    out_features: int
        the number of the output features/channels; also the size of the batch norm.
    layer_module: Callable(in_features, out_features, **kwargs)
        builds the main layer, e.g. ``torch.nn.Conv3d`` or ``torch.nn.Linear``.
    batch_norm_module
        builds the batch normalization layer from ``out_features``, e.g. ``torch.nn.BatchNorm3d``.
        If None (default), no normalization is applied.
    activation_module
        builds the activation layer (called without arguments). Default is ``torch.nn.ReLU``.
    kwargs
        additional arguments passed to ``layer_module``.

    Notes
    -----
    If ``layer`` supports a bias term, make sure to pass ``bias=False``: the batch norm that
    follows it presumably makes the bias redundant.
    """

    def __init__(self, in_features: int, out_features: int, *, layer_module, batch_norm_module=None,
                 activation_module=nn.ReLU, **kwargs):
        super().__init__()
        # attribute assignment order (layer, activation, bn) fixes the submodule registration order
        self.layer = layer_module(in_features, out_features, **kwargs)
        self.activation = activation_module()
        self.bn = identity if batch_norm_module is None else batch_norm_module(out_features)

    def forward(self, x):
        out = self.layer(x)
        out = self.bn(out)
        return self.activation(out)
class CenteredCrop(nn.Module):
    """
    Crops the incoming tensor along every axis after the first two.
    The first two axes (presumably batch and channels) are always kept whole.

    Parameters
    ----------
    start
        the start index along each cropped axis.
    stop
        the stop index along each cropped axis. If None, the crop is symmetric:
        ``stop = -start`` for every non-zero ``start``, and ``None`` (keep up to the end)
        where ``start`` is zero.

    Notes
    -----
    Presumably ``build_slices(start, stop)`` yields ``slice(start[i], stop[i])`` for each axis.
    Confirm against ``dpipe.im.utils``.
    """

    def __init__(self, start, stop=None):
        super().__init__()
        if stop is None:
            start = np.asarray(start)
            # a zero margin maps to `None` rather than `-0`: a `0:-0` slice would be empty
            stop = np.where(start, -start, None)
        self.slices = (slice(None), slice(None), *build_slices(start, stop))

    def forward(self, x):
        return x[self.slices]


class SplitReduce(nn.Module):
    """
    Feeds the same input to every path and combines their outputs with ``reduce``:

        in -> (path_1, ..., path_n) -> reduce -> out

    Parameters
    ----------
    reduce: Callable
        receives the collection of path outputs produced by ``pam(paths, x)`` and combines them.
    paths: nn.Module
        the branches; each is stored in an ``nn.ModuleList``, so they must be modules.
    """

    def __init__(self, reduce, *paths):
        super().__init__()
        self.reduce = reduce
        self.paths = nn.ModuleList(list(paths))

    def forward(self, x):
        return self.reduce(pam(self.paths, x))


class Split(SplitReduce):
    """Feeds the same input to every path and returns their outputs as a tuple."""

    def __init__(self, *paths):
        super().__init__(tuple, *paths)


def _cat(tensors, dim):
    """Concatenate ``tensors`` along ``dim``."""
    # A module-level function (instead of a lambda) keeps `SplitCat` instances picklable,
    # which `torch.save(model)` and multiprocessing rely on.
    return torch.cat(tuple(tensors), dim=dim)


class SplitCat(SplitReduce):
    """
    Feeds the same input to every path and concatenates their outputs along ``axis``.

    Parameters
    ----------
    paths: nn.Module
        the branches.
    axis: int
        the concatenation axis. Default is 1.
    """

    def __init__(self, *paths, axis=1):
        super().__init__(partial(_cat, dim=axis), *paths)


class SplitAdd(nn.Module):
    """
    Feeds the same input to every path and sums their outputs:

        in -> (path_1 + ... + path_n) -> out

    Raises
    ------
    ValueError
        if no paths are given.
    """

    def __init__(self, *paths):
        super().__init__()
        if not paths:
            raise ValueError('`paths` must contain at least one element.')
        self.init_path, *paths = paths
        self.other_paths = nn.ModuleList(list(paths))

    def forward(self, x):
        # start from the first output instead of 0, so no zero tensor of matching
        # shape/dtype/device has to be constructed
        result = self.init_path(x)
        for path in self.other_paths:
            result = result + path(x)
        return result
class Lambda(nn.Module):
    """
    Wraps an arbitrary callable ``func`` as a module.

    ``forward`` passes all of its positional and keyword arguments to ``func`` together with
    the ``kwargs`` fixed at construction. A keyword given both at call time and at construction
    raises ``TypeError`` (it is never silently overridden).
    """

    def __init__(self, func, **kwargs):
        super().__init__()
        self.func, self.kwargs = func, kwargs

    def forward(self, *args, **kwargs):
        fn = self.func
        return fn(*args, **kwargs, **self.kwargs)