SingleOutputMultiheadAttention

class continual.SingleOutputMultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=True, device=None, dtype=None, sequence_len=None, query_index=-1, forward_returns_attn_mask=True, embed_dim_second=False, single_output_forward=True)[source]

MultiheadAttention which only computes the attention output for a single query during continual inference.

Continual MHAs were proposed by Hedegaard et al. in “Continual Transformers: Redundancy-Free Attention for Online Inference” https://arxiv.org/abs/2201.06268 (paper) https://www.youtube.com/watch?v=gy802Tlp-eQ (video).

This module augments PyTorch's MultiheadAttention with forward_step / forward_steps functions, through which one or more query, key, and value tokens are passed to yield the multi-head attention, computing an updated output for each incoming token.
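As a minimal sketch of the step interface (not taken from the library docs; the shapes follow the forward_step documentation below, and the separate key and value tensors are purely illustrative):

import torch
import continual as co

mha = co.SingleOutputMultiheadAttention(embed_dim=512, num_heads=8, sequence_len=32)

# One query, key, and value token each, of shape (batch, embed_dim)
q, k, v = torch.rand(10, 512), torch.rand(10, 512), torch.rand(10, 512)
step_out = mha.forward_step(q, k, v)  # None until ``sequence_len`` steps have been observed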

Parameters:
  • embed_dim – total dimension of the model.

  • num_heads – parallel attention heads.

  • dropout – Dropout probability applied to attn_output_weights. Default: 0.0 (no dropout).

  • bias – add bias as module parameter. Default: True.

  • add_bias_kv – add bias to the key and value sequences at dim=0.

  • add_zero_attn – add a new batch of zeros to the key and value sequences at dim=1.

  • kdim – total number of features in key. Default: None.

  • vdim – total number of features in value. Default: None.

  • batch_first – If True, the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature). Default: True.

  • device – torch device to initialize layer on. Defaults to None.

  • dtype – datatype of layer parameters. Defaults to None.

  • sequence_len – Length of token sequence.

  • query_index – The index of the query for which attention is computed; -1 denotes the latest query.

  • forward_returns_attn_mask – Whether forward should return attention mask.

  • embed_dim_second – Whether the embedding dimension should come second in input and output tensors, i.e. (batch, embed, seq) when batch_first is True.

  • single_output_forward – Whether forward should be restricted to compute attention for only one query.

Note

If kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.
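A quick sketch of this default (an assumption, not from the docs: omitting key and value is taken to fall back to self-attention on the query, as in the example below; dropout is 0.0, so both calls are deterministic):

import torch
import continual as co

mha = co.SingleOutputMultiheadAttention(embed_dim=64, num_heads=4, sequence_len=8)
x = torch.rand(2, 8, 64)  # (batch, seq, embed) with the default batch_first=True

out_self, _ = mha.forward(x)       # kdim = vdim = embed_dim; key/value default to the query
out_qkv, _ = mha.forward(x, x, x)  # the equivalent explicit self-attention call
assert torch.allclose(out_self, out_qkv, atol=1e-6)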

Examples:

import torch
import continual as co

mha = co.SingleOutputMultiheadAttention(
    embed_dim=512,
    num_heads=8,
    sequence_len=32,
    dropout=0.0,
    batch_first=True,
    embed_dim_second=True,
)
x = torch.rand(10, 512, 32)

out, attn_mask = mha.forward(x)

# continual inference API
firsts = mha.forward_steps(x[:,:,:-1])
last = mha.forward_step(x[:,:,-1])

assert firsts is None  # The module first needs to observe ``sequence_len`` values
assert torch.allclose(out[:,:,-1], last, atol=1e-6)
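Once the state is warm, the step API can keep consuming new tokens one at a time; a usage sketch continuing from the example above (the new_tokens tensor is illustrative and uses the same (batch, embed, step) layout):

new_tokens = torch.rand(10, 512, 5)  # five further steps
outputs = [mha.forward_step(new_tokens[:, :, t]) for t in range(new_tokens.shape[-1])]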

forward_step(query, key=None, value=None, update_state=True, *args, **kwargs)[source]

Forward computation for a single step with state initialisation.
Parameters:
  • query (Tensor) – query token for this step, used for mapping a query and a set of key-value pairs to an output. See “Attention Is All You Need” for more details.

  • key (Optional[Tensor]) – key token for this step.

  • value (Optional[Tensor]) – value token for this step.

  • update_state (bool) – Whether internal state should be updated during this operation.

Return type:

Optional[Tensor]

Shapes for inputs:
  • query: (N, E), where N is the batch size and E is the embedding dimension.

  • key: (N, E), where N is the batch size and E is the embedding dimension.

  • value: (N, E), where N is the batch size and E is the embedding dimension.

Shapes for outputs:
  • attn_output: (N, E), where N is the batch size and E is the embedding dimension.
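
A sketch of the update_state argument (an assumption based on the parameter in the signature above, not a documented recipe): a step can be evaluated without advancing the internal state, e.g. to score a candidate token before committing to it. Continuing from the example above:

candidate = torch.rand(10, 512)                         # hypothetical token of shape (N, E)
peek = mha.forward_step(candidate, update_state=False)  # internal state is left untouched
out = mha.forward_step(candidate)                       # regular call; the state advances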

forward_steps(query, key=None, value=None, update_state=True, *args, **kwargs)[source]

Forward computation for multiple steps with state initialisation.

Parameters:
  • query (Tensor) – query tokens for the steps.

  • key (Optional[Tensor]) – key tokens for the steps.

  • value (Optional[Tensor]) – value tokens for the steps.

  • update_state (bool) – Whether internal state should be updated during this operation.

Returns:

Stepwise layer outputs, or None until ``sequence_len`` tokens have been observed.

Return type:

Optional[Tensor]
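
As a usage sketch (mirroring the example above), forward_steps is convenient for priming a freshly initialised layer from a buffer of past tokens before switching to per-token calls:

mha2 = co.SingleOutputMultiheadAttention(
    embed_dim=512, num_heads=8, sequence_len=32, batch_first=True, embed_dim_second=True
)
history = torch.rand(10, 512, 31)                    # one token short of ``sequence_len``
assert mha2.forward_steps(history) is None           # still warming up
first_out = mha2.forward_step(torch.rand(10, 512))   # the 32nd token completes the window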
