Source code for continual.transformer
from collections import OrderedDict
from enum import Enum
from functools import partial, reduce
from typing import Callable, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor, nn
from continual.delay import State as DelayState
from .closure import Identity, Lambda
from .container import BroadcastReduce, Residual, Sequential
from .delay import Delay, PaddingMode
from .linear import Linear
from .multihead_attention import (
RetroactiveMultiheadAttention,
SingleOutputMultiheadAttention,
)
# Names exported by a star-import of this module.
__all__ = [
    "TransformerEncoder",
    "TransformerEncoderLayerFactory",
    "SingleOutputTransformerEncoderLayer",
    "RetroactiveTransformerEncoderLayer",
]
class MhaType(Enum):
    """Type of Multi-head Attention.

    Supported types are:
    - RETROACTIVE: RetroactiveMultiheadAttention
    - SINGLE_OUTPUT: SingleOutputMultiheadAttention
    - SINGLE_OUTPUT_FORWARD: SingleOutputMultiheadAttention which is also
      restricted to a single output during ``forward``
    - REGULAR: nn.MultiheadAttention

    Members can be looked up from their string value, e.g. ``MhaType("retroactive")``.
    """

    RETROACTIVE = "retroactive"
    SINGLE_OUTPUT = "single_output"
    SINGLE_OUTPUT_FORWARD = "single_output_forward"
    REGULAR = "regular"
class SelectOrDelay(Delay):
    """Pick a single temporal index in ``forward``; behave as the inherited delay
    during ``forward_step(s)``."""

    def forward(self, x: Tensor) -> Optional[Tensor]:
        """Return the step lying ``self._delay`` positions before the last one.

        Args:
            x: input with at least three dimensions; the third one is indexed as time.

        Returns:
            The selected step with the indexed dimension re-inserted with size 1.
        """
        assert x.ndim >= 3  # N, C, T
        # Count backwards from the final step by the configured delay.
        step_index = -1 - self._delay
        selected = x[:, :, step_index]
        return selected.unsqueeze(2)
class RetroactiveUnity(Delay):
    """Unity mapping during forward. During forward_step(s), a single-to-many mapping is assumed,
    and all cached values are output."""

    def __init__(
        self,
        delay: int,
        temporal_fill: PaddingMode = "zeros",
        auto_shrink: bool = False,
        time_dim=-1,
    ):
        """Initialise the block.

        Args:
            delay (int): the number of steps to delay an output. The step buffer
                holds ``delay + 1`` entries.
            temporal_fill (PaddingMode, optional): Temporal state initialisation mode
                ("zeros" or "replicate"). Defaults to "zeros".
            auto_shrink (bool, optional): Whether to shrink the temporal dimension of the
                feature map during forward. This is handy for residuals that are parallel
                to modules which reduce the number of temporal steps. Defaults to False.
            time_dim (int, optional): Which dimension to concatenate step outputs along.
                Defaults to -1.
        """
        # Assigned before the parent initialiser, as in the original ordering.
        self.time_dim = time_dim
        Delay.__init__(self, delay, temporal_fill, auto_shrink)

    def init_state(
        self,
        first_output: Tensor,
    ) -> DelayState:
        """Create the initial ``(buffer, index)`` state.

        The buffer stacks ``delay + 1`` copies of the padding made from
        ``first_output`` along a new leading dimension. The index starts at
        ``-delay``; negative values mark steps for which no output is produced.
        """
        padding = self._make_padding(first_output)
        state_buffer = torch.stack([padding for _ in range(self.delay + 1)], dim=0)
        state_index = torch.tensor(-self.delay)
        return state_buffer, state_index

    def _forward_step(
        self, input: Tensor, prev_state: DelayState
    ) -> Tuple[Tensor, DelayState]:
        """Store ``input`` and, once warmed up, return all buffered steps.

        Args:
            input: the newest step.
            prev_state: ``(buffer, index)`` from the previous call, or None to
                start from a fresh state.

        Returns:
            ``(output, (buffer, new_index))`` where ``output`` is None while the
            index is negative, and otherwise the buffered steps ordered oldest to
            newest along ``time_dim``.
        """
        if prev_state is None:
            buffer, index = self.init_state(input)
        else:
            buffer, index = prev_state

        buffer_len = self.delay + 1

        # Update state
        buffer[index % buffer_len] = input
        new_index = index + 1
        if new_index > 0:
            # Wrap by the buffer length (delay + 1), i.e. the same modulus used for
            # the write position above. Wrapping by ``delay`` would leave one slot
            # never overwritten after the first cycle and divide by zero for delay == 0.
            new_index = new_index % buffer_len

        # Get output
        output = None
        if index >= 0:
            # Rotate so that the slot after the one just written (the oldest) comes first.
            output = buffer.clone().roll(shifts=int(-index - 1), dims=0)
            idx = (
                self.time_dim + len(output.shape)
                if self.time_dim < 0
                else self.time_dim
            )
            # Move the leading buffer dimension to position ``idx``.
            output = output.permute(
                list(range(1, idx + 1)) + [0] + list(range(idx + 1, len(output.shape)))
            )
        return output, (buffer, new_index)
