From 948d399dd39344746f50bf7a446896d343da5f76 Mon Sep 17 00:00:00 2001 From: David Gable Date: Sat, 23 May 2026 18:09:49 -0700 Subject: [PATCH 1/6] Added basic signal editor called Lab --- .../utils/layers/dojo/lab_executor.py | 211 ++++ .../utils/layers/dojo/routers/mosaic.py | 147 ++- .../partials/trial_viewer/_chart.html | 22 +- .../partials/trial_viewer/_header.html | 1 + .../partials/trial_viewer/_lab_trigger.html | 17 + .../partials/trial_viewer/_signal_lab.html | 1019 +++++++++++++++++ .../dojo/templates/static/js/trial-viewer.js | 121 +- .../templates/static/ts/src/trial-viewer.ts | 141 ++- .../templates/static/ts/src/types/global.d.ts | 4 + .../static/vendored/litegraph.min.css | 8 + .../static/vendored/litegraph.min.js | 915 +++++++++++++++ .../layers/dojo/templates/trial_viewer.html | 9 + 12 files changed, 2587 insertions(+), 28 deletions(-) create mode 100644 src/mujoco_mojo/utils/layers/dojo/lab_executor.py create mode 100644 src/mujoco_mojo/utils/layers/dojo/templates/partials/trial_viewer/_lab_trigger.html create mode 100644 src/mujoco_mojo/utils/layers/dojo/templates/partials/trial_viewer/_signal_lab.html create mode 100644 src/mujoco_mojo/utils/layers/dojo/templates/static/vendored/litegraph.min.css create mode 100644 src/mujoco_mojo/utils/layers/dojo/templates/static/vendored/litegraph.min.js diff --git a/src/mujoco_mojo/utils/layers/dojo/lab_executor.py b/src/mujoco_mojo/utils/layers/dojo/lab_executor.py new file mode 100644 index 00000000..03cdf6e1 --- /dev/null +++ b/src/mujoco_mojo/utils/layers/dojo/lab_executor.py @@ -0,0 +1,211 @@ +""" +Lab graph executor. + +Takes a LiteGraph-serialised graph (from graph.serialize() in the browser) +and executes it against a Polars DataFrame, returning the output series keyed +by their Signal Out labels. + +LiteGraph link format: + links: [[link_id, from_node_id, from_slot, to_node_id, to_slot, type], ...] +""" + +from __future__ import annotations + +from collections import defaultdict, deque +from typing import Any + +import polars as pl + +from mujoco_mojo.utils.filters.filters import ( + AbsoluteValueFilter, + ClipFilter, + DeadbandFilter, + HighPassFilter, + LowPassFilter, + MedianFilter, + NormalizeFilter, + RollingMeanFilter, + SavitzkyGolayFilter, + ScaleFilter, + TaringFilter, + WrapFilter, +) + +# Map node type string → filter class (single-input filters only) +_FILTER_MAP = { + "low_pass": LowPassFilter, + "high_pass": HighPassFilter, + "scale": ScaleFilter, + "rolling_mean": RollingMeanFilter, + "median": MedianFilter, + "savitzky_golay": SavitzkyGolayFilter, + "clip": ClipFilter, + "deadband": DeadbandFilter, + "wrap": WrapFilter, + "taring": TaringFilter, + "normalize": NormalizeFilter, + "absolute_value": AbsoluteValueFilter, +} + + +class LabExecutor: + """ + Execute a LiteGraph filter graph against a Polars DataFrame. + + Usage:: + + executor = LabExecutor(graph_dict) + outputs = executor.execute(df) # {label: pl.Series} + """ + + def __init__(self, graph: dict[str, Any]) -> None: + self.nodes: dict[int, dict] = {n["id"]: n for n in graph.get("nodes", [])} + # links keyed by link_id → [link_id, from_node, from_slot, to_node, to_slot, type] + self.links: dict[int, list] = {lnk[0]: lnk for lnk in graph.get("links", [])} + + # ── public helpers ──────────────────────────────────────────────────────── + + @staticmethod + def _bare_type(node: dict) -> str: + """Strip the LiteGraph category prefix, e.g. 'Signal/signal_in' → 'signal_in'.""" + return node.get("type", "").split("/")[-1] + + @property + def signal_in_columns(self) -> list[str]: + """Column names used by all Signal In nodes — used for validation.""" + return [ + n.get("properties", {}).get("column", "") + for n in self.nodes.values() + if self._bare_type(n) == "signal_in" + ] + + @property + def output_labels(self) -> list[str]: + """Labels produced by all Signal Out nodes.""" + return [ + n.get("properties", {}).get("label") or f"out_{n['id']}" + for n in self.nodes.values() + if self._bare_type(n) == "signal_out" + ] + + def execute(self, df: pl.DataFrame) -> dict[str, pl.Series]: + """Run the graph and return {output_label: series}.""" + # slot_data[node_id][slot_index] = computed series + slot_data: dict[int, dict[int, pl.Series]] = defaultdict(dict) + + for nid in self._topo_sort(): + node = self.nodes[nid] + ntype = self._bare_type(node) + props = node.get("properties", {}) + + if ntype == "signal_in": + col = props.get("column", "") + series = ( + df[col].cast(pl.Float64) + if col in df.columns + else pl.Series(name=col, values=[0.0] * len(df)) + ) + slot_data[nid][0] = series + + elif ntype == "signal_out": + signal = self._get_input(node, 0, slot_data) + if signal is not None: + slot_data[nid][0] = signal + + else: + signal = self._get_input(node, 0, slot_data) + if signal is None: + continue + wrt = self._get_input(node, 1, slot_data) + slot_data[nid][0] = self._apply(ntype, props, signal, wrt, df) + + # Collect Signal Out results + outputs: dict[str, pl.Series] = {} + for nid, node in self.nodes.items(): + if node.get("type") == "signal_out": + series = slot_data.get(nid, {}).get(0) + if series is not None: + label = node.get("properties", {}).get("label") or f"out_{nid}" + outputs[label] = series + return outputs + + # ── internals ───────────────────────────────────────────────────────────── + + def _topo_sort(self) -> list[int]: + in_degree: dict[int, int] = {nid: 0 for nid in self.nodes} + adj: dict[int, list[int]] = defaultdict(list) + for _, from_node, _, to_node, _, *_ in self.links.values(): + adj[from_node].append(to_node) + in_degree[to_node] += 1 + queue = deque(nid for nid, deg in in_degree.items() if deg == 0) + order: list[int] = [] + while queue: + nid = queue.popleft() + order.append(nid) + for nxt in adj[nid]: + in_degree[nxt] -= 1 + if in_degree[nxt] == 0: + queue.append(nxt) + return order + + def _get_input( + self, + node: dict, + slot: int, + slot_data: dict[int, dict[int, pl.Series]], + ) -> pl.Series | None: + inputs = node.get("inputs", []) + if slot >= len(inputs): + return None + link_id = inputs[slot].get("link") + if link_id is None or link_id not in self.links: + return None + lnk = self.links[link_id] + from_node, from_slot = lnk[1], lnk[2] + return slot_data.get(from_node, {}).get(from_slot) + + def _apply( + self, + ntype: str, + props: dict, + signal: pl.Series, + wrt: pl.Series | None, + df: pl.DataFrame, + ) -> pl.Series: + # Derivative and Integral handle wrt directly + if ntype == "derivative": + if wrt is not None: + dx = ( + wrt.cast(pl.Float64) + .diff() + .fill_null(strategy="forward") + .fill_null(1) + ) + return signal.cast(pl.Float64).diff().fill_null(0) / dx + dt = float(props.get("dt", 0.001)) or 0.001 + return signal.cast(pl.Float64).diff().fill_null(0) / dt + + if ntype == "integral": + if wrt is not None: + dx = wrt.cast(pl.Float64).diff().fill_null(0) + return (signal.cast(pl.Float64) * dx).cum_sum() + dt = float(props.get("dt", 0.001)) or 0.001 + return signal.cast(pl.Float64).cum_sum() * dt + + cls = _FILTER_MAP.get(ntype) + if cls is None: + return signal + + # Strip None values so Pydantic uses field defaults + clean = {k: v for k, v in props.items() if v is not None} + try: + filt = cls(**clean) + except Exception: + return signal + + tmp = pl.DataFrame({"_s": signal.cast(pl.Float64)}) + ctx = filt.apply_with_context(tmp["_s"], df) + if ctx is not None: + return ctx + tmp = tmp.with_columns(filt.apply(pl.col("_s")).alias("_s")) + return tmp["_s"] diff --git a/src/mujoco_mojo/utils/layers/dojo/routers/mosaic.py b/src/mujoco_mojo/utils/layers/dojo/routers/mosaic.py index 03105e6f..a5076bed 100644 --- a/src/mujoco_mojo/utils/layers/dojo/routers/mosaic.py +++ b/src/mujoco_mojo/utils/layers/dojo/routers/mosaic.py @@ -395,6 +395,120 @@ async def delete_profile(name: str): return {"deleted": path.relative_to(d).with_suffix("").as_posix()} +# --------------------------------------------------------------------------- +# Lab · filter graph configs stored under ~/.mujoco-mojo/lab/ +# --------------------------------------------------------------------------- + +_LAB_PREFIX = "Lab" # Virtual column category shown in the Y-axis selector +_LAB_MAX_BYTES = 1024 * 1024 # 1 MB + + +def _get_lab_dir() -> Path: + d: Path = Path.home() / ".mujoco-mojo" / "lab" + d.mkdir(parents=True, exist_ok=True) + return d + + +def _sanitize_lab_name(name: str) -> str: + """ + Return a filesystem-safe relative path from the user-supplied lab name. + + Supports folder separators, e.g. 'robotics/arm_reach/baseline'. + Each segment is sanitized independently; empty segments are dropped. + """ + name = name.strip()[:256] + segments = [s.strip() for s in name.split("/") if s.strip()] + safe: list[str] = [] + for seg in segments: + seg = re.sub(r"[^\w\s\-]", "", seg) + seg = re.sub(r"\s+", "_", seg) + seg = re.sub(r"_+", "_", seg).strip("_") + if seg: + safe.append(seg[:64]) + return "/".join(safe) or "lab" + + +def _resolve_lab_path(name: str) -> Path: + d = _get_lab_dir() + path = (d / f"{_sanitize_lab_name(name)}.json").resolve() + if not path.is_relative_to(d.resolve()): + raise HTTPException(status_code=400, detail="Invalid lab name") + return path + + +def _lab_meta(path: Path, d: Path) -> dict: + """Parse a saved lab file and return metadata for the API.""" + from mujoco_mojo.utils.layers.dojo.lab_executor import LabExecutor + + try: + graph = json.loads(path.read_text(encoding="utf-8")) + exc = LabExecutor(graph) + return { + "name": path.relative_to(d).with_suffix("").as_posix(), + "modified": int(path.stat().st_mtime * 1000), + "signal_in_columns": exc.signal_in_columns, + "outputs": exc.output_labels, + } + except Exception: + return { + "name": path.relative_to(d).with_suffix("").as_posix(), + "modified": int(path.stat().st_mtime * 1000), + "signal_in_columns": [], + "outputs": [], + } + + +@router.get("/api/lab") +async def list_labs(): + """List all saved lab graphs with their input column requirements and output labels.""" + d = _get_lab_dir() + return [ + _lab_meta(f, d) + for f in sorted( + d.rglob("*.json"), key=lambda p: p.stat().st_mtime, reverse=True + ) + ] + + +@router.get("/api/lab/{name:path}") +async def get_lab(name: str): + """Return the raw LiteGraph JSON for a saved lab.""" + path = _resolve_lab_path(name) + if not path.exists(): + raise HTTPException(status_code=404, detail="Lab not found") + return json.loads(path.read_text(encoding="utf-8")) + + +@router.post("/api/lab/{name:path}") +async def save_lab(name: str, request: Request): + """Save a LiteGraph graph JSON as a named lab.""" + cl = request.headers.get("content-length") + if cl and int(cl) > _LAB_MAX_BYTES: + raise HTTPException(status_code=413, detail="Lab payload too large") + body = await request.json() + path = _resolve_lab_path(name) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(body), encoding="utf-8") + d = _get_lab_dir() + return {"name": path.relative_to(d).with_suffix("").as_posix()} + + +@router.delete("/api/lab/{name:path}") +async def delete_lab(name: str): + """Delete a saved lab.""" + path = _resolve_lab_path(name) + if not path.exists(): + raise HTTPException(status_code=404, detail="Lab not found") + path.unlink() + d = _get_lab_dir() + # Remove empty parent directories up to (but not including) the lab root. + parent = path.parent + while parent != d and parent.is_dir() and not any(parent.iterdir()): + parent.rmdir() + parent = parent.parent + return {"deleted": path.relative_to(d).with_suffix("").as_posix()} + + @lru_cache(maxsize=128) def _get_column_manifest(path_str: str, mtime: float) -> ColumnManifest: """Retrieves all column names from the table schema.""" @@ -498,8 +612,39 @@ async def get_trial_data( logger.warning(f"Could not parse filters for {trial_id}: {e}") filter_errors.append(_format_filter_error(e)) - # build response data, applying per-column filters where present data: dict = {} + + # ── Lab virtual columns ──────────────────────────────────────────────── + # Columns named "Lab/{lab_name}/{output_label}" are computed by running + # the saved lab graph rather than reading from the parquet file. + lab_cols = [c for c in requested if c.startswith(f"{_LAB_PREFIX}/")] + if lab_cols: + from mujoco_mojo.utils.layers.dojo.lab_executor import LabExecutor + + # Group by lab name to execute each graph once + from collections import defaultdict as _dd + + by_lab: dict[str, list[tuple[str, str]]] = _dd(list) + for col in lab_cols: + parts = col.split("/", 2) + if len(parts) == 3: + _, lab_name, output_label = parts + by_lab[lab_name].append((col, output_label)) + + for lab_name, col_outputs in by_lab.items(): + lab_path = _resolve_lab_path(lab_name) + if not lab_path.exists(): + continue + try: + graph = json.loads(lab_path.read_text(encoding="utf-8")) + outputs = LabExecutor(graph).execute(df) + for full_col, output_label in col_outputs: + if output_label in outputs: + data[full_col] = outputs[output_label].to_list() + except Exception as exc: + logger.warning(f"Lab '{lab_name}' execution failed: {exc}") + + # ── build response data, applying per-column filters where present ──── for col in requested: if col not in df.columns: continue diff --git a/src/mujoco_mojo/utils/layers/dojo/templates/partials/trial_viewer/_chart.html b/src/mujoco_mojo/utils/layers/dojo/templates/partials/trial_viewer/_chart.html index 5ba2e5ec..b5c9a86f 100644 --- a/src/mujoco_mojo/utils/layers/dojo/templates/partials/trial_viewer/_chart.html +++ b/src/mujoco_mojo/utils/layers/dojo/templates/partials/trial_viewer/_chart.html @@ -1026,19 +1026,14 @@
@@ -1056,9 +1051,9 @@ Save
- -
+ +
@@ -1240,5 +1235,6 @@
+ diff --git a/src/mujoco_mojo/utils/layers/dojo/templates/partials/trial_viewer/_header.html b/src/mujoco_mojo/utils/layers/dojo/templates/partials/trial_viewer/_header.html index 40a60583..ffa80805 100644 --- a/src/mujoco_mojo/utils/layers/dojo/templates/partials/trial_viewer/_header.html +++ b/src/mujoco_mojo/utils/layers/dojo/templates/partials/trial_viewer/_header.html @@ -26,6 +26,7 @@
+ {% include 'partials/trial_viewer/_lab_trigger.html' %}