Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions pluto/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
_RETRY_WAIT_MIN = 0.5
_RETRY_WAIT_MAX = 4.0

# Server-side cap for ``GET /api/runs/metrics/raw``. Mirrors
# RAW_METRICS_MAX_LIMIT in web/server/routes/runs-openapi.ts.
_RAW_METRICS_MAX = 50000


class QueryError(Exception):
"""Raised when a query to the Pluto server fails."""
Expand Down Expand Up @@ -247,6 +251,83 @@ def get_metrics(

return _to_dataframe(raw)

def get_raw_metrics(
self,
project: str,
run_id: Union[int, str],
metric_name: str,
step_min: Optional[int] = None,
step_max: Optional[int] = None,
limit: int = _RAW_METRICS_MAX,
) -> Any:
"""Fetch every raw write for a single metric (un-deduped).

Reads from ClickHouse's ``mlop_metrics`` directly — no FINAL,
no reservoir sampling — so callers see every row that landed
for the given ``(run, metric)``, including duplicates from
resumed runs, re-logs, or manual corrections. Use this for
audit and cleanup; use :meth:`get_metrics` for analysis or
chart rendering.

Each row carries the original write ``time``, which
disambiguates writes that share a step, plus a
``nonFiniteFlags`` bitmask: ``bit0=NaN``, ``bit1=+Inf``,
``bit2=-Inf``. When any flag is set, ``value`` is ``0`` and
the flag is the source of truth — JSON serializes non-finite
floats as null, so the bitmask is the only reliable signal
once the payload leaves ClickHouse.

The server caps responses at 50 000 rows. When the cap is
hit a :class:`UserWarning` is emitted; narrow ``step_min`` /
``step_max`` to page.

Args:
project: Project name.
run_id: Numeric server ID (``int``) or display ID string
(e.g. ``"MMP-1"``).
metric_name: Metric name (e.g. ``"train/loss"``). Required —
scoping every raw scan to a single metric keeps
payloads bounded.
step_min: Minimum step number (inclusive).
step_max: Maximum step number (inclusive).
limit: Max rows to return (max 50 000).

Returns:
``pandas.DataFrame`` with columns ``logGroup``, ``step``,
``time``, ``value``, ``nonFiniteFlags``. If pandas is not
installed, ``list[dict]`` with the same keys.
"""
params: Dict[str, Any] = {
'runId': self._resolve_run_id(project, run_id),
'projectName': project,
'logName': metric_name,
'limit': min(limit, _RAW_METRICS_MAX),
}
if step_min is not None:
params['stepMin'] = step_min
if step_max is not None:
params['stepMax'] = step_max

resp = self._get('/api/runs/metrics/raw', params=params)
rows = resp.get('rows', [])
if resp.get('truncated'):
warnings.warn(
f'get_raw_metrics: result truncated at {len(rows)} rows. '
f'Narrow step_min/step_max to retrieve more.',
stacklevel=2,
)

try:
import pandas as pd

if not rows:
return pd.DataFrame(
columns=['logGroup', 'step', 'time', 'value', 'nonFiniteFlags']
)
return pd.DataFrame(rows)
except ImportError:
return rows
Comment on lines +312 to +329

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider renaming the logGroup column to metric in the returned data for consistency with get_metrics. In the Pluto SDK, metric is the standard name for the metric identifier column. Additionally, using resp.get('rows') or [] is safer than resp.get('rows', []) as it handles cases where the server might explicitly return null for the rows key, preventing potential iteration errors.

        rows = resp.get('rows') or []
        if resp.get('truncated'):
            warnings.warn(
                f'get_raw_metrics: result truncated at {len(rows)} rows. '
                f'Narrow step_min/step_max to retrieve more.',
                stacklevel=2,
            )

        # For consistency with get_metrics, map logGroup to metric
        for row in rows:
            if 'logGroup' in row and 'metric' not in row:
                row['metric'] = row.pop('logGroup')

        try:
            import pandas as pd

            if not rows:
                return pd.DataFrame(
                    columns=['metric', 'step', 'time', 'value', 'nonFiniteFlags']
                )
            return pd.DataFrame(rows)
        except ImportError:
            return rows

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pushing back on both:

1. logGroupmetric rename. Two different fields:

  • logName is the full metric identifier (train/loss) — that's what get_metrics renames to metric (see pluto/query.py:247-249).
  • logGroup is the prefix only (train for train/loss) — a separate, narrower field.

Renaming logGroupmetric would put the prefix string into a column named metric, which is semantically wrong and would mislead callers. Also, metric_name is required on get_raw_metrics so every row in a response is already for the same metric — there's no need for a metric column to disambiguate (which is the actual reason get_metrics has one: it can be called for many metrics at once and merges results).

2. resp.get('rows') or []. The server endpoint's Zod schema is rows: z.array(...) — non-null by contract. The current .get('rows', []) is already more defensive than peer methods in this file (e.g. get_metrics does bare result['metrics'] indexing at pluto/query.py:241/244). Adding or [] would guard a state the response schema can't produce and would be inconsistent with the rest of the module.

Not applying either.


# ------------------------------------------------------------------
# Statistics / comparison
# ------------------------------------------------------------------
Expand Down Expand Up @@ -589,6 +670,25 @@ def get_metrics(
)


def get_raw_metrics(
project: str,
run_id: Union[int, str],
metric_name: str,
step_min: Optional[int] = None,
step_max: Optional[int] = None,
limit: int = _RAW_METRICS_MAX,
) -> Any:
"""Fetch raw metric writes. See :meth:`Client.get_raw_metrics`."""
return _get_client().get_raw_metrics(
project,
run_id,
metric_name,
step_min=step_min,
step_max=step_max,
limit=limit,
)


def get_statistics(
project: str,
run_id: Union[int, str],
Expand Down
48 changes: 31 additions & 17 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,28 @@ def _poll_metric_names(
)


def _poll_metric_present(
project: str,
run_id: int,
metric_name: str,
timeout: float = _POLL_TIMEOUT,
) -> None:
"""Poll until *metric_name* has at least one raw write for the run.

Reads the un-deduped ``mlop_metrics`` table directly via
``pq.get_raw_metrics``, bypassing the ``mlop_metric_summaries_v2``
refreshable MV that backs ``pq.get_metric_names`` (5-minute refresh
interval, longer than the e2e poll window). Use this helper when a
test logs a metric and immediately needs to confirm it landed.
"""
rows = _poll(
fn=lambda: pq.get_raw_metrics(project, run_id, metric_name, limit=100),
check=lambda r: len(r) > 0,
timeout=timeout,
)
assert len(rows) > 0, f"'{metric_name}' has no rows on server for run {run_id}"


def _poll_run(
project: str,
run_id: int,
Expand Down Expand Up @@ -272,14 +294,11 @@ def test_e2e_metrics_logged():
run.log({'train/loss': 1.0 - step * 0.1, 'train/acc': step * 0.1})
run.finish()

# Check metric names exist (poll for eventual consistency)
metric_names = _poll_metric_names(
TESTING_PROJECT_NAME, run_id, ['train/loss', 'train/acc']
)
assert (
'train/loss' in metric_names
), f"'train/loss' not in server metric names: {metric_names}"
assert 'train/acc' in metric_names
# Verify each metric landed in raw storage. Reads the un-deduped
# mlop_metrics table directly (no MV refresh latency), so this is
# immediate after run.finish() flushes the sync process.
_poll_metric_present(TESTING_PROJECT_NAME, run_id, 'train/loss')
_poll_metric_present(TESTING_PROJECT_NAME, run_id, 'train/acc')

# Check metric values
metrics = pq.get_metrics(TESTING_PROJECT_NAME, run_id, metric_names=['train/loss'])
Expand Down Expand Up @@ -460,12 +479,8 @@ def test_e2e_multiple_metrics_single_log():
)
run.finish()

expected = ['multi/loss', 'multi/accuracy', 'multi/lr']
metric_names = _poll_metric_names(TESTING_PROJECT_NAME, run_id, expected)
for name in expected:
assert (
name in metric_names
), f"'{name}' not in server metric names: {metric_names}"
for name in ('multi/loss', 'multi/accuracy', 'multi/lr'):
_poll_metric_present(TESTING_PROJECT_NAME, run_id, name)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -623,9 +638,8 @@ def test_e2e_full_lifecycle():
assert 'validated' in server_tags
assert 'lifecycle' not in server_tags # Removed

# Metrics (poll for eventual consistency)
metric_names = _poll_metric_names(TESTING_PROJECT_NAME, run_id, ['lifecycle/loss'])
assert 'lifecycle/loss' in metric_names
# Metrics (raw read — bypasses the summaries MV's refresh interval)
_poll_metric_present(TESTING_PROJECT_NAME, run_id, 'lifecycle/loss')

metrics = pq.get_metrics(
TESTING_PROJECT_NAME, run_id, metric_names=['lifecycle/loss']
Expand Down
23 changes: 21 additions & 2 deletions tests/test_fork_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,26 @@ def _poll_metric_names(
)


def _poll_metric_present(
project: str,
run_id: int,
metric_name: str,
timeout: float = _POLL_TIMEOUT,
) -> None:
"""Poll until *metric_name* has at least one raw write for the run.

Reads ``mlop_metrics`` directly via ``pq.get_raw_metrics``, bypassing
the ``mlop_metric_summaries_v2`` refreshable MV (5-minute interval)
that backs ``pq.get_metric_names``.
"""
rows = _poll(
fn=lambda: pq.get_raw_metrics(project, run_id, metric_name, limit=100),
check=lambda r: len(r) > 0,
timeout=timeout,
)
assert len(rows) > 0, f"'{metric_name}' has no rows on server for run {run_id}"


def _poll_max_step(
project: str,
run_id: int,
Expand Down Expand Up @@ -260,8 +280,7 @@ def test_fork_e2e_log_metrics(parent_run):
run.log({'fork/loss': 0.5 - step * 0.1})
run.finish()

metric_names = _poll_metric_names(FORK_PROJECT, run_id, ['fork/loss'])
assert 'fork/loss' in metric_names
_poll_metric_present(FORK_PROJECT, run_id, 'fork/loss')

metrics = pq.get_metrics(FORK_PROJECT, run_id, metric_names=['fork/loss'])
if hasattr(metrics, 'to_dict'):
Expand Down
137 changes: 137 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,143 @@ def test_numeric_string_is_treated_as_server_id(self, client, mock_response):
assert params['runId'] == 42


# ---------------------------------------------------------------------------
# get_raw_metrics
# ---------------------------------------------------------------------------


class TestGetRawMetrics:
def test_basic(self, client, mock_response):
rows = [
{
'logGroup': 'train',
'step': 0,
'time': '2025-01-01 00:00:00.000',
'value': 1.0,
'nonFiniteFlags': 0,
},
{
'logGroup': 'train',
'step': 0,
'time': '2025-01-01 00:00:01.000',
'value': 1.1,
'nonFiniteFlags': 0,
},
]
client._client.get.return_value = mock_response(
200, {'rows': rows, 'truncated': False}
)
result = client.get_raw_metrics('proj', 42, 'train/loss')
params = client._client.get.call_args[1]['params']
assert params['logName'] == 'train/loss'
assert params['runId'] == 42
assert params['projectName'] == 'proj'
assert params['limit'] == 50000
try:
import pandas as pd

assert isinstance(result, pd.DataFrame)
assert len(result) == 2
assert list(result.columns) == [
'logGroup',
'step',
'time',
'value',
'nonFiniteFlags',
]
except ImportError:
assert result == rows

def test_step_range_passed_through(self, client, mock_response):
client._client.get.return_value = mock_response(
200, {'rows': [], 'truncated': False}
)
client.get_raw_metrics('proj', 42, 'train/loss', step_min=100, step_max=200)
params = client._client.get.call_args[1]['params']
assert params['stepMin'] == 100
assert params['stepMax'] == 200

def test_step_range_omitted_when_none(self, client, mock_response):
client._client.get.return_value = mock_response(
200, {'rows': [], 'truncated': False}
)
client.get_raw_metrics('proj', 42, 'train/loss')
params = client._client.get.call_args[1]['params']
assert 'stepMin' not in params
assert 'stepMax' not in params

def test_limit_clamped_to_max(self, client, mock_response):
client._client.get.return_value = mock_response(
200, {'rows': [], 'truncated': False}
)
client.get_raw_metrics('proj', 42, 'train/loss', limit=10**9)
params = client._client.get.call_args[1]['params']
assert params['limit'] == 50000

def test_truncated_emits_warning(self, client, mock_response):
rows = [
{
'logGroup': 'train',
'step': i,
'time': '2025-01-01',
'value': float(i),
'nonFiniteFlags': 0,
}
for i in range(3)
]
client._client.get.return_value = mock_response(
200, {'rows': rows, 'truncated': True}
)
with pytest.warns(UserWarning, match='truncated'):
client.get_raw_metrics('proj', 42, 'train/loss')

def test_no_warning_when_not_truncated(self, client, mock_response):
client._client.get.return_value = mock_response(
200, {'rows': [], 'truncated': False}
)
import warnings

with warnings.catch_warnings():
warnings.simplefilter('error')
client.get_raw_metrics('proj', 42, 'train/loss')

def test_empty_returns_empty_dataframe_with_columns(self, client, mock_response):
client._client.get.return_value = mock_response(
200, {'rows': [], 'truncated': False}
)
result = client.get_raw_metrics('proj', 42, 'train/loss')
try:
import pandas as pd

assert isinstance(result, pd.DataFrame)
assert len(result) == 0
assert list(result.columns) == [
'logGroup',
'step',
'time',
'value',
'nonFiniteFlags',
]
except ImportError:
assert result == []

def test_display_id_resolves_to_numeric(self, client, mock_response):
client._client.get.side_effect = [
mock_response(200, {'id': 99, 'displayId': 'MMP-1'}),
mock_response(200, {'rows': [], 'truncated': False}),
]
client.get_raw_metrics('proj', 'MMP-1', 'train/loss')
first_url = client._client.get.call_args_list[0][0][0]
assert 'by-display-id/MMP-1' in first_url
second_call = client._client.get.call_args_list[1]
assert second_call[0][0].endswith('/api/runs/metrics/raw')
assert second_call[1]['params']['runId'] == 99

def test_metric_name_required(self, client):
with pytest.raises(TypeError):
client.get_raw_metrics('proj', 42) # type: ignore[call-arg]


# ---------------------------------------------------------------------------
# get_statistics
# ---------------------------------------------------------------------------
Expand Down
Loading