class RetroactiveLambda(Lambda):
    """
    Lambda wrapper for functions that are applied after retroactive modules.

    All step-wise entry points delegate to ``forward``, so that each step is
    processed in the same way as a full ``forward`` call.
    """

    def forward(self, input: Tensor) -> Tensor:
        """Apply the wrapped function through ``Lambda.forward``."""
        return Lambda.forward(self, input)

    def forward_step(self, input: Tensor, *args, **kwargs) -> Tensor:
        """Process one step by calling ``forward``; extra arguments are ignored."""
        return self.forward(input)

    def _forward_step(
        self, input: Tensor, prev_state=None, *args, **kwargs
    ) -> Tuple[Tensor, object]:
        """Process one step and hand ``prev_state`` back unchanged.

        Returns:
            A ``(output, prev_state)`` pair.
        """
        return self.forward(input), prev_state

    def forward_steps(self, input: Tensor, *args, **kwargs) -> Tensor:
        """Call ``forward`` on every slice along dimension 2 and stack the results
        along dimension 2."""
        return torch.stack(
            [self.forward(input[:, :, t]) for t in range(input.shape[2])], dim=2
        )

    @staticmethod
    def build_from(
        fn: Callable[[Tensor], Tensor], takes_time=False
    ) -> "RetroactiveLambda":
        """Wrap ``fn`` in a ``RetroactiveLambda``.

        ``takes_time`` is forwarded by keyword: elsewhere in this module ``Lambda``
        is constructed with other callables in the positional slots that follow
        ``fn``, so a positional ``takes_time`` would land in the wrong parameter.
        """
        return RetroactiveLambda(fn, takes_time=takes_time)
class NaiveResidual(nn.Module):
    """Residual wrapper that adds the input to the output of ``fn``."""

    def __init__(self, fn):
        """Store the callable whose output the input is added to."""
        super().__init__()
        self.fn = fn

    def forward(self, x):
        """Return ``fn(x) + x``."""
        transformed = self.fn(x)
        return transformed + x
def sum_last_pairs(inputs: Sequence[Tensor]) -> Tensor:
    """Sum tensors element-wise, aligned on their last temporal steps.

    If any input differs in shape from the first one, every input is cut down to
    its last ``T_min`` steps along dimension 2, where ``T_min`` is the smallest
    size of dimension 2 among the inputs. Inputs of identical shape are summed
    untouched.

    Args:
        inputs: one or more tensors. A single tensor is returned as-is.

    Returns:
        The sum of the (possibly trimmed) inputs.

    Raises:
        IndexError: if ``inputs`` is empty.
    """
    first = inputs[0]
    # Check every input against the first: comparing only the first pair would miss
    # a mismatch further down and would fail on a single-element sequence.
    if any(inp.shape != first.shape for inp in inputs[1:]):
        t_min = min(inp.shape[2] for inp in inputs)
        inputs = [inp[:, :, -t_min:] for inp in inputs]
    return reduce(torch.Tensor.add, inputs[1:], inputs[0])
