diff --git a/pluto/compat/wandb.py b/pluto/compat/wandb.py index a9e702a..8540e3c 100644 --- a/pluto/compat/wandb.py +++ b/pluto/compat/wandb.py @@ -41,6 +41,7 @@ """ import atexit +import json import logging import os import threading @@ -55,6 +56,9 @@ logger = logging.getLogger(__name__) +# Distinct from None so config dedup can tell "never logged" from "logged None". +_MISSING = object() + _original_wandb_init = None _original_wandb_log = None _original_wandb_finish = None @@ -91,6 +95,13 @@ def __init__(self, wandb_run, pluto_run, pluto_module, wandb_disabled=False): self._fallback_step = 0 # Used when wandb is disabled (_step won't increment) self._closed = False self._close_lock = threading.Lock() + # Keys we've already warned about being unforwardable to Pluto, so a + # value logged every step warns once rather than spamming the logs. + self._unforwardable_warned: set = set() + # Last config values we synced to Pluto, keyed by log key. Lets us skip + # redundant update_config() calls when a str/bool/config value is logged + # unchanged every step (a common pattern: phase/status/checkpoint paths). + self._last_logged_config: Dict[str, Any] = {} if self._pluto_run: atexit.register(self._atexit_cleanup_pluto) @@ -158,7 +169,26 @@ def _do_finish(): logger.debug(f'pluto.compat.wandb: Pluto finish timed out after {timeout}s') def log(self, data: Dict[str, Any], step=None, commit=None, **kwargs): - """Log metrics to both wandb and Pluto.""" + """Log metrics to both wandb and Pluto. + + Value routing for the Pluto side: + - int/float and any scalar exposing .item() (numpy/torch/etc.) + -> Pluto metrics (time-series), matching Pluto core's own log() + - wandb media (Image/Video/Audio/Histogram/Table), and lists + thereof -> converted Pluto media + - str and bool -> Pluto config (latest-wins). Pluto has no + string/bool time-series metric, so these mirror wandb's + summary/overview placement and stay queryable via + get_run().config. 
+ - anything else with no metric/media mapping -> preserved as + config if it survives update_config's normalization (incl. + OmegaConf), otherwise dropped and reported to Sentry telemetry + once per key (a maintainer-coverage signal, not a user-facing + warning). See _handle_unforwardable. + + str/bool/config values are deduped against the last synced value, so + logging an unchanged value every step doesn't spam update_config. + """ # Determine the step to use for Pluto. # When step is explicit, use it. Otherwise: # - Normal mode: read wandb's _step before log() increments it @@ -186,11 +216,29 @@ def log(self, data: Dict[str, Any], step=None, commit=None, **kwargs): # Pluto.log() natively supports lists, so we just need # to convert each element and pass the list through. pluto_data: Dict[str, Any] = {} + # String values have no time-series metric equivalent in + # Pluto (op._process_log_item_sync only keeps int/float/ + # tensor/File/Data). wandb puts loose strings in the run + # summary/overview; the closest Pluto analogue is config, + # which is latest-wins and queryable via get_run().config. + # This is what lets e.g. a resume skill read back the most + # recent checkpoint/r2_path for a run. + pluto_config: Dict[str, Any] = {} for key, value in data.items(): - if isinstance(value, (int, float)): + if isinstance(value, bool): + # bool is a subclass of int, but Pluto drops bool + # metrics — surface it as config so it isn't lost. + # Skip if unchanged since last log (avoid redundant + # config writes when logged every step). 
+ if self._last_logged_config.get(key, _MISSING) != value: + pluto_config[key] = value + elif isinstance(value, (int, float)): pluto_data[key] = value - elif _is_torch_tensor_scalar(value): - pluto_data[key] = value.item() + elif (num := _as_scalar_number(value)) is not None: + pluto_data[key] = num + elif isinstance(value, str): + if self._last_logged_config.get(key, _MISSING) != value: + pluto_config[key] = value elif isinstance(value, (list, tuple)): # List of wandb media — convert each element. converted_items = [] @@ -200,22 +248,102 @@ def log(self, data: Dict[str, Any], step=None, commit=None, **kwargs): converted_items.append(c) if converted_items: pluto_data[key] = converted_items + else: + # Not a media list (e.g. list of primitives) — + # preserve as config if possible, else warn. + self._handle_unforwardable(key, value, pluto_config) else: # Try to convert wandb data types to pluto equivalents converted = _convert_wandb_to_pluto(key, value, self._pluto) if converted is not None: pluto_data[key] = converted - + else: + # No metric/media mapping — last-resort handling + # so the value is never silently dropped. + self._handle_unforwardable(key, value, pluto_config) + + # Metrics and config are sent in independent try blocks: a + # failure logging metrics must NOT skip the config update (or + # vice versa) — str/bool from the same wandb.log() call live in + # config and would otherwise be silently lost. if pluto_data: - log_kwargs = {} - if actual_step is not None: - log_kwargs['step'] = actual_step - self._pluto_run.log(pluto_data, **log_kwargs) + try: + log_kwargs = {} + if actual_step is not None: + log_kwargs['step'] = actual_step + self._pluto_run.log(pluto_data, **log_kwargs) + except Exception as e: + logger.debug( + f'pluto.compat.wandb: Failed to log metrics to Pluto: {e}' + ) + + if pluto_config: + try: + self._pluto_run.update_config(pluto_config) + # Only remember as synced once the update succeeds. 
def _handle_unforwardable(self, key, value, pluto_config: Dict[str, Any]) -> None:
    """Last-resort handling for a logged value with no metric/media mapping.

    Two outcomes, tried in order:

    1. If ``_config_storable_value`` reports the value as storable, queue
       its native form in ``pluto_config`` (skipped when it equals the
       value last synced for this key).
    2. Otherwise nothing is forwarded to Pluto. The first time a given key
       hits this path, a debug line is logged and a maintainer-facing
       Sentry message (keyed on the value's type name) is sent. Later
       occurrences of the same key are ignored.

    Args:
        key: The log key the value was logged under.
        value: The raw value passed to ``log()``.
        pluto_config: Pending config updates for the current ``log()``
            call; mutated in place.
    """
    storable, native = _config_storable_value(value)
    if storable:
        # _MISSING (not None) is the "never synced" marker, so a key whose
        # last synced value was None is still deduplicated correctly.
        if self._last_logged_config.get(key, _MISSING) != native:
            pluto_config[key] = native
        return

    # Report each unforwardable key only once per wrapper.
    if key in self._unforwardable_warned:
        return
    self._unforwardable_warned.add(key)

    type_name = type(value).__name__
    # Quiet locally (debug only) — not a user-actionable problem.
    logger.debug(
        'pluto.compat.wandb: not forwarding %r to Pluto — type %s has no '
        'metric/media/config mapping (still logged to W&B).',
        key,
        type_name,
    )
    # The message is keyed on the type (not the run-specific key) so that
    # occurrences of the same unhandled type group together.
    try:
        from pluto import sentry

        sentry.capture_message(
            f'wandb compat: unforwardable Pluto log value of type '
            f'{type_name!r} (no metric/media/config mapping)',
            level='warning',
        )
    except Exception as e:
        # Telemetry is best-effort and must never break log(), but leave a
        # trace instead of swallowing the failure silently.
        logger.debug(
            'pluto.compat.wandb: failed to report unforwardable type %s: %s',
            type_name,
            e,
        )
def _as_scalar_number(value):
    """Return ``value`` as a Python int/float if it is a scalar, else None.

    A value qualifies when it exposes a callable ``.item()`` whose result
    is an ``int`` or ``float`` (numpy scalars, 0-d arrays, scalar tensors
    and similar wrappers).

    Returns None when:
    - ``value`` is a ``bool`` or ``str`` (the caller routes those to config);
    - ``value`` has no callable ``.item``;
    - ``.item()`` raises (e.g. a multi-element array/tensor);
    - ``.item()`` returns a ``bool`` or any non-int/float.
    """
    if isinstance(value, (bool, str)):
        return None
    item = getattr(value, 'item', None)
    if not callable(item):
        return None
    try:
        result = item()
    except Exception:
        # Whatever .item() raised, the value is "not a scalar" for us.
        return None
    # bool is an int subclass, so it must be rejected explicitly.
    if isinstance(result, bool) or not isinstance(result, (int, float)):
        return None
    return result


def _config_storable_value(value):
    """Return ``(storable, native)`` for the config fallback.

    The value is normalised with ``pluto.util.to_native_config`` and then
    checked for JSON-serialisability with ``json.dumps``. Anything that
    fails either step is reported as not storable.

    If ``to_native_config`` cannot be imported, the value is checked as-is
    instead of being rejected outright, so plain JSON-compatible values
    remain storable.

    The returned ``native`` is a deep copy. Callers keep it as the
    "last synced" value for deduplication, so it must not alias a mutable
    container the user may later modify in place — otherwise a changed
    value would compare equal to itself and never be re-synced.

    Returns:
        ``(True, native_value)`` when storable, else ``(False, None)``.
    """
    import copy

    try:
        from pluto.util import to_native_config
    except ImportError:
        to_native_config = None

    try:
        native = value if to_native_config is None else to_native_config(value)
        json.dumps(native)
        # Safe to deep-copy: json.dumps just proved it is plain data.
        return True, copy.deepcopy(native)
    except Exception:
        return False, None
def _make_wrapper():
    """Return ``(wrapper, pluto_run)`` built entirely from MagicMock stand-ins.

    ``atexit.register`` is patched while the wrapper is constructed so no
    real exit handler is installed by the test.
    """
    fake_wandb_run = MagicMock()
    fake_wandb_run._step = 7
    fake_pluto_run = MagicMock()
    fake_pluto_module = MagicMock()
    with mock.patch.object(wandb_compat.atexit, 'register'):
        shim = wandb_compat.WandbRunWrapper(
            fake_wandb_run,
            fake_pluto_run,
            fake_pluto_module,
            wandb_disabled=False,
        )
    return shim, fake_pluto_run
def test_log_forwards_numpy_scalars_as_metrics():
    """numpy scalar types are forwarded to Pluto metrics as Python numbers."""
    np = pytest.importorskip('numpy')
    wrapper, pluto_run = _make_wrapper()

    payload = {
        'checkpoint/step': np.int64(100),
        'loss': np.float32(0.5),
    }
    wrapper.log(payload)

    forwarded = pluto_run.log.call_args.args[0]
    step_value = forwarded['checkpoint/step']
    loss_value = forwarded['loss']

    # .item() must have produced plain Python numbers, not numpy types.
    assert step_value == 100
    assert isinstance(step_value, int)
    assert abs(loss_value - 0.5) < 1e-6
    assert isinstance(loss_value, float)
def test_log_does_not_treat_failing_item_as_scalar():
    """A non-scalar whose .item() raises must not crash or produce a metric."""
    wrapper, pluto_run = _make_wrapper()

    class _MultiElement:
        def item(self):
            raise ValueError('can only convert an array of size 1')

    # _MultiElement is not JSON-serialisable either, so log() ends up on the
    # maintainer-telemetry path. Patch it so this test can never emit a real
    # Sentry event.
    with mock.patch('pluto.sentry.capture_message'):
        wrapper.log({'weird': _MultiElement()})

    assert not pluto_run.log.called
    assert not pluto_run.update_config.called
def test_json_serializable_unmapped_value_falls_back_to_config():
    """JSON-compatible values without a metric mapping are kept as config."""
    wrapper, pluto_run = _make_wrapper()

    payload = {'meta/info': {'kind': 'resume', 'attempt': 3}, 'note': None}
    wrapper.log(payload)

    synced = pluto_run.update_config.call_args.args[0]
    assert synced['meta/info'] == {'kind': 'resume', 'attempt': 3}
    assert synced['note'] is None
    # Nothing numeric was logged, so the metrics path must stay untouched.
    assert not pluto_run.log.called


def test_log_skips_redundant_config_updates():
    """Re-logging an unchanged config value must not call update_config."""
    wrapper, pluto_run = _make_wrapper()
    sync = pluto_run.update_config

    # First occurrence: the value is synced.
    wrapper.log({'phase': 'train', 'loss': 0.5})
    assert sync.call_count == 1
    assert sync.call_args.args[0] == {'phase': 'train'}

    # Identical value on the next step: no config write at all.
    sync.reset_mock()
    wrapper.log({'phase': 'train', 'loss': 0.4})
    assert sync.call_count == 0

    # New value: exactly one write, carrying only the changed key.
    wrapper.log({'phase': 'val', 'loss': 0.3})
    assert sync.call_count == 1
    assert sync.call_args.args[0] == {'phase': 'val'}
+ cfg = pluto_run.update_config.call_args.args[0] + assert cfg['hparams'] == {'lr': 0.01, 'sched': {'name': 'cosine'}} + assert not cap.called # OmegaConf is storable -> no maintainer alert