diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 825e89fc8c..8c0bb3bd3c 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -16,6 +16,7 @@ from abc import abstractmethod from collections import deque +from collections.abc import Callable, Mapping from dataclasses import dataclass, field from functools import reduce from typing import TypeVar @@ -63,8 +64,7 @@ def get( def receive_transform(self, *args: Transform) -> None: ... def receive_tfmessage(self, msg: TFMessage) -> None: - for transform in msg.transforms: - self.receive_transform(transform) + self.receive_transform(*msg.transforms) MsgT = TypeVar("MsgT") @@ -117,6 +117,47 @@ class MultiTBuffer: def __init__(self, buffer_size: float = 10.0) -> None: self.buffers: dict[tuple[str, str], TBuffer] = {} self.buffer_size = buffer_size + self._on_transform_cbs: list[Callable[[], None]] = [] + self._cached_children_of: dict[str, list[str]] = {} + self._cached_roots: list[str] = [] + self._cached_num_edges: int = 0 + + def subscribe(self, cb: Callable[[], None]) -> Callable[[], None]: + """Subscribe to transform updates. Returns an unsubscribe callable.""" + self._on_transform_cbs.append(cb) + + def unsub() -> None: + try: + self._on_transform_cbs.remove(cb) + except ValueError: + pass + + return unsub + + def _invalidate_cache(self) -> None: + """Recompute children_of/roots cache if edges changed.""" + num_edges = len(self.buffers) + if num_edges != self._cached_num_edges: + self._cached_num_edges = num_edges + children_of: dict[str, list[str]] = {} + all_children: set[str] = set() + for parent, child in self.buffers: + children_of.setdefault(parent, []).append(child) + all_children.add(child) + self._cached_children_of = children_of + self._cached_roots = [f for f in children_of if f not in all_children] + + @property + def children_of(self) -> Mapping[str, list[str]]: + """Adjacency map: parent -> [children]. Cached, recomputed when edges change.""" + self._invalidate_cache() + return self._cached_children_of + + @property + def roots(self) -> list[str]: + """Frames that are parents but never children. Cached, recomputed when edges change.""" + self._invalidate_cache() + return self._cached_roots def receive_transform(self, *args: Transform) -> None: for transform in args: @@ -124,6 +165,8 @@ def receive_transform(self, *args: Transform) -> None: if key not in self.buffers: self.buffers[key] = TBuffer(self.buffer_size) self.buffers[key].add(transform) + for cb in self._on_transform_cbs: + cb() def get_frames(self) -> set[str]: frames = set() @@ -187,6 +230,13 @@ def get_transform_search( if direct is not None: return [direct] + # Build bidirectional adjacency from the cached forward map (O(E) once) + neighbors: dict[str, set[str]] = {} + for parent, kids in self.children_of.items(): + for kid in kids: + neighbors.setdefault(parent, set()).add(kid) + neighbors.setdefault(kid, set()).add(parent) + # BFS to find shortest path queue: deque[tuple[str, list[Transform]]] = deque([(parent_frame, [])]) visited = {parent_frame} @@ -197,14 +247,10 @@ def get_transform_search( if current_frame == child_frame: return path - # Get all connections for current frame - connections = self.get_connections(current_frame) - - for next_frame in connections: + for next_frame in neighbors.get(current_frame, set()): if next_frame not in visited: visited.add(next_frame) - # Get the transform between current and next frame transform = self.get_transform( current_frame, next_frame, time_point, time_tolerance ) diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index 9bba9dd82f..24ae3b8779 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -37,6 +37,7 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.tf2_msgs import TFMessage from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.pubsub.patterns import Glob, pattern_matches from dimos.utils.logging_config import setup_logger @@ -69,31 +70,6 @@ # # as well as pubsubs={} to specify which protocols to listen to. - -# TODO better TF processing -# -# this is rerun bridge specific, rerun has a specific (better) way of handling TFs -# using entity path conventions, each of these nodes in a path are TF frames: -# -# /world/robot1/base_link/camera/optical -# -# While here since we are just listening on TFMessage messages which optionally contain -# just a subset of full TF tree we don't know the full tree structure to build full entity -# path for a transform being published -# -# This is easy to reconstruct but a service/tf.py already does this so should be integrated here -# -# we have decoupled entity paths and actual transforms (like ROS TF frames) -# https://rerun.io/docs/concepts/logging-and-ingestion/transforms -# -# tf#/world -# tf#/base_link -# tf#/camera -# -# In order to solve this, bridge needs to own it's own tf service -# and render it's tf tree into correct rerun entity paths - - logger = setup_logger() if TYPE_CHECKING: @@ -184,6 +160,9 @@ class Config(ModuleConfig): viewer_mode: ViewerMode = field(default_factory=_resolve_viewer_mode) connect_url: str = "rerun+http://127.0.0.1:9877/proxy" memory_limit: str = "25%" + # When True, TFMessages are intercepted and rendered as hierarchical + # entity paths, bypassing visual_override for TF topics. + tf_enabled: bool = True # Blueprint factory: callable(rrb) -> Blueprint for viewer layout configuration # Set to None to disable default blueprint @@ -208,6 +187,7 @@ class RerunBridgeModule(Module): default_config = Config config: Config + _last_tf_render_time: float = 0.0 @lru_cache(maxsize=256) def _visual_override_for_entity_path( @@ -273,6 +253,11 @@ def _on_message(self, msg: Any, topic: Any) -> None: return self._last_log[entity_path] = now + # TFMessages are handled by the shared TF service via callbacks. + # Early return prevents them from hitting the visual_override path. + if self.config.tf_enabled and isinstance(msg, TFMessage): + return + # apply visual overrides (including final_convert which handles .to_rerun()) rerun_data: RerunData | None = self._visual_override_for_entity_path(entity_path)(msg) @@ -294,6 +279,10 @@ def start(self) -> None: super().start() self._last_log: dict[str, float] = {} + self._last_tf_render_time: float = 0.0 + if self.config.tf_enabled: + unsub_tf = self.tf.subscribe(self._on_tf_changed) + self._disposables.add(Disposable(unsub_tf)) logger.info("Rerun bridge starting", viewer_mode=self.config.viewer_mode) # Initialize and spawn Rerun viewer @@ -341,6 +330,60 @@ def start(self) -> None: self._log_static() + def _on_tf_changed(self) -> None: + """Called by TF service on every transform update. Rate-limits re-renders.""" + now = time.monotonic() + if now - self._last_tf_render_time >= self.config.min_interval_sec: + self._last_tf_render_time = now + self._render_tf_tree() + + def _render_tf_tree(self) -> None: + """Render the TF tree as hierarchical Rerun entity paths. + + Uses the shared TF service's children_of/roots (cached by edge count) + and DFS-walks the tree to log each transform at its hierarchical + entity path (e.g. world/base_link/camera). + """ + import rerun as rr + + tf = self.tf + children = tf.children_of + roots = tf.roots + + visited: set[str] = set() + + def _walk(frame: str, entity_path: str) -> None: + if frame in visited: + return + visited.add(frame) + for child in children.get(frame, []): + child_path = f"{entity_path}/{child}" + transform = tf.get_transform(frame, child) + if transform is not None: + rr.log( + child_path, + rr.Transform3D( + translation=[ + transform.translation.x, + transform.translation.y, + transform.translation.z, + ], + rotation=rr.Quaternion( + xyzw=[ + transform.rotation.x, + transform.rotation.y, + transform.rotation.z, + transform.rotation.w, + ], + ), + ), + ) + _walk(child, child_path) + + prefix = self.config.entity_prefix + for root in roots: + _walk(root, f"{prefix}/{root}" if prefix else root) + def _log_static(self) -> None: import rerun as rr diff --git a/dimos/visualization/rerun/test_tf_tree.py b/dimos/visualization/rerun/test_tf_tree.py new file mode 100644 index 0000000000..71441222b4 --- /dev/null +++ b/dimos/visualization/rerun/test_tf_tree.py @@ -0,0 +1,191 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for TF tree rendering in RerunBridgeModule.""" + +from __future__ import annotations + +import builtins +import sys +from types import ModuleType +from typing import Any +from unittest.mock import MagicMock, patch + +# Stub out heavy/unavailable dependencies before importing bridge. +_real_import = builtins.__import__ + + +def _mock_import( + name: str, globals: Any = None, locals: Any = None, fromlist: Any = (), level: int = 0 +) -> Any: + try: + return _real_import(name, globals, locals, fromlist, level) + except (ModuleNotFoundError, ImportError): + if "lazy_loader" in name: + m = ModuleType(name) + m.attach = lambda *a, **kw: (lambda n: None, lambda: [], []) # type: ignore[attr-defined] + sys.modules[name] = m + return m + mock_mod: Any = MagicMock() + sys.modules[name] = mock_mod + return mock_mod # type: ignore[return-value] + + +builtins.__import__ = _mock_import # type: ignore[assignment] + +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.tf2_msgs import TFMessage +from dimos.protocol.tf import MultiTBuffer +from dimos.visualization.rerun.bridge import Config, RerunBridgeModule + +# Restore normal import after our modules are loaded. +builtins.__import__ = _real_import + + +def _make_transform( + parent: str, + child: str, + tx: float = 0.0, + ty: float = 0.0, + tz: float = 0.0, +) -> Transform: + return Transform( + translation=Vector3(tx, ty, tz), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id=parent, + child_frame_id=child, + ) + + +def _make_bridge(*, tf_enabled: bool = True) -> RerunBridgeModule: + """Create a RerunBridgeModule without running Module lifecycle.""" + bridge = object.__new__(RerunBridgeModule) + bridge.config = Config(pubsubs=[], tf_enabled=tf_enabled, entity_prefix="world") + bridge._last_log = {} + bridge._last_tf_render_time = 0.0 + if tf_enabled: + bridge._tf = MultiTBuffer() # type: ignore[assignment] # inject plain buffer for testing (production uses LCMTF which extends both) + else: + bridge._tf = None + return bridge + + +class TestRenderTfTree: + """Tests for _render_tf_tree DFS walk and entity path construction.""" + + @patch("rerun.log") + @patch("rerun.Transform3D") + @patch("rerun.Quaternion") + def test_simple_chain( + self, mock_quat: MagicMock, mock_t3d: MagicMock, mock_log: MagicMock + ) -> None: + """A→B→C chain produces world/A/B and world/A/B/C entity paths.""" + bridge = _make_bridge() + bridge.tf.receive_transform( + _make_transform("odom", "base_link", tx=1.0), + _make_transform("base_link", "camera", tz=0.5), + ) + + bridge._render_tf_tree() + + logged_paths = [c.args[0] for c in mock_log.call_args_list] + assert "world/odom/base_link" in logged_paths + assert "world/odom/base_link/camera" in logged_paths + assert len(logged_paths) == 2 + + @patch("rerun.log") + @patch("rerun.Transform3D") + @patch("rerun.Quaternion") + def test_multiple_roots( + self, mock_quat: MagicMock, mock_t3d: MagicMock, mock_log: MagicMock + ) -> None: + """Two disjoint trees produce separate root paths.""" + bridge = _make_bridge() + bridge.tf.receive_transform( + _make_transform("odom", "base_link"), + _make_transform("map", "marker"), + ) + + bridge._render_tf_tree() + + logged_paths = {c.args[0] for c in mock_log.call_args_list} + assert "world/odom/base_link" in logged_paths + assert "world/map/marker" in logged_paths + assert len(logged_paths) == 2 + + @patch("rerun.log") + def test_tf_disabled_falls_through(self, mock_log: MagicMock) -> None: + """When tf_enabled=False, TFMessage is NOT intercepted by the TF path.""" + bridge = _make_bridge(tf_enabled=False) + + msg = TFMessage(_make_transform("odom", "base_link")) + + # _on_message should NOT enter the TF intercept branch. + # It will try the visual override path which needs _visual_override_for_entity_path. + # We patch that to verify the fallthrough. + with patch.object(bridge, "_visual_override_for_entity_path") as mock_vo: + mock_vo.return_value = lambda m: None # suppress further processing + bridge._on_message(msg, "/tf") + + # visual override path was reached (not short-circuited by TF intercept) + mock_vo.assert_called_once() + + @patch("rerun.log") + @patch("rerun.Transform3D") + @patch("rerun.Quaternion") + def test_incremental_update( + self, mock_quat: MagicMock, mock_t3d: MagicMock, mock_log: MagicMock + ) -> None: + """Adding a new child after initial render extends the tree.""" + bridge = _make_bridge() + bridge.tf.receive_transform( + _make_transform("odom", "base_link"), + ) + bridge._render_tf_tree() + assert len(mock_log.call_args_list) == 1 + + mock_log.reset_mock() + bridge.tf.receive_transform( + _make_transform("base_link", "lidar"), + ) + bridge._render_tf_tree() + + logged_paths = {c.args[0] for c in mock_log.call_args_list} + assert "world/odom/base_link" in logged_paths + assert "world/odom/base_link/lidar" in logged_paths + + @patch("rerun.log") + @patch("rerun.Transform3D") + @patch("rerun.Quaternion") + def test_cycle_protection( + self, mock_quat: MagicMock, mock_t3d: MagicMock, mock_log: MagicMock + ) -> None: + """A cycle in the TF graph does not cause infinite recursion.""" + bridge = _make_bridge() + # Create A→B→C→A cycle + bridge.tf.receive_transform( + _make_transform("A", "B"), + _make_transform("B", "C"), + _make_transform("C", "A"), + ) + + # Should not raise or hang + bridge._render_tf_tree() + + logged_paths = [c.args[0] for c in mock_log.call_args_list] + # A pure cycle has no root (every frame is a child of some other frame), + # so roots list is empty and no DFS starts → zero logs. + assert len(logged_paths) == 0 # pure cycle → no root → no DFS starts