def SingleOutputTransformerEncoderLayer(
    d_model: int,
    nhead: int,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    activation: Union[nn.Module, Callable[[Tensor], Tensor]] = nn.functional.relu,
    layer_norm_eps: float = 1e-5,
    # batch_first: bool = True,
    # norm_first: bool = False,
    device=None,
    dtype=None,
    sequence_len: Optional[int] = None,
    single_output_forward=False,
    query_index: int = -1,
):
    """Continual Single-output Transformer Encoder layer.

    Contrary to the ``torch.nn.TransformerEncoderLayer``, this layer only computes the attention
    for the last query during ``forward_step``. The behavior during ``forward`` is controllable
    with the ``single_output_forward`` parameter.

    The continual formulation of the Transformer Encoder Layer was proposed by Hedegaard et al.
    in "Continual Transformers: Redundancy-Free Attention for Online Inference".
    https://arxiv.org/abs/2201.06268 (paper) https://www.youtube.com/watch?v=gy802Tlp-eQ (video).

    .. note::
        In order to handle positional encoding correctly for continual input streams, the
        :class:`RecyclingPositionalEncoding` should be used together with this module.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multi-head attention module (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, given as a module
            or unary callable. It is inserted into the feed-forward block as passed; no
            string names are resolved in this function. Default: relu.
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        device: torch device to initialize layer on. Defaults to None.
        dtype: datatype of layer parameters. Defaults to None.
        sequence_len: length of token-sequence to perform attention across. Must be a
            positive integer; the default of None is rejected.
        single_output_forward: whether to restrict the attention to the last token during
            forward. Defaults to False.
        query_index: the sequence position index to compute the attention for.

    Raises:
        AssertionError: if ``sequence_len`` is None or not positive.

    Examples::

        encoder_layer = co.SingleOutputTransformerEncoderLayer(
            d_model=512, nhead=8, sequence_len=32, dropout=0.0
        )
        x = torch.rand(10, 512, 32)  # (N, E, T)

        # corresponds to torch.nn.TransformerEncoderLayer
        out = encoder_layer.forward(x)

        # continual inference API
        firsts = encoder_layer.forward_steps(x[:,:,:-1])
        last = encoder_layer.forward_step(x[:,:,-1])

        assert firsts is None  # The module first needs to observe ``sequence_len`` values
        assert torch.allclose(out[:,:,-1], last, atol=1e-6)
    """
    # Guard against the None default explicitly: ``None > 0`` would raise a TypeError
    # rather than report the message below.
    assert (
        sequence_len is not None and sequence_len > 0
    ), "Please provide a positive integer value as sequence length."

    factory_kwargs = {"device": device, "dtype": dtype}

    norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
    norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

    mha = SingleOutputMultiheadAttention(
        embed_dim=d_model,
        num_heads=nhead,
        dropout=dropout,
        bias=True,
        batch_first=True,
        embed_dim_second=True,
        query_index=query_index,
        device=device,
        dtype=dtype,
        sequence_len=sequence_len,
        forward_returns_attn_mask=False,
        single_output_forward=single_output_forward,
    )

    # Feed-forward block: linear -> activation -> dropout -> linear -> dropout.
    ff = Sequential(
        OrderedDict(
            [
                (
                    "linear1",
                    Linear(d_model, dim_feedforward, channel_dim=1, **factory_kwargs),
                ),
                (
                    "activation",
                    activation,
                ),
                (
                    "dropout",
                    nn.Dropout(dropout),
                ),
                (
                    "linear2",
                    Linear(dim_feedforward, d_model, channel_dim=1, **factory_kwargs),
                ),
                (
                    "dropout2",
                    nn.Dropout(dropout),
                ),
            ]
        )
    )

    return Sequential(
        # Attention branch summed with a residual branch. With single_output_forward
        # the residual selects a single step during forward (SelectOrDelay).
        BroadcastReduce(
            OrderedDict(
                [
                    (
                        "residual",
                        SelectOrDelay(mha.delay)
                        if single_output_forward
                        else Identity(),
                    ),
                    (
                        "self_attn",
                        mha,
                    ),
                ]
            ),
            reduce=sum_last_pairs,
            auto_delay=False,
        ),
        Sequential(
            OrderedDict(
                [
                    ("norm1", Lambda(norm1, takes_time=False)),
                    ("_ff_block", Residual(ff)),
                    ("norm2", Lambda(norm2, takes_time=False)),
                ]
            )
        ),
    )
