Source code for continual.multihead_attention.single_output_mha

import math
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from continual.logging import getLogger
from continual.module import _callmode

from .mha_base import MultiheadAttentionBase, scaled_dot_prod_attn_flops

logger = getLogger(__name__)
logger_once = getLogger(__name__, log_once=True)

State = Tuple[
    Tensor,  # Q_mem, (B, Nt-1, E)
    Tensor,  # K_T_mem, (B, E, Ns)
    Tensor,  # V_mem, (B, Ns, E)
    Tensor,  # stride_index
]


def _scaled_dot_product_attention_default_state(
    batch_size: int,
    sequence_len: int,
    embed_dim: int,
    num_heads: int,
    query_index=-1,
    init_fn=torch.zeros,
    dtype=None,
    device=None,
):
    init_fn = partial(init_fn, dtype=dtype, device=device)
    E = embed_dim // num_heads
    B = batch_size * num_heads
    N = sequence_len
    Nq = sequence_len - query_index - 1 if query_index >= 0 else -query_index - 1
    Q_mem = init_fn((B, Nq, E))
    K_T_mem = init_fn((B, E, N))
    V_mem = init_fn((B, N, E))
    return (Q_mem, K_T_mem, V_mem)
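
# Example (illustrative sketch, not part of the library source): building the
# default state. B is batch_size * num_heads and E is embed_dim // num_heads,
# matching the definitions above; with the default query_index=-1, no queries
# are buffered.
#
#     state = _scaled_dot_product_attention_default_state(
#         batch_size=2, sequence_len=8, embed_dim=16, num_heads=4
#     )
#     Q_mem, K_T_mem, V_mem = state
#     assert Q_mem.shape == (8, 0, 4)    # (B, Nq, E): empty query memory
#     assert K_T_mem.shape == (8, 4, 8)  # (B, E, N): transposed key memory
#     assert V_mem.shape == (8, 8, 4)    # (B, N, E): value memory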


def _scaled_dot_product_attention_step(
    prev_state: State,
    q_step: Tensor,  # step input (B, E)
    k_step: Tensor,  # step input (B, E)
    v_step: Tensor,  # step input (B, E)
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, State]:
    """
    Computes the Continual Single-output Scaled Dot-Product Attention on query, key and value tensors.
    Returns attended values and updated states.

    Args:
        q_step, k_step, v_step: query, key and value tensors for a step. See Shape section for shape details.
        attn_mask: optional tensor containing mask values to be added to calculated
            attention. May be 2D or 3D; see Shape section for details.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied.

    Shape:
        - q_step: :math:`(B, E)` where B is batch size and E is embedding dimension.
        - k_step: :math:`(B, E)` where B is batch size and E is embedding dimension.
        - v_step: :math:`(B, E)` where B is batch size and E is embedding dimension.

        - Output: attention values have shape :math:`(B, 1, E)`; the updated state is returned as the second element of the output tuple.
    """
    if attn_mask is not None:  # pragma: no cover
        logger_once.warning(
            "attn_mask is not supported in scaled_dot_product_attention yet and will be skipped"
        )
    if dropout_p != 0.0:  # pragma: no cover
        logger_once.warning(
            "dropout_p is not supported in continual scaled_dot_product_attention yet and will be skipped"
        )

    (
        Q_mem,  # (B, Nq, E)
        K_T_mem,  # (B, E, Ns)
        V_mem,  # (B, Ns, E)
    ) = prev_state

    B, E = q_step.shape
    q_step = q_step / math.sqrt(E)
    q_sel = (Q_mem[:, 0] if Q_mem.shape[1] > 0 else q_step).unsqueeze(1)

    # Update states
    # Note: We're allowing the K and V mem to have one more entry than
    # strictly necessary to simplify computations.

    K_T_new = torch.roll(K_T_mem, shifts=-1, dims=(2,))
    K_T_new[:, :, -1] = k_step

    V_new = torch.roll(V_mem, shifts=-1, dims=(1,))
    V_new[:, -1] = v_step

    attn = torch.bmm(q_sel, K_T_new)
    attn_sm = F.softmax(attn, dim=-1)

    if dropout_p > 0.0:  # pragma: no cover
        attn_sm = F.dropout(attn_sm, p=dropout_p)

    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn_sm, V_new)

    if Q_mem.shape[1] > 0:
        Q_new = torch.roll(Q_mem, shifts=-1, dims=(1,))
        Q_new[:, -1] = q_step
    else:
        Q_new = Q_mem

    new_states = (Q_new, K_T_new, V_new)

    return output, new_states
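
# Example (illustrative sketch, not part of the library source): stepping the
# attention over a full window. After ``sequence_len`` steps the rolled key and
# value memories hold exactly the last N tokens, so the single-output result
# matches an ordinary scaled dot-product attention for the latest query.
#
#     B, E, N = 2, 4, 8  # (batch * heads, head dim, sequence length)
#     state = _scaled_dot_product_attention_default_state(
#         batch_size=B, sequence_len=N, embed_dim=E, num_heads=1
#     )
#     q, k, v = (torch.randn(N, B, E) for _ in range(3))
#     for t in range(N):
#         out, state = _scaled_dot_product_attention_step(state, q[t], k[t], v[t])
#     attn = F.softmax((q[-1] / math.sqrt(E)).unsqueeze(1) @ k.permute(1, 2, 0), dim=-1)
#     expected = attn @ v.permute(1, 0, 2)  # (B, 1, E)
#     assert torch.allclose(out, expected, atol=1e-6)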


class SingleOutputMultiheadAttention(MultiheadAttentionBase):
    """
    MultiheadAttention which only computes the attention output for a single query
    during continual inference.

    Continual MHAs were proposed by Hedegaard et al. in
    "Continual Transformers: Redundancy-Free Attention for Online Inference"
    https://arxiv.org/abs/2201.06268 (paper)
    https://www.youtube.com/watch?v=gy802Tlp-eQ (video).

    This module augments the MultiheadAttention in PyTorch with
    `forward_step` / `forward_steps` functions, in which one / more
    query, key, and value tokens are passed to yield the multihead attention,
    and updated outputs are computed for each token input.

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        device: torch device to initialize layer on. Defaults to None.
        dtype: datatype of layer parameters. Defaults to None.
        sequence_len: Length of token sequence.
        query_index: The index of the query for which to compute the attention.
            Here, -1 denotes the latest query.
        forward_returns_attn_mask: Whether forward should return attention mask.
        embed_dim_second: Whether the embed dimension should be second.
        single_output_forward: Whether forward should be restricted to compute
            attention for only one query.

    .. note::
        If :attr:`kdim` and :attr:`vdim` are None, they will be set
        to :attr:`embed_dim` such that query, key, and value have the same
        number of features.
    Examples::

        mha = co.SingleOutputMultiheadAttention(
            embed_dim=512,
            num_heads=8,
            sequence_len=32,
            dropout=0.0,
            batch_first=True,
            embed_dim_second=True,
        )
        x = torch.rand(10, 512, 32)

        out, attn_mask = mha.forward(x)

        # continual inference API
        firsts = mha.forward_steps(x[:,:,:-1])
        last = mha.forward_step(x[:,:,-1])

        assert firsts is None  # The module first needs to observe ``sequence_len`` values
        assert torch.allclose(out[:,:,-1], last, atol=1e-6)
    """

    _state_shape = 4  # Q_mem, K_T_mem, V_mem, stride_index
    _dynamic_state_inds = [True, True, True, False]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=True,
        device=None,
        dtype=None,
        sequence_len=None,
        query_index=-1,
        forward_returns_attn_mask=True,
        embed_dim_second=False,
        single_output_forward=True,
    ) -> None:
        MultiheadAttentionBase.__init__(
            self,
            embed_dim,
            num_heads,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            kdim,
            vdim,
            batch_first,
            device,
            dtype,
            sequence_len,
            partial(
                _scaled_dot_product_attention_default_state,
                sequence_len=sequence_len,
                embed_dim=embed_dim,
                num_heads=num_heads,
                query_index=query_index,
            ),
            _scaled_dot_product_attention_step,
            forward_returns_attn_mask,
            embed_dim_second,
        )
        assert query_index < sequence_len
        self.query_index = query_index
        self.single_output_forward = single_output_forward

        self.register_buffer("Q_mem", torch.tensor([]), persistent=False)
        self.register_buffer("K_T_mem", torch.tensor([]), persistent=False)
        self.register_buffer("V_mem", torch.tensor([]), persistent=False)
        self.register_buffer("stride_index", torch.tensor(0), persistent=False)

    def get_state(self) -> Optional[State]:
        if len(self.V_mem) > 0:
            return (
                self.Q_mem,
                self.K_T_mem,
                self.V_mem,
                self.stride_index,
            )
        return None

    def set_state(self, state: State):
        (
            self.Q_mem,
            self.K_T_mem,
            self.V_mem,
            self.stride_index,
        ) = state

    def clean_state(self):
        self.Q_mem = torch.tensor([], device=self.Q_mem.device)
        self.K_T_mem = torch.tensor([], device=self.K_T_mem.device)
        self.V_mem = torch.tensor([], device=self.V_mem.device)
        self.stride_index = torch.tensor(0)

    @property
    def delay(self) -> int:
        return (
            self.sequence_len - self.query_index - 1
            if self.query_index >= 0
            else -self.query_index - 1
        )

    def forward(
        self,
        query: Tensor,
        key: Tensor = None,
        value: Tensor = None,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if not self.single_output_forward:
            return MultiheadAttentionBase.forward(
                self, query, key, value, key_padding_mask, need_weights, attn_mask
            )

        if key is None:
            key = query
        if value is None:
            value = query

        # Select a single query entry
        if self.batch_first:
            if self.embed_dim_second:
                query = query[:, :, self.query_index].unsqueeze(2)
            else:
                query = query[:, self.query_index].unsqueeze(1)
        else:
            query = query[self.query_index].unsqueeze(0)

        o = MultiheadAttentionBase.forward(
            self, query, key, value, key_padding_mask, need_weights, attn_mask
        )
        return o

    def _forward_step(
        self,
        query: Tensor,
        key: Tensor = None,
        value: Tensor = None,
        prev_state: Optional[State] = None,
    ) -> Tuple[Optional[Tensor], State]:
        """
        Args:
            query, key, value: step inputs for mapping a query and a set of key-value
                pairs to an output. See "Attention Is All You Need" for more details.

        Shapes for inputs:
            - query: :math:`(N, E)` where N is the batch size and E is the embedding dimension.
            - key: :math:`(N, E)` where N is the batch size and E is the embedding dimension.
            - value: :math:`(N, E)` where N is the batch size and E is the embedding dimension.

        Shapes for outputs:
            - attn_output: :math:`(N, E)` where N is the batch size and E is the
              embedding dimension; :math:`(N, L, E)` if ``batch_first`` is ``True``.
        """
        if key is None:
            key = query
        if value is None:
            value = query

        o, next_state = MultiheadAttentionBase._forward_step(
            self, query, key, value, prev_state
        )

        if o is not None:
            o = o.squeeze(0)

        return o, next_state
    def forward_step(
        self,
        query: Tensor,
        key: Tensor = None,
        value: Tensor = None,
        update_state=True,
        *args,
        **kwargs,
    ) -> Optional[Tensor]:
        """
        Args:
            query, key, value: step inputs for mapping a query and a set of key-value
                pairs to an output. See "Attention Is All You Need" for more details.

        Shapes for inputs:
            - query: :math:`(N, E)` where N is the batch size and E is the embedding dimension.
            - key: :math:`(N, E)` where N is the batch size and E is the embedding dimension.
            - value: :math:`(N, E)` where N is the batch size and E is the embedding dimension.

        Shapes for outputs:
            - attn_output: :math:`(N, E)` where N is the batch size and E is the
              embedding dimension; :math:`(N, L, E)` if ``batch_first`` is ``True``.
        """
        o = MultiheadAttentionBase.forward_step(
            self, query, key, value, update_state, *args, **kwargs
        )
        return o.squeeze(1 if self.batch_first else 0) if isinstance(o, Tensor) else o
    def forward_steps(
        self,
        query: Tensor,
        key: Tensor = None,
        value: Tensor = None,
        update_state=True,
        *args,
        **kwargs,
    ) -> Optional[Tensor]:
        """Forward computation for multiple steps with state initialisation

        Args:
            query (Tensor): query.
            key (Tensor): key.
            value (Tensor): value.
            update_state (bool): Whether internal state should be updated during this operation.

        Returns:
            Tensor: Stepwise layer outputs
        """
        o = MultiheadAttentionBase.forward_steps(
            self, query, key, value, update_state, *args, **kwargs
        )

        if isinstance(o, Tensor):
            o = o.squeeze(2)
            if self.embed_dim_second:
                o = o.transpose(1, 2)  # N T E -> N E T

        return o
    def flops(self, include_muls=True, include_adds=False, include_exps=False):
        f = 0

        # Linear projection
        steps_taken = {
            _callmode("forward"): (
                1 if self.single_output_forward else self.sequence_len
            ),
            _callmode("forward_step"): 1,
        }[self.call_mode]

        f += (
            steps_taken
            * self.embed_dim
            * self.embed_dim
            * 3  # Assuming equal len for Q, K, and V
        )

        if include_adds:
            f += 3 * steps_taken * self.embed_dim * (self.embed_dim - 1)

        if self.in_proj_bias is not None:
            f += 3 * steps_taken * self.embed_dim

            if include_adds:
                f += 3 * steps_taken * self.embed_dim

        # Multi-head Scaled Dot-Product Attention
        f += self.num_heads * {
            _callmode("forward"): (
                single_output_scaled_dot_prod_attn_flops
                if self.single_output_forward
                else scaled_dot_prod_attn_flops
            ),
            _callmode("forward_step"): single_output_scaled_dot_prod_attn_flops,
        }[self.call_mode](
            self.sequence_len,
            self.embed_dim // self.num_heads,
            include_muls,
            include_adds,
            include_exps,
        )

        # Linear projection
        f += 1 * self.embed_dim * (self.embed_dim + 1)

        return f
def single_output_scaled_dot_prod_attn_flops(
    sequence_len, embed_dim, include_muls=True, include_adds=False, include_exps=False
):
    n = sequence_len
    d = embed_dim

    flops = 0

    if include_muls:
        flops += 2 * n * d + 2 * d
    if include_adds:
        flops += 2 * n * d - d - 1
    if include_exps:
        flops += n

    return flops
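
# Example (illustrative sketch, not part of the library source): the
# single-output attention FLOPs grow linearly with the sequence length, whereas
# full attention (scaled_dot_prod_attn_flops, imported above and assumed to take
# the same five arguments it receives in ``flops`` above) grows quadratically.
# For sequence_len=32 and a head dimension of 64:
#
#     single = single_output_scaled_dot_prod_attn_flops(32, 64)  # 2*32*64 + 2*64 = 4224 muls
#     full = scaled_dot_prod_attn_flops(32, 64, True, False, False)
#     assert single < full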
