Source code for continual.shape
from typing import Sequence, overload
from torch import Tensor, nn
from .module import CoModule
__all__ = ["Reshape"]
[docs]class Reshape(CoModule, nn.Module):
"""Reshape non-temporal dimensions of an input
Arguments:
shape: The required shape of non-temporal dimensions.
contiguous: Whether reshaped output should be made contiguous.
"""
_state_shape = 0
_dynamic_state_inds = []
@overload
def __init__(self, shape: Sequence[int], contiguous: bool = False):
... # pragma: no cover
@overload
def __init__(self, *shape: int, contiguous=False):
... # pragma: no cover
def __init__(self, *shape, contiguous=False):
nn.Module.__init__(self)
self.contiguous = contiguous
assert len(shape) > 0
if isinstance(shape[0], int):
self.shape = shape
else:
assert isinstance(shape[0], Sequence)
assert isinstance(shape[0][0], int)
self.shape = shape[0]
def extra_repr(self):
s = f"{self.shape}"
if self.contiguous:
s += ", contiguous=True"
return s
def forward(self, input: Tensor) -> Tensor:
T = input.shape[2]
x = input.moveaxis(2, 0).reshape(T, *self.shape).moveaxis(0, 2)
if self.contiguous:
x = x.contiguous()
return x
def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
return self.forward(input)
def forward_step(self, input: Tensor, update_state=True) -> Tensor:
return self._forward_step(input, None)[0]
def _forward_step(self, input: Tensor, prev_state=None) -> Tensor:
x = input.reshape(self.shape)
if self.contiguous:
x = x.contiguous()
return x, prev_state