RetroactiveTransformerEncoderLayer

class continual.RetroactiveTransformerEncoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=<function relu>, layer_norm_eps=1e-05, device=None, dtype=None, sequence_len=None)[source]

Continual Retroactive Transformer Encoder layer.

When a new token is received, it computes the updated attention values corresponding to prior tokens as well.

The continual formulation of the Transformer Encoder Layer was proposed by Hedegaard et al. in “Continual Transformers: Redundancy-Free Attention for Online Inference” (paper: https://arxiv.org/abs/2201.06268, video: https://www.youtube.com/watch?v=gy802Tlp-eQ).
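As a minimal sketch of what the retroactive update means in practice: the (N, E, T) tensor layout and the retroactive output shape are inferred from the example further down, and the small layer sizes used here are illustrative assumptions.

import torch

import continual as co

layer = co.RetroactiveTransformerEncoderLayer(
    d_model=16, nhead=2, sequence_len=4, dropout=0.0
)

stream = torch.rand(2, 16, 4)  # (N, E, T)

# Feed the stream one token at a time; each step input has shape (N, E).
for t in range(stream.shape[-1]):
    y = layer.forward_step(stream[:, :, t])

# Once ``sequence_len`` tokens have been observed, a step returns retroactively
# updated outputs for the whole window, i.e. shape (N, E, sequence_len), rather
# than an output for the newest token only.
assert y.shape == (2, 16, 4)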

Note

To handle positional encoding correctly for continual input streams, the RecyclingPositionalEncoding should be used together with this module (see the composition sketch after the examples below).

Parameters:
  • d_model (int) – the number of expected features in the input (required).

  • nhead (int) – the number of heads in the multi-head attention (required).

  • dim_feedforward (int) – the dimension of the feedforward network model (default=2048).

  • dropout (float) – the dropout value (default=0.1).

  • activation (Union[Module, Callable[[Tensor], Tensor]]) – the activation function of the intermediate layer, can be a string (“relu” or “gelu”) or a unary callable. Default: relu.

  • layer_norm_eps (float) – the eps value in layer normalization components (default=1e-5).

  • device – torch device to initialize layer on. Defaults to None.

  • dtype – datatype of layer parameters. Defaults to None.

  • sequence_len (Optional[int]) – length of token-sequence to perform attention across. Defaults to None.

Examples:

import torch

import continual as co

# dropout is disabled so that the full-sequence and step-wise outputs match exactly
encoder_layer = co.RetroactiveTransformerEncoderLayer(
    d_model=512, nhead=8, sequence_len=32, dropout=0.0
)
x = torch.rand(10, 512, 32)  # (N, E, T)

# corresponds to torch.nn.TransformerEncoderLayer
out = encoder_layer.forward(x)

# continual inference API
firsts = encoder_layer.forward_steps(x[:,:,:-1])
last = encoder_layer.forward_step(x[:,:,-1])

assert firsts is None  # The module first needs to observe ``sequence_len`` values
assert torch.allclose(out, last, atol=1e-6)
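As mentioned in the note above, a RecyclingPositionalEncoding should precede this layer for continual input streams. The composition below is a sketch only: the RecyclingPositionalEncoding arguments (embedding dimension and number of recycled positional embeddings) and the use of co.Sequential to chain the two modules are assumptions, so consult the respective documentation for exact signatures.

import torch

import continual as co

d_model, seq_len = 512, 32

net = co.Sequential(
    # Assumed arguments: embedding dim and number of positional embeddings to recycle.
    co.RecyclingPositionalEncoding(d_model, 2 * seq_len),
    co.RetroactiveTransformerEncoderLayer(
        d_model=d_model, nhead=8, sequence_len=seq_len, dropout=0.0
    ),
)

x = torch.rand(1, d_model, seq_len)  # (N, E, T)

# Continual inference: prime the state with the first tokens, then take a step
# with the newest token to obtain retroactively updated outputs for the window.
net.forward_steps(x[:, :, :-1])
last = net.forward_step(x[:, :, -1])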