diff --git a/pluto/query.py b/pluto/query.py index 49277a2..87e6d46 100644 --- a/pluto/query.py +++ b/pluto/query.py @@ -36,6 +36,10 @@ _RETRY_WAIT_MIN = 0.5 _RETRY_WAIT_MAX = 4.0 +# Server-side cap for ``GET /api/runs/metrics/raw``. Mirrors +# RAW_METRICS_MAX_LIMIT in web/server/routes/runs-openapi.ts. +_RAW_METRICS_MAX = 50000 + class QueryError(Exception): """Raised when a query to the Pluto server fails.""" @@ -247,6 +251,83 @@ def get_metrics( return _to_dataframe(raw) + def get_raw_metrics( + self, + project: str, + run_id: Union[int, str], + metric_name: str, + step_min: Optional[int] = None, + step_max: Optional[int] = None, + limit: int = _RAW_METRICS_MAX, + ) -> Any: + """Fetch every raw write for a single metric (un-deduped). + + Reads from ClickHouse's ``mlop_metrics`` directly — no FINAL, + no reservoir sampling — so callers see every row that landed + for the given ``(run, metric)``, including duplicates from + resumed runs, re-logs, or manual corrections. Use this for + audit and cleanup; use :meth:`get_metrics` for analysis or + chart rendering. + + Each row carries the original write ``time``, which + disambiguates writes that share a step, plus a + ``nonFiniteFlags`` bitmask: ``bit0=NaN``, ``bit1=+Inf``, + ``bit2=-Inf``. When any flag is set, ``value`` is ``0`` and + the flag is the source of truth — JSON serializes non-finite + floats as null, so the bitmask is the only reliable signal + once the payload leaves ClickHouse. + + The server caps responses at 50 000 rows. When the cap is + hit a :class:`UserWarning` is emitted; narrow ``step_min`` / + ``step_max`` to page. + + Args: + project: Project name. + run_id: Numeric server ID (``int``) or display ID string + (e.g. ``"MMP-1"``). + metric_name: Metric name (e.g. ``"train/loss"``). Required — + scoping every raw scan to a single metric keeps + payloads bounded. + step_min: Minimum step number (inclusive). + step_max: Maximum step number (inclusive). + limit: Max rows to return (max 50 000). + + Returns: + ``pandas.DataFrame`` with columns ``logGroup``, ``step``, + ``time``, ``value``, ``nonFiniteFlags``. If pandas is not + installed, ``list[dict]`` with the same keys. + """ + params: Dict[str, Any] = { + 'runId': self._resolve_run_id(project, run_id), + 'projectName': project, + 'logName': metric_name, + 'limit': min(limit, _RAW_METRICS_MAX), + } + if step_min is not None: + params['stepMin'] = step_min + if step_max is not None: + params['stepMax'] = step_max + + resp = self._get('/api/runs/metrics/raw', params=params) + rows = resp.get('rows', []) + if resp.get('truncated'): + warnings.warn( + f'get_raw_metrics: result truncated at {len(rows)} rows. ' + f'Narrow step_min/step_max to retrieve more.', + stacklevel=2, + ) + + try: + import pandas as pd + + if not rows: + return pd.DataFrame( + columns=['logGroup', 'step', 'time', 'value', 'nonFiniteFlags'] + ) + return pd.DataFrame(rows) + except ImportError: + return rows + # ------------------------------------------------------------------ # Statistics / comparison # ------------------------------------------------------------------ @@ -589,6 +670,25 @@ def get_metrics( ) +def get_raw_metrics( + project: str, + run_id: Union[int, str], + metric_name: str, + step_min: Optional[int] = None, + step_max: Optional[int] = None, + limit: int = _RAW_METRICS_MAX, +) -> Any: + """Fetch raw metric writes. See :meth:`Client.get_raw_metrics`.""" + return _get_client().get_raw_metrics( + project, + run_id, + metric_name, + step_min=step_min, + step_max=step_max, + limit=limit, + ) + + def get_statistics( project: str, run_id: Union[int, str], diff --git a/tests/test_e2e.py b/tests/test_e2e.py index ba9230a..fb00cbf 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -71,6 +71,28 @@ def _poll_metric_names( ) +def _poll_metric_present( + project: str, + run_id: int, + metric_name: str, + timeout: float = _POLL_TIMEOUT, +) -> None: + """Poll until *metric_name* has at least one raw write for the run. + + Reads the un-deduped ``mlop_metrics`` table directly via + ``pq.get_raw_metrics``, bypassing the ``mlop_metric_summaries_v2`` + refreshable MV that backs ``pq.get_metric_names`` (5-minute refresh + interval, longer than the e2e poll window). Use this helper when a + test logs a metric and immediately needs to confirm it landed. + """ + rows = _poll( + fn=lambda: pq.get_raw_metrics(project, run_id, metric_name, limit=100), + check=lambda r: len(r) > 0, + timeout=timeout, + ) + assert len(rows) > 0, f"'{metric_name}' has no rows on server for run {run_id}" + + def _poll_run( project: str, run_id: int, @@ -272,14 +294,11 @@ def test_e2e_metrics_logged(): run.log({'train/loss': 1.0 - step * 0.1, 'train/acc': step * 0.1}) run.finish() - # Check metric names exist (poll for eventual consistency) - metric_names = _poll_metric_names( - TESTING_PROJECT_NAME, run_id, ['train/loss', 'train/acc'] - ) - assert ( - 'train/loss' in metric_names - ), f"'train/loss' not in server metric names: {metric_names}" - assert 'train/acc' in metric_names + # Verify each metric landed in raw storage. Reads the un-deduped + # mlop_metrics table directly (no MV refresh latency), so this is + # immediate after run.finish() flushes the sync process. + _poll_metric_present(TESTING_PROJECT_NAME, run_id, 'train/loss') + _poll_metric_present(TESTING_PROJECT_NAME, run_id, 'train/acc') # Check metric values metrics = pq.get_metrics(TESTING_PROJECT_NAME, run_id, metric_names=['train/loss']) @@ -460,12 +479,8 @@ def test_e2e_multiple_metrics_single_log(): ) run.finish() - expected = ['multi/loss', 'multi/accuracy', 'multi/lr'] - metric_names = _poll_metric_names(TESTING_PROJECT_NAME, run_id, expected) - for name in expected: - assert ( - name in metric_names - ), f"'{name}' not in server metric names: {metric_names}" + for name in ('multi/loss', 'multi/accuracy', 'multi/lr'): + _poll_metric_present(TESTING_PROJECT_NAME, run_id, name) # --------------------------------------------------------------------------- @@ -623,9 +638,8 @@ def test_e2e_full_lifecycle(): assert 'validated' in server_tags assert 'lifecycle' not in server_tags # Removed - # Metrics (poll for eventual consistency) - metric_names = _poll_metric_names(TESTING_PROJECT_NAME, run_id, ['lifecycle/loss']) - assert 'lifecycle/loss' in metric_names + # Metrics (raw read — bypasses the summaries MV's refresh interval) + _poll_metric_present(TESTING_PROJECT_NAME, run_id, 'lifecycle/loss') metrics = pq.get_metrics( TESTING_PROJECT_NAME, run_id, metric_names=['lifecycle/loss'] diff --git a/tests/test_fork_e2e.py b/tests/test_fork_e2e.py index d88dc88..91b715b 100644 --- a/tests/test_fork_e2e.py +++ b/tests/test_fork_e2e.py @@ -55,6 +55,26 @@ def _poll_metric_names( ) +def _poll_metric_present( + project: str, + run_id: int, + metric_name: str, + timeout: float = _POLL_TIMEOUT, +) -> None: + """Poll until *metric_name* has at least one raw write for the run. + + Reads ``mlop_metrics`` directly via ``pq.get_raw_metrics``, bypassing + the ``mlop_metric_summaries_v2`` refreshable MV (5-minute interval) + that backs ``pq.get_metric_names``. + """ + rows = _poll( + fn=lambda: pq.get_raw_metrics(project, run_id, metric_name, limit=100), + check=lambda r: len(r) > 0, + timeout=timeout, + ) + assert len(rows) > 0, f"'{metric_name}' has no rows on server for run {run_id}" + + def _poll_max_step( project: str, run_id: int, @@ -260,8 +280,7 @@ def test_fork_e2e_log_metrics(parent_run): run.log({'fork/loss': 0.5 - step * 0.1}) run.finish() - metric_names = _poll_metric_names(FORK_PROJECT, run_id, ['fork/loss']) - assert 'fork/loss' in metric_names + _poll_metric_present(FORK_PROJECT, run_id, 'fork/loss') metrics = pq.get_metrics(FORK_PROJECT, run_id, metric_names=['fork/loss']) if hasattr(metrics, 'to_dict'): diff --git a/tests/test_query.py b/tests/test_query.py index dec315d..c8e0461 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -328,6 +328,143 @@ def test_numeric_string_is_treated_as_server_id(self, client, mock_response): assert params['runId'] == 42 +# --------------------------------------------------------------------------- +# get_raw_metrics +# --------------------------------------------------------------------------- + + +class TestGetRawMetrics: + def test_basic(self, client, mock_response): + rows = [ + { + 'logGroup': 'train', + 'step': 0, + 'time': '2025-01-01 00:00:00.000', + 'value': 1.0, + 'nonFiniteFlags': 0, + }, + { + 'logGroup': 'train', + 'step': 0, + 'time': '2025-01-01 00:00:01.000', + 'value': 1.1, + 'nonFiniteFlags': 0, + }, + ] + client._client.get.return_value = mock_response( + 200, {'rows': rows, 'truncated': False} + ) + result = client.get_raw_metrics('proj', 42, 'train/loss') + params = client._client.get.call_args[1]['params'] + assert params['logName'] == 'train/loss' + assert params['runId'] == 42 + assert params['projectName'] == 'proj' + assert params['limit'] == 50000 + try: + import pandas as pd + + assert isinstance(result, pd.DataFrame) + assert len(result) == 2 + assert list(result.columns) == [ + 'logGroup', + 'step', + 'time', + 'value', + 'nonFiniteFlags', + ] + except ImportError: + assert result == rows + + def test_step_range_passed_through(self, client, mock_response): + client._client.get.return_value = mock_response( + 200, {'rows': [], 'truncated': False} + ) + client.get_raw_metrics('proj', 42, 'train/loss', step_min=100, step_max=200) + params = client._client.get.call_args[1]['params'] + assert params['stepMin'] == 100 + assert params['stepMax'] == 200 + + def test_step_range_omitted_when_none(self, client, mock_response): + client._client.get.return_value = mock_response( + 200, {'rows': [], 'truncated': False} + ) + client.get_raw_metrics('proj', 42, 'train/loss') + params = client._client.get.call_args[1]['params'] + assert 'stepMin' not in params + assert 'stepMax' not in params + + def test_limit_clamped_to_max(self, client, mock_response): + client._client.get.return_value = mock_response( + 200, {'rows': [], 'truncated': False} + ) + client.get_raw_metrics('proj', 42, 'train/loss', limit=10**9) + params = client._client.get.call_args[1]['params'] + assert params['limit'] == 50000 + + def test_truncated_emits_warning(self, client, mock_response): + rows = [ + { + 'logGroup': 'train', + 'step': i, + 'time': '2025-01-01', + 'value': float(i), + 'nonFiniteFlags': 0, + } + for i in range(3) + ] + client._client.get.return_value = mock_response( + 200, {'rows': rows, 'truncated': True} + ) + with pytest.warns(UserWarning, match='truncated'): + client.get_raw_metrics('proj', 42, 'train/loss') + + def test_no_warning_when_not_truncated(self, client, mock_response): + client._client.get.return_value = mock_response( + 200, {'rows': [], 'truncated': False} + ) + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter('error') + client.get_raw_metrics('proj', 42, 'train/loss') + + def test_empty_returns_empty_dataframe_with_columns(self, client, mock_response): + client._client.get.return_value = mock_response( + 200, {'rows': [], 'truncated': False} + ) + result = client.get_raw_metrics('proj', 42, 'train/loss') + try: + import pandas as pd + + assert isinstance(result, pd.DataFrame) + assert len(result) == 0 + assert list(result.columns) == [ + 'logGroup', + 'step', + 'time', + 'value', + 'nonFiniteFlags', + ] + except ImportError: + assert result == [] + + def test_display_id_resolves_to_numeric(self, client, mock_response): + client._client.get.side_effect = [ + mock_response(200, {'id': 99, 'displayId': 'MMP-1'}), + mock_response(200, {'rows': [], 'truncated': False}), + ] + client.get_raw_metrics('proj', 'MMP-1', 'train/loss') + first_url = client._client.get.call_args_list[0][0][0] + assert 'by-display-id/MMP-1' in first_url + second_call = client._client.get.call_args_list[1] + assert second_call[0][0].endswith('/api/runs/metrics/raw') + assert second_call[1]['params']['runId'] == 99 + + def test_metric_name_required(self, client): + with pytest.raises(TypeError): + client.get_raw_metrics('proj', 42) # type: ignore[call-arg] + + # --------------------------------------------------------------------------- # get_statistics # ---------------------------------------------------------------------------