import functools
import itertools
from abc import ABC
from enum import Enum
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.hooks as hooks
from torch import Tensor
from torch.nn.modules.module import (
_global_backward_hooks,
_global_forward_hooks,
_global_forward_pre_hooks,
)
# Public names exported by this module.
__all__ = ["CoModule", "call_mode"]

# Internal state of a continual module, stored as a 3-tuple of tensors:
# the first element is a value buffer, while the others are indexes.
State = Tuple[Tensor, Tensor, Tensor]
class PaddingMode(Enum):
    """Enumeration of padding-mode identifiers.

    Each member's value is the lowercase string form of its name.
    """

    REPLICATE = "replicate"
    ZEROS = "zeros"
    NEG_INF = "neg_inf"
# Call modes are encoded as scalar integer tensors keyed by name.
# An Enum (`class CallMode(Enum)` with the members FORWARD, FORWARD_STEPS and
# FORWARD_STEP) is not compatible with torch.jit, hence the plain dict.
CALL_MODES = {
    name: torch.tensor(code)
    for code, name in enumerate(("forward", "forward_steps", "forward_step"))
}
def _callmode(cm) -> torch.Tensor:
"""_summary_
Args:
cm (Union[str, int, torch.Tensor]): Identifier for call mode
Returns:
torch.Tensor: validated call_mode
"""
if isinstance(cm, str):
cm = CALL_MODES[cm.lower()]
elif isinstance(cm, int):
cm = torch.tensor(int)
return cm
class _CallModeContext(object):
    """Context-manager which temporarily specifies a call_mode

    When the call_mode context is used, the ``__call__`` function of continual
    modules will be set accordingly. Contexts may be nested; each exit restores
    the call mode that was active when the corresponding context was created.

    Attributes:
        cur (torch.Tensor): The currently selected call mode.
        prev (Optional[torch.Tensor]): The call mode that was selected before
            the innermost active context, or None when no context is active.

    Example::

        sequence = torch.randn(1, 3, 10)
        step = sequence[:, :, -1]

        forward_output = module(sequence)  # Calls `forward`

        with co.call_mode("forward_step"):
            forward_step_output = module(step)  # Calls `forward_step`

        forward_output = module(sequence)  # Calls `forward`
    """

    def __init__(self):
        self.cur = _callmode("forward")
        self.prev = None
        # (cur, prev) pairs saved by `__call__`, innermost context last.
        self._saved = []

    def __call__(self, value: Union[str, int, torch.Tensor]):
        """Select ``value`` as the call mode and return self for use in ``with``."""
        # Convert first so that an invalid value leaves the context untouched.
        mode = _callmode(value)
        self._saved.append((self.cur, self.prev))
        self.prev = self.cur
        self.cur = mode
        return self

    def __enter__(self):
        # Returning self makes `with call_mode(...) as ctx:` usable.
        return self

    def __exit__(self, *args, **kwargs):
        # Restore exactly what was active before the matching `__call__`.
        # Without a matching `__call__` there is nothing to restore.
        if self._saved:
            self.cur, self.prev = self._saved.pop()


