Residual
- class continual.Residual(module, temporal_fill=None, reduce='sum', residual_shrink=False)
Residual connection wrapper for input.
This module is shorthand for a BroadcastReduce with one delay stream:
    import torch
    import continual as co

    conv = co.Conv3d(32, 32, kernel_size=3, padding=1)
    res1 = co.BroadcastReduce(conv, co.Delay(2), reduce="sum")
    res2 = co.Residual(conv)
    x = torch.randn(1, 32, 5, 5, 5)
    assert torch.equal(res1(x), res2(x))
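To illustrate why the skip stream needs a delay at all, here is a minimal plain-Python sketch that does not use the library. It assumes a toy "module" that sums a centered window of three inputs; `centered_sum3`, `offline_residual`, and `StepwiseResidual` are hypothetical names introduced for illustration only, and the one-step delay applies to this toy setting, not necessarily to the library's own delay bookkeeping.

```python
from collections import deque


def centered_sum3(xs):
    """Toy stand-in for a module: sum over a centered window of 3 with zero padding."""
    padded = [0.0] + list(xs) + [0.0]
    return [padded[t] + padded[t + 1] + padded[t + 2] for t in range(len(xs))]


def offline_residual(xs):
    """Whole-sequence residual: module output plus the input at the same time step."""
    return [y + x for y, x in zip(centered_sum3(xs), xs)]


class StepwiseResidual:
    """Step-wise residual that delays the skip stream to match the toy module."""

    def __init__(self, delay=1):
        self.window = deque([0.0, 0.0, 0.0], maxlen=3)  # last three inputs
        self.skip = deque(maxlen=delay + 1)  # inputs waiting for their module output
        self.delay = delay

    def step(self, x):
        self.window.append(x)
        self.skip.append(x)
        if len(self.skip) <= self.delay:
            return None  # the module output for this position is not ready yet
        # sum(window) is centered on the input received `delay` steps ago
        return sum(self.window) + self.skip[0]


xs = [1.0, 2.0, 3.0, 4.0]
model = StepwiseResidual(delay=1)
steps = [model.step(x) for x in xs + [0.0]]  # trailing 0.0 plays the role of end padding
stepwise = [y for y in steps if y is not None]
assert stepwise == offline_residual(xs)
```

Without the delayed skip stream, each step would add the newest input to a module output that describes an earlier position, and the step-wise results would no longer match the whole-sequence results.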
- Parameters:
module (CoModule) – module to which a residual should be added.
temporal_fill (PaddingMode, optional) – temporal fill type in delay. Defaults to None.
reduce (Reduction, optional) – Reduction function. Defaults to “sum”.
residual_shrink (bool, optional) –
Whether the residual should shrink its forward output to match the temporal dimension reduction of the wrapped module. Options:
“centered” or True: centered residual shrink;
“lagging”: lagging shrink;
“leading”: leading shrink, i.e. no delay during forward_step(s).
Defaults to False.
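As a rough illustration of the centered case only, the following plain-Python sketch shows what shrinking the residual means when the wrapped module shortens the temporal dimension. It does not call the library; `valid_sum3`, `centered_shrink`, and `residual_with_centered_shrink` are hypothetical helpers, and the even split between both ends is an assumption about what "centered" implies. The lagging and leading variants are not shown.

```python
def valid_sum3(xs):
    """Toy module without padding: a window of 3 shortens T steps to T - 2."""
    return [xs[t] + xs[t + 1] + xs[t + 2] for t in range(len(xs) - 2)]


def centered_shrink(xs, reduction):
    """Trim `reduction` steps in total, split evenly between both ends."""
    assert reduction % 2 == 0, "centered trimming needs an even reduction"
    half = reduction // 2
    return xs[half:len(xs) - half]


def residual_with_centered_shrink(xs):
    """Add each module output to the input at the center of its window."""
    ys = valid_sum3(xs)
    skip = centered_shrink(xs, len(xs) - len(ys))
    return [y + x for y, x in zip(ys, skip)]


out = residual_with_centered_shrink([1.0, 2.0, 3.0, 4.0, 5.0])
assert len(out) == 3  # same length as the shortened module output
```

Without the shrink, the five-step input and the three-step module output could not be summed element-wise.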
- Returns:
BroadcastReduce module with residual.
- Return type: