Shortcuts

Source code for continual.closure

from functools import partial
from typing import Callable, Union

import torch
from torch import Tensor, nn

from .module import CoModule
from .utils import function_repr

__all__ = [
    "Lambda",
    "Multiply",
    "Add",
    "Identity",
    "Constant",
    "Zero",
    "One",
]


[docs]class Lambda(CoModule, nn.Module): """Module wrapper for stateless functions. .. note:: Operations performed in a Lambda are not counted in `ptflops` Args: fn: Function to be called during forward. forward_only_fn: Function to be called only during ``forward``. ``fn`` is used for the other call modes. forward_step_only_fn: Function to be called only during ``forward_step``. ``fn`` is used for the other call modes. forward_steps_only_fn: Function to be called only during ``forward_steps``. ``fn`` is used for the other call modes. forward_only_fn: Function to be called only during forward. ``fn`` is used for the other call modes. takes_time: If True, ``fn`` receives all steps, if False, it received one step and no time dimension. Defaults to False. Examples:: x = torch.arange(90).reshape(1,3,30) * 1.0 # Using named function def same_stats_different_values(x): return torch.randn_like(x) * x.std() + x.mean() same_stats_layer = co.Lambda(same_stats_different_values) same_stats_layer(x) # Using unnamed function mean_layer = co.Lambda(lambda x: torch.ones_like(x) * x.mean()) mean_layer(x) # Using functor sigmoid = co.Lambda(torch.nn.Sigmoid()) sigmoid(x) """ _state_shape = 0 _dynamic_state_inds = [] def __init__( self, fn: Callable[[Tensor], Tensor] = None, forward_only_fn=None, forward_step_only_fn=None, forward_steps_only_fn=None, takes_time=False, ): nn.Module.__init__(self) assert callable(fn) or all( [ callable(forward_only_fn), callable(forward_step_only_fn), callable(forward_steps_only_fn), ] ), "Either fn or all of forward_only_fn, forward_step_only_fn, and forward_steps_only_fn should be callable." self.fn = fn self.forward_only_fn = forward_only_fn self.forward_step_only_fn = forward_step_only_fn self.forward_steps_only_fn = forward_steps_only_fn self.takes_time = takes_time @staticmethod def build_from( fn: Callable[[Tensor], Tensor], forward_only_fn=None, forward_step_only_fn=None, forward_steps_only_fn=None, takes_time=False, ) -> "Lambda": return Lambda( fn, forward_only_fn, forward_step_only_fn, forward_steps_only_fn, takes_time ) def __repr__(self) -> str: s = self.__class__.__name__ + "(" if callable(self.fn): s += f"{function_repr(self.fn)}" if callable(self.forward_only_fn): if callable(self.fn): s += ", " s += f"{function_repr(self.forward_only_fn)}" if callable(self.forward_step_only_fn): s += f", {function_repr(self.forward_step_only_fn)}" if callable(self.forward_steps_only_fn): s += f", {function_repr(self.forward_steps_only_fn)}" if self.takes_time: s += ", takes_time=True" s += ")" return s def forward(self, input: Tensor) -> Tensor: if self.forward_only_fn is not None: return self.forward_only_fn(input) if self.takes_time: return self.fn(input) return torch.stack( [self.fn(input[:, :, t]) for t in range(input.shape[2])], dim=2 ) def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor: if self.forward_steps_only_fn is not None: return self.forward_steps_only_fn(input) if self.takes_time: return self.fn(input) return torch.stack( [self.fn(input[:, :, t]) for t in range(input.shape[2])], dim=2 ) def forward_step(self, input: Tensor, update_state=True) -> Tensor: return self._forward_step(input)[0] def _forward_step(self, input: Tensor, prev_state=None) -> Tensor: if self.forward_step_only_fn is not None: return self.forward_step_only_fn(input), prev_state if self.takes_time: input = input.unsqueeze(dim=2) output = self.fn(input) if self.takes_time: output = output.squeeze(dim=2) return output, prev_state
def _multiply(x: Tensor, factor: Union[float, int, Tensor]): return x * factor
[docs]def Multiply(factor: float) -> Lambda: r"""Applies an scaling transformation to the incoming data: :math:`y = ax`. Args: factor (float): Number to multiply with. """ fn = partial(_multiply, factor=factor) return Lambda(fn, takes_time=True)
def _add(x: Tensor, constant: Union[float, int, Tensor]): return x + constant
[docs]def Add(constant: float) -> Lambda: r"""Applies an additive translation to the incoming data: :math:`y = x + a`. Args: constant (float): Number to add. """ """Create Lambda with addition function""" fn = partial(_add, constant=constant) return Lambda(fn, takes_time=True)
def _unity(x: Tensor): return x
[docs]def Identity(*args, **kwargs) -> Lambda: """A placeholder identity operator that is argument-insensitive. Args: args: any argument (unused) kwargs: any keyword argument (unused) Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. Examples:: m = co.Identity(54, unused_argument1=0.1, unused_argument2=False) input = torch.randn(128, 20) output = m(input) assert output.size() == torch.Size([128, 20]) """ return Lambda(_unity, takes_time=True)
[docs]def Constant(constant: float) -> Lambda: """Returns ``constant * torch.ones_like(input)``. Arguments: constant: Constant value to return. """ return Lambda(lambda x: constant * torch.ones_like(x), takes_time=True)
[docs]def Zero() -> Lambda: """Returns ``torch.zeros_like(input)``.""" return Lambda(torch.zeros_like, takes_time=True)
[docs]def One() -> Lambda: """Returns ``torch.ones_like(input)``.""" return Lambda(torch.ones_like, takes_time=True)