Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 53 additions & 7 deletions dimos/protocol/tf/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from abc import abstractmethod
from collections import deque
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from functools import reduce
from typing import TypeVar
Expand Down Expand Up @@ -63,8 +64,7 @@ def get(
def receive_transform(self, *args: Transform) -> None: ...

def receive_tfmessage(self, msg: TFMessage) -> None:
for transform in msg.transforms:
self.receive_transform(transform)
self.receive_transform(*msg.transforms)


MsgT = TypeVar("MsgT")
Expand Down Expand Up @@ -117,13 +117,56 @@ class MultiTBuffer:
def __init__(self, buffer_size: float = 10.0) -> None:
self.buffers: dict[tuple[str, str], TBuffer] = {}
self.buffer_size = buffer_size
self._on_transform_cbs: list[Callable[[], None]] = []
self._cached_children_of: dict[str, list[str]] = {}
self._cached_roots: list[str] = []
self._cached_num_edges: int = 0

def subscribe(self, cb: Callable[[], None]) -> Callable[[], None]:
"""Subscribe to transform updates. Returns an unsubscribe callable."""
self._on_transform_cbs.append(cb)

def unsub() -> None:
try:
self._on_transform_cbs.remove(cb)
except ValueError:
pass

return unsub

def _invalidate_cache(self) -> None:
"""Recompute children_of/roots cache if edges changed."""
num_edges = len(self.buffers)
if num_edges != self._cached_num_edges:
self._cached_num_edges = num_edges
children_of: dict[str, list[str]] = {}
all_children: set[str] = set()
for parent, child in self.buffers:
children_of.setdefault(parent, []).append(child)
all_children.add(child)
self._cached_children_of = children_of
self._cached_roots = [f for f in children_of if f not in all_children]

@property
def children_of(self) -> Mapping[str, list[str]]:
"""Adjacency map: parent -> [children]. Cached, recomputed when edges change."""
self._invalidate_cache()
return self._cached_children_of

@property
def roots(self) -> list[str]:
"""Frames that are parents but never children. Cached, recomputed when edges change."""
self._invalidate_cache()
return self._cached_roots

def receive_transform(self, *args: Transform) -> None:
for transform in args:
key = (transform.frame_id, transform.child_frame_id)
if key not in self.buffers:
self.buffers[key] = TBuffer(self.buffer_size)
self.buffers[key].add(transform)
for cb in self._on_transform_cbs:
cb()

def get_frames(self) -> set[str]:
frames = set()
Expand Down Expand Up @@ -187,6 +230,13 @@ def get_transform_search(
if direct is not None:
return [direct]

# Build bidirectional adjacency from the cached forward map (O(E) once)
neighbors: dict[str, set[str]] = {}
for parent, kids in self.children_of.items():
for kid in kids:
neighbors.setdefault(parent, set()).add(kid)
neighbors.setdefault(kid, set()).add(parent)

# BFS to find shortest path
queue: deque[tuple[str, list[Transform]]] = deque([(parent_frame, [])])
visited = {parent_frame}
Expand All @@ -197,14 +247,10 @@ def get_transform_search(
if current_frame == child_frame:
return path

# Get all connections for current frame
connections = self.get_connections(current_frame)

for next_frame in connections:
for next_frame in neighbors.get(current_frame, set()):
if next_frame not in visited:
visited.add(next_frame)

# Get the transform between current and next frame
transform = self.get_transform(
current_frame, next_frame, time_point, time_tolerance
)
Expand Down
93 changes: 68 additions & 25 deletions dimos/visualization/rerun/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from dimos.core.core import rpc
from dimos.core.module import Module, ModuleConfig
from dimos.msgs.sensor_msgs import Image, PointCloud2
from dimos.msgs.tf2_msgs import TFMessage
from dimos.protocol.pubsub.impl.lcmpubsub import LCM
from dimos.protocol.pubsub.patterns import Glob, pattern_matches
from dimos.utils.logging_config import setup_logger
Expand Down Expand Up @@ -69,31 +70,6 @@
#
# as well as pubsubs={} to specify which protocols to listen to.


# TODO better TF processing
#
# this is rerun bridge specific, rerun has a specific (better) way of handling TFs
# using entity path conventions, each of these nodes in a path are TF frames:
#
# /world/robot1/base_link/camera/optical
#
# While here since we are just listening on TFMessage messages which optionally contain
# just a subset of full TF tree we don't know the full tree structure to build full entity
# path for a transform being published
#
# This is easy to reconstruct but a service/tf.py already does this so should be integrated here
#
# we have decoupled entity paths and actual transforms (like ROS TF frames)
# https://rerun.io/docs/concepts/logging-and-ingestion/transforms
#
# tf#/world
# tf#/base_link
# tf#/camera
#
# In order to solve this, bridge needs to own it's own tf service
# and render it's tf tree into correct rerun entity paths


logger = setup_logger()

if TYPE_CHECKING:
Expand Down Expand Up @@ -184,6 +160,9 @@ class Config(ModuleConfig):
viewer_mode: ViewerMode = field(default_factory=_resolve_viewer_mode)
connect_url: str = "rerun+http://127.0.0.1:9877/proxy"
memory_limit: str = "25%"
# When True, TFMessages are intercepted and rendered as hierarchical
# entity paths, bypassing visual_override for TF topics.
tf_enabled: bool = True

# Blueprint factory: callable(rrb) -> Blueprint for viewer layout configuration
# Set to None to disable default blueprint
Expand All @@ -208,6 +187,7 @@ class RerunBridgeModule(Module):

default_config = Config
config: Config
_last_tf_render_time: float = 0.0

@lru_cache(maxsize=256)
def _visual_override_for_entity_path(
Expand Down Expand Up @@ -273,6 +253,11 @@ def _on_message(self, msg: Any, topic: Any) -> None:
return
self._last_log[entity_path] = now

# TFMessages are handled by the shared TF service via callbacks.
# Early return prevents them from hitting the visual_override path.
if self.config.tf_enabled and isinstance(msg, TFMessage):
return

# apply visual overrides (including final_convert which handles .to_rerun())
rerun_data: RerunData | None = self._visual_override_for_entity_path(entity_path)(msg)

Expand All @@ -294,6 +279,10 @@ def start(self) -> None:
super().start()

self._last_log: dict[str, float] = {}
self._last_tf_render_time: float = 0.0
if self.config.tf_enabled:
unsub_tf = self.tf.subscribe(self._on_tf_changed)
self._disposables.add(Disposable(unsub_tf))
logger.info("Rerun bridge starting", viewer_mode=self.config.viewer_mode)

# Initialize and spawn Rerun viewer
Expand Down Expand Up @@ -341,6 +330,60 @@ def start(self) -> None:

self._log_static()

def _on_tf_changed(self) -> None:
"""Called by TF service on every transform update. Rate-limits re-renders."""
now = time.monotonic()
if now - self._last_tf_render_time >= self.config.min_interval_sec:
self._last_tf_render_time = now
self._render_tf_tree()

def _render_tf_tree(self) -> None:
"""Render the TF tree as hierarchical Rerun entity paths.

Uses the shared TF service's children_of/roots (cached by edge count)
and DFS-walks the tree to log each transform at its hierarchical
entity path (e.g. world/base_link/camera).
"""
import rerun as rr

tf = self.tf
children = tf.children_of
roots = tf.roots

visited: set[str] = set()

def _walk(frame: str, entity_path: str) -> None:
if frame in visited:
return
visited.add(frame)
for child in children.get(frame, []):
child_path = f"{entity_path}/{child}"
transform = tf.get_transform(frame, child)
if transform is not None:
rr.log(
child_path,
rr.Transform3D(
translation=[
transform.translation.x,
transform.translation.y,
transform.translation.z,
],
rotation=rr.Quaternion(
xyzw=[
transform.rotation.x,
transform.rotation.y,
transform.rotation.z,
transform.rotation.w,
],
),
),
)
_walk(child, child_path)

prefix = self.config.entity_prefix
for root in roots:
_walk(root, f"{prefix}/{root}" if prefix else root)

def _log_static(self) -> None:
import rerun as rr

Expand Down
Loading
Loading