From 9ab19a39dc4d57df8c6a717011a78e2596b17a94 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 8 May 2026 01:02:26 +0000 Subject: [PATCH 1/2] feat(query): add get_raw_metrics for SDK access to un-deduped writes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Backend PR Trainy-ai/server-private#448 added /api/runs/metrics/raw alongside the dedup migration: most reads go through the deduped v2 tables and summaries, but power users debugging duplicate writes (resumes, re-logs, manual corrections) need to see every row that actually landed in ClickHouse, with the original write time. This is the SDK side of that endpoint. pluto.query.get_raw_metrics takes a single (run, metric) — keeping payloads bounded — and returns rows with logGroup, step, time, value, and a nonFiniteFlags bitmask (bit0=NaN, bit1=+Inf, bit2=-Inf). The bitmask is the only reliable signal of non-finite values once the payload is JSON-serialized; value is 0 when any flag is set. The server caps at 50_000 rows per request. When the cap is hit the SDK emits a UserWarning telling the caller to narrow step_min/step_max — making truncation impossible to miss in audit work without changing the return shape away from the DataFrame/list pattern get_metrics already uses. Co-Authored-By: Claude Opus 4.7 (1M context) --- pluto/query.py | 100 ++++++++++++++++++++++++++++++++ tests/test_query.py | 137 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 237 insertions(+) diff --git a/pluto/query.py b/pluto/query.py index 49277a2..87e6d46 100644 --- a/pluto/query.py +++ b/pluto/query.py @@ -36,6 +36,10 @@ _RETRY_WAIT_MIN = 0.5 _RETRY_WAIT_MAX = 4.0 +# Server-side cap for ``GET /api/runs/metrics/raw``. Mirrors +# RAW_METRICS_MAX_LIMIT in web/server/routes/runs-openapi.ts. +_RAW_METRICS_MAX = 50000 + class QueryError(Exception): """Raised when a query to the Pluto server fails.""" @@ -247,6 +251,83 @@ def get_metrics( return _to_dataframe(raw) + def get_raw_metrics( + self, + project: str, + run_id: Union[int, str], + metric_name: str, + step_min: Optional[int] = None, + step_max: Optional[int] = None, + limit: int = _RAW_METRICS_MAX, + ) -> Any: + """Fetch every raw write for a single metric (un-deduped). + + Reads from ClickHouse's ``mlop_metrics`` directly — no FINAL, + no reservoir sampling — so callers see every row that landed + for the given ``(run, metric)``, including duplicates from + resumed runs, re-logs, or manual corrections. Use this for + audit and cleanup; use :meth:`get_metrics` for analysis or + chart rendering. + + Each row carries the original write ``time``, which + disambiguates writes that share a step, plus a + ``nonFiniteFlags`` bitmask: ``bit0=NaN``, ``bit1=+Inf``, + ``bit2=-Inf``. When any flag is set, ``value`` is ``0`` and + the flag is the source of truth — JSON serializes non-finite + floats as null, so the bitmask is the only reliable signal + once the payload leaves ClickHouse. + + The server caps responses at 50 000 rows. When the cap is + hit a :class:`UserWarning` is emitted; narrow ``step_min`` / + ``step_max`` to page. + + Args: + project: Project name. + run_id: Numeric server ID (``int``) or display ID string + (e.g. ``"MMP-1"``). + metric_name: Metric name (e.g. ``"train/loss"``). Required — + scoping every raw scan to a single metric keeps + payloads bounded. + step_min: Minimum step number (inclusive). + step_max: Maximum step number (inclusive). + limit: Max rows to return (max 50 000). + + Returns: + ``pandas.DataFrame`` with columns ``logGroup``, ``step``, + ``time``, ``value``, ``nonFiniteFlags``. If pandas is not + installed, ``list[dict]`` with the same keys. + """ + params: Dict[str, Any] = { + 'runId': self._resolve_run_id(project, run_id), + 'projectName': project, + 'logName': metric_name, + 'limit': min(limit, _RAW_METRICS_MAX), + } + if step_min is not None: + params['stepMin'] = step_min + if step_max is not None: + params['stepMax'] = step_max + + resp = self._get('/api/runs/metrics/raw', params=params) + rows = resp.get('rows', []) + if resp.get('truncated'): + warnings.warn( + f'get_raw_metrics: result truncated at {len(rows)} rows. ' + f'Narrow step_min/step_max to retrieve more.', + stacklevel=2, + ) + + try: + import pandas as pd + + if not rows: + return pd.DataFrame( + columns=['logGroup', 'step', 'time', 'value', 'nonFiniteFlags'] + ) + return pd.DataFrame(rows) + except ImportError: + return rows + # ------------------------------------------------------------------ # Statistics / comparison # ------------------------------------------------------------------ @@ -589,6 +670,25 @@ def get_metrics( ) +def get_raw_metrics( + project: str, + run_id: Union[int, str], + metric_name: str, + step_min: Optional[int] = None, + step_max: Optional[int] = None, + limit: int = _RAW_METRICS_MAX, +) -> Any: + """Fetch raw metric writes. See :meth:`Client.get_raw_metrics`.""" + return _get_client().get_raw_metrics( + project, + run_id, + metric_name, + step_min=step_min, + step_max=step_max, + limit=limit, + ) + + def get_statistics( project: str, run_id: Union[int, str], diff --git a/tests/test_query.py b/tests/test_query.py index dec315d..c8e0461 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -328,6 +328,143 @@ def test_numeric_string_is_treated_as_server_id(self, client, mock_response): assert params['runId'] == 42 +# --------------------------------------------------------------------------- +# get_raw_metrics +# --------------------------------------------------------------------------- + + +class TestGetRawMetrics: + def test_basic(self, client, mock_response): + rows = [ + { + 'logGroup': 'train', + 'step': 0, + 'time': '2025-01-01 00:00:00.000', + 'value': 1.0, + 'nonFiniteFlags': 0, + }, + { + 'logGroup': 'train', + 'step': 0, + 'time': '2025-01-01 00:00:01.000', + 'value': 1.1, + 'nonFiniteFlags': 0, + }, + ] + client._client.get.return_value = mock_response( + 200, {'rows': rows, 'truncated': False} + ) + result = client.get_raw_metrics('proj', 42, 'train/loss') + params = client._client.get.call_args[1]['params'] + assert params['logName'] == 'train/loss' + assert params['runId'] == 42 + assert params['projectName'] == 'proj' + assert params['limit'] == 50000 + try: + import pandas as pd + + assert isinstance(result, pd.DataFrame) + assert len(result) == 2 + assert list(result.columns) == [ + 'logGroup', + 'step', + 'time', + 'value', + 'nonFiniteFlags', + ] + except ImportError: + assert result == rows + + def test_step_range_passed_through(self, client, mock_response): + client._client.get.return_value = mock_response( + 200, {'rows': [], 'truncated': False} + ) + client.get_raw_metrics('proj', 42, 'train/loss', step_min=100, step_max=200) + params = client._client.get.call_args[1]['params'] + assert params['stepMin'] == 100 + assert params['stepMax'] == 200 + + def test_step_range_omitted_when_none(self, client, mock_response): + client._client.get.return_value = mock_response( + 200, {'rows': [], 'truncated': False} + ) + client.get_raw_metrics('proj', 42, 'train/loss') + params = client._client.get.call_args[1]['params'] + assert 'stepMin' not in params + assert 'stepMax' not in params + + def test_limit_clamped_to_max(self, client, mock_response): + client._client.get.return_value = mock_response( + 200, {'rows': [], 'truncated': False} + ) + client.get_raw_metrics('proj', 42, 'train/loss', limit=10**9) + params = client._client.get.call_args[1]['params'] + assert params['limit'] == 50000 + + def test_truncated_emits_warning(self, client, mock_response): + rows = [ + { + 'logGroup': 'train', + 'step': i, + 'time': '2025-01-01', + 'value': float(i), + 'nonFiniteFlags': 0, + } + for i in range(3) + ] + client._client.get.return_value = mock_response( + 200, {'rows': rows, 'truncated': True} + ) + with pytest.warns(UserWarning, match='truncated'): + client.get_raw_metrics('proj', 42, 'train/loss') + + def test_no_warning_when_not_truncated(self, client, mock_response): + client._client.get.return_value = mock_response( + 200, {'rows': [], 'truncated': False} + ) + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter('error') + client.get_raw_metrics('proj', 42, 'train/loss') + + def test_empty_returns_empty_dataframe_with_columns(self, client, mock_response): + client._client.get.return_value = mock_response( + 200, {'rows': [], 'truncated': False} + ) + result = client.get_raw_metrics('proj', 42, 'train/loss') + try: + import pandas as pd + + assert isinstance(result, pd.DataFrame) + assert len(result) == 0 + assert list(result.columns) == [ + 'logGroup', + 'step', + 'time', + 'value', + 'nonFiniteFlags', + ] + except ImportError: + assert result == [] + + def test_display_id_resolves_to_numeric(self, client, mock_response): + client._client.get.side_effect = [ + mock_response(200, {'id': 99, 'displayId': 'MMP-1'}), + mock_response(200, {'rows': [], 'truncated': False}), + ] + client.get_raw_metrics('proj', 'MMP-1', 'train/loss') + first_url = client._client.get.call_args_list[0][0][0] + assert 'by-display-id/MMP-1' in first_url + second_call = client._client.get.call_args_list[1] + assert second_call[0][0].endswith('/api/runs/metrics/raw') + assert second_call[1]['params']['runId'] == 99 + + def test_metric_name_required(self, client): + with pytest.raises(TypeError): + client.get_raw_metrics('proj', 42) # type: ignore[call-arg] + + # --------------------------------------------------------------------------- # get_statistics # --------------------------------------------------------------------------- From c3c5f1161a43ea6e35079b6b847e3c51eae1bdba Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 9 May 2026 00:11:15 +0000 Subject: [PATCH 2/2] test(e2e): use get_raw_metrics for metric-presence checks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The four e2e tests that log a metric and immediately verify it on the server were going through pq.get_metric_names, which reads from mlop_metric_summaries_v2 — a refreshable MV with a 5-minute interval. Worst-case staleness exceeds the 60s poll window, so the check times out before the summary has refreshed even though the underlying row landed in mlop_metrics within milliseconds. (See ingest/docker-setup/ sql/07_metric_summaries_v2_refresh_mv.sql in server-private — the schema comment calls out interval + compute_time as the staleness budget.) Switch the affected sites to a new _poll_metric_present helper that reads the un-deduped mlop_metrics table directly via the new pq.get_raw_metrics method. That endpoint has no MV in its path, so write-then-read is consistent the moment run.finish() drains the sync process. The follow-up pq.get_metrics value checks already use a real-time path (mlop_metrics_v2 FINAL, fed by a non-refreshable mirror MV) and remain unchanged. Tests fixed: test_e2e_metrics_logged test_e2e_multiple_metrics_single_log test_e2e_full_lifecycle test_fork_e2e_log_metrics This does not address the underlying server-side behavior of the /metric-names endpoint — real callers (UI, MCP, SDK list flows) still see up to ~5min of staleness on freshly-logged metrics. Tracking that as a separate server-private follow-up. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_e2e.py | 48 +++++++++++++++++++++++++++--------------- tests/test_fork_e2e.py | 23 ++++++++++++++++++-- 2 files changed, 52 insertions(+), 19 deletions(-) diff --git a/tests/test_e2e.py b/tests/test_e2e.py index ba9230a..fb00cbf 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -71,6 +71,28 @@ def _poll_metric_names( ) +def _poll_metric_present( + project: str, + run_id: int, + metric_name: str, + timeout: float = _POLL_TIMEOUT, +) -> None: + """Poll until *metric_name* has at least one raw write for the run. + + Reads the un-deduped ``mlop_metrics`` table directly via + ``pq.get_raw_metrics``, bypassing the ``mlop_metric_summaries_v2`` + refreshable MV that backs ``pq.get_metric_names`` (5-minute refresh + interval, longer than the e2e poll window). Use this helper when a + test logs a metric and immediately needs to confirm it landed. + """ + rows = _poll( + fn=lambda: pq.get_raw_metrics(project, run_id, metric_name, limit=100), + check=lambda r: len(r) > 0, + timeout=timeout, + ) + assert len(rows) > 0, f"'{metric_name}' has no rows on server for run {run_id}" + + def _poll_run( project: str, run_id: int, @@ -272,14 +294,11 @@ def test_e2e_metrics_logged(): run.log({'train/loss': 1.0 - step * 0.1, 'train/acc': step * 0.1}) run.finish() - # Check metric names exist (poll for eventual consistency) - metric_names = _poll_metric_names( - TESTING_PROJECT_NAME, run_id, ['train/loss', 'train/acc'] - ) - assert ( - 'train/loss' in metric_names - ), f"'train/loss' not in server metric names: {metric_names}" - assert 'train/acc' in metric_names + # Verify each metric landed in raw storage. Reads the un-deduped + # mlop_metrics table directly (no MV refresh latency), so this is + # immediate after run.finish() flushes the sync process. + _poll_metric_present(TESTING_PROJECT_NAME, run_id, 'train/loss') + _poll_metric_present(TESTING_PROJECT_NAME, run_id, 'train/acc') # Check metric values metrics = pq.get_metrics(TESTING_PROJECT_NAME, run_id, metric_names=['train/loss']) @@ -460,12 +479,8 @@ def test_e2e_multiple_metrics_single_log(): ) run.finish() - expected = ['multi/loss', 'multi/accuracy', 'multi/lr'] - metric_names = _poll_metric_names(TESTING_PROJECT_NAME, run_id, expected) - for name in expected: - assert ( - name in metric_names - ), f"'{name}' not in server metric names: {metric_names}" + for name in ('multi/loss', 'multi/accuracy', 'multi/lr'): + _poll_metric_present(TESTING_PROJECT_NAME, run_id, name) # --------------------------------------------------------------------------- @@ -623,9 +638,8 @@ def test_e2e_full_lifecycle(): assert 'validated' in server_tags assert 'lifecycle' not in server_tags # Removed - # Metrics (poll for eventual consistency) - metric_names = _poll_metric_names(TESTING_PROJECT_NAME, run_id, ['lifecycle/loss']) - assert 'lifecycle/loss' in metric_names + # Metrics (raw read — bypasses the summaries MV's refresh interval) + _poll_metric_present(TESTING_PROJECT_NAME, run_id, 'lifecycle/loss') metrics = pq.get_metrics( TESTING_PROJECT_NAME, run_id, metric_names=['lifecycle/loss'] diff --git a/tests/test_fork_e2e.py b/tests/test_fork_e2e.py index d88dc88..91b715b 100644 --- a/tests/test_fork_e2e.py +++ b/tests/test_fork_e2e.py @@ -55,6 +55,26 @@ def _poll_metric_names( ) +def _poll_metric_present( + project: str, + run_id: int, + metric_name: str, + timeout: float = _POLL_TIMEOUT, +) -> None: + """Poll until *metric_name* has at least one raw write for the run. + + Reads ``mlop_metrics`` directly via ``pq.get_raw_metrics``, bypassing + the ``mlop_metric_summaries_v2`` refreshable MV (5-minute interval) + that backs ``pq.get_metric_names``. + """ + rows = _poll( + fn=lambda: pq.get_raw_metrics(project, run_id, metric_name, limit=100), + check=lambda r: len(r) > 0, + timeout=timeout, + ) + assert len(rows) > 0, f"'{metric_name}' has no rows on server for run {run_id}" + + def _poll_max_step( project: str, run_id: int, @@ -260,8 +280,7 @@ def test_fork_e2e_log_metrics(parent_run): run.log({'fork/loss': 0.5 - step * 0.1}) run.finish() - metric_names = _poll_metric_names(FORK_PROJECT, run_id, ['fork/loss']) - assert 'fork/loss' in metric_names + _poll_metric_present(FORK_PROJECT, run_id, 'fork/loss') metrics = pq.get_metrics(FORK_PROJECT, run_id, metric_names=['fork/loss']) if hasattr(metrics, 'to_dict'):