# Source code for continual.delay
from typing import Tuple, Union
import torch
from torch import Tensor
from .module import CoModule, PaddingMode
from .utils import temporary_parameter
# Continual state carried by ``Delay``: the ``(state_buffer, state_index)``
# pair produced by ``Delay.init_state`` and consumed by ``Delay.set_state``.
State = Tuple[Tensor, Tensor]
# Public names exported by this module.
__all__ = ["Delay"]
class Delay(CoModule, torch.nn.Module):
    """Delay an input by a number of steps.

    This module only introduces a delay in the continual modes, i.e. on
    ``forward_step`` and ``forward_steps``. In essence it caches the input for
    ``delay`` steps before outputting it again. The regular ``forward`` never
    delays; it only (optionally) shrinks the temporal dimension.

    The ``Delay`` module is used extensively in various container modules to
    align delays of different computational branches. For instance, it is used
    to align the :class:`Residual` module as shown in the example below.

    Arguments:
        delay: The number of steps to delay an output.
        temporal_fill: Temporal state initialisation mode ("zeros" or "replicate").
        auto_shrink: Whether to shrink the temporal dimension of the feature
            map during forward. This is handy for residuals that are parallel
            to modules which reduce the number of temporal steps.
            Options: "centered" or True: Centered residual shrink;
            "lagging": lagging shrink.

    Examples::

        conv = co.Conv3d(32, 32, kernel_size=3, padding=1)
        residual = co.BroadcastReduce(conv, co.Delay(2), reduce="sum")
    """

    @property
    def _state_shape(self):
        # Two state entries (buffer and index) when a delay is configured,
        # none otherwise. NOTE(review): presumably consumed by CoModule.
        return 2 if self.delay > 0 else 0

    @property
    def _dynamic_state_inds(self):
        # One flag per state entry (buffer, index).
        # NOTE(review): exact semantics are defined by CoModule - confirm there.
        return [True, False] if self.delay > 0 else []

    def __init__(
        self,
        delay: int,
        temporal_fill: PaddingMode = "zeros",
        auto_shrink: Union[bool, str] = False,
    ):
        assert delay >= 0, "delay must be non-negative"
        self._delay = delay
        assert auto_shrink in {
            True,
            False,
            "centered",
            "lagging",
        }, 'auto_shrink must be True, False, "centered" or "lagging"'
        self.auto_shrink = auto_shrink
        assert temporal_fill in {
            "zeros",
            "replicate",
        }, 'temporal_fill must be "zeros" or "replicate"'
        # Factory producing the initial content of every buffer slot from the
        # first observed input.
        self._make_padding = {"zeros": torch.zeros_like, "replicate": torch.clone}[
            temporal_fill
        ]

        super().__init__()
        # Non-persistent buffers: they are not stored in the state_dict.
        self.register_buffer("state_buffer", torch.tensor([]), persistent=False)
        self.register_buffer("state_index", torch.tensor(0), persistent=False)

    def init_state(
        self,
        first_output: Tensor,
    ) -> State:
        """Create a fresh ``(state_buffer, state_index)`` pair.

        Arguments:
            first_output: Tensor used as a template for the buffer slots
                (zero-filled or cloned, depending on ``temporal_fill``).

        Returns:
            A buffer with ``delay`` slots stacked along dim 0 and a negative
            start index. The index counts up to zero while no output is
            produced yet.
        """
        padding = self._make_padding(first_output)
        state_buffer = torch.stack([padding for _ in range(self.delay)], dim=0)
        # Centered shrink also drops the first ``delay`` inputs, so it has to
        # wait twice as long before emitting. Both documented spellings of the
        # centered mode (True and "centered") must behave the same here,
        # matching the ``[delay:-delay]`` slice taken in ``forward``.
        centered = self.auto_shrink in (True, "centered")
        state_index = torch.tensor(-2 * self.delay if centered else -self.delay)
        return state_buffer, state_index

    def clean_state(self):
        """Reset the state to the empty (uninitialised) state."""
        # Keep both buffers on the device the module currently lives on.
        device = self.state_buffer.device
        self.state_buffer = torch.tensor([], device=device)
        self.state_index = torch.tensor(0, device=device)

    def get_state(self):
        """Return ``(state_buffer, state_index)`` or ``None`` if uninitialised."""
        if len(self.state_buffer) > 0:
            return (self.state_buffer, self.state_index)
        return None

    def set_state(self, state: State):
        """Overwrite the stored state with ``state``."""
        self.state_buffer, self.state_index = state

    def _forward_step(
        self, input: Tensor, prev_state: Union[State, None]
    ) -> Tuple[Union[Tensor, None], State]:
        """Push ``input`` into the ring buffer and pop the delayed value.

        Arguments:
            input: The current step input.
            prev_state: Previous ``(buffer, index)`` state, or ``None`` to
                initialise a new state from ``input``.

        Returns:
            The delayed output (``None`` while the index is still negative)
            and the next state. Note that ``buffer`` is written in place.
        """
        if self._delay == 0:
            return input, prev_state

        if prev_state is None:
            buffer, index = self.init_state(input)
        else:
            buffer, index = prev_state

        # Get output: a negative index means the delay has not elapsed yet.
        if index >= 0:
            # Clone, because the slot is overwritten right below.
            output = buffer[index].clone()
        else:
            output = None

        # Update state
        buffer[index % self.delay] = input
        new_index = index + 1
        if new_index > 0:
            # Wrap around once the index has left the negative range.
            new_index = new_index % self.delay

        return output, (buffer, new_index)

    def forward_step(self, input: Tensor, update_state=True) -> Tensor:
        """Process a single step; a zero delay passes ``input`` straight through."""
        if self._delay == 0:
            return input
        return CoModule.forward_step(self, input, update_state)

    def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
        """Process multiple steps; a zero delay passes ``input`` straight through."""
        if self._delay == 0:
            return input

        # ``padding`` is set to ``(delay,)`` only for the duration of the call.
        with temporary_parameter(self, "padding", (self.delay,)):
            output = CoModule.forward_steps(self, input, pad_end, update_state)

        return output

    def forward(self, input: Tensor) -> Tensor:
        """Regular forward: never delays, optionally shrinks dim 2.

        Returns:
            ``input`` unchanged when ``auto_shrink`` is falsy or ``delay`` is 0;
            ``input[:, :, :-delay]`` for "lagging"; otherwise (centered)
            ``input[:, :, delay:-delay]``.
        """
        # No delay during regular forward
        if not self.auto_shrink or self.delay == 0:
            return input
        if self.auto_shrink == "lagging":
            return input[:, :, : -self.delay]
        return input[:, :, self.delay : -self.delay]

    @property
    def receptive_field(self) -> int:
        return self.delay + 1

    @property
    def delay(self) -> int:
        return self._delay

    @property
    def stride(self) -> Tuple[int]:
        return (1,)

    def extra_repr(self):
        # Show the actual mode, so that "lagging" is not reported as True.
        shrink_str = f", auto_shrink={self.auto_shrink!r}" if self.auto_shrink else ""
        return f"{self.delay}" + shrink_str