def RetroactiveTransformerEncoderLayer(
    d_model: int,
    nhead: int,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    activation: Union[nn.Module, Callable[[Tensor], Tensor]] = nn.functional.relu,
    layer_norm_eps: float = 1e-5,
    # batch_first: bool = True,
    # norm_first: bool = False,
    device=None,
    dtype=None,
    sequence_len: Optional[int] = None,
):
    """Continual Retroactive Transformer Encoder layer.

    When a new token is received, it computes the updated attention values corresponding
    to prior tokens as well.

    The continual formulation of the Transformer Encoder Layer was proposed by Hedegaard et al.
    in "Continual Transformers: Redundancy-Free Attention for Online Inference".
    https://arxiv.org/abs/2201.06268 (paper) https://www.youtube.com/watch?v=gy802Tlp-eQ (video).

    .. note::
        In order to handle positional encoding correctly for continual input streams, the
        :class:`RecyclingPositionalEncoding` should be used together with this module.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multi-head attention module (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, given as a module
            or unary callable. It is wrapped in a ``Lambda`` as passed; no string names
            are resolved in this function. Default: relu.
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        device: torch device to initialize layer on. Defaults to None.
        dtype: datatype of layer parameters. Defaults to None.
        sequence_len: length of token-sequence to perform attention across. Must be a
            positive integer; the default of None is rejected.

    Raises:
        AssertionError: if ``sequence_len`` is None or not positive.

    Examples::

        encoder_layer = co.RetroactiveTransformerEncoderLayer(
            d_model=512, nhead=8, sequence_len=32, dropout=0.0
        )
        x = torch.rand(10, 512, 32)  # (N, E, T)

        # corresponds to torch.nn.TransformerEncoderLayer
        out = encoder_layer.forward(x)

        # continual inference API
        firsts = encoder_layer.forward_steps(x[:,:,:-1])
        last = encoder_layer.forward_step(x[:,:,-1])

        assert firsts is None  # The module first needs to observe ``sequence_len`` values
        assert torch.allclose(out, last, atol=1e-6)
    """
    # Guard against the None default explicitly: ``None > 0`` would raise a TypeError
    # rather than report the message below.
    assert (
        sequence_len is not None and sequence_len > 0
    ), "Please provide a positive integer value as sequence length."

    factory_kwargs = {"device": device, "dtype": dtype}

    norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
    norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

    mha = RetroactiveMultiheadAttention(
        embed_dim=d_model,
        num_heads=nhead,
        dropout=dropout,
        bias=True,
        batch_first=True,
        embed_dim_second=True,
        device=device,
        dtype=dtype,
        sequence_len=sequence_len,
        forward_returns_attn_mask=False,
    )

    # Feed-forward block: linear -> activation -> dropout -> linear -> dropout.
    ff = Sequential(
        OrderedDict(
            [
                ("linear1", Linear(d_model, dim_feedforward, **factory_kwargs)),
                ("activation", Lambda(activation, takes_time=True)),
                ("dropout", nn.Dropout(dropout)),
                ("linear2", Linear(dim_feedforward, d_model, **factory_kwargs)),
                ("dropout2", nn.Dropout(dropout)),
            ]
        )
    )

    return Sequential(
        # Attention branch summed with a residual branch that caches the inputs
        # (RetroactiveUnity) for the same number of steps as the attention delay.
        BroadcastReduce(
            OrderedDict(
                [
                    ("residual", RetroactiveUnity(mha.delay)),
                    ("self_attn", mha),
                ]
            ),
            reduce="sum",
            auto_delay=False,
        ),
        # Norms and feed-forward are grouped in a plain nn.Sequential and applied
        # through RetroactiveLambda, which routes step calls through forward.
        RetroactiveLambda(
            nn.Sequential(
                OrderedDict(
                    [
                        ("norm1", norm1),
                        ("_ff_block", NaiveResidual(ff)),
                        ("norm2", norm2),
                    ]
                )
            )
        ),
    )
# TODO: impl
def StepLocalTransformerEncoderLayer(
    d_model: int,
    nhead: int,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    activation: Union[nn.Module, Callable[[Tensor], Tensor]] = nn.functional.relu,
    layer_norm_eps: float = 1e-5,
    # batch_first: bool = True,
    # norm_first: bool = False,
    device=None,
    dtype=None,
    sequence_len: int = None,
):
    """Placeholder for a step-local Transformer Encoder layer.

    Not implemented yet: the arguments mirror those of the other encoder-layer
    constructors in this module but are ignored, and no layer is built.

    Returns:
        None, always.
    """
    return None