call_mode = _CallModeContext()
def _clone_first(state: State) -> State:
# return (state[0].clone(), state[1], state[2])
return (state[0].clone(),) + state[1:]
# return (state[0].clone(),) + state[1:]
[docs]class CoModule(ABC):
"""Base class for continual modules.
Deriving from this class provides base-functionality and enforces
the implementation of necessary methods.
Attributes:
receptive_field (int): Temporal receptive field of the module.
delay (int): Number of step inputs to observe before the modules produces valid outputs.
stride (Tuple[int,...]): (Spatio)-temporal stride.
padding (Tuple[int,...]): (Spatio)-temporal padding.
"""
receptive_field: int = 1
stride: Tuple[int, ...] = (1,)
padding: Tuple[int, ...] = (0,)
_make_padding = torch.zeros_like
_state_shape: int = 0
_dynamic_state_inds: List[bool] = []
_call_mode = _callmode("forward")
def __init_subclass__(cls) -> None:
CoModule._validate_class(cls)
@staticmethod
def _validate_class(cls):
for fn, description in [
("forward_step", "forward computation for a single temporal step"),
(
"forward_steps",
"forward computation for multiple temporal step",
),
(
"forward",
"a forward computation which is identical to a regular non-continual forward.",
),
("get_state", "a retrieval of the internal state."),
("set_state", "an update of the internal state."),
("clean_state", "an internal state clean-up."),
]:
assert callable(
getattr(cls, fn, None)
), f"{cls} should implement a `{fn}` function which performs {description} to satisfy the CoModule interface."
for prop in {"delay", "receptive_field"}:
assert type(getattr(cls, prop, None)) in {
int,
torch.Tensor,
property,
}, f"{cls} should implement a `{prop}` property to satisfy the CoModule interface."
for prop in {"stride", "padding", "_state_shape", "_dynamic_state_inds"}:
assert type(getattr(cls, prop, None)) in {
int,
property,
tuple,
list,
}, f"{cls} should implement a `{prop}` property to satisfy the CoModule interface."
@staticmethod
def is_valid(module):
try:
CoModule._validate_class(module)
except AssertionError:
return False
return True
[docs] def get_state(self) -> Optional[State]:
"""Get model state.
Returns:
Optional[State]: A State tuple if the model has been initialised and otherwise None.
"""
... # pragma: no cover
[docs] def set_state(self, state: State):
"""Set model state
Args:
state (State): State tuple to set as new internal internal state
"""
... # pragma: no cover
[docs] def clean_state(self):
"""Clean model state, resetting the network memory."""
... # pragma: no cover
@property
def delay(self) -> int:
return self.receptive_field - 1 - self.padding[0]
[docs] def forward_step(
self, input: Tensor, update_state: bool = True
) -> Optional[Tensor]:
"""Performs a forward computation for a single frame and (optionally) updates internal states accordingly.
This function performs efficient continual inference.
Illustration::
O+S O+S O+S O+S (O: output, S: updated internal state)
↑ ↑ ↑ ↑
N N N N (N: network module)
↑ ↑ ↑ ↑
I I I I (I: input frame)
Args:
input (Tensor): Layer input.
update_state (bool): Whether internal state should be updated during this operation.
Returns:
Optional[Tensor]: Step output. This will be a placeholder while the module initializes and every (stride - 1) / stride.
"""
state = self.get_state()
if not update_state and state:
state = _clone_first(state)
output, state = self._forward_step(input, state)
if update_state and state:
self.set_state(state)
return output
[docs] def forward_steps(
self, input: Tensor, pad_end: bool = False, update_state=True
) -> Optional[Tensor]:
"""Performs a forward computation across multiple time-steps while updating internal states for continual inference (if update_state=True).
Start-padding is always accounted for, but end-padding is omitted per default in expectance of the next input step. It can be added by specifying pad_end=True. If so, the output-input mapping the exact same as that of forward.
Illustration::
O (O: output)
↑
----------------- (-: aggregation)
O O+S O+S O+S O (O: output, S: updated internal state)
↑ ↑ ↑ ↑ ↑
N N N N N (N: network module)
↑ ↑ ↑ ↑ ↑
P I I I P (I: input frame, P: padding)
Args:
input (Tensor): Layer input.
pad_end (bool): Whether results for temporal padding at sequence end should be included.
update_state (bool): Whether internal state should be updated during this operation.
Returns:
Optional[Tensor]: Layer output
"""
return self._forward_steps_impl(input, pad_end, update_state)
def _forward_steps_impl(
self, input: Tensor, pad_end: bool = False, update_state: bool = True
) -> Optional[Tensor]:
"""Forward computation for multiple steps with state initialisation
Args:
module (CoModule): Continual module.
input (Tensor): Layer input.
pad_end (bool): Whether results for temporal padding at sequence end should be included.
update_state (bool): Whether internal state should be updated during this operation.
Returns:
Tensor: Layer output
"""
outs = []
state: Optional[State] = self.get_state()
if not update_state and state is not None:
state = _clone_first(state)
for t in range(input.shape[2]):
o, state = self._forward_step(input[:, :, t], state)
if isinstance(o, Tensor):
outs.append(o)
if update_state and state is not None:
self.set_state(state)
if pad_end:
# Don't save state for the end-padding
opt_state = self.get_state()
if opt_state is not None:
state = _clone_first(opt_state)
for t, i in enumerate(
[self._make_padding(input[:, :, -1]) for _ in range(self.padding[0])]
):
o, state = self._forward_step(i, state)
if isinstance(o, Tensor):
outs.append(o)
if len(outs) == 0:
return None # pragma: no cover
return torch.stack(outs, dim=2)
[docs] def forward(self, input: Tensor) -> Tensor:
"""Performs a forward computation over multiple time-steps.
This function is identical to the corresponding module in _torch.nn_,
ensuring cross-compatibility. Moreover, it's handy for efficient training
on clip-based data.
Illustration::
O (O: output)
↑
N (N: network module)
↑
----------------- (-: aggregation)
P I I I P (I: input frame, P: padding)
Args:
input (Tensor): Network input.
"""
... # pragma: no cover
[docs] def warm_up(self, step_shape: List[int]):
"""Warms up the model state with a dummy input.
The initial `self.delay` steps will produce results, but they will be inexact.
To warm up the model with a user-defined data, pass the data to forward_steps::
net.forward_steps(user_data)
Args:
step_shape (Sequence[int]): input shape with which to warm the model up, including batch size.
"""
step_shape = (*step_shape[:2], self.delay, *step_shape[2:])
dummy = self._make_padding(torch.zeros(step_shape, dtype=torch.float))
self.forward_steps(dummy)
@property
def call_mode(self) -> torch.Tensor:
return self._call_mode
@call_mode.setter
def call_mode(self, value):
self._call_mode = _callmode(value)
if hasattr(self, "__len__"):
for m in self:
if hasattr(m, "call_mode"):
m.call_mode = self._call_mode
def _call_impl(self, *input, **kwargs):  # noqa: C901 # pragma: no cover
    """Call the module, dispatching on the active call mode.

    Modified version of torch.nn.Module._call_impl. Instead of always calling
    ``forward``, the function to call is chosen among ``forward``,
    ``forward_steps`` and ``forward_step`` (or ``_slow_forward`` while
    tracing in "forward" mode). The mode of the module-level ``call_mode``
    context is used when a context is active (its ``prev`` is not None);
    otherwise the module's own ``call_mode`` is used.

    Args:
        *input: Positional arguments for the selected function. Forward
            pre-hooks may replace them.
        **kwargs: Keyword arguments for the selected function.

    Returns:
        The output of the selected function, after processing by any
        registered forward and backward hooks.

    Raises:
        KeyError: If the dispatch table has no entry for the combination of
            tracing state and call mode. There is no entry for tracing with
            "forward_steps" or "forward_step".
    """
    # The context-manager mode takes precedence over the module's own mode
    _call_mode = call_mode.cur if call_mode.prev is not None else self.call_mode
    # Dispatch table keyed by (is tracing, call mode).
    # NOTE(review): the lookup key contains a Tensor; presumably this relies
    # on `_call_mode` being one of the tensor objects stored in CALL_MODES —
    # confirm that modes given as ints or as other tensors are matched.
    forward_call = {
        (True, _callmode("forward")): self._slow_forward,
        (False, _callmode("forward")): self.forward,
        (False, _callmode("forward_steps")): self.forward_steps,
        (False, _callmode("forward_step")): self.forward_step,
    }[(bool(torch._C._get_tracing_state()), _call_mode)]

    # If we don't have any hooks, we want to skip the rest of the logic in
    # this function, and just call forward.
    if not (
        self._backward_hooks
        or self._forward_hooks
        or self._forward_pre_hooks
        or _global_backward_hooks
        or _global_forward_hooks
        or _global_forward_pre_hooks
    ):
        return forward_call(*input, **kwargs)
    # Do not call functions when jit is used
    full_backward_hooks, non_full_backward_hooks = [], []
    if self._backward_hooks or _global_backward_hooks:
        full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
    # Forward pre-hooks: global hooks run before the module's own hooks, and
    # a non-None return value replaces the input (wrapped in a tuple if needed)
    if _global_forward_pre_hooks or self._forward_pre_hooks:
        for hook in itertools.chain(
            _global_forward_pre_hooks.values(), self._forward_pre_hooks.values()
        ):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result

    bw_hook = None
    if full_backward_hooks:
        bw_hook = hooks.BackwardHook(self, full_backward_hooks)
        input = bw_hook.setup_input_hook(input)

    result = forward_call(*input, **kwargs)
    # Forward hooks: a non-None return value replaces the result
    if _global_forward_hooks or self._forward_hooks:
        for hook in itertools.chain(
            _global_forward_hooks.values(), self._forward_hooks.values()
        ):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result

    if bw_hook:
        result = bw_hook.setup_output_hook(result)

    # Handle the non-full backward hooks
    if non_full_backward_hooks:
        # Descend into the result until a Tensor is found: the first Tensor
        # value of a dict, otherwise element 0
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in non_full_backward_hooks:
                # Bind the module as first argument while keeping the hook's metadata
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
            self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

    return result

# Calling a module instance goes through the call-mode aware implementation
__call__ = _call_impl
[docs] @staticmethod
def build_from(
module: torch.nn.Module,
*args,
**kwargs,
) -> "CoModule":
"""Copy parameters and weights from a non-continual module and
build the corresponding continual version.
Args:
module (torch.nn.Module): Module from which to copy variables and weights
Returns:
CoModule: Continual Module with the parameters and weights of the passed module.
"""