from typing import Callable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.modules.conv import (
_ConvNd,
_pair,
_reverse_repeat_tuple,
_single,
_size_1_t,
_size_2_t,
_size_3_t,
_triple,
)
from .logging import getLogger
from .module import CoModule, PaddingMode
logger = getLogger(__name__)

# Step state of a continual convolution: (state_buffer, state_index, stride_index).
State = Tuple[Tensor, Tensor, Tensor]

# NOTE: The commented-out block below would compile an optional C++
# implementation of the single-step forward pass. It is currently disabled,
# so the pure-Python step implementation (`_forward_step_py`) is always used.
# _forward_step_impl = None
# from pathlib import Path
# try:
# from torch.utils.cpp_extension import load as load_cpp
# _forward_step_impl = load_cpp(
# name="cpp_impl",
# sources=[str(Path(__file__).parent / "conv.cpp")],
# verbose=False,
# ).forward_step
# except Exception as e: # pragma: no cover
# logger.warning(
# "Unable to compile CoConv C++ implementation. Falling back to Python version."
# )
# logger.warning(e)

# Public API of this module.
__all__ = ["Conv1d", "Conv2d", "Conv3d"]
class _ConvCoNd(CoModule, _ConvNd):
    """Base class for the continual N-d convolutions (`Conv1d`, `Conv2d`, `Conv3d`).

    The input dimension directly after ``(batch_size, channel)`` is treated as time.
    `forward` performs a regular convolution over a whole clip, whereas
    `_forward_step` consumes a single time step (an input without the time
    dimension) and carries a state of partial results between steps.
    """

    def __init__(
        self,
        ConvClass: type,
        conv_func: Callable,
        input_shape_desciption: Tuple[str, ...],
        size_fn: Callable,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride,
        padding,
        dilation,
        groups: int = 1,
        bias: bool = True,
        padding_mode: PaddingMode = "zeros",
        device=None,
        dtype=None,
        temporal_fill: PaddingMode = "zeros",
    ):
        """Initialise the convolution and the settings used for step-wise computation.

        Args:
            ConvClass: Regular convolution class being emulated; must subclass `_ConvNd`.
            conv_func: Functional convolution used for the computation, e.g. `F.conv1d`.
            input_shape_desciption: Names of the input dimensions,
                e.g. ``("batch_size", "channel", "time")``.
            size_fn: Expands an int or tuple into a tuple with one entry per
                convolved dimension, e.g. `_single`, `_pair` or `_triple`.
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, device, dtype: As for the corresponding `torch.nn` convolution.
            padding_mode: Padding mode; normalised through `PaddingMode`.
            temporal_fill: ``'zeros'`` or ``'replicate'``; selects how the step
                state is initialised (see `init_state`).

        Raises:
            AssertionError: If `ConvClass` is not a `_ConvNd` subclass, or if the
                temporal dilation is not 1.
        """
        assert issubclass(
            ConvClass, _ConvNd
        ), "The ConvClass should be a subclass of `_ConvNd`"
        kernel_size = size_fn(kernel_size)
        padding = size_fn(padding)
        stride = size_fn(stride)

        self._ConvClass = ConvClass
        self._conv_func = conv_func
        self.input_shape_desciption = input_shape_desciption
        self._input_len = len(self.input_shape_desciption)

        if stride[0] > 1:
            logger.warning(
                f"Temporal stride of {stride[0]} will result in skipped outputs every {stride[0] - 1} / {stride[0]} steps"
            )

        dilation = size_fn(dilation)
        assert dilation[0] == 1, "Temporal dilation > 1 is not supported currently."

        self.padding_mode = PaddingMode(padding_mode).value
        self.t_padding_mode = PaddingMode(temporal_fill).value

        _ConvNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed=False,
            output_padding=size_fn(0),
            groups=groups,
            bias=bias,
            # `_ConvNd.__init__` reassigns `self.padding_mode`, so hand it the
            # normalised string (not the raw argument, which may be an enum).
            padding_mode=self.padding_mode,
            device=device,
            dtype=dtype,
        )

        # Factory producing the initial state contents from the first step's partials.
        self._make_padding = {
            PaddingMode.ZEROS.value: torch.zeros_like,
            PaddingMode.REPLICATE.value: torch.clone,
        }[self.t_padding_mode]

        # Padding used in `forward` (for non-"zeros" modes, applied via `F.pad`).
        self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2)

        # Paddings used in step mode. A single frame is padded by
        # `kernel_size[0] - 1` on both temporal sides, so the convolution returns
        # `kernel_size[0]` temporal positions of partial results.
        # `F.pad` spec covering only the non-temporal dimensions:
        self._step_space_rprt = _reverse_repeat_tuple((0, *self.padding[1:]), 2)
        # Convolution padding used together with `_step_space_rprt`:
        self._step_time_pad = (
            self.kernel_size[0] - 1,
            *[0 for _ in self.padding[1:]],
        )
        # Convolution padding used for the "zeros" padding mode:
        self._step_padding = (self.kernel_size[0] - 1, *self.padding[1:])
        self._step_stride = (1, *self.stride[1:])

        # Step state; an empty 1-D `state_buffer` marks "no state yet".
        self.register_buffer("state_buffer", torch.tensor([]), persistent=False)
        self.register_buffer("state_index", torch.tensor(0), persistent=False)
        self.register_buffer("stride_index", torch.tensor(0), persistent=False)

    @property
    def _stateless(self) -> bool:
        """True if each output depends only on the current step (no state needed)."""
        return self.kernel_size[0] == 1 and self.padding[0] == 0 and self.stride[0] == 1

    @property
    def _state_shape(self) -> int:
        """Number of state tensors: 0 when stateless, otherwise 3."""
        if self._stateless:
            return 0
        else:
            return 3

    @property
    def _dynamic_state_inds(self) -> List[bool]:
        """Per-state-entry flags consumed by `CoModule`; only the buffer is flagged."""
        if self._stateless:
            return []
        else:
            return [True, False, False]

    def init_state(
        self,
        first_output: Tensor,
    ) -> State:
        """Create the initial step state.

        Args:
            first_output: The partial results of the first step that are
                destined for future outputs (``x_rest`` in `_forward_step_py`).

        Returns:
            ``(state_buffer, state_index, stride_index)``. The buffer stacks
            ``kernel_size[0] - 1`` copies of ``zeros_like(first_output)`` or
            ``clone(first_output)`` (per `temporal_fill`) along a new leading
            dimension. ``stride_index`` starts at
            ``stride - (kernel_size[0] - 1) - 1 + padding[0]``, which delays the
            first valid output.
        """
        padding = self._make_padding(first_output)
        repeat_shape = [self.kernel_size[0] - 1]
        repeat_shape.extend((1,) * len(self.input_shape_desciption))
        state_buffer = padding.repeat(repeat_shape)
        state_index = torch.tensor(0)
        stride_index = torch.tensor(
            self.stride[0] - len(state_buffer) - 1 + self.padding[0]
        )
        return (state_buffer, state_index, stride_index)

    def clean_state(self):
        """Reset the stored state to the "no state" marker."""
        self.state_buffer = torch.tensor([], device=self.state_buffer.device)
        self.state_index = torch.tensor(0)
        self.stride_index = torch.tensor(0)

    def get_state(self) -> Optional[State]:
        """Return the stored state, or None if no state has been initialised.

        The "no state" marker is the 1-D empty tensor set by `__init__` and
        `clean_state`, whereas an initialised buffer always has one leading
        dimension more than the input. Testing the rank (rather than the
        length) keeps the state when ``kernel_size[0] == 1`` with temporal
        stride/padding, where the initialised buffer legitimately has length 0.
        """
        if self.state_buffer.dim() > 1:
            return (self.state_buffer, self.state_index, self.stride_index)
        return None

    def set_state(self, state: State):
        """Store ``(state_buffer, state_index, stride_index)``."""
        self.state_buffer, self.state_index, self.stride_index = state

    @torch.jit.export
    def _forward_step(
        self, input: Tensor, prev_state: Optional[State]
    ) -> Tuple[Optional[Tensor], Optional[State]]:
        """Compute one continual step; always uses the Python implementation."""
        # assert (
        #     len(input.shape) == self._input_len - 1
        # ), f"A tensor of shape {(*self.input_shape_desciption[:2], *self.input_shape_desciption[3:])} should be passed as input but got {input.shape}"
        # if (
        #     _forward_step_impl is not None
        #     and not self.training
        #     and self.padding_mode == "zeros"
        # ):
        #     # Call C++ impl
        #     output, next_state = _forward_step_impl(
        #         input,
        #         self.weight,
        #         self.bias,
        #         self.stride,
        #         self.padding,
        #         self._step_padding,
        #         self.dilation,
        #         self.groups,
        #         *(prev_state or (None, None, None)),
        #     )
        #     if output is None:
        #         output = None
        #     return output, next_state
        return self._forward_step_py(input, prev_state)

    def _forward_step_py(
        self, input: Tensor, prev_state: Optional[State]
    ) -> Tuple[Optional[Tensor], Optional[State]]:
        """Pure-Python single step.

        Args:
            input: One time step, i.e. the input without its time dimension.
            prev_state: State from the previous step, or None to initialise it.

        Returns:
            ``(output, next_state)``; ``output`` is None on steps that do not
            produce an output (during warm-up or skipped by a temporal stride).
        """
        # e.g. B, C -> B, C, 1
        x = input.unsqueeze(2).to(device=self.weight.device)

        if self._stateless:
            return self.forward(x).squeeze(2), prev_state

        if self.padding_mode == "zeros":
            x = self._conv_func(
                input=x,
                weight=self.weight,
                bias=None,
                stride=self._step_stride,
                padding=self._step_padding,
                dilation=self.dilation,
                groups=self.groups,
            )
        else:
            x = self._conv_func(
                input=F.pad(x, self._step_space_rprt, mode=self.padding_mode),
                weight=self.weight,
                bias=None,
                stride=self._step_stride,
                padding=self._step_time_pad,
                dilation=self.dilation,
                groups=self.groups,
            )

        # Position 0 completes the current output; the rest feeds future outputs.
        x_out, x_rest = x[:, :, 0], x[:, :, 1:]

        # Prepare previous state
        if prev_state is None:
            buffer, index, stride_index = self.init_state(x_rest)
        else:
            buffer, index, stride_index = prev_state
        assert index is not None
        assert stride_index is not None

        tot = self.kernel_size[0] - 1
        output_is_valid = stride_index == self.stride[0] - 1

        if output_is_valid:
            # Build index tensors on the buffer's device, so the step still works
            # after the module (and thus its state buffers) has been moved.
            slot_inds = torch.remainder(
                torch.arange(tot, device=buffer.device) + index, tot
            )
            time_inds = torch.arange(tot - 1, -1, -1, device=buffer.device)
            # Sum the partial results stored for the current output.
            x_out = x_out + torch.sum(buffer[slot_inds, :, :, time_inds], dim=0)

            if self.bias is not None:
                bias = self.bias.unsqueeze(0)
                for _ in range(self._input_len - 3):
                    bias = bias.unsqueeze(-1)
                x_out += bias

        # Update next state (ring buffer of partial results)
        if self.kernel_size[0] > 1:
            # Avoid in-place modification of a tensor autograd may need.
            next_buffer = buffer.clone() if self.training else buffer
            next_buffer[index] = x_rest
            next_index = (index + 1) % tot
        else:
            next_buffer = buffer
            next_index = index

        next_stride_index = stride_index + 1
        if next_stride_index > 0:
            next_stride_index = next_stride_index % self.stride[0]

        if output_is_valid:
            return x_out, (next_buffer, next_index, next_stride_index)
        return None, (next_buffer, next_index, next_stride_index)

    def forward_steps(
        self, input: Tensor, pad_end: bool = False, update_state: bool = True
    ) -> Optional[Tensor]:
        """Process several time steps; delegates to `CoModule._forward_steps_impl`.

        NOTE(review): the exact semantics of `pad_end` / `update_state` are
        defined by `CoModule._forward_steps_impl` (not visible here).
        """
        # assert (
        #     len(input.shape) == self._input_len
        # ), f"A tensor of shape {self.input_shape_desciption} should be passed as input but got {input.shape}."
        return self._forward_steps_impl(input, pad_end, update_state)

    def forward(self, input: Tensor) -> Tensor:
        """Performs a full forward computation exactly as the regular layer would.

        This method is handy for efficient training on clip-based data.

        Args:
            input (Tensor): Layer input

        Returns:
            Tensor: Layer output
        """
        assert (
            len(input.shape) == self._input_len
        ), f"A tensor of shape {self.input_shape_desciption} should be passed as input but got {input.shape}."
        if self.padding_mode == "zeros":
            output = self._conv_func(
                input=input,
                weight=self.weight,
                bias=self.bias,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
            )
        else:
            output = self._conv_func(
                input=F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight=self.weight,
                bias=self.bias,
                stride=self.stride,
                padding=(0,) * len(self.padding),
                dilation=self.dilation,
                groups=self.groups,
            )
        return output

    @property
    def receptive_field(self) -> int:
        """Temporal receptive field of the layer (in time steps)."""
        return self.kernel_size[0] + (self.kernel_size[0] - 1) * (self.dilation[0] - 1)
