SingleOutputTransformerEncoderLayer¶
- class continual.SingleOutputTransformerEncoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=<function relu>, layer_norm_eps=1e-05, device=None, dtype=None, sequence_len=None, single_output_forward=False, query_index=-1)[source]¶
Continual Single-output Transformer Encoder layer.
In contrast to torch.nn.TransformerEncoderLayer, this layer only computes the attention for the last query during forward_step. The behavior during forward is controllable with the single_output_forward parameter.
The continual formulation of the Transformer Encoder Layer was proposed by Hedegaard et al. in “Continual Transformers: Redundancy-Free Attention for Online Inference”. https://arxiv.org/abs/2201.06268 (paper), https://www.youtube.com/watch?v=gy802Tlp-eQ (video).
Note
In order to handle positional encoding correctly for continual input streams, the RecyclingPositionalEncoding should be used together with this module.
- Parameters:
d_model (int) – the number of expected features in the input (required).
nhead (int) – the number of heads in the multiheadattention models (required).
dim_feedforward (int) – the dimension of the feedforward network model (default=2048).
dropout (float) – the dropout value (default=0.1).
activation (Union[Module, Callable[[Tensor], Tensor]]) – the activation function of the intermediate layer, can be a string (“relu” or “gelu”) or a unary callable. Default: relu.
layer_norm_eps (float) – the eps value in layer normalization components (default=1e-5).
device – torch device to initialize layer on. Defaults to None.
dtype – datatype of layer parameters. Defaults to None.
sequence_len (Optional[int]) – length of token-sequence to perform attention across. Defaults to None.
single_output_forward (bool) – whether to restrict the attention to the last token during forward. Defaults to False.
query_index (int) – the sequence position index to compute the attention for. Defaults to -1 (the last token).
Examples:
encoder_layer = co.SingleOutputTransformerEncoderLayer(
    d_model=512, nhead=8, sequence_len=32, dropout=0.0
)
x = torch.rand(10, 512, 32)  # (N, E, T)

# corresponds to torch.nn.TransformerEncoderLayer
out = encoder_layer.forward(x)

# continual inference API
firsts = encoder_layer.forward_steps(x[:, :, :-1])
last = encoder_layer.forward_step(x[:, :, -1])

assert firsts is None  # The module first needs to observe ``sequence_len`` values
assert torch.allclose(out[:, :, -1], last, atol=1e-6)