from collections import OrderedDict, abc
from enum import Enum
from functools import reduce, wraps
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union, overload
import torch
from torch import Tensor, nn
from .delay import Delay
from .logging import getLogger
from .module import CoModule, PaddingMode, _callmode
from .skip import Skip
from .utils import (
function_repr,
load_state_dict,
num_from,
state_dict,
temporary_parameter,
)
# Module-level logger obtained from the package's own logging helper.
logger = getLogger(__name__)

# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "Sequential",
    "BroadcastReduce",
    "Residual",
    "Broadcast",
    "Parallel",
    "ParallelDispatch",
    "Reduce",
    "Conditional",
]

# Generic placeholders used in the signatures of the stream-routing containers,
# e.g. ``Broadcast.forward(input: T) -> List[T]``.
T = TypeVar("T")
S = TypeVar("S")

# Type alias for a module state: a list of optional tensors.
State = List[Optional[Tensor]]
class Reduction(Enum):
    """Types of parallel tensor reduce operation.

    Supported types are:
        - SUM: Element-wise summation
        - CONCAT: Channel-wise concatenation
        - MUL: Hadamard (element-wise) product
        - MAX: Element-wise maximum

    Members can be looked up by their string value,
    e.g. ``Reduction("sum") is Reduction.SUM``.
    """

    SUM = "sum"
    CONCAT = "concat"
    MUL = "mul"
    MAX = "max"
# A reduction maps a sequence of (parallel) tensors to a single tensor.
ReductionFunc = Callable[[Sequence[Tensor]], Tensor]
# Accepted forms of a ``reduce`` argument: a ``Reduction`` member, a custom
# reduction function, or the string value of a ``Reduction`` member.
ReductionFuncOrEnum = Union[Reduction, ReductionFunc, str]
def apply_forward(module: "CoModule", input: Tensor):
    """Run ``module`` on ``input`` and return its primary output.

    Recurrent ``nn.RNNBase`` modules have ``forward`` called directly, and only
    the first element of the returned tuple is kept. For torch RNNs this is the
    output sequence. Every other module is invoked through ``__call__``.
    """
    is_recurrent = isinstance(module, nn.RNNBase)
    if not is_recurrent:
        return module(input)
    outputs = module.forward(input)
    return outputs[0]
def reduce_sum(inputs: Sequence[Tensor]) -> Tensor:
    """Element-wise sum of all inputs.

    Args:
        inputs (Sequence[Tensor]): At least two tensors with broadcastable shapes.

    Returns:
        Tensor: ``inputs[0] + inputs[1] + ...``, accumulated left to right.
    """
    assert len(inputs) >= 2
    total = inputs[0]
    for tensor in inputs[1:]:
        total = total.add(tensor)
    return total
def reduce_concat(inputs: Sequence[Tensor], *, dim: int = 1) -> Tensor:
    """Concatenate inputs along the channel dimension.

    Args:
        inputs (Sequence[Tensor]): Tensors whose shapes match in every dimension
            except ``dim``. Concatenation does not broadcast.
        dim (int, optional): Dimension to concatenate along. Defaults to 1, the
            channel dimension for inputs of shape (B, C, T, H, W).

    Returns:
        Tensor: Inputs concatenated along ``dim``.
    """
    # torch.cat requires a tuple/list, so materialise any generic Sequence.
    return torch.cat(list(inputs), dim=dim)
def reduce_mul(inputs: Sequence[Tensor]) -> Tensor:
    """Hadamard (element-wise) product of all inputs.

    Args:
        inputs (Sequence[Tensor]): At least two tensors with broadcastable shapes.

    Returns:
        Tensor: Hadamard product ``inputs[0] * inputs[1] * ...``, accumulated
        left to right.
    """
    assert len(inputs) >= 2
    first, *rest = inputs
    product = first
    for factor in rest:
        product = product.mul(factor)
    return product
def reduce_max(inputs: Sequence[Tensor]) -> Tensor:
    """Element-wise maximum over all inputs.

    Args:
        inputs (Sequence[Tensor]): At least two tensors with broadcastable shapes.

    Returns:
        Tensor: Element-wise maximum, folded pairwise from left to right.
    """
    assert len(inputs) >= 2
    result = inputs[0]
    for candidate in inputs[1:]:
        # Two-tensor form of torch.max: element-wise maximum.
        result = torch.max(result, candidate)
    return result
def nonempty(fn: "ReductionFunc") -> "ReductionFunc":
    """Guard a reduction so that it is skipped for missing or empty inputs.

    The returned wrapper keeps ``fn``'s metadata via ``functools.wraps``. It
    returns ``None`` instead of calling ``fn`` when any input is ``None`` or has
    size zero along its first dimension.
    """

    @wraps(fn)
    def wrapped(inputs: Sequence[Tensor]) -> Optional[Tensor]:
        for item in inputs:
            if item is None or item.shape[0] == 0:
                return None  # pragma: no cover
        return fn(inputs)

    return wrapped