class Conv1d(_ConvCoNd):
    r"""Continual 1D convolution over a temporal input signal.

    Continual Convolutions were proposed by
    Hedegaard et al.: "Continual 3D Convolutional Neural Networks for Real-time Processing of Videos", in ECCV (2022),
    https://arxiv.org/pdf/2106.00050.pdf (paper) https://www.youtube.com/watch?v=Jm2A7dVEaF4 (video).

    Assuming an input of shape `(B, C, T)`, it computes the convolution over one temporal instant `t` at a time
    where `t` ∈ `range(T)`, and keeps an internal state.

    Args:
        in_channels (int): Number of channels in the input signal
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. NB: a temporal stride > 1 is allowed, but then
            step-wise computation only yields an output every `stride` steps (a warning is logged). Default: 1
        padding (int or tuple, optional): Padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. NB: dilation > 1 over the temporal
            dimension is not supported. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        padding_mode (string, optional): Padding mode, e.g. ``'zeros'`` or ``'replicate'``. Default: ``'zeros'``
        device (optional): Device on which the parameters are created. Default: ``None``
        dtype (optional): Data type of the parameters. Default: ``None``
        temporal_fill (string, optional): ``'zeros'`` or ``'replicate'`` (= "boring video"). `temporal_fill` determines
            how state is initialised and which padding is applied during `forward_steps` along the temporal dimension.
            Default: ``'zeros'``

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size[0]})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}[0]}`
        bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
            then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}[0]}`
        state_buffer, state_index, stride_index (Tensor): the step state; a running buffer of partial computations
            from previous frames (plus bookkeeping indices) used for the calculation of subsequent outputs.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: PaddingMode = "zeros",
        device=None,
        dtype=None,
        temporal_fill: PaddingMode = "zeros",
    ):
        _ConvCoNd.__init__(
            self,
            nn.Conv1d,
            F.conv1d,
            ("batch_size", "channel", "time"),
            _single,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
            temporal_fill,
        )

    @staticmethod
    def build_from(
        module: nn.Conv1d, temporal_fill: Optional[PaddingMode] = None, **kwargs
    ) -> "Conv1d":
        """Create a continual ``Conv1d`` from a regular :class:`torch.nn.Conv1d`.

        Args:
            module: The convolution whose hyper-parameters and weights are copied.
            temporal_fill: State initialisation mode. If falsy, ``module.padding_mode`` is used.
            **kwargs: Override any constructor argument derived from `module`.

        Returns:
            Conv1d: The continual convolution.
        """
        comodule = Conv1d(
            **{
                **dict(
                    in_channels=module.in_channels,
                    out_channels=module.out_channels,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding,
                    dilation=module.dilation,
                    groups=module.groups,
                    bias=module.bias is not None,
                    padding_mode=module.padding_mode,
                    temporal_fill=temporal_fill or module.padding_mode,
                ),
                **kwargs,
            }
        )
        with torch.no_grad():
            comodule.weight.copy_(module.weight)
            # `kwargs` may override `bias`, so only copy when both layers have one.
            if module.bias is not None and comodule.bias is not None:
                comodule.bias.copy_(module.bias)
        return comodule
class Conv2d(_ConvCoNd):
    r"""Continual 2D convolution over a spatio-temporal input signal.

    Continual Convolutions were proposed by
    Hedegaard et al.: "Continual 3D Convolutional Neural Networks for Real-time Processing of Videos", in ECCV (2022),
    https://arxiv.org/pdf/2106.00050.pdf (paper) https://www.youtube.com/watch?v=Jm2A7dVEaF4 (video).

    Assuming an input of shape `(B, C, T, S)`, it computes the convolution over one temporal instant `t` at a time
    where `t` ∈ `range(T)`, and keeps an internal state.

    Args:
        in_channels (int): Number of channels in the input signal
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. NB: a temporal stride > 1 is allowed, but then
            step-wise computation only yields an output every `stride` steps (a warning is logged). Default: 1
        padding (int or tuple, optional): Padding added to all four sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. NB: dilation > 1 over the temporal
            dimension is not supported. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        padding_mode (string, optional): Padding mode, e.g. ``'zeros'`` or ``'replicate'``. Default: ``'zeros'``
        device (optional): Device on which the parameters are created. Default: ``None``
        dtype (optional): Data type of the parameters. Default: ``None``
        temporal_fill (string, optional): ``'zeros'`` or ``'replicate'`` (= "boring video"). `temporal_fill` determines
            how state is initialised and which padding is applied during `forward_steps` along the temporal dimension.
            Default: ``'zeros'``

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
            then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        state_buffer, state_index, stride_index (Tensor): the step state; a running buffer of partial computations
            from previous frames (plus bookkeeping indices) used for the calculation of subsequent outputs.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: PaddingMode = "zeros",
        device=None,
        dtype=None,
        temporal_fill: PaddingMode = "zeros",
    ):
        _ConvCoNd.__init__(
            self,
            nn.Conv2d,
            F.conv2d,
            ("batch_size", "channel", "time", "space"),
            _pair,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
            temporal_fill,
        )

    @staticmethod
    def build_from(
        module: nn.Conv2d, temporal_fill: Optional[PaddingMode] = None, **kwargs
    ) -> "Conv2d":
        """Create a continual ``Conv2d`` from a regular :class:`torch.nn.Conv2d`.

        Args:
            module: The convolution whose hyper-parameters and weights are copied.
            temporal_fill: State initialisation mode. If falsy, ``module.padding_mode`` is used.
            **kwargs: Override any constructor argument derived from `module`.

        Returns:
            Conv2d: The continual convolution.
        """
        comodule = Conv2d(
            **{
                **dict(
                    in_channels=module.in_channels,
                    out_channels=module.out_channels,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding,
                    dilation=module.dilation,
                    groups=module.groups,
                    bias=module.bias is not None,
                    padding_mode=module.padding_mode,
                    temporal_fill=temporal_fill or module.padding_mode,
                ),
                **kwargs,
            }
        )
        # Copy parameters the same way as `Conv1d`/`Conv3d.build_from`.
        with torch.no_grad():
            comodule.weight.copy_(module.weight)
            # `kwargs` may override `bias`, so only copy when both layers have one.
            if module.bias is not None and comodule.bias is not None:
                comodule.bias.copy_(module.bias)
        return comodule
