Shortcuts

Source code for continual.linear

import torch
from torch import Tensor, nn
from torch.nn.functional import linear

from .module import CoModule

__all__ = ["Linear"]


[docs]class Linear(CoModule, nn.Linear): r"""Applies a linear transformation to a dimension of the incoming data: :math:`y = xA^T + b`. This module supports :ref:`TensorFloat32<tf32_on_ampere>`. Args: in_features: size of each input sample out_features: size of each output sample bias: If set to ``False``, the layer will not learn an additive bias. Default: ``True`` channel_dim: Channel dimension index over which to perform linear projection. Default: -1. Shape: - Input: :math:`(B, C_{in}, T, *)` where :math:`*` means any number of additional dimensions and :math:`C_{in} = \text{in\_features}` if `channel_dim = 2`. If channel_dim = -1, the order of input dimensions is :math:`(*, C_{in})`. - Output: :math:`(B, C_{out}, T, *)` where all but the last dimension are the same shape as the input and :math:`C_{out} = \text{out\_features}` if `channel_dim = 2`. If channel_dim = -1, the order of input dimensions is :math:`(*, C_{in})`. Attributes: weight: the learnable weights of the module of shape :math:`(\text{out\_features}, \text{in\_features})`. The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where :math:`k = \frac{1}{\text{in\_features}}` bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. If :attr:`bias` is ``True``, the values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{in\_features}}` Examples:: # Use like torch.nn.Linear m = co.Linear(20, 30) input = torch.randn(128, 20) output = m(input) assert output.size() == torch.Size([128, 30]) # Or in conjunction with other continual modules # B C T H W input = torch.randn(1, 3, 16, 128, 128) net = co.Sequential( co.Conv3d(3, 32, 3), co.AdaptiveAvgPool3d((1, 1, 1), 32), co.Linear(32, 10, channel_dim=1), ) output = net(input) assert output.size() == torch.Size([1, 10, 1, 1, 1]) """ _state_shape = 0 _dynamic_state_inds = [] def __init__( self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None, channel_dim=-1, ) -> None: nn.Linear.__init__(self, in_features, out_features, bias, device, dtype) self.channel_dim = channel_dim def extra_repr(self): return nn.Linear.extra_repr(self) + f", channel_dim={self.channel_dim}" def forward(self, input: Tensor) -> Tensor: if self.channel_dim != -1: input = input.swapaxes(self.channel_dim, -1) output = linear(input, self.weight, self.bias) # Assumes channel-last if self.channel_dim != -1: output = output.swapaxes(self.channel_dim, -1) return output def forward_step(self, input: Tensor, update_state=True) -> Tensor: return self.forward(input) def _forward_step(self, input: Tensor, prev_state=None) -> Tensor: return self.forward(input), prev_state def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor: return self.forward(input) @staticmethod def build_from( module: nn.Linear, channel_dim=-1, **kwargs, ) -> "Linear": comodule = Linear( **{ **dict( in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, device=module.weight.device, dtype=module.weight.dtype, channel_dim=channel_dim, ), **kwargs, } ) with torch.no_grad(): comodule.weight.copy_(module.weight) if module.bias is not None: comodule.bias.copy_(module.bias) return comodule