class FlattenableStateDict:
    """Mixin that adds optional flattening to ``state_dict``/``load_state_dict``.

    Classes using this mixin are assumed to also inherit from ``nn.Module``.
    Flattening is requested per call with ``flatten=True``, or enabled for every
    call by setting the class attribute ``flatten_state_dict`` to ``True``. The
    actual work is delegated to the ``state_dict``/``load_state_dict`` helpers
    imported from ``.utils``.
    """

    # Fallback used whenever a call does not request flattening itself.
    flatten_state_dict = False

    def __init__(self, *args, **kwargs):
        ...  # pragma: no cover

    def state_dict(
        self, destination=None, prefix="", keep_vars=False, flatten=False
    ) -> "OrderedDict[str, Tensor]":
        # Inside this method the name refers to the module-level helper.
        use_flat = flatten if flatten else self.flatten_state_dict
        return state_dict(self, destination, prefix, keep_vars, use_flat)

    def load_state_dict(
        self,
        state_dict: "OrderedDict[str, Tensor]",
        strict: bool = True,
        flatten=False,
    ):
        use_flat = flatten if flatten else self.flatten_state_dict
        return load_state_dict(self, state_dict, strict, use_flat)
def co_add_module(self, name: str, module: Optional["nn.Module"]) -> None:
    """Register ``module`` on ``self`` under ``name``, converting it if needed.

    If ``CoModule.is_valid`` rejects the module, it is first passed through
    ``continual.convert.continual``. Registration then happens through
    ``nn.Module.add_module``.
    """
    if CoModule.is_valid(module):
        registered = module
    else:
        # Imported lazily to break a cyclical import.
        from continual.convert import continual

        registered = continual(module)
    nn.Module.add_module(self, name, registered)
class Broadcast(CoModule, nn.Module):
    """Broadcast one input stream to multiple output streams.

    This is needed for handling parallel streams in subsequent modules. The
    very same input object is repeated in the returned list; it is not copied.
    For instance, here is how it is used to create a residual connection::

        residual = co.Sequential(
            co.Broadcast(2),
            co.Parallel(
                co.Conv3d(32, 32, kernel_size=3, padding=1),
                co.Delay(2),
            ),
            co.Reduce("sum"),
        )

    Since the ``Broadcast`` -> ``Parallel`` -> ``Reduce`` sequence is so common,
    identical behavior can be achieved with ``BroadcastReduce``::

        residual = co.BroadcastReduce(
            co.Conv3d(32, 32, kernel_size=3, padding=1),
            co.Delay(2),
            reduce="sum"
        )

    Even shorter, the library features a residual connection, which
    automatically handles delays::

        residual = co.Residual(co.Conv3d(32, 32, kernel_size=3, padding=1))

    Args:
        num_streams (int):
            Number of streams to broadcast to. If none are given, a Sequential
            module may infer it automatically.
    """

    _state_shape = 0
    _dynamic_state_inds = []

    def __init__(
        self,
        num_streams: int = None,
    ):
        nn.Module.__init__(self)
        self.num_streams = num_streams

    def forward(self, input: T) -> List[T]:
        assert isinstance(
            self.num_streams, int
        ), "Unknown number of target streams in Broadcast."
        return [input] * self.num_streams

    def _forward_step(self, input: T, prev_state=None):
        # Stateless: the incoming state is handed back untouched.
        outputs = self.forward(input)
        return outputs, prev_state

    def forward_step(self, input: T, update_state=True) -> List[T]:
        return self.forward(input)

    def forward_steps(self, input: T, pad_end=False, update_state=True) -> List[T]:
        return self.forward(input)