def TransformerEncoderLayerFactory(
    d_model: int,
    nhead: int,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    activation: Union[nn.Module, Callable[[Tensor], Tensor], str] = nn.functional.relu,
    layer_norm_eps: float = 1e-5,
    # batch_first: bool = True,
    # norm_first: bool = False,
    device=None,
    dtype=None,
    sequence_len: int = None,
) -> Callable[[MhaType], Sequential]:
    """Defines the hyper-parameters of Continual Transformer Encoder layers, where each layer
    contains feed forward networks and continual multi-head attentions as proposed by
    Vaswani et al. in "Attention is all you need".

    It can produce either a :class:`SingleOutputTransformerEncoderLayer` or a
    :class:`RetroactiveTransformerEncoderLayer`.
    These were proposed by Hedegaard et al. in
    "Continual Transformers: Redundancy-Free Attention for Online Inference".
    https://arxiv.org/abs/2201.06268 (paper) https://www.youtube.com/watch?v=gy802Tlp-eQ (video).

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multi-head attention modules (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Any other value is passed on to the
            layer constructor unchanged. Default: relu.
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        device: torch device to initialize layer on. Defaults to None.
        dtype: datatype of layer parameters. Defaults to None.
        sequence_len: length of token-sequence to perform attention across. Defaults to None.

    Returns:
        Callable[[Union[str,MhaType]], Sequential]: Factory function that returns the layer
        module given the desired MHA type (one of "retroactive", "single_output",
        "single_output_forward", and "regular").

    Examples::

        encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8, sequence_len=32)
        transformer_encoder = co.TransformerEncoder(encoder_layer, num_layers=2)
        src = torch.rand(10, 512, 32)
        out = transformer_encoder(src)
    """
    # Resolve activation names. Only strings are looked up, so that arbitrary
    # callables are never hashed or compared against the names.
    if isinstance(activation, str):
        named_activations = {
            "relu": nn.functional.relu,
            "gelu": nn.functional.gelu,
        }
        activation = named_activations.get(activation, activation)

    def TransformerEncoderLayer(mha_type: MhaType):
        """Build one encoder layer of the requested attention type.

        ``mha_type`` may be an ``MhaType`` member or its string value.
        """
        factory_fn = {
            MhaType.RETROACTIVE: RetroactiveTransformerEncoderLayer,
            MhaType.SINGLE_OUTPUT: SingleOutputTransformerEncoderLayer,
            MhaType.SINGLE_OUTPUT_FORWARD: partial(
                SingleOutputTransformerEncoderLayer, single_output_forward=True
            ),
            MhaType.REGULAR: StepLocalTransformerEncoderLayer,
        }[MhaType(mha_type)]

        return factory_fn(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            # batch_first
            # norm_first
            device=device,
            dtype=dtype,
            sequence_len=sequence_len,
        )

    return TransformerEncoderLayer
