TransformerEncoderLayerFactory

class continual.TransformerEncoderLayerFactory(d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=<function relu>, layer_norm_eps=1e-05, device=None, dtype=None, sequence_len=None)

Defines the hyper-parameters of Continual Transformer Encoder layers. Each layer combines a feed-forward network with continual multi-head attention, building on the Transformer architecture proposed by Vaswani et al. in “Attention is all you need”.

The factory can produce either a SingleOutputTransformerEncoderLayer or a RetroactiveTransformerEncoderLayer. Both were proposed by Hedegaard et al. in “Continual Transformers: Redundancy-Free Attention for Online Inference”. Paper: https://arxiv.org/abs/2201.06268. Video: https://www.youtube.com/watch?v=gy802Tlp-eQ.

Parameters:
  • d_model (int) – the number of expected features in the input (required).

  • nhead (int) – the number of heads in the multi-head attention models (required).

  • dim_feedforward (int) – the dimension of the feedforward network model (default=2048).

  • dropout (float) – the dropout value (default=0.1).

  • activation (Union[Module, Callable[[Tensor], Tensor], str]) – the activation function of the intermediate layer. It can be a string (“relu” or “gelu”) or a unary callable. Default: relu.

  • layer_norm_eps (float) – the eps value in layer normalization components (default=1e-5).

  • device – torch device to initialize layer on. Defaults to None.

  • dtype – datatype of layer parameters. Defaults to None.

  • sequence_len (Optional[int]) – length of token-sequence to perform attention across. Defaults to None.
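
As a rough illustration of how these hyper-parameters relate, the sketch below mirrors the documented defaults in a plain dataclass and derives the per-head feature width. The `EncoderLayerHyperParams` class and its `head_dim` property are hypothetical and not part of the library. The divisibility check assumes the usual multi-head attention convention that `d_model` is split evenly across `nhead` heads.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class EncoderLayerHyperParams:
    """Hypothetical container mirroring the documented defaults."""

    d_model: int
    nhead: int
    dim_feedforward: int = 2048
    dropout: float = 0.1
    layer_norm_eps: float = 1e-5
    sequence_len: Optional[int] = None

    def __post_init__(self) -> None:
        # Assumed convention: features are split evenly across heads.
        if self.d_model % self.nhead != 0:
            raise ValueError("d_model must be divisible by nhead")

    @property
    def head_dim(self) -> int:
        """Feature width handled by each attention head."""
        return self.d_model // self.nhead


params = EncoderLayerHyperParams(d_model=512, nhead=8, sequence_len=32)
print(params.head_dim)  # 64
```

With the values used in the example further down (d_model=512, nhead=8), each head would operate on 64 features under this assumption.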

Returns:

A factory function that returns the layer module for the desired MHA type (one of “retroactive”, “single_output”, “single_output_forward”, or “regular”).

Return type:

Callable[[Union[str,MhaType]], Sequential]
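
The return value described above follows a factory pattern: the hyper-parameters are fixed first, and the layer variant is chosen later by passing an MHA type. The sketch below shows that pattern in plain Python. It is not the library's implementation: `layer_factory` is a hypothetical name, the `MhaType` enum here is a stand-in built from the four documented type names, and a dict takes the place of the real layer module.

```python
from enum import Enum
from typing import Callable, Union


class MhaType(Enum):
    """Stand-in enum built from the four documented type names."""

    RETROACTIVE = "retroactive"
    SINGLE_OUTPUT = "single_output"
    SINGLE_OUTPUT_FORWARD = "single_output_forward"
    REGULAR = "regular"


def layer_factory(d_model: int, nhead: int) -> Callable[[Union[str, MhaType]], dict]:
    """Capture the hyper-parameters now; choose the layer variant later."""

    def make_layer(mha_type: Union[str, MhaType]) -> dict:
        # MhaType(...) accepts either an enum member or its string value
        # and raises ValueError for an unknown name.
        kind = MhaType(mha_type)
        # A dict stands in for the real layer module in this sketch.
        return {"mha_type": kind.value, "d_model": d_model, "nhead": nhead}

    return make_layer


make_layer = layer_factory(d_model=512, nhead=8)
first = make_layer("retroactive")
last = make_layer(MhaType.SINGLE_OUTPUT)
print(first["mha_type"], last["mha_type"])  # retroactive single_output
```

Deferring the choice of MHA type this way lets one set of hyper-parameters be reused for several layers that differ only in their attention variant.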

Examples:

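import torch
import continual as co

# The factory fixes the hyper-parameters shared by the encoder's layers.
encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8, sequence_len=32)
transformer_encoder = co.TransformerEncoder(encoder_layer, num_layers=2)
src = torch.rand(10, 512, 32)  # (batch, d_model, sequence_len)
out = transformer_encoder(src)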