Lambda
- class continual.Lambda(fn=None, forward_only_fn=None, forward_step_only_fn=None, forward_steps_only_fn=None, takes_time=False)
Module wrapper for stateless functions.
Note
Operations performed inside a Lambda are not counted by ptflops.
- Parameters:
  - fn (Callable[[Tensor], Tensor]) – Function to be called in every call mode (forward, forward_step, forward_steps) that is not given its own mode-specific function.
  - forward_only_fn – Function to be called only during forward. fn is used for the other call modes.
  - forward_step_only_fn – Function to be called only during forward_step. fn is used for the other call modes.
  - forward_steps_only_fn – Function to be called only during forward_steps. fn is used for the other call modes.
  - takes_time – If True, fn receives all steps; if False, it receives one step at a time, with no time dimension. Defaults to False.
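The fallback rule described by these parameters amounts to: a mode-specific function, when given, replaces fn for that one call mode only. Below is a minimal pure-Python sketch of that selection, assuming the three call modes are named after the methods forward, forward_step and forward_steps. The names select_fn, identity and negate are hypothetical and used for illustration only; this is not continual's implementation.

```python
from typing import Callable, Optional


def select_fn(
    mode: str,
    fn: Callable,
    forward_only_fn: Optional[Callable] = None,
    forward_step_only_fn: Optional[Callable] = None,
    forward_steps_only_fn: Optional[Callable] = None,
) -> Callable:
    """Hypothetical helper: pick the callable a Lambda-like module would use.

    Mirrors the documented fallback rule only; this is not continual's code.
    """
    overrides = {
        "forward": forward_only_fn,
        "forward_step": forward_step_only_fn,
        "forward_steps": forward_steps_only_fn,
    }
    if mode not in overrides:
        raise ValueError(f"unknown call mode: {mode!r}")
    override = overrides[mode]
    # A mode-specific function wins; every other mode falls back to fn.
    return override if override is not None else fn


def identity(v):
    return v


def negate(v):
    return -v


# Only forward_step is overridden, so forward and forward_steps still use fn.
for mode in ("forward", "forward_step", "forward_steps"):
    chosen = select_fn(mode, identity, forward_step_only_fn=negate)
    print(mode, chosen(3))
# forward 3
# forward_step -3
# forward_steps 3
```

Keeping the overrides in a dictionary keyed by mode makes the "one override per mode, fn for everything else" rule explicit in a single lookup.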
Examples:
import torch
import continual as co

# Shape: (batch=1, channels=3, time=30)
x = torch.arange(90).reshape(1, 3, 30) * 1.0

# Using named function
def same_stats_different_values(x):
    return torch.randn_like(x) * x.std() + x.mean()

same_stats_layer = co.Lambda(same_stats_different_values)
same_stats_layer(x)

# Using unnamed function
mean_layer = co.Lambda(lambda x: torch.ones_like(x) * x.mean())
mean_layer(x)

# Using functor (a callable module instance)
sigmoid = co.Lambda(torch.nn.Sigmoid())
sigmoid(x)
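Because takes_time defaults to False, mean_layer in the examples above receives one step at a time, so x.mean() averages over that single step rather than over the whole sequence. The pure-Python sketch below contrasts the two settings. It assumes a simplified data layout (no batch dimension; a sequence is a list of time steps, each a list of channel values), and apply_like_lambda, step_mean and seq_mean are hypothetical helpers, not continual's implementation.

```python
from statistics import mean
from typing import Callable, List

Step = List[float]  # channel values at one time step
Seq = List[Step]    # time steps in order


def apply_like_lambda(fn: Callable, seq: Seq, takes_time: bool) -> Seq:
    """Hypothetical stand-in for how a Lambda-like module feeds fn."""
    if takes_time:
        # fn sees the whole sequence, time dimension included.
        return fn(seq)
    # fn sees one step at a time, with no time dimension.
    return [fn(step) for step in seq]


def step_mean(step: Step) -> Step:
    # Replace every channel value with the mean over channels at this step.
    m = mean(step)
    return [m] * len(step)


def seq_mean(seq: Seq) -> Seq:
    # Replace every value with the mean over all channels and all steps.
    m = mean(v for step in seq for v in step)
    return [[m] * len(step) for step in seq]


seq = [[0.0, 2.0], [4.0, 6.0]]  # 2 time steps, 2 channels
print(apply_like_lambda(step_mean, seq, takes_time=False))  # [[1.0, 1.0], [5.0, 5.0]]
print(apply_like_lambda(seq_mean, seq, takes_time=True))    # [[3.0, 3.0], [3.0, 3.0]]
```

The same averaging idea gives different results depending on takes_time, so a function that needs statistics across time should be wrapped with takes_time=True.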