class TransformerEncoder(Sequential):
    """Continual Transformer Encoder is a stack of N encoder layers.

    The continual formulation of the Transformer Encoder was proposed by Hedegaard et al.
    in "Continual Transformers: Redundancy-Free Attention for Online Inference".
    https://arxiv.org/abs/2201.06268 (paper) https://www.youtube.com/watch?v=gy802Tlp-eQ (video).

    .. note::
        This class deviates from the Pytorch implementation in the following ways:
        1) `encoder_layer` parameter takes a factory functor, TransformerEncoderLayerFactory
        2) `mask` and `src_key_padding_mask` are not supported currently.

    .. note::
        The efficiency gains of ``forward_step`` compared to ``forward`` is highly dependent
        on the chosen ``num_layers``. Here, a lower ``num_layers`` is most efficient.
        Accordingly, we recommend increasing ``d_model``, ``nhead``, and ``dim_feedforward``
        of the :class:`TransformerEncoderLayerFactory` rather than increasing ``num_layers`` if larger
        models are desired. Keeping the parameter-count equal, this was found to work well
        for regular Transformer Encoders as well (https://arxiv.org/pdf/2210.00640.pdf).

    .. note::
        In order to handle positional encoding correctly for continual input streams, the
        :class:`RecyclingPositionalEncoding` should be used together with this module.

    Args:
        encoder_layer: An instance of :class:`TransformerEncoderLayerFactory`.
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::

        encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8, sequence_len=32)
        transformer_encoder = co.TransformerEncoder(encoder_layer, num_layers=2)
        src = torch.rand(10, 512, 32)
        out = transformer_encoder(src)
    """

    def __init__(
        self,
        encoder_layer: Callable[[MhaType, Optional[bool]], Sequential],
        num_layers: int,
        norm: nn.Module = None,
    ):
        """Assemble the layer stack.

        A single layer uses the single-output type. Otherwise the stack is one
        retroactive layer, ``num_layers - 2`` regular layers, one
        single-output-forward layer, and a final adapter which squeezes the last
        dimension in the step modes. If ``norm`` is given it is appended as a
        ``"norm"`` module wrapped in ``Lambda``.
        """
        if num_layers == 1:
            stack = [encoder_layer(MhaType.SINGLE_OUTPUT)]
        else:
            # NOTE(review): the squeeze adapter is placed inside the multi-layer
            # branch; confirm against the upstream file, since the nesting could
            # not be verified from this copy.
            def unity(x):
                return x

            def squeeze_last(x):
                return x.squeeze(-1)

            stack = [encoder_layer(MhaType.RETROACTIVE)]
            stack.extend(
                RetroactiveLambda(encoder_layer(MhaType.REGULAR), takes_time=True)
                for _ in range(1, num_layers - 1)
            )
            stack.append(
                RetroactiveLambda(
                    encoder_layer(MhaType.SINGLE_OUTPUT_FORWARD), takes_time=True
                )
            )
            stack.append(
                Lambda(unity, None, squeeze_last, squeeze_last, takes_time=True)
            )

        Sequential.__init__(self, OrderedDict([("layers", Sequential(*stack))]))

        if norm is not None:
            self.add_module("norm", Lambda(norm, takes_time=False))

    @staticmethod
    def build_from(
        trans_enc: nn.TransformerEncoder, sequence_len: int
    ) -> "TransformerEncoder":
        """Create a continual encoder from an ``nn.TransformerEncoder`` and copy its weights.

        Hyper-parameters are read from the first layer of ``trans_enc``; its ``norm``
        is passed on as-is. Weights are matched by sorting the state-dict keys of both
        networks and pairing them positionally.

        Args:
            trans_enc: the source encoder.
            sequence_len: sequence length passed on to the layer factory.

        Returns:
            The new ``TransformerEncoder`` with the state dict loaded.

        Raises:
            AssertionError: if ``trans_enc`` is not an ``nn.TransformerEncoder`` or the
                trailing two components of any paired keys differ.
        """
        assert isinstance(trans_enc, nn.TransformerEncoder)

        # Create model
        tel = trans_enc.layers[0]
        layer_factory = TransformerEncoderLayerFactory(
            d_model=tel.self_attn.embed_dim,
            nhead=tel.self_attn.num_heads,
            dim_feedforward=tel.linear1.out_features,
            dropout=tel.dropout.p,
            activation=tel.activation,
            layer_norm_eps=tel.norm1.eps,
            device=tel.linear1.weight.device,
            dtype=tel.linear1.weight.dtype,
            sequence_len=sequence_len,
        )
        net = TransformerEncoder(
            layer_factory, num_layers=trans_enc.num_layers, norm=trans_enc.norm
        )

        # Transfer weights
        def key_tail(key):
            # Last two dot-separated components of a state-dict key.
            return ".".join(key.split(".")[-2:])

        net_keys = list(net.state_dict().keys())

        # Strip wrapper-specific name parts before sorting (presumably so that the
        # keys order the same way as those of the torch encoder), and remember each
        # key's original position.
        stripped = [
            (
                name.replace("fn.", "")
                .replace("_ff_block.", "")
                .replace(".0.self_attn", ".self_attn"),
                pos,
            )
            for pos, name in enumerate(net_keys)
        ]
        match_keys, key_inds = zip(*sorted(stripped, key=lambda pair: pair[0]))

        reg_keys, weights = zip(
            *sorted(trans_enc.state_dict().items(), key=lambda pair: pair[0])
        )

        assert all(
            key_tail(own) == key_tail(ref) for own, ref in zip(match_keys, reg_keys)
        )

        transferred = {
            net_keys[key_inds[i]]: weights[i] for i in range(len(net_keys))
        }
        net.load_state_dict(transferred)

        return net