SingleOutputMultiheadAttention¶
- class continual.SingleOutputMultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=True, device=None, dtype=None, sequence_len=None, query_index=-1, forward_returns_attn_mask=True, embed_dim_second=False, single_output_forward=True)[source]¶
MultiHeadAttention which only computes the attention output for a single query during continual inference.
Continual MHAs were proposed by Hedegaard et al. in “Continual Transformers: Redundancy-Free Attention for Online Inference” https://arxiv.org/abs/2201.06268 (paper) https://www.youtube.com/watch?v=gy802Tlp-eQ (video).
This module augments PyTorch's MultiheadAttention with forward_step / forward_steps functions, through which one or more query, key, and value tokens are passed, and an updated multi-head attention output is computed for each input token.
- Parameters:
embed_dim – total dimension of the model.
num_heads – parallel attention heads.
dropout – a Dropout layer on attn_output_weights. Default: 0.0.
bias – add bias as module parameter. Default: True.
add_bias_kv – add bias to the key and value sequences at dim=0.
add_zero_attn – add a new batch of zeros to the key and value sequences at dim=1.
kdim – total number of features in key. Default: None.
vdim – total number of features in value. Default: None.
batch_first – If True, the input and output tensors are provided as (batch, seq, feature); otherwise as (seq, batch, feature). Default: True.
device – torch device to initialize layer on. Defaults to None.
dtype – datatype of layer parameters. Defaults to None.
sequence_len – Length of token sequence.
query_index – The index of the query to compute the attention. Here, -1 denotes the latest query.
forward_returns_attn_mask – Whether forward should return attention mask.
embed_dim_second – Whether the embed dimension should be second.
single_output_forward – Whether forward should be restricted to compute attention for only one query.
Note

If kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.

Examples:
```python
import continual as co
import torch

mha = co.SingleOutputMultiheadAttention(
    embed_dim=512,
    num_heads=8,
    sequence_len=32,
    dropout=0.0,
    batch_first=True,
    embed_dim_second=True,
)
x = torch.rand(10, 512, 32)

out, attn_mask = mha.forward(x)

# Continual inference API
firsts = mha.forward_steps(x[:, :, :-1])
last = mha.forward_step(x[:, :, -1])

assert firsts is None  # The module first needs to observe ``sequence_len`` values
assert torch.allclose(out[:, :, -1], last, atol=1e-6)
```
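In an online setting, forward_step can be applied once per incoming token. A minimal sketch under the same configuration as above (the per-step input shape (batch, embed_dim) follows the shapes listed for forward_step below):

```python
import continual as co
import torch

mha = co.SingleOutputMultiheadAttention(
    embed_dim=512,
    num_heads=8,
    sequence_len=32,
    batch_first=True,
    embed_dim_second=True,
)

stream = torch.rand(10, 512, 100)  # (batch, embed_dim, 100 incoming tokens)
for t in range(stream.shape[-1]):
    # Each step consumes one token of shape (batch, embed_dim);
    # None is returned until ``sequence_len`` tokens have been observed.
    y = mha.forward_step(stream[:, :, t])
    if y is not None:
        ...  # consume the attention output for the latest query
```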
- forward_step(query, key=None, value=None, update_state=True, *args, **kwargs)[source]¶
- Parameters:
query (Tensor) – step input for the query. Attention maps a query and a set of key-value pairs to an output; see “Attention Is All You Need” for more details.
key (Optional[Tensor]) – step input for the key.
value (Optional[Tensor]) – step input for the value.
- Return type:
Optional[Tensor]
- Shapes for inputs:
query: (N, E), where N is the batch size and E is the embedding dimension.
key: (N, E), where N is the batch size and E is the embedding dimension.
value: (N, E), where N is the batch size and E is the embedding dimension.
- Shapes for outputs:
attn_output: (N, E), where N is the batch size and E is the embedding dimension, if batch_first is True.
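Key and value steps can also be supplied explicitly, and update_state=False (see the signature above) computes a step without advancing the internal state. A minimal self-contained sketch:

```python
import continual as co
import torch

mha = co.SingleOutputMultiheadAttention(embed_dim=512, num_heads=8, sequence_len=32)

q = torch.rand(10, 512)  # one query token, shape (N, E)
k = torch.rand(10, 512)  # one key token, shape (N, E)
v = torch.rand(10, 512)  # one value token, shape (N, E)

# Peek at the step output without advancing the internal state.
y = mha.forward_step(q, k, v, update_state=False)
assert y is None  # still warming up: fewer than ``sequence_len`` tokens seen
```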
- forward_steps(query, key=None, value=None, update_state=True, *args, **kwargs)[source]¶
Forward computation for multiple steps with state initialisation
- Parameters:
query (Tensor) – query tensor covering multiple steps.
key (Tensor) – key tensor covering multiple steps.
value (Tensor) – value tensor covering multiple steps.
update_state (bool) – Whether internal state should be updated during this operation.
- Returns:
Stepwise layer outputs
- Return type:
Tensor
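A minimal sketch of the warm-up behaviour of forward_steps, mirroring the class example above (shapes assume the defaults batch_first=True and embed_dim_second=False):

```python
import continual as co
import torch

mha = co.SingleOutputMultiheadAttention(embed_dim=64, num_heads=4, sequence_len=16)

x = torch.rand(2, 20, 64)  # (batch, seq, embed_dim)

# Feeding fewer than ``sequence_len`` tokens only primes the internal state;
# as in the class example, no output is produced yet.
assert mha.forward_steps(x[:, :15]) is None

# After ``sequence_len`` tokens have been observed, each step yields an output.
out = mha.forward_step(x[:, 15])  # shape (2, 64)
```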