class Parallel(FlattenableStateDict, CoModule, nn.Sequential):
    """Container for executing modules in parallel.

    Modules will be added to it in the order they are passed in the
    constructor. Plain ``nn.Module``s are converted to continual modules
    before their delays are inspected.

    For instance, here is how it is used to create a residual connection::

        residual = co.Sequential(
            co.Broadcast(2),
            co.Parallel(
                co.Conv3d(32, 32, kernel_size=3, padding=1),
                co.Delay(2),
            ),
            co.Reduce("sum"),
        )

    Since the ``Broadcast`` -> ``Parallel`` -> ``Reduce`` sequence is so common,
    identical behavior can be achieved with ``BroadcastReduce``::

        residual = co.BroadcastReduce(
            co.Conv3d(32, 32, kernel_size=3, padding=1),
            co.Delay(2),
            reduce="sum"
        )

    Even shorter, the library features a residual connection, which
    automatically handles delays::

        residual = co.Residual(co.Conv3d(32, 32, kernel_size=3, padding=1))

    Args:
        arg (OrderedDict[str, CoModule]): An OrderedDict of strings and modules.
        *args (CoModule): Comma-separated modules.
        auto_delay (bool, optional):
            Automatically add delay to modules in order to match the longest delay.
            Defaults to True.

    Raises:
        ValueError: If no modules are given.
        AssertionError: If the modules do not share the same temporal stride.
    """

    @overload
    def __init__(
        self,
        *args: CoModule,
        auto_delay=True,
    ) -> None:
        ...  # pragma: no cover

    @overload
    def __init__(
        self,
        arg: "OrderedDict[str, CoModule]",
        auto_delay=True,
    ) -> None:
        ...  # pragma: no cover

    def __init__(
        self,
        *args,
        auto_delay=True,
    ):
        nn.Module.__init__(self)

        if len(args) == 1 and isinstance(args[0], OrderedDict):
            modules = list(args[0].items())
        else:
            modules = [(str(idx), module) for idx, module in enumerate(args)]

        if not modules:
            # Previously surfaced as an opaque "max() arg is an empty sequence".
            raise ValueError("Parallel expects at least one module.")

        if not all(CoModule.is_valid(module) for _, module in modules):
            # Convert up front, just as `co_add_module` would on registration,
            # so that `delay` and `stride` can be read from every module below.
            from continual.convert import continual  # break cyclical import

            modules = [
                (key, module if CoModule.is_valid(module) else continual(module))
                for key, module in modules
            ]

        if auto_delay:
            # If there is a delay mismatch, automatically add delay to match the longest
            max_delay = max(module.delay for _, module in modules)
            modules = [
                (
                    key,
                    (
                        Sequential(module, Delay(max_delay - module.delay))
                        if module.delay < max_delay
                        else module
                    ),
                )
                for key, module in modules
            ]

        strides = [num_from(getattr(m, "stride", 1)) for _, m in modules]
        assert (
            len(set(strides)) == 1
        ), f"Expected all modules to have the same stride, but got strides {strides}"

        for key, module in modules:
            self.add_module(key, module)

        delays = set(m.delay for m in self)
        if len(delays) != 1:  # pragma: no cover
            logger.warning(
                f"It is recommended that parallel modules have the same delay, but found delays {delays}. "
                "Temporal consistency cannot be guaranteed."
            )
        self._delay = max(delays)
        self._receptive_field = max(m.receptive_field for m in self)

    def add_module(self, name: str, module: Optional["nn.Module"]) -> None:
        co_add_module(self, name, module)

    @property
    def _state_shape(self):
        return [m._state_shape for m in self]

    @property
    def _dynamic_state_inds(self):
        return [m._dynamic_state_inds for m in self]

    def _forward_step(
        self, inputs: List[T], prev_state: Optional[List[Optional[S]]] = None
    ):
        # One state slot per parallel branch; missing state means "fresh start".
        prev_state = prev_state or [None for _ in range(len(self))]
        outs, next_state = [], []
        for i, module in enumerate(self):
            out, n_state = module._forward_step(inputs[i], prev_state=prev_state[i])
            outs.append(out)
            next_state.append(n_state)
        return outs, next_state

    def forward_step(self, inputs: List[T], update_state=True) -> List[T]:
        outs = []
        for i, m in enumerate(self):
            with temporary_parameter(m, "call_mode", _callmode("forward_step")):
                outs.append(m(inputs[i], update_state=update_state))
        return outs

    def forward_steps(
        self, inputs: List[T], pad_end=False, update_state=True
    ) -> List[T]:
        outs = []
        for i, m in enumerate(self):
            with temporary_parameter(m, "call_mode", _callmode("forward_steps")):
                outs.append(m(inputs[i], pad_end=pad_end, update_state=update_state))
        return outs

    def forward(self, inputs: List[T]) -> List[T]:
        outs = []
        for i, m in enumerate(self):
            with temporary_parameter(m, "call_mode", _callmode("forward")):
                outs.append(apply_forward(m, inputs[i]))
        return outs

    @property
    def receptive_field(self) -> int:
        return self._receptive_field

    @property
    def delay(self) -> int:
        return self._delay

    @property
    def stride(self) -> Tuple[int]:
        # All branches share a stride (asserted in __init__); report the first.
        return getattr(next(iter(self)), "stride", (1,))

    @property
    def padding(self) -> Tuple[int]:
        return (max(getattr(m, "padding", (0,))[0] for m in self),)

    def clean_state(self):
        for m in self:
            if hasattr(m, "clean_state"):
                m.clean_state()
