RetroactiveTransformerEncoderLayer

class continual.RetroactiveTransformerEncoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=<function relu>, layer_norm_eps=1e-05, device=None, dtype=None, sequence_len=None)[source]
Continual Retroactive Transformer Encoder layer.
When a new token is received, it also computes the updated attention outputs corresponding to the prior tokens in the window.
The continual formulation of the Transformer Encoder layer was proposed by Hedegaard et al. in “Continual Transformers: Redundancy-Free Attention for Online Inference” (paper: https://arxiv.org/abs/2201.06268, video: https://www.youtube.com/watch?v=gy802Tlp-eQ).
Note

In order to handle positional encoding correctly for continual input streams, the RecyclingPositionalEncoding should be used together with this module, as in the sketch below.
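A minimal sketch of this pairing. It assumes that RecyclingPositionalEncoding accepts embed_dim and num_embeds arguments and that its forward_step consumes one (N, E) token at a time; see its own documentation for the exact signature:

    import torch
    import continual as co

    seq_len = 32
    pos_enc = co.RecyclingPositionalEncoding(embed_dim=512, num_embeds=seq_len)
    encoder_layer = co.RetroactiveTransformerEncoderLayer(
        d_model=512, nhead=8, sequence_len=seq_len, dropout=0.0
    )

    # Step-wise inference: add a recycled positional encoding to each
    # incoming token before it reaches the encoder layer.
    x_step = torch.rand(10, 512)  # one token step of shape (N, E)
    out = encoder_layer.forward_step(pos_enc.forward_step(x_step))
    # As for the encoder layer alone, outputs only appear once
    # ``sequence_len`` steps have been observed.

The two modules should also be composable with co.Sequential.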
Parameters:

d_model (int) – the number of expected features in the input (required).
nhead (int) – the number of heads in the multi-head attention models (required).
dim_feedforward (int) – the dimension of the feedforward network model (default=2048).
dropout (float) – the dropout value (default=0.1).
activation (Union[Module, Callable[[Tensor], Tensor]]) – the activation function of the intermediate layer; can be a Module or a unary callable. Default: relu.
layer_norm_eps (float) – the eps value in layer normalization components (default=1e-5).
device – torch device to initialize the layer on. Defaults to None.
dtype – datatype of the layer parameters. Defaults to None.
sequence_len (Optional[int]) – length of the token sequence to perform attention across. Defaults to None.
Examples:
    import torch
    import continual as co

    encoder_layer = co.RetroactiveTransformerEncoderLayer(
        d_model=512, nhead=8, sequence_len=32, dropout=0.0
    )
    x = torch.rand(10, 512, 32)  # (N, E, T)

    # corresponds to torch.nn.TransformerEncoderLayer
    out = encoder_layer.forward(x)

    # continual inference API
    firsts = encoder_layer.forward_steps(x[:, :, :-1])
    last = encoder_layer.forward_step(x[:, :, -1])
    assert firsts is None  # The module first needs to observe ``sequence_len`` values
    assert torch.allclose(out, last, atol=1e-6)
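Once the module has observed sequence_len steps, each further call to forward_step retroactively updates the outputs for the entire window. A brief continuation of the example above, where the extra input steps (new_steps) are made up for illustration:

    # continue the stream with five hypothetical new tokens, (N, E, T_new)
    new_steps = torch.rand(10, 512, 5)
    for t in range(new_steps.shape[-1]):
        out = encoder_layer.forward_step(new_steps[:, :, t])
        # each step yields retroactively updated outputs for the whole
        # window, i.e. ``out`` has shape (N, E, 32) here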