class Conv3d(_ConvCoNd):
    r"""Continual 3D convolution over a spatio-temporal input signal.

    Continual Convolutions were proposed by
    Hedegaard et al.: "Continual 3D Convolutional Neural Networks for Real-time Processing of Videos", in ECCV (2022),
    https://arxiv.org/pdf/2106.00050.pdf (paper) https://www.youtube.com/watch?v=Jm2A7dVEaF4 (video).

    Assuming an input of shape `(B, C, T, H, W)`, it computes the convolution over one temporal instant `t` at a time
    where `t` ∈ `range(T)`, and keeps an internal state. Both clip-wise (`forward`) and step-wise (`forward_steps`)
    computation are supported.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. NB: a temporal stride > 1 is allowed, but then
            step-wise computation only yields an output every `stride` steps (a warning is logged). Default: 1
        padding (int or tuple, optional): Padding added to all six sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. NB: dilation > 1 over the temporal
            dimension is not supported. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        padding_mode (string, optional): Padding mode, e.g. ``'zeros'`` or ``'replicate'``. Default: ``'zeros'``
        device (optional): Device on which the parameters are created. Default: ``None``
        dtype (optional): Data type of the parameters. Default: ``None``
        temporal_fill (string, optional): ``'zeros'`` or ``'replicate'`` (= "boring video"). `temporal_fill` determines
            how state is initialised and which padding is applied during `forward_steps` along the temporal dimension.
            Default: ``'zeros'``

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
        bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
            then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
        state_buffer, state_index, stride_index (Tensor): the step state; a running buffer of partial computations
            from previous frames (plus bookkeeping indices) used for the calculation of subsequent outputs.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: PaddingMode = "zeros",
        device=None,
        dtype=None,
        temporal_fill: PaddingMode = "zeros",
    ):
        _ConvCoNd.__init__(
            self,
            nn.Conv3d,
            F.conv3d,
            ("batch_size", "channel", "time", "height", "width"),
            _triple,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
            temporal_fill,
        )

    @staticmethod
    def build_from(
        module: nn.Conv3d, temporal_fill: Optional[PaddingMode] = None, **kwargs
    ) -> "Conv3d":
        """Create a continual ``Conv3d`` from a regular :class:`torch.nn.Conv3d`.

        Args:
            module: The convolution whose hyper-parameters and weights are copied.
            temporal_fill: State initialisation mode. If falsy, ``module.padding_mode`` is used.
            **kwargs: Override any constructor argument derived from `module`.

        Returns:
            Conv3d: The continual convolution.
        """
        comodule = Conv3d(
            **{
                **dict(
                    in_channels=module.in_channels,
                    out_channels=module.out_channels,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding,
                    dilation=module.dilation,
                    groups=module.groups,
                    bias=module.bias is not None,
                    padding_mode=module.padding_mode,
                    temporal_fill=temporal_fill or module.padding_mode,
                ),
                **kwargs,
            }
        )
        with torch.no_grad():
            comodule.weight.copy_(module.weight)
            # `kwargs` may override `bias`, so only copy when both layers have one.
            if module.bias is not None and comodule.bias is not None:
                comodule.bias.copy_(module.bias)
        return comodule