class ParallelDispatch(CoModule, nn.Module):
    """Reorder, copy, and group streams from parallel streams.

    Reorder example::

        net = co.Sequential(
            co.Broadcast(2),
            co.Parallel(co.Add(1), co.Identity()),
            co.ParallelDispatch([1,0]),  # Reorder stream 0 and 1
            co.Parallel(co.Identity(), co.Add(2)),
            co.Reduce("max"),
        )

        assert torch.equal(net(torch.tensor([0])), torch.tensor([3]))

    Depiction of the reorder example::

               | -> co.Add(1)     \\ / -> co.Identity() |
        [0] -> |                   X                    | -> max -> [3]
               | -> co.Identity() / \\ -> co.Add(2)     |

    Copy example::

        net = co.Sequential(
            co.Broadcast(2),
            co.Parallel(co.Add(1), co.Identity()),
            co.ParallelDispatch([0, 0, 1]),  # Copy stream 0
            co.Parallel(co.Identity(), co.Add(2), co.Identity()),
            co.Reduce("max"),
        )

        assert torch.equal(net(torch.tensor([0])), torch.tensor([3]))

    Depiction of the copy example::

               | -> co.Add(1) -> | -> co.Identity() -> |
        [0] -> |                 | -> co.Add(2) -----> | -> max -> [3]
               | -> co.Identity() --> co.Identity() -> |

    Group example::

        net = co.Sequential(
            co.Broadcast(2),
            co.Parallel(co.Add(2), co.Identity()),
            co.ParallelDispatch([[0, 0], 1]),  # Copy and group stream 0
            co.Parallel(co.Reduce("sum"), co.Identity()),
            co.Reduce("max"),
        )

        assert torch.equal(net(torch.tensor([0])), torch.tensor([4]))

    Depiction of the group example::

                                 | -> |
               | -> co.Add(2) -> |    | -> sum ------> |
        [0] -> |                 | -> |                | -> max -> [4]
               | -> co.Identity() ----> co.Identity() -> |

    Args:
        dispatch_mapping (Sequence[Union[int, Sequence[int]]]):
            input-to-output mapping, where the integers signify the input stream ordering
            and the positions denote corresponding output ordering.
            Strings are rejected (they are not valid stream ids).
            Examples::

                [1,0] to shuffle order of streams.
                [0,1,1] to copy stream 1 onto a new stream.
                [[0,1],2] to group stream 0 and 1 while keeping stream 2 separate.
    """

    _state_shape = 0
    _dynamic_state_inds = []

    def __init__(
        self,
        dispatch_mapping: Sequence[Union[int, Sequence[int]]],
    ):
        nn.Module.__init__(self)

        def is_int_or_valid_list(x):
            if isinstance(x, int):
                return True
            if isinstance(x, str):
                # A str is a Sequence whose items are strs again, so recursing
                # into it never terminates (RecursionError). Reject it outright.
                return False
            if isinstance(x, abc.Sequence):
                return all(is_int_or_valid_list(z) for z in x)
            return False

        assert isinstance(dispatch_mapping, abc.Sequence) and is_int_or_valid_list(
            dispatch_mapping
        ), "The dispatch_mapping should be of type Sequence[Union[StreamId, Sequence[StreamId]]]"
        self.dispatch_mapping = dispatch_mapping

    def forward(self, input: List[T]) -> List[Union[T, List[T]]]:
        def dispatch(mapping):
            # Nested sequences become nested lists; ints select an input stream.
            if isinstance(mapping, abc.Sequence):
                return [dispatch(m) for m in mapping]
            return input[mapping]

        return dispatch(self.dispatch_mapping)

    def _forward_step(self, input: List[T], prev_state=None):
        # Stateless: the incoming state is handed back untouched.
        return self.forward_step(input), prev_state

    def forward_step(
        self, input: List[T], update_state=True
    ) -> List[Union[T, List[T]]]:
        return self.forward(input)

    def forward_steps(
        self, input: List[T], pad_end=False, update_state=True
    ) -> List[Union[T, List[T]]]:
        return self.forward(input)
class Reduce(CoModule, nn.Module):
    """Reduce multiple input streams to a single using the selected function.

    For instance, here is how it is used to sum streams in a residual connection::

        residual = co.Sequential(
            co.Broadcast(2),
            co.Parallel(
                co.Conv3d(32, 32, kernel_size=3, padding=1),
                co.Delay(2),
            ),
            co.Reduce("sum"),
        )

    A user-defined function can be passed as well::

        from functools import reduce

        def my_sum(inputs):
            return reduce(torch.Tensor.add, inputs[1:], inputs[0])

        residual = co.Sequential(
            co.Broadcast(2),
            co.Parallel(
                co.Conv3d(32, 32, kernel_size=3, padding=1),
                co.Delay(2),
            ),
            co.Reduce(my_sum),
        )

    The reduction returns ``None`` when any input is ``None`` or has size zero
    along its first dimension.

    Args:
        reduce (Union[str, Reduction, Callable[[Sequence[Tensor]], Tensor]]):
            Reduce function. Either one of ["sum", "concat", "mul", "max"] (or the
            corresponding ``Reduction`` member), or a user-defined function
            mapping a sequence of tensors to a single one.

    Raises:
        ValueError: If ``reduce`` is neither callable nor a valid ``Reduction`` value.
    """

    _state_shape = 0
    _dynamic_state_inds = []

    def __init__(
        self,
        reduce: ReductionFuncOrEnum = "sum",
    ):
        nn.Module.__init__(self)
        self.reduce = nonempty(
            reduce
            if callable(reduce)
            else {
                Reduction.SUM: reduce_sum,
                Reduction.CONCAT: reduce_concat,
                Reduction.MUL: reduce_mul,
                Reduction.MAX: reduce_max,
            }[Reduction(reduce)]
        )

    def forward(self, inputs: List[T]) -> T:
        return self.reduce(inputs)

    def _forward_step(self, inputs: List[T], prev_state=None):
        # Only reduce once every stream has produced a tensor for this step.
        if all(isinstance(i, Tensor) for i in inputs):
            return self.reduce(inputs), prev_state
        return None, prev_state  # pragma: no cover

    def forward_step(self, inputs: List[T], update_state=True) -> T:
        return self._forward_step(inputs)[0]

    def forward_steps(self, inputs: List[T], pad_end=False, update_state=True) -> T:
        return self.reduce(inputs)
