Shortcuts

Source code for continual.positional_encoding

from typing import Optional, Tuple

import numpy as np
import torch
from torch import Tensor, nn

import continual as co

# Names exported by ``from continual.positional_encoding import *``;
# CyclicPositionalEncoding is defined in this module but not listed here.
__all__ = ["RecyclingPositionalEncoding"]


class CyclicPositionalEncoding(nn.Module):
    """Cyclic Positional Encoding as proposed by Ma et al. in
    "Learning to Iteratively Solve Routing Problems with Dual-Aspect Collaborative Transformer"
    https://arxiv.org/abs/2110.02544

    A static (non-learned) ``num_embeddings`` x ``embedding_dim`` table is computed once
    at construction and stored in the ``pattern`` buffer. ``forward`` looks rows up by
    integer index, so the module can be called like an ``nn.Embedding``.

    Args:
        num_embeddings: number of positions in the table (must be >= 1).
        embedding_dim: dimensionality of each embedding (must be >= 2).
        mean_pooling: if True, every row is averaged with its two cyclic neighbours on
            each side (5 rows in total). If False, rows are left as computed. In both
            cases the column-wise mean of the raw table is subtracted afterwards.

    Raises:
        ValueError: if ``num_embeddings < 1`` or ``embedding_dim < 2``.
    """

    def __init__(
        self, num_embeddings: int, embedding_dim: int, mean_pooling: bool = True
    ):
        super().__init__()

        # The frequency schedule divides by ``embedding_dim // 2`` and by the period
        # ``skip``; degenerate sizes would otherwise fail with a ZeroDivisionError.
        if num_embeddings < 1:
            raise ValueError(f"num_embeddings must be >= 1, got {num_embeddings}")
        if embedding_dim < 2:
            raise ValueError(f"embedding_dim must be >= 2, got {embedding_dim}")

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        half_dim = embedding_dim // 2

        skip_base = np.power(num_embeddings, 1 / half_dim)
        skip_set = np.linspace(skip_base, num_embeddings, half_dim, dtype="int")
        x = np.zeros((num_embeddings, embedding_dim))

        for i in range(embedding_dim):
            # see Appendix B: each run of three consecutive columns shares one period
            group = i // 3 * 3 + 1
            skip = skip_set[group] if group < half_dim else skip_set[-1]

            # get z(i) in the paper (via longer_pattern), sampled on a 0.01-step grid
            if num_embeddings > skip:
                longer_pattern = np.arange(
                    0, np.ceil((num_embeddings) / skip) * skip + 0.01, 0.01
                )
            else:
                longer_pattern = np.arange(0, num_embeddings + 0.01, 0.01)
                skip = num_embeddings

            num = len(longer_pattern) - 1
            omiga = 2 * np.pi / skip

            # see Appendix B: a phase offset only for column indices > embedding_dim // 2
            fai = 0 if i <= half_dim else 2 * np.pi * ((-i + half_dim) / half_dim)

            # Eq. (4) in the paper: odd columns use cosine, even columns use sine
            wave = np.cos if i % 2 == 1 else np.sin
            curve = self._folded_wave(wave, longer_pattern, omiga, fai)
            # Take num_embeddings + 1 evenly spaced grid points and keep the first
            # num_embeddings of them as the column values.
            sample_at = np.linspace(0, num, num_embeddings + 1, dtype="int")
            x[:, i] = curve[sample_at][:num_embeddings]

        pattern = torch.from_numpy(x).float()

        # Averaging the adjacent embeddings if needed (optional, almost the same
        # performance). Neighbours wrap around cyclically via the modulo.
        rows = torch.arange(num_embeddings)
        shifts = [-2, -1, 0, 1, 2] if mean_pooling else [0]
        pattern_sum = torch.zeros_like(pattern)
        for shift in shifts:
            # Row j receives row (j + shift) mod num_embeddings.
            pattern_sum += pattern[(rows + shift + num_embeddings) % num_embeddings]
        # Cyclic shifts preserve column sums, so subtracting the raw column mean leaves
        # every column of the final table with (approximately) zero mean.
        pattern = 1.0 / len(shifts) * pattern_sum - pattern.mean(0)
        self.register_buffer("pattern", pattern)

    @staticmethod
    def _folded_wave(wave, x, omiga, fai=0):
        """Apply ``wave`` (np.sin / np.cos) to ``x`` folded into a triangle ramp of period 2*T."""
        T = 2 * np.pi / omiga
        return wave(omiga * np.abs(np.mod(x, 2 * T) - T) + fai)

    def forward(self, input: Tensor) -> Tensor:
        """Return the rows of ``pattern`` selected by the integer index tensor ``input``."""
        return self.pattern[input]


# Step state threaded through ``_forward_step``: a 1-tuple holding the position index.
State = Tuple[Tensor]


class RecyclingPositionalEncoding(co.CoModule, nn.Module):
    """Recycling Positional Encoding with learned or static weights.

    Recycling Positional Encoding was proposed by Hedegaard et al. in
    "Continual Transformers: Redundancy-Free Attention for Online Inference"
    https://arxiv.org/abs/2201.06268 (paper)
    https://www.youtube.com/watch?v=gy802Tlp-eQ (video).

    When static encoding is selected, the module employs "Cyclic Positional Encoding"
    as proposed by Ma et al. in
    "Learning to Iteratively Solve Routing Problems with Dual-Aspect Collaborative Transformer"
    https://arxiv.org/abs/2110.02544.

    Args:
        embed_dim: dimensionality of positional embeddings.
        num_embeds: number of embeddings to recycle among.
        learned: whether embeddings should be learned or static sinusoidal.
        forward_update_index_steps: the number of index steps to offset the encoding
            query with each time ``forward`` is called. This ensures that positional
            encodings have a new starting position at each call.

    Examples::

        pe = RecyclingPositionalEncoding(
            embed_dim=10,
            num_embeds=16 * 2 - 1,
            forward_update_index_steps=0
        )
        x = torch.zeros((1, 10, 16))  # (B, C, T)

        o_forward = pe.forward(x)
        o_forward_steps = pe.forward_steps(x[:, :, :-1])
        o_forward_step = pe.forward_step(x[:, :, -1])

        assert torch.equal(o_forward[:, :, :-1], o_forward_steps)
        assert torch.equal(o_forward[:, :, -1], o_forward_step)
    """

    _state_shape = 1
    _dynamic_state_inds = [False]

    def __init__(
        self,
        embed_dim: int,
        num_embeds: int,
        learned: bool = True,
        forward_update_index_steps: int = 1,
    ):
        nn.Module.__init__(self)
        self.pe = {True: nn.Embedding, False: CyclicPositionalEncoding}[learned](
            num_embeds, embed_dim
        )
        # Current position in the recycled table (0-dim long tensor). Registered as a
        # buffer so ``.to(device)`` moves it with the module; persistent=False keeps it
        # out of the state dict.
        self.register_buffer("state_index", torch.tensor(0), persistent=False)
        self.forward_update_index_steps = forward_update_index_steps

    def forward(
        self, input: Tensor, update_index_steps: Optional[int] = None
    ) -> Tensor:
        """Add positional encodings to a whole sequence.

        Args:
            input: tensor with time on dim 2, e.g. (B, C, T) as in the class example.
            update_index_steps: how far to advance ``state_index`` after this call;
                defaults to ``forward_update_index_steps``.

        Returns:
            ``input`` plus the encodings for positions ``state_index, ...,
            state_index + T - 1`` (modulo ``num_embeddings``).
        """
        T = input.shape[2]
        assert T <= self.pe.num_embeddings, (
            f"Sequence length {T} exceeds num_embeddings={self.pe.num_embeddings}"
        )
        position_ids = (
            torch.arange(T, device=input.device).unsqueeze(0) + self.state_index
        ) % self.pe.num_embeddings
        index_update = (
            self.forward_update_index_steps
            if update_index_steps is None
            else update_index_steps
        )
        self.state_index = (self.state_index + index_update) % self.pe.num_embeddings
        # (1, T) ids -> (1, T, C) embeddings -> (1, C, T), broadcast over the batch.
        position_embeddings = self.pe(position_ids).transpose(1, 2)
        return input + position_embeddings

    def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
        """Encode several steps at once; advances the index by T when ``update_state``.

        ``pad_end`` is accepted for interface compatibility and is not used here.
        """
        return self.forward(
            input, update_index_steps=input.shape[2] if update_state else 0
        )

    def forward_step(self, input: Tensor, update_state=True) -> Tensor:
        """Encode a single step (e.g. (B, C), as in the class example) at ``state_index``."""
        output = input + self.pe(self.state_index.unsqueeze(0))
        if update_state:
            self.state_index = (self.state_index + 1) % self.pe.num_embeddings
        return output

    def _forward_step(
        self, input: Tensor, prev_state: Optional[State] = None
    ) -> Tuple[Tensor, State]:
        """Functional single step: take the index from ``prev_state``, return the next one.

        When ``prev_state`` is None, ``init_state`` is used, which also resets
        ``self.state_index`` to zero.
        """
        if prev_state is None:
            state_index = self.init_state()[0]
        else:
            state_index = prev_state[0]
        output = input + self.pe(state_index.unsqueeze(0))
        state_index = (state_index + 1) % self.pe.num_embeddings
        return output, (state_index,)

    def init_state(self) -> State:
        """Reset the position index to zero and return it as the initial state."""
        # zeros_like keeps the buffer's device/dtype; torch.tensor(0) would move it to
        # the CPU and break embedding lookups on a CUDA module.
        self.state_index = torch.zeros_like(self.state_index)
        return (self.state_index,)

    def clean_state(self):
        """Reset the position index to zero on the buffer's current device."""
        self.state_index = torch.zeros_like(self.state_index)