From 5bf17958b28024bd81852836724509af683afe83 Mon Sep 17 00:00:00 2001 From: Florian Rottach Date: Mon, 20 Apr 2026 08:08:12 +0200 Subject: [PATCH 1/4] [Doc] Optuna HP optimization --- examples/GALLERY_HEADER.rst | 38 +++++++++++++++++++ examples/hp_search.yaml | 46 +++++++++++++++++++++++ stable_pretraining/callbacks/hp_metric.py | 16 ++++++++ 3 files changed, 100 insertions(+) create mode 100644 examples/hp_search.yaml create mode 100644 stable_pretraining/callbacks/hp_metric.py diff --git a/examples/GALLERY_HEADER.rst b/examples/GALLERY_HEADER.rst index 187b6e572..b8b7d5de4 100644 --- a/examples/GALLERY_HEADER.rst +++ b/examples/GALLERY_HEADER.rst @@ -2,3 +2,41 @@ Examples ======== Configuration examples for stable-pretraining. + + +# Bayesian Hyperparameter Search with Optuna +Sweeping over a search space is very easy with Hydra and Optuna. `hp_search.yaml` provides an example configuration for performing hyperparameter optimization using Optuna's TPE sampler (Bayesian optimization). + +First, make sure to install Optuna and the Hydra Optuna sweeper plugin if you haven't already: +``` +pip install optuna hydra-optuna-sweeper +``` + +Then, make sure to register the callback HPMetricLogger, which stores the best value of your hyperparameter optimization metric on the module as `hp_metric`. More complex logic can also be implemented in this callback. + +``` +from stable_pretraining.callbacks.hp_metric import HPMetricLogger + +callbacks=[HPMetricLogger(metric_name="eval/some_metric")] +``` + +Finally, make sure your train script contains this code to return the hp_metric to Optuna: +``` +... +manager = spt.Manager(...)
+manager() + + +if hasattr(module, "hp_metric"): + result = module.hp_metric.item() + if np.isnan(result): + logger.warning("HP Metric is NaN, returning inf for optimization.") + result = float("inf") + logger.info(f"HP Metric: {result}") + return result +``` + +Now you can simply run the hyperparameter search and it will automatically run multiple trials: +``` +python train.py --config-name=hydra_hp_search +``` diff --git a/examples/hp_search.yaml b/examples/hp_search.yaml new file mode 100644 index 000000000..d8e59de0f --- /dev/null +++ b/examples/hp_search.yaml @@ -0,0 +1,46 @@ +defaults: + - _self_ + - mymodel + - override hydra/launcher: submitit_slurm + - override hydra/sweeper: optuna + +hydra: + mode: MULTIRUN + sweeper: + _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper + direction: minimize + n_trials: 300 # The number of trials for hyperparameter optimization + n_jobs: 15 # Number of parallel jobs to run for hyperparameter optimization + max_failure_rate: 0.3 # Maximum allowed failure rate for trials (e.g., due to out-of-memory errors) + sampler: + _target_: optuna.samplers.TPESampler + seed: 42 + n_startup_trials: 10 # Number of initial random trials before using the TPE sampler + multivariate: True + consider_prior: True + # Some example hyperparameters to search over + params: + module.optimizer.lr: choice(0.01, 0.005, 0.001, 0.0005) + module.masking_ratio: choice(0.0, 0.1, 0.3, 0.6, 0.9) + module.layers: choice(1, 2, 3) + module.activation_function: choice("relu", "gelu") + data.batch_size: choice(32, 64, 128) + sweep: + dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} + + launcher: + partition: gpu + timeout_min: 43200 + cpus_per_task: 8 + mem_gb: 128 + gpus_per_node: 1 + + +# --- Use this section to override any default settings from your config.yaml --- +trainer: + logger: + version: hp_${hydra:job.num} + +data: + some_property: override_value diff --git a/stable_pretraining/callbacks/hp_metric.py 
b/stable_pretraining/callbacks/hp_metric.py
new file mode 100644
index 000000000..c52203d24
--- /dev/null
+++ b/stable_pretraining/callbacks/hp_metric.py
@@ -0,0 +1,53 @@
+import math
+
+import pytorch_lightning as pl
+
+
+class HPMetricLogger(pl.Callback):
+    """Track the best value of one validation metric for hyperparameter search.
+
+    After every validation epoch (sanity-check runs excluded) the metric named
+    ``metric_name`` is read from ``trainer.callback_metrics`` and the best value
+    seen so far is stored on ``pl_module.hp_metric``.
+
+    Args:
+        metric_name: Key of the metric in ``trainer.callback_metrics``.
+        mode: ``"min"`` (default) keeps the smallest value, ``"max"`` the
+            largest.
+
+    Raises:
+        ValueError: If ``mode`` is neither ``"min"`` nor ``"max"``.
+    """
+
+    def __init__(self, metric_name, mode="min"):
+        super().__init__()
+        if mode not in ("min", "max"):
+            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
+        self.metric_name = metric_name
+        self.mode = mode
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        """Update ``pl_module.hp_metric`` with the best metric value so far.
+
+        Raises:
+            KeyError: If ``metric_name`` was not logged during validation.
+        """
+        # The pre-training sanity check validates on a few batches only; its
+        # metric must not be recorded as a trial result.
+        if trainer.sanity_checking:
+            return
+        if self.metric_name not in trainer.callback_metrics:
+            raise KeyError(
+                f"Metric {self.metric_name!r} not found in callback_metrics; "
+                f"available: {sorted(trainer.callback_metrics)}"
+            )
+        hp_metric = trainer.callback_metrics[self.metric_name]
+        best = getattr(pl_module, "hp_metric", None)
+        # NaN compares False with everything, so handle it explicitly: a NaN
+        # best is always replaced and a NaN candidate never replaces a best.
+        if best is None or math.isnan(float(best)):
+            pl_module.hp_metric = hp_metric
+        elif not math.isnan(float(hp_metric)):
+            improved = hp_metric < best if self.mode == "min" else hp_metric > best
+            if improved:
+                pl_module.hp_metric = hp_metric
From 6c6cb90dc77379c82ebc55185a2322c0d0d4d387 Mon Sep 17 00:00:00 2001
From: Florian Rottach
Date: Mon, 20 Apr 2026 08:16:51 +0200
Subject: [PATCH 2/4] [Doc] Extend hp documentation

---
 examples/GALLERY_HEADER.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/GALLERY_HEADER.rst b/examples/GALLERY_HEADER.rst
index b8b7d5de4..922a86624 100644
--- a/examples/GALLERY_HEADER.rst
+++ b/examples/GALLERY_HEADER.rst
@@ -40,3 +40,4 @@ Now you can simply run the hyperparameter search and it will automatically run m
 ```
 python train.py --config-name=hydra_hp_search
 ```
+It is recommended to use the EarlyStopping callback in combination with hyperparameter optimization to avoid wasting resources on bad trials.
From 0dd35dbf23638e53904c68b3ce3b8a6ee345da4f Mon Sep 17 00:00:00 2001
From: Florian Rottach
Date: Thu, 23 Apr 2026 07:05:14 +0200
Subject: [PATCH 3/4] [Feature] logging config for probes

---
 stable_pretraining/callbacks/probe.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/stable_pretraining/callbacks/probe.py b/stable_pretraining/callbacks/probe.py
index bcbef5881..08aa38e39 100644
--- a/stable_pretraining/callbacks/probe.py
+++ b/stable_pretraining/callbacks/probe.py
@@ -53,6 +53,7 @@ class OnlineProbe(TrainableCallback):
 optimizer step.
Default is 1 (no accumulation). metrics: Metrics to track during training/validation. Can be dict, list, tuple, or single metric instance. + log_on: Log intervals for probe metrics. Options are "step", "epoch", or "both" (default "both"). Note: - The probe module is stored in pl_module.callbacks_modules[name] @@ -79,6 +80,7 @@ def __init__( gradient_clip_algorithm: str = "norm", metrics: Optional[Union[dict, tuple, list, torchmetrics.Metric]] = None, verbose: bool = None, + log_on: str = "both", ) -> None: from .utils import resolve_verbose @@ -89,6 +91,7 @@ def __init__( logging.warning(f"Not loss given to {name}, will use output of `probe`") self.loss = loss self.verbose = resolve_verbose(verbose) + self.log_on = log_on # Store probe configuration for later initialization self._probe_config = probe @@ -109,7 +112,6 @@ def __init__( logging.info(f" input: {input}") logging.info(f" target: {target}") logging.info(f" accumulate_grad_batches: {accumulate_grad_batches}") - # Setup metrics self.metrics = metrics logging.info(" wrapping forward") self.wrap_forward(pl_module=module) @@ -180,11 +182,18 @@ def new_forward(self, batch, stage, callback=self, fn=fn): metric_logs[f"eval/{callback.name}_{metric_name}"] = metric # Raw scalars (loss): sync across GPUs + on_step = callback.log_on in ["step", "both"] + on_epoch = callback.log_on in ["epoch", "both"] + if scalar_logs: - self.log_dict(scalar_logs, on_step=True, on_epoch=True, sync_dist=True) + self.log_dict( + scalar_logs, on_step=on_step, on_epoch=on_epoch, sync_dist=True + ) # torchmetrics: handle their own distributed sync, do NOT use sync_dist if metric_logs: - self.log_dict(metric_logs, on_step=True, on_epoch=True, sync_dist=False) + self.log_dict( + metric_logs, on_step=on_step, on_epoch=on_epoch, sync_dist=False + ) return outputs # Bind the new method to the instance From 03eafc124bd6ab77d2e0e9e7515d7823c6ef26d8 Mon Sep 17 00:00:00 2001 From: Florian Rottach Date: Thu, 23 Apr 2026 07:12:53 +0200 Subject: 
[PATCH 4/4] extract logs --- stable_pretraining/callbacks/probe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_pretraining/callbacks/probe.py b/stable_pretraining/callbacks/probe.py index 08aa38e39..c641a35b8 100644 --- a/stable_pretraining/callbacks/probe.py +++ b/stable_pretraining/callbacks/probe.py @@ -181,10 +181,10 @@ def new_forward(self, batch, stage, callback=self, fn=fn): metric(preds, y) metric_logs[f"eval/{callback.name}_{metric_name}"] = metric - # Raw scalars (loss): sync across GPUs on_step = callback.log_on in ["step", "both"] on_epoch = callback.log_on in ["epoch", "both"] + # Raw scalars (loss): sync across GPUs if scalar_logs: self.log_dict( scalar_logs, on_step=on_step, on_epoch=on_epoch, sync_dist=True