diff --git a/stable_pretraining/callbacks/probe.py b/stable_pretraining/callbacks/probe.py index 0b2ada5f2..c552ccbc3 100644 --- a/stable_pretraining/callbacks/probe.py +++ b/stable_pretraining/callbacks/probe.py @@ -52,6 +52,7 @@ class OnlineProbe(TrainableCallback): optimizer step. Default is 1 (no accumulation). metrics: Metrics to track during training/validation. Can be dict, list, tuple, or single metric instance. + log_on: Log intervals for probe metrics. Options are "step", "epoch", or "both" (default "both"). Note: * The probe module is stored in ``pl_module.callbacks_modules[name]``. @@ -78,6 +79,7 @@ def __init__( gradient_clip_algorithm: str = "norm", metrics: Optional[Union[dict, tuple, list, torchmetrics.Metric]] = None, verbose: bool = None, + log_on: str = "both", ) -> None: from .utils import resolve_verbose @@ -88,6 +90,7 @@ def __init__( logging.warning(f"Not loss given to {name}, will use output of `probe`") self.loss = loss self.verbose = resolve_verbose(verbose) + self.log_on = log_on # Store probe configuration for later initialization self._probe_config = probe @@ -108,7 +111,6 @@ def __init__( logging.info(f" input: {input}") logging.info(f" target: {target}") logging.info(f" accumulate_grad_batches: {accumulate_grad_batches}") - # Setup metrics self.metrics = metrics logging.info(" wrapping forward") self.wrap_forward(pl_module=module) @@ -178,12 +180,19 @@ def new_forward(self, batch, stage, callback=self, fn=fn): metric(preds, y) metric_logs[f"eval/{callback.name}_{metric_name}"] = metric + on_step = callback.log_on in ["step", "both"] + on_epoch = callback.log_on in ["epoch", "both"] + # Raw scalars (loss): sync across GPUs if scalar_logs: - self.log_dict(scalar_logs, on_step=True, on_epoch=True, sync_dist=True) + self.log_dict( + scalar_logs, on_step=on_step, on_epoch=on_epoch, sync_dist=True + ) # torchmetrics: handle their own distributed sync, do NOT use sync_dist if metric_logs: - self.log_dict(metric_logs, on_step=True, on_epoch=True, sync_dist=False) + self.log_dict( + metric_logs, on_step=on_step, on_epoch=on_epoch, sync_dist=False + ) return outputs # Bind the new method to the instance