class Sequential(FlattenableStateDict, CoModule, nn.Sequential):
    """A sequential container.

    This module is an augmentation of `torch.nn.Sequential`
    which adds continual inference methods.

    Modules will be added to it in the order they are passed in the
    constructor. Alternatively, an ``OrderedDict`` of modules can be
    passed in. The ``forward()``, ``forward_step()`` and ``forward_steps()``
    methods of ``Sequential`` accept any input and forward it to the first
    module it contains. It then "chains" outputs to inputs sequentially for
    each subsequent module, finally returning the output of the last module.

    The value a ``Sequential`` provides over manually calling a sequence
    of modules is that it allows treating the whole container as a
    single module, such that performing a transformation on the
    ``Sequential`` applies to each of the modules it stores (which are
    each a registered submodule of the ``Sequential``).

    An empty ``Sequential`` behaves like an identity with receptive field 1,
    stride (1,) and padding (0,).

    Example::

        # Using Sequential to create a small model. When `model` is run,
        # input will first be passed to `Conv2d(1,20,5)`. The output of
        # `Conv2d(1,20,5)` will be used as the input to the first
        # `ReLU`; the output of the first `ReLU` will become the input
        # for `Conv2d(20,64,5)`. Finally, the output of
        # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
        model = co.Sequential(
                  co.Conv2d(1,20,5),
                  nn.ReLU(),
                  co.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Using Sequential with OrderedDict. This is functionally the
        # same as the above code
        model = co.Sequential(OrderedDict([
                  ('conv1', co.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', co.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))
    """

    @overload
    def __init__(self, *args: nn.Module) -> None:
        ...  # pragma: no cover

    @overload
    def __init__(self, arg: "OrderedDict[str, nn.Module]") -> None:
        ...  # pragma: no cover

    def __init__(self, *args):
        nn.Module.__init__(self)
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            modules = list(args[0].items())
        else:
            modules = [(str(idx), module) for idx, module in enumerate(args)]

        # If a co.Broadcast is followed by a co.Parallel, automatically infer num_streams
        for (_, current), (_, following) in zip(modules, modules[1:]):
            if isinstance(current, Broadcast) and isinstance(following, Parallel):
                current.num_streams = current.num_streams or len(following)

        for name, module in modules:
            self.add_module(name, module)

    def add_module(self, name: str, module: Optional["nn.Module"]) -> None:
        co_add_module(self, name, module)

    def forward(self, input):
        for m in self:
            with temporary_parameter(m, "call_mode", _callmode("forward")):
                input = apply_forward(m, input)
        return input

    def forward_step(self, input, update_state=True):
        for module in self:
            # ptflops only works when __call__ is triggered
            with temporary_parameter(module, "call_mode", _callmode("forward_step")):
                input = module(
                    input, update_state=update_state
                )  # == module.forward_step
            # Exact type check: anything other than a plain Tensor or list ends the chain.
            if type(input) not in {Tensor, list}:
                return None
        return input

    def _forward_step(
        self, input: torch.Tensor, prev_state: Optional[List[State]] = None
    ):
        # One state slot per child module; missing state means "fresh start".
        prev_state = prev_state or [None for _ in range(len(self))]
        next_state = prev_state.copy()
        for i, module in enumerate(self):
            input, n_state = module._forward_step(input, prev_state=prev_state[i])
            next_state[i] = n_state
            if input is None:
                return None, next_state
        return input, next_state

    @property
    def _state_shape(self):
        return [m._state_shape for m in self]

    @property
    def _dynamic_state_inds(self):
        return [m._dynamic_state_inds for m in self]

    def forward_steps(self, input: Tensor, pad_end=False, update_state=True):
        for m in self:
            if type(input) not in {Tensor, list} or len(input) == 0:
                return None  # pragma: no cover
            # ptflops only works when __call__ is triggered
            with temporary_parameter(m, "call_mode", _callmode("forward_steps")):
                # == m.forward_steps
                input = m(input, pad_end=pad_end, update_state=update_state)
        return input

    @property
    def receptive_field(self) -> int:
        modules = list(self)
        if not modules:
            return 1  # An empty container is an identity.
        # Fold from the last module backwards: each earlier module with temporal
        # stride s widens the receptive field of everything after it.
        rf = modules[-1].receptive_field
        for m in reversed(modules[:-1]):
            s = getattr(m, "stride", [1])[0]
            rf = s * rf + m.receptive_field - s
        return rf

    @property
    def stride(self) -> Tuple[int]:
        tot = 1
        for m in self:
            tot *= getattr(m, "stride", (1,))[0]
        return (tot,)

    @property
    def padding(self) -> Tuple[int]:
        # Each module's padding is scaled by the accumulated stride before it.
        p, s = 0, 1
        for m in self:
            p += getattr(m, "padding", (0,))[0] * s
            s *= getattr(m, "stride", (1,))[0]
        return (p,)

    @staticmethod
    def build_from(module: nn.Sequential) -> "Sequential":
        from .convert import continual  # import here due to circular import

        return Sequential(
            OrderedDict([(k, continual(m)) for k, m in module._modules.items()])
        )

    def clean_state(self):
        for m in self:
            if hasattr(m, "clean_state"):
                m.clean_state()

    def append(self, module: nn.Module) -> "Sequential":
        r"""Appends a given module to the end.

        Args:
            module (nn.Module): module to append

        Returns:
            Sequential: ``self``, to allow chaining.
        """
        self.add_module(str(len(self)), module)
        return self
class BroadcastReduce(Sequential):
    """Broadcast an input to parallel modules and reduce.

    This module is a shorthand for::

        co.Sequential(co.Broadcast(), co.Parallel(*args), co.Reduce(reduce))

    For instance, it can be used to succinctly create a continual 3D Inception Module::

        def norm_relu(module, channels):
            return co.Sequential(
                module,
                nn.BatchNorm3d(channels),
                nn.ReLU(),
            )

        inception_module = co.BroadcastReduce(
            co.Conv3d(192, 64, kernel_size=1),
            co.Sequential(
                norm_relu(co.Conv3d(192, 96, kernel_size=1), 96),
                norm_relu(co.Conv3d(96, 128, kernel_size=3, padding=1), 128),
            ),
            co.Sequential(
                norm_relu(co.Conv3d(192, 16, kernel_size=1), 16),
                norm_relu(co.Conv3d(16, 32, kernel_size=5, padding=2), 32),
            ),
            co.Sequential(
                co.MaxPool3d(kernel_size=(1, 3, 3), padding=(0, 1, 1), stride=1),
                norm_relu(co.Conv3d(192, 32, kernel_size=1), 32),
            ),
            reduce="concat",
        )

    Args:
        arg (OrderedDict[str, CoModule]): An OrderedDict of modules to be applied in parallel.
        *args (CoModule): Modules to be applied in parallel.
        reduce (ReductionFuncOrEnum, optional):
            Function used to reduce the parallel outputs.
            One of "sum", "concat", "mul" or "max", or a custom reduce function.
            Defaults to "sum".
        auto_delay (bool, optional):
            Automatically add delay to modules in order to match the longest delay.
            Defaults to True.

    Raises:
        AssertionError: If fewer than two modules are given or their strides differ.
    """

    @overload
    def __init__(
        self,
        *args: CoModule,
        reduce: ReductionFuncOrEnum = "sum",
        auto_delay=True,
    ) -> None:
        ...  # pragma: no cover

    @overload
    def __init__(
        self,
        arg: "OrderedDict[str, CoModule]",
        reduce: ReductionFuncOrEnum = "sum",
        auto_delay=True,
    ) -> None:
        ...  # pragma: no cover

    def __init__(
        self,
        *args,
        reduce: ReductionFuncOrEnum = "sum",
        auto_delay=True,
    ):
        nn.Module.__init__(self)

        if len(args) == 1 and isinstance(args[0], OrderedDict):
            modules = list(args[0].items())
        else:
            modules = [(str(idx), module) for idx, module in enumerate(args)]

        assert (
            len(modules) > 1
        ), "You should pass at least two modules for the map-reduce operation to make sense."

        if not all(CoModule.is_valid(module) for _, module in modules):
            # Convert up front, just as `co_add_module` would on registration,
            # so that `delay` and `stride` can be read from every module below.
            from continual.convert import continual  # break cyclical import

            modules = [
                (key, module if CoModule.is_valid(module) else continual(module))
                for key, module in modules
            ]

        if auto_delay:
            # If there is a delay mismatch, automatically add delay to match the longest
            max_delay = max(module.delay for _, module in modules)
            modules = [
                (
                    key,
                    (
                        Sequential(module, Delay(max_delay - module.delay))
                        if module.delay < max_delay
                        else module
                    ),
                )
                for key, module in modules
            ]

        strides = [num_from(getattr(m, "stride", 1)) for _, m in modules]
        assert (
            len(set(strides)) == 1
        ), f"Expected all modules to have the same stride, but got strides {strides}"

        for key, module in modules:
            self.add_module(key, module)

        self.reduce = nonempty(
            reduce
            if callable(reduce)
            else {
                Reduction.SUM: reduce_sum,
                Reduction.CONCAT: reduce_concat,
                Reduction.MUL: reduce_mul,
                Reduction.MAX: reduce_max,
            }[Reduction(reduce)]
        )

        self._delay = max(m.delay for m in self)
        self._receptive_field = max(m.receptive_field for m in self)

    def add_module(self, name: str, module: Optional["nn.Module"]) -> None:
        co_add_module(self, name, module)

    @property
    def _state_shape(self):
        return [m._state_shape for m in self]

    @property
    def _dynamic_state_inds(self):
        return [m._dynamic_state_inds for m in self]

    def _forward_step(self, input: torch.Tensor, prev_state: List[State] = None):
        # One state slot per parallel branch; missing state means "fresh start".
        prev_state = prev_state or [None for _ in range(len(self))]
        next_state = prev_state.copy()
        outs = []
        for i, module in enumerate(self):
            out, n_state = module._forward_step(input, prev_state=prev_state[i])
            next_state[i] = n_state
            outs.append(out)
        # Only reduce once every branch has produced a tensor for this step.
        if all(isinstance(o, Tensor) for o in outs):
            return self.reduce(outs), next_state
        return None, next_state

    def forward_step(self, input: Tensor, update_state=True) -> Tensor:
        outs = []
        for m in self:
            # Use the _callmode enum, like the other containers, not a raw string.
            with temporary_parameter(m, "call_mode", _callmode("forward_step")):
                outs.append(m(input, update_state=update_state))  # == m.forward_step
        if all(isinstance(o, Tensor) for o in outs):
            return self.reduce(outs)
        return None

    def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
        outs = []
        for m in self:
            with temporary_parameter(m, "call_mode", _callmode("forward_steps")):
                # m.forward_steps
                outs.append(m(input, pad_end=pad_end, update_state=update_state))
        return self.reduce(outs)

    def forward(self, input: Tensor) -> Tensor:
        outs = []
        for m in self:
            with temporary_parameter(m, "call_mode", _callmode("forward")):
                outs.append(apply_forward(m, input))
        return self.reduce(outs)

    @property
    def receptive_field(self) -> int:
        return self._receptive_field

    @property
    def delay(self) -> int:
        return self._delay

    @property
    def stride(self) -> Tuple[int]:
        # All branches share a stride (asserted in __init__); report the first.
        return getattr(next(iter(self)), "stride", (1,))

    @property
    def padding(self) -> Tuple[int]:
        return (max(getattr(m, "padding", (0,))[0] for m in self),)

    def clean_state(self):
        for m in self:
            if hasattr(m, "clean_state"):
                m.clean_state()

    def extra_repr(self):
        return f"reduce={self.reduce.__name__}"
def Residual(
    module: CoModule,
    temporal_fill: PaddingMode = None,
    reduce: ReductionFuncOrEnum = "sum",
    residual_shrink: Union[bool, str] = False,
) -> BroadcastReduce:
    """Residual connection wrapper for input.

    This module produces a short form of BroadcastReduce with one delay stream::

        conv = co.Conv3d(32, 32, kernel_size=3, padding=1)
        res1 = co.BroadcastReduce(conv, co.Delay(2), reduce="sum")
        res2 = co.Residual(conv)

        x = torch.randn(1, 32, 5, 5, 5)
        assert torch.equal(res1(x), res2(x))

    Args:
        module (CoModule): module to which a residual should be added.
        temporal_fill (PaddingMode, optional): temporal fill type in delay.
            Defaults to the module's ``temporal_fill`` if it has one, otherwise
            ``PaddingMode.REPLICATE``.
        reduce (ReductionFuncOrEnum, optional): Reduction function ("sum",
            "concat", "mul", "max" or a custom callable). Defaults to "sum".
        residual_shrink (Union[bool, str], optional):
            Set residual to shrink its forward to match the temporal dimension
            reduction of the wrapped module. Options:

            - ``True`` or ``"centered"``: centered shrink (the residual delay is halved);
            - ``"lagging"``: lagging shrink;
            - ``"leading"``: leading shrink, i.e. no delay during forward_step(s).

            Ignored when ``module.receptive_field - 2 * padding == 1``.
            Defaults to False.

    Returns:
        BroadcastReduce: BroadcastReduce module with residual.

    Raises:
        AssertionError: If the module's temporal stride is not 1, or if a
            centered shrink is requested for an odd delay.
    """
    assert num_from(getattr(module, "stride", 1)) == 1, (
        "The simple `Residual` only works for modules with temporal stride=1. "
        "Complex residuals can be achieved using `BroadcastReduce` or the `Broadcast`, `Parallel`, and `Reduce` modules."
    )
    temporal_fill = temporal_fill or getattr(
        module, "temporal_fill", PaddingMode.REPLICATE.value
    )
    delay = module.delay

    # When the padding keeps the temporal length unchanged there is nothing to shrink.
    equal_padding = module.receptive_field - num_from(module.padding) * 2 == 1
    if equal_padding:
        residual_shrink = False

    if residual_shrink in {True, "centered"}:
        assert delay % 2 == 0, "Auto-shrink only works for even-number delays."
        delay = delay // 2

    if residual_shrink == "leading":
        res = Skip(delay)
    else:
        res = Delay(delay, temporal_fill, auto_shrink=residual_shrink)

    return BroadcastReduce(
        res,  # Residual first yields easier broadcasting in reduce functions
        module,
        reduce=reduce,
        auto_delay=False,
    )
class Conditional(FlattenableStateDict, CoModule, nn.Module):
    """Module wrapper for conditional invocations at runtime.

    For instance, it can be used to apply a softmax if the module isn't training::

        net = co.Sequential()

        def not_training(module, x):
            return not net.training

        net.append(co.Conditional(not_training, torch.nn.Softmax(dim=1)))

    Both branches are delayed to the larger of their two delays, so that their
    outputs stay temporally aligned.

    Args:
        predicate (Callable[[CoModule, Tensor], bool]):
            Function used to evaluate whether one module or the other should be
            invoked. It receives this ``Conditional`` instance and the current input.
        on_true (CoModule): Module to invoke when the predicate is True.
        on_false (Optional[CoModule]): Module to invoke when the predicate is
            False. If no module is passed, the input is returned unchanged.
    """

    def __init__(
        self,
        predicate: Callable[[CoModule, Tensor], bool],
        on_true: CoModule,
        on_false: CoModule = None,
    ):
        from continual.convert import continual  # Break cyclical import

        assert callable(predicate), "The passed function should be callable."
        if not isinstance(on_true, CoModule):
            on_true = continual(on_true)
        if not (isinstance(on_false, CoModule) or on_false is None):
            on_false = continual(on_false)

        nn.Module.__init__(self)

        self.predicate = predicate

        # Ensure modules have the same delay
        self._delay = max(on_true.delay, getattr(on_false, "delay", 0))
        self._receptive_field = max(
            on_true.receptive_field, getattr(on_false, "receptive_field", 1)
        )

        self.add_module(
            "0",
            (
                on_true
                if on_true.delay == self._delay
                else Sequential(Delay(self._delay - on_true.delay), on_true)
            ),
        )

        if on_false is not None:
            self.add_module(
                "1",
                (
                    on_false
                    if on_false.delay == self._delay
                    else Sequential(Delay(self._delay - on_false.delay), on_false)
                ),
            )

    def forward(self, input: Tensor) -> Tensor:
        if self.predicate(self, input):
            return apply_forward(self._modules["0"], input)
        elif "1" in self._modules:
            return apply_forward(self._modules["1"], input)
        return input

    def forward_step(self, input: Tensor, update_state=True) -> Tensor:
        if self.predicate(self, input):
            return self._modules["0"].forward_step(input, update_state=update_state)
        elif "1" in self._modules:
            return self._modules["1"].forward_step(input, update_state=update_state)
        return input

    def _forward_step(
        self, input: Tensor, prev_state: Optional[State] = None
    ) -> Tuple[Tensor, Optional[State]]:
        # The state is a per-branch list, [state of "0", state of "1"], matching
        # `_state_shape`. Only the taken branch's slot is updated, so the
        # structure stays the same across steps.
        states = list(prev_state) if prev_state else [None, None]
        if self.predicate(self, input):
            out, states[0] = self._modules["0"]._forward_step(input, states[0])
            return out, states
        if "1" in self._modules:
            out, states[1] = self._modules["1"]._forward_step(input, states[1])
            return out, states
        return input, states

    def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
        # Forward pad_end/update_state; they were previously dropped silently.
        if self.predicate(self, input):
            return self._modules["0"].forward_steps(
                input, pad_end=pad_end, update_state=update_state
            )
        elif "1" in self._modules:
            return self._modules["1"].forward_steps(
                input, pad_end=pad_end, update_state=update_state
            )
        return input

    @property
    def _state_shape(self):
        return [m._state_shape for m in self._modules.values()]

    @property
    def _dynamic_state_inds(self):
        return [m._dynamic_state_inds for m in self._modules.values()]

    @property
    def delay(self) -> int:
        return self._delay

    @property
    def receptive_field(self) -> int:
        return self._receptive_field

    def extra_repr(self):
        return f"predicate={function_repr(self.predicate)}"