BroadcastReduce
- class continual.BroadcastReduce(*args: CoModule, reduce: Union[Reduction, Callable[[Sequence[Tensor]], Tensor], str] = 'sum', auto_delay=True)
- class continual.BroadcastReduce(arg: OrderedDict[str, CoModule], reduce: Union[Reduction, Callable[[Sequence[Tensor]], Tensor], str] = 'sum', auto_delay=True)
Broadcast an input to parallel modules and reduce their outputs. This module is shorthand for:

```python
co.Sequential(co.Broadcast(), co.Parallel(*args), co.Reduce(reduce))
```
For instance, it can be used to succinctly create a continual 3D Inception Module:
```python
def norm_relu(module, channels):
    return co.Sequential(
        module,
        nn.BatchNorm3d(channels),
        nn.ReLU(),
    )

inception_module = co.BroadcastReduce(
    co.Conv3d(192, 64, kernel_size=1),
    co.Sequential(
        norm_relu(co.Conv3d(192, 96, kernel_size=1), 96),
        norm_relu(co.Conv3d(96, 128, kernel_size=3, padding=1), 128),
    ),
    co.Sequential(
        norm_relu(co.Conv3d(192, 16, kernel_size=1), 16),
        norm_relu(co.Conv3d(16, 32, kernel_size=5, padding=2), 32),
    ),
    co.Sequential(
        co.MaxPool3d(kernel_size=(1, 3, 3), padding=(0, 1, 1), stride=1),
        norm_relu(co.Conv3d(192, 32, kernel_size=1), 32),
    ),
    reduce="concat",
)
```
- Parameters:
arg (OrderedDict[str, CoModule]) – An OrderedDict of modules to be applied in parallel.
*args (CoModule) – Modules to be applied in parallel.
reduce (ReductionFuncOrEnum, optional) – Function used to reduce the parallel outputs. Sum or concatenation can be specified by passing “sum” or “concat” respectively. Custom reduce functions can also be passed. Defaults to “sum”.
auto_delay (bool, optional) – Automatically add delay to modules in order to match the longest delay. Defaults to True.
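Conceptually, BroadcastReduce feeds the same input to every branch and then combines the branch outputs with the chosen reduction. The semantics can be sketched in plain Python; the `broadcast_reduce` function below is a hypothetical illustration using lists and plain callables in place of tensors and CoModules, not the library's implementation:

```python
from typing import Callable, List, Sequence

def broadcast_reduce(
    branches: Sequence[Callable[[List[float]], List[float]]],
    reduce: str,
    x: List[float],
) -> List[float]:
    # Broadcast: every branch receives the same input.
    outputs = [branch(x) for branch in branches]
    if reduce == "sum":
        # Element-wise sum across branch outputs.
        return [sum(vals) for vals in zip(*outputs)]
    if reduce == "concat":
        # Concatenate branch outputs (along the channel
        # dimension when working with real tensors).
        return [v for out in outputs for v in out]
    raise ValueError(f"Unknown reduction: {reduce}")

double = lambda x: [2 * v for v in x]
negate = lambda x: [-v for v in x]

broadcast_reduce([double, negate], "sum", [1.0, 2.0])     # [1.0, 2.0]
broadcast_reduce([double, negate], "concat", [1.0, 2.0])  # [2.0, 4.0, -1.0, -2.0]
```

Note that "concat" only works when the branch outputs have compatible shapes along the non-concatenated dimensions, whereas "sum" requires all branch outputs to have identical shapes.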