diff --git a/changes.md b/changes.md new file mode 100644 index 0000000000..d1e4b4b2e7 --- /dev/null +++ b/changes.md @@ -0,0 +1,19 @@ +# PR #1643 (rconnect) — Paul Review Fixes + +## Commits (local, not pushed) + +### 1. `81769d273` — Log exception + unblock stop() on startup failure +- If `_serve()` throws, `_server_ready` was never set → `stop()` blocked 5s +- Now logs exception and sets `_server_ready` in finally +- **Revert:** `git revert 81769d273` + +## Reviewer was wrong on +- `_server_ready` race — it IS set inside `async with` (after bind), not before +- `msg.get("x") or 0` — code already uses `msg.get("x", 0)` correctly + +## Not addressed (need Jeff's input) +- `vis_module` always bundling `RerunWebSocketServer` — opt-out design choice +- `LCM()` instantiated for non-rerun backends — wasted resource +- `rerun-connect` skipping `WebsocketVisModule` — intentional? +- Default `host = "0.0.0.0"` — intentional for remote viewer use case +- Hardcoded test ports — should use port=0 for parallel safety diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index b8165658d9..9c5623d141 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -32,7 +32,7 @@ from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.spec import perception -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module def default_transform() -> Transform: @@ -120,5 +120,5 @@ def stop(self) -> None: demo_camera = autoconnect( CameraModule.blueprint(), - RerunBridgeModule.blueprint(), + vis_module("rerun"), ) diff --git a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py index f3de842b46..b39dd7bcec 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py +++ b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py @@ -15,36 +15,45 @@ from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 from dimos.mapping.voxels import VoxelGridMapper -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module voxel_size = 0.05 mid360_fastlio = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=-1), - RerunBridgeModule.blueprint( - visual_override={ - "world/lidar": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/lidar": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + }, + }, ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") mid360_fastlio_voxels = autoconnect( FastLio2.blueprint(), VoxelGridMapper.blueprint(publish_interval=1.0, voxel_size=voxel_size, carve_columns=False), - RerunBridgeModule.blueprint( - visual_override={ - "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - "world/lidar": None, - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + "world/lidar": None, + }, + }, ), ).global_config(n_workers=3, robot_model="mid360_fastlio2_voxels") mid360_fastlio_voxels_native = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=3.0), - RerunBridgeModule.blueprint( - visual_override={ - "world/lidar": None, - "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/lidar": None, + "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + }, + }, ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") diff --git a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py index c8835b3e89..958af084e2 100644 --- a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py +++ b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py @@ -14,9 +14,9 @@ from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.lidar.livox.module import Mid360 -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module mid360 = autoconnect( Mid360.blueprint(), - RerunBridgeModule.blueprint(), + vis_module("rerun"), ).global_config(n_workers=2, robot_model="mid360") diff --git a/dimos/manipulation/blueprints.py b/dimos/manipulation/blueprints.py index a9fb0fb44b..90e468aaf2 100644 --- a/dimos/manipulation/blueprints.py +++ b/dimos/manipulation/blueprints.py @@ -46,8 +46,8 @@ from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.sensor_msgs.JointState import JointState from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.robot.foxglove_bridge import FoxgloveBridge # TODO: migrate to rerun from dimos.utils.data import get_data +from dimos.visualization.vis_module import vis_module def _make_base_pose( @@ -409,7 +409,7 @@ def _make_piper_config( base_transform=_XARM_PERCEPTION_CAMERA_TRANSFORM, ), ObjectSceneRegistrationModule.blueprint(target_frame="world"), - FoxgloveBridge.blueprint(), # TODO: migrate to rerun + vis_module("foxglove"), ) .transports( { diff --git a/dimos/manipulation/grasping/demo_grasping.py b/dimos/manipulation/grasping/demo_grasping.py index 782283029b..f1ce67709e 100644 --- a/dimos/manipulation/grasping/demo_grasping.py +++ b/dimos/manipulation/grasping/demo_grasping.py @@ -14,15 +14,14 @@ # limitations under the License. from pathlib import Path -from dimos.agents.mcp.mcp_client import McpClient -from dimos.agents.mcp.mcp_server import McpServer +from dimos.agents.agent import Agent from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera from dimos.manipulation.grasping.graspgen_module import graspgen from dimos.manipulation.grasping.grasping import GraspingModule from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.visualization.vis_module import vis_module camera_module = RealSenseCamera.blueprint(enable_pointcloud=False) @@ -44,7 +43,6 @@ ("/tmp", "/tmp", "rw") ], # Grasp visualization debug standalone: python -m dimos.manipulation.grasping.visualize_grasps ), - FoxgloveBridge.blueprint(), - McpServer.blueprint(), - McpClient.blueprint(), + vis_module("foxglove"), + Agent.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/perception/demo_object_scene_registration.py b/dimos/perception/demo_object_scene_registration.py index c6d8c96625..13fb26cbb5 100644 --- a/dimos/perception/demo_object_scene_registration.py +++ b/dimos/perception/demo_object_scene_registration.py @@ -13,14 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.mcp.mcp_client import McpClient -from dimos.agents.mcp.mcp_server import McpServer +from dimos.agents.agent import Agent from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera from dimos.hardware.sensors.camera.zed.compat import ZEDCamera from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.visualization.vis_module import vis_module camera_choice = "zed" @@ -34,7 +33,6 @@ demo_object_scene_registration = autoconnect( camera_module, ObjectSceneRegistrationModule.blueprint(target_frame="world", prompt_mode=YoloePromptMode.LRPC), - FoxgloveBridge.blueprint(), - McpServer.blueprint(), - McpClient.blueprint(), + vis_module("foxglove"), + Agent.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 5910093d61..44bfa8e280 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -157,6 +157,7 @@ "reid-module": "dimos.perception.detection.reid.module", "replanning-a-star-planner": "dimos.navigation.replanning_a_star.module", "rerun-bridge-module": "dimos.visualization.rerun.bridge", + "rerun-web-socket-server": "dimos.visualization.rerun.websocket_server", "ros-nav": "dimos.navigation.rosnav", "simple-phone-teleop": "dimos.teleop.phone.phone_extensions", "spatial-memory": "dimos.perception.spatial_perception", diff --git a/dimos/robot/drone/blueprints/basic/drone_basic.py b/dimos/robot/drone/blueprints/basic/drone_basic.py index fbe6621ae1..c60483cb0a 100644 --- a/dimos/robot/drone/blueprints/basic/drone_basic.py +++ b/dimos/robot/drone/blueprints/basic/drone_basic.py @@ -20,10 +20,9 @@ from dimos.core.blueprints import autoconnect from dimos.core.global_config import global_config -from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.robot.drone.camera_module import DroneCameraModule from dimos.robot.drone.connection_module import DroneConnectionModule -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.visualization.vis_module import vis_module def _static_drone_body(rr: Any) -> list[Any]: @@ -60,23 +59,12 @@ def _drone_rerun_blueprint() -> Any: _rerun_config = { "blueprint": _drone_rerun_blueprint, - "pubsubs": [LCM()], "static": { "world/tf/base_link": _static_drone_body, }, } -# Conditional visualization -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - _vis = FoxgloveBridge.blueprint() -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - _vis = RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config) -else: - _vis = autoconnect() +_vis = vis_module(global_config.viewer, rerun_config=_rerun_config) # Determine connection string based on replay flag connection_string = "udp:0.0.0.0:14550" @@ -92,7 +80,6 @@ def _drone_rerun_blueprint() -> Any: outdoor=False, ), DroneCameraModule.blueprint(camera_intrinsics=[1000.0, 1000.0, 960.0, 540.0]), - WebsocketVisModule.blueprint(), ) __all__ = [ diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py index 5b127fb697..9efe400895 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py @@ -17,10 +17,11 @@ from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image -from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1 import unitree_g1 +from dimos.visualization.vis_module import vis_module unitree_g1_shm = autoconnect( unitree_g1.transports( @@ -30,10 +31,9 @@ ), } ), - FoxgloveBridge.blueprint( - shm_channels=[ - "/color_image#sensor_msgs.Image", - ] + vis_module( + global_config.viewer, + foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, ), ) diff --git a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py index c3da9521c5..220caff949 100644 --- a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py +++ b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py @@ -40,8 +40,7 @@ from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( WavefrontFrontierExplorer, ) -from dimos.protocol.pubsub.impl.lcmpubsub import LCM -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.visualization.vis_module import vis_module def _convert_camera_info(camera_info: Any) -> Any: @@ -90,7 +89,6 @@ def _g1_rerun_blueprint() -> Any: rerun_config = { "blueprint": _g1_rerun_blueprint, - "pubsubs": [LCM()], "visual_override": { "world/camera_info": _convert_camera_info, "world/global_map": _convert_global_map, @@ -101,18 +99,7 @@ def _g1_rerun_blueprint() -> Any: }, } -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - _with_vis = autoconnect(FoxgloveBridge.blueprint()) -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - _with_vis = autoconnect( - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config) - ) -else: - _with_vis = autoconnect() +_with_vis = vis_module(global_config.viewer, rerun_config=rerun_config) def _create_webcam() -> Webcam: @@ -147,8 +134,6 @@ def _create_webcam() -> Webcam: VoxelGridMapper.blueprint(voxel_size=0.1), CostMapper.blueprint(), WavefrontFrontierExplorer.blueprint(), - # Visualization - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_g1") .transports( diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py index a0d1e6a7ae..1e0f32d25c 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py @@ -22,10 +22,9 @@ from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image -from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator from dimos.robot.unitree.go2.connection import GO2Connection -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.visualization.vis_module import vis_module # Mac has some issue with high bandwidth UDP, so we use pSHMTransport for color_image # actually we can use pSHMTransport for all platforms, and for all streams @@ -87,9 +86,6 @@ def _go2_rerun_blueprint() -> Any: rerun_config = { "blueprint": _go2_rerun_blueprint, - # any pubsub that supports subscribe_all and topic that supports str(topic) - # is acceptable here - "pubsubs": [LCM()], # Custom converters for specific rerun entity paths # Normally all these would be specified in their respectative modules # Until this is implemented we have central overrides here @@ -106,29 +102,19 @@ def _go2_rerun_blueprint() -> Any: }, } - -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - with_vis = autoconnect( - _transports_base, - FoxgloveBridge.blueprint(shm_channels=["/color_image#sensor_msgs.Image"]), - ) -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - with_vis = autoconnect( - _transports_base, - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), - ) -else: - with_vis = _transports_base +_with_vis = autoconnect( + _transports_base, + vis_module( + global_config.viewer, + rerun_config=rerun_config, + foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, + ), +) unitree_go2_basic = ( autoconnect( - with_vis, + _with_vis, GO2Connection.blueprint(), - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py index 1c55f3e93c..0468cad40d 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py @@ -22,15 +22,13 @@ from dimos.core.blueprints import autoconnect from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator -from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import with_vis +from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import _with_vis from dimos.robot.unitree.go2.fleet_connection import Go2FleetConnection -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule unitree_go2_fleet = ( autoconnect( - with_vis, + _with_vis, Go2FleetConnection.blueprint(), - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index d6367310de..1b67de3b75 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -26,12 +26,12 @@ from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.teleop.quest.quest_extensions import ArmTeleopModule from dimos.teleop.quest.quest_types import Buttons -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module # Arm teleop with press-and-hold engage (has rerun viz) teleop_quest_rerun = autoconnect( ArmTeleopModule.blueprint(), - RerunBridgeModule.blueprint(), + vis_module("rerun"), ).transports( { ("left_controller_output", PoseStamped): LCMTransport("/teleop/left_delta", PoseStamped), diff --git a/dimos/visualization/rerun/test_viewer_ws_e2e.py b/dimos/visualization/rerun/test_viewer_ws_e2e.py new file mode 100644 index 0000000000..80c4743e61 --- /dev/null +++ b/dimos/visualization/rerun/test_viewer_ws_e2e.py @@ -0,0 +1,329 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end test: dimos-viewer (headless) → WebSocket → RerunWebSocketServer. + +dimos-viewer is started in ``--connect`` mode so it initialises its WebSocket +client. The viewer needs a gRPC proxy to connect to; we give it a non-existent +one so the viewer starts up anyway but produces no visualisation. The important +part is that the WebSocket client inside the viewer tries to connect to +``ws://127.0.0.1:/ws``. + +Because the viewer is a native GUI application it cannot run headlessly in CI +without a display. This test therefore verifies the connection at the protocol +level by using the ``RerunWebSocketServer`` module directly as the server and +injecting synthetic JSON messages that mimic what the viewer would send once a +user clicks in the 3D viewport. +""" + +import asyncio +import json +import os +import shutil +import subprocess +import threading +import time +from typing import Any + +import pytest + +from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + +_E2E_PORT = 13032 + + +def _make_server(port: int = _E2E_PORT) -> RerunWebSocketServer: + return RerunWebSocketServer(port=port) + + +def _wait_for_server(port: int, timeout: float = 5.0) -> None: + import websockets.asyncio.client as ws_client + + async def _probe() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{port}/ws"): + pass + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + asyncio.run(_probe()) + return + except Exception: + time.sleep(0.05) + raise TimeoutError(f"Server on port {port} did not become ready within {timeout}s") + + +def _send_messages(port: int, messages: list[dict[str, Any]], *, delay: float = 0.05) -> None: + import websockets.asyncio.client as ws_client + + async def _run() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{port}/ws") as ws: + for msg in messages: + await ws.send(json.dumps(msg)) + await asyncio.sleep(delay) + + asyncio.run(_run()) + + +class TestViewerProtocolE2E: + """Verify the full Python-server side of the viewer ↔ DimOS protocol. + + These tests use the ``RerunWebSocketServer`` as the server and a dummy + WebSocket client (playing the role of dimos-viewer) to inject messages. + They confirm every message type is correctly routed and that only click + messages produce stream publishes. + """ + + def test_viewer_click_reaches_stream(self) -> None: + """A viewer click message received over WebSocket publishes PointStamped.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + done = threading.Event() + + def _on_pt(pt: Any) -> None: + received.append(pt) + done.set() + + server.clicked_point.subscribe(_on_pt) + + _send_messages( + _E2E_PORT, + [ + { + "type": "click", + "x": 10.0, + "y": 20.0, + "z": 0.5, + "entity_path": "/world/robot", + "timestamp_ms": 42000, + } + ], + ) + + done.wait(timeout=3.0) + server.stop() + + assert len(received) == 1 + pt = received[0] + assert abs(pt.x - 10.0) < 1e-9 + assert abs(pt.y - 20.0) < 1e-9 + assert abs(pt.z - 0.5) < 1e-9 + assert pt.frame_id == "/world/robot" + assert abs(pt.ts - 42.0) < 1e-6 + + def test_viewer_keyboard_twist_no_publish(self) -> None: + """Twist messages from keyboard control do not publish clicked_point.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + server.clicked_point.subscribe(received.append) + + _send_messages( + _E2E_PORT, + [ + { + "type": "twist", + "linear_x": 0.5, + "linear_y": 0.0, + "linear_z": 0.0, + "angular_x": 0.0, + "angular_y": 0.0, + "angular_z": 0.8, + } + ], + ) + + server.stop() + assert received == [] + + def test_viewer_stop_no_publish(self) -> None: + """Stop messages do not publish clicked_point.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + server.clicked_point.subscribe(received.append) + + _send_messages(_E2E_PORT, [{"type": "stop"}]) + + server.stop() + assert received == [] + + def test_full_viewer_session_sequence(self) -> None: + """Realistic session: connect, heartbeats, click, WASD, stop → one point.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + done = threading.Event() + + def _on_pt(pt: Any) -> None: + received.append(pt) + done.set() + + server.clicked_point.subscribe(_on_pt) + + _send_messages( + _E2E_PORT, + [ + # Initial heartbeats (viewer connects and starts 1 Hz heartbeat) + {"type": "heartbeat", "timestamp_ms": 1000}, + {"type": "heartbeat", "timestamp_ms": 2000}, + # User clicks a point in the 3D viewport + { + "type": "click", + "x": 3.14, + "y": 2.71, + "z": 1.41, + "entity_path": "/world", + "timestamp_ms": 3000, + }, + # User presses W (forward) + { + "type": "twist", + "linear_x": 0.5, + "linear_y": 0.0, + "linear_z": 0.0, + "angular_x": 0.0, + "angular_y": 0.0, + "angular_z": 0.0, + }, + # User releases W + {"type": "stop"}, + # Another heartbeat + {"type": "heartbeat", "timestamp_ms": 4000}, + ], + delay=0.2, + ) + + done.wait(timeout=3.0) + server.stop() + + assert len(received) == 1, f"Expected exactly 1 click, got {len(received)}" + pt = received[0] + assert abs(pt.x - 3.14) < 1e-9 + assert abs(pt.y - 2.71) < 1e-9 + assert abs(pt.z - 1.41) < 1e-9 + + def test_reconnect_after_disconnect(self) -> None: + """Server keeps accepting new connections after a client disconnects.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + all_done = threading.Event() + + def _on_pt(pt: Any) -> None: + received.append(pt) + if len(received) >= 2: + all_done.set() + + server.clicked_point.subscribe(_on_pt) + + # First connection — send one click and disconnect + _send_messages( + _E2E_PORT, + [{"type": "click", "x": 1.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 0}], + ) + + # Second connection (simulating viewer reconnect) — send another click + _send_messages( + _E2E_PORT, + [{"type": "click", "x": 2.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 0}], + ) + + all_done.wait(timeout=5.0) + server.stop() + + xs = sorted(pt.x for pt in received) + assert xs == [1.0, 2.0], f"Unexpected xs: {xs}" + + +class TestViewerBinaryConnectMode: + """Smoke test: dimos-viewer binary starts in --connect mode and its WebSocket + client attempts to connect to our Python server.""" + + @pytest.mark.skipif( + shutil.which("dimos-viewer") is None + or "--connect" not in subprocess.run( + ["dimos-viewer", "--help"], capture_output=True, text=True + ).stdout, + reason="dimos-viewer binary not installed or does not support --connect", + ) + def test_viewer_ws_client_connects(self) -> None: + """dimos-viewer --connect starts and its WS client connects to our server.""" + server = _make_server() + server.start() + _wait_for_server(_E2E_PORT) + + received: list[Any] = [] + + def _on_pt(pt: Any) -> None: + received.append(pt) + + server.clicked_point.subscribe(_on_pt) + + # Start dimos-viewer in --connect mode, pointing it at a non-existent gRPC + # proxy (it will fail to stream data, but that's fine) and at our WS server. + # Use DISPLAY="" to prevent it from opening a window (it will exit quickly + # without a display, but the WebSocket connection happens before the GUI loop). + proc = subprocess.Popen( + [ + "dimos-viewer", + "--connect", + f"--ws-url=ws://127.0.0.1:{_E2E_PORT}/ws", + ], + env={ + **os.environ, + "DISPLAY": "", + }, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + # Give the viewer up to 5 s to connect its WebSocket client to our server. + # We detect the connection by waiting for the server to accept a client. + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + # Check if any connection was established by sending a message and + # verifying the viewer is still running. + if proc.poll() is not None: + # Viewer exited (expected without a display) — check if it connected first. + break + time.sleep(0.1) + + proc.terminate() + try: + proc.wait(timeout=3) + except subprocess.TimeoutExpired: + proc.kill() + + stdout = proc.stdout.read().decode(errors="replace") if proc.stdout else "" + stderr = proc.stderr.read().decode(errors="replace") if proc.stderr else "" + server.stop() + + # The viewer should log that it is connecting to our WS URL. + # Check both stdout and stderr since log output destination varies. + combined = stdout + stderr + assert f"ws://127.0.0.1:{_E2E_PORT}" in combined, ( + f"Viewer did not attempt WS connection.\nstdout:\n{stdout}\nstderr:\n{stderr}" + ) diff --git a/dimos/visualization/rerun/test_websocket_server.py b/dimos/visualization/rerun/test_websocket_server.py new file mode 100644 index 0000000000..73c6759eec --- /dev/null +++ b/dimos/visualization/rerun/test_websocket_server.py @@ -0,0 +1,407 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RerunWebSocketServer. + +Uses ``MockViewerPublisher`` to simulate dimos-viewer sending events, matching +the exact JSON protocol used by the Rust ``WsPublisher`` in the viewer. +""" + +import asyncio +import json +import threading +import time +from typing import Any + +from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + +_TEST_PORT = 13031 + + +class MockViewerPublisher: + """Python mirror of the Rust WsPublisher in dimos-viewer. + + Connects to a running ``RerunWebSocketServer`` and exposes the same + ``send_click`` / ``send_twist`` / ``send_stop`` / ``send_heartbeat`` + API that the real viewer uses. Useful for unit tests that need to + exercise the server without a real viewer binary. + + Usage:: + + with MockViewerPublisher("ws://127.0.0.1:13031/ws") as pub: + pub.send_click(1.0, 2.0, 0.0, "/world", timestamp_ms=1000) + pub.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) + pub.send_stop() + """ + + def __init__(self, url: str) -> None: + self._url = url + self._ws: Any = None + self._loop: asyncio.AbstractEventLoop | None = None + + def __enter__(self) -> "MockViewerPublisher": + self._loop = asyncio.new_event_loop() + self._ws = self._loop.run_until_complete(self._connect()) + return self + + def __exit__(self, *_: Any) -> None: + if self._ws is not None and self._loop is not None: + self._loop.run_until_complete(self._ws.close()) + if self._loop is not None: + self._loop.close() + + async def _connect(self) -> Any: + import websockets.asyncio.client as ws_client + + return await ws_client.connect(self._url) + + def send_click( + self, + x: float, + y: float, + z: float, + entity_path: str = "", + timestamp_ms: int = 0, + ) -> None: + """Send a click event — matches viewer SelectionChange handler output.""" + self._send( + { + "type": "click", + "x": x, + "y": y, + "z": z, + "entity_path": entity_path, + "timestamp_ms": timestamp_ms, + } + ) + + def send_twist( + self, + linear_x: float, + linear_y: float, + linear_z: float, + angular_x: float, + angular_y: float, + angular_z: float, + ) -> None: + """Send a twist (WASD keyboard) event.""" + self._send( + { + "type": "twist", + "linear_x": linear_x, + "linear_y": linear_y, + "linear_z": linear_z, + "angular_x": angular_x, + "angular_y": angular_y, + "angular_z": angular_z, + } + ) + + def send_stop(self) -> None: + """Send a stop event (Space bar or key release).""" + self._send({"type": "stop"}) + + def send_heartbeat(self, timestamp_ms: int = 0) -> None: + """Send a heartbeat (1 Hz keepalive from viewer).""" + self._send({"type": "heartbeat", "timestamp_ms": timestamp_ms}) + + def flush(self, delay: float = 0.1) -> None: + """Wait briefly so the server processes queued messages.""" + time.sleep(delay) + + def _send(self, msg: dict[str, Any]) -> None: + assert self._loop is not None and self._ws is not None, "Not connected" + self._loop.run_until_complete(self._ws.send(json.dumps(msg))) + + +def _collect(received: list[Any], done: threading.Event) -> Any: + """Return a callback that appends to *received* and signals *done*.""" + + def _cb(msg: Any) -> None: + received.append(msg) + done.set() + + return _cb + + +def _make_module(port: int = _TEST_PORT) -> RerunWebSocketServer: + return RerunWebSocketServer(port=port) + + +def _wait_for_server(port: int, timeout: float = 3.0) -> None: + """Block until the WebSocket server accepts an upgrade handshake.""" + + async def _probe() -> None: + import websockets.asyncio.client as ws_client + + async with ws_client.connect(f"ws://127.0.0.1:{port}/ws"): + pass + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + asyncio.run(_probe()) + return + except Exception: + time.sleep(0.05) + raise TimeoutError(f"Server on port {port} did not become ready within {timeout}s") + + +class TestRerunWebSocketServerStartup: + def test_server_binds_port(self) -> None: + """After start(), the server must be reachable on the configured port.""" + mod = _make_module() + mod.start() + try: + _wait_for_server(_TEST_PORT) + finally: + mod.stop() + + def test_stop_is_idempotent(self) -> None: + """Calling stop() twice must not raise.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + mod.stop() + mod.stop() + + +class TestClickMessages: + def test_click_publishes_point_stamped(self) -> None: + """A single click publishes one PointStamped with correct coords.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + done = threading.Event() + mod.clicked_point.subscribe(_collect(received, done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_click(1.5, 2.5, 0.0, "/world", timestamp_ms=1000) + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + + assert len(received) == 1 + pt = received[0] + assert abs(pt.x - 1.5) < 1e-9 + assert abs(pt.y - 2.5) < 1e-9 + assert abs(pt.z - 0.0) < 1e-9 + + def test_click_sets_frame_id_from_entity_path(self) -> None: + """entity_path is stored as frame_id on the published PointStamped.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + done = threading.Event() + mod.clicked_point.subscribe(_collect(received, done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_click(0.0, 0.0, 0.0, "/robot/base", timestamp_ms=2000) + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + assert received and received[0].frame_id == "/robot/base" + + def test_click_timestamp_converted_from_ms(self) -> None: + """timestamp_ms is converted to seconds on PointStamped.ts.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + done = threading.Event() + mod.clicked_point.subscribe(_collect(received, done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_click(0.0, 0.0, 0.0, "", timestamp_ms=5000) + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + assert received and abs(received[0].ts - 5.0) < 1e-6 + + def test_multiple_clicks_all_published(self) -> None: + """A burst of clicks all arrive on the stream.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + all_arrived = threading.Event() + + def _cb(pt: Any) -> None: + received.append(pt) + if len(received) >= 3: + all_arrived.set() + + mod.clicked_point.subscribe(_cb) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_click(1.0, 0.0, 0.0) + pub.send_click(2.0, 0.0, 0.0) + pub.send_click(3.0, 0.0, 0.0) + pub.flush() + + all_arrived.wait(timeout=3.0) + mod.stop() + + assert sorted(pt.x for pt in received) == [1.0, 2.0, 3.0] + + +class TestNonClickMessages: + def test_heartbeat_does_not_publish(self) -> None: + """Heartbeat messages must not trigger a clicked_point publish.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + mod.clicked_point.subscribe(received.append) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_heartbeat(9999) + pub.flush() + + mod.stop() + assert received == [] + + def test_twist_does_not_publish_clicked_point(self) -> None: + """Twist messages must not trigger a clicked_point publish.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + mod.clicked_point.subscribe(received.append) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) + pub.flush() + + mod.stop() + assert received == [] + + def test_stop_does_not_publish_clicked_point(self) -> None: + """Stop messages must not trigger a clicked_point publish.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + mod.clicked_point.subscribe(received.append) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_stop() + pub.flush() + + mod.stop() + assert received == [] + + def test_twist_publishes_on_tele_cmd_vel(self) -> None: + """Twist messages publish a Twist on the tele_cmd_vel stream.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + done = threading.Event() + mod.tele_cmd_vel.subscribe(_collect(received, done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + + assert len(received) == 1 + tw = received[0] + assert abs(tw.linear.x - 0.5) < 1e-9 + assert abs(tw.angular.z - 0.8) < 1e-9 + + def test_stop_publishes_zero_twist_on_tele_cmd_vel(self) -> None: + """Stop messages publish a zero Twist on the tele_cmd_vel stream.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + received: list[Any] = [] + done = threading.Event() + mod.tele_cmd_vel.subscribe(_collect(received, done)) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_stop() + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + + assert len(received) == 1 + tw = received[0] + assert tw.is_zero() + + def test_invalid_json_does_not_crash(self) -> None: + """Malformed JSON is silently dropped; server stays alive.""" + import websockets.asyncio.client as ws_client + + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + async def _send_bad() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{_TEST_PORT}/ws") as ws: + await ws.send("this is not json {{") + await asyncio.sleep(0.1) + await ws.send(json.dumps({"type": "heartbeat", "timestamp_ms": 0})) + await asyncio.sleep(0.1) + + asyncio.run(_send_bad()) + mod.stop() + + def test_mixed_message_sequence(self) -> None: + """Realistic sequence: heartbeat → click → twist → stop publishes one point.""" + mod = _make_module() + mod.start() + _wait_for_server(_TEST_PORT) + + # Subscribe before sending so we don't race against the click dispatch. + received: list[Any] = [] + done = threading.Event() + + def _cb(pt: Any) -> None: + received.append(pt) + done.set() + + mod.clicked_point.subscribe(_cb) + + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as pub: + pub.send_heartbeat(1000) + pub.send_click(7.0, 8.0, 9.0, "/map", timestamp_ms=1100) + pub.send_twist(0.3, 0.0, 0.0, 0.0, 0.0, 0.2) + pub.send_stop() + pub.flush() + + done.wait(timeout=2.0) + mod.stop() + + assert len(received) == 1 + assert abs(received[0].x - 7.0) < 1e-9 + assert abs(received[0].y - 8.0) < 1e-9 + assert abs(received[0].z - 9.0) < 1e-9 diff --git a/dimos/visualization/rerun/websocket_server.py b/dimos/visualization/rerun/websocket_server.py new file mode 100644 index 0000000000..e75df4eb25 --- /dev/null +++ b/dimos/visualization/rerun/websocket_server.py @@ -0,0 +1,201 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WebSocket server module that receives events from dimos-viewer. + +When dimos-viewer is started with ``--connect``, LCM multicast is unavailable +across machines. The viewer falls back to sending click, twist, and stop events +as JSON over a WebSocket connection. This module acts as the server-side +counterpart: it listens for those connections and translates incoming messages +into DimOS stream publishes. + +Message format (newline-delimited JSON, ``"type"`` discriminant): + + {"type":"heartbeat","timestamp_ms":1234567890} + {"type":"click","x":1.0,"y":2.0,"z":3.0,"entity_path":"/world","timestamp_ms":1234567890} + {"type":"twist","linear_x":0.5,"linear_y":0.0,"linear_z":0.0, + "angular_x":0.0,"angular_y":0.0,"angular_z":0.8} + {"type":"stop"} +""" + +import asyncio +import json +import threading +from typing import Any + +import websockets + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import Out +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class Config(ModuleConfig): + # Intentionally binds 0.0.0.0 by default so the viewer can connect from + # any machine on the network (the typical robot deployment scenario). + host: str = "0.0.0.0" + port: int = 3030 + + +class RerunWebSocketServer(Module[Config]): + """Receives dimos-viewer WebSocket events and publishes them as DimOS streams. + + The viewer connects to this module (not the other way around) when running + in ``--connect`` mode. Each click event is converted to a ``PointStamped`` + and published on the ``clicked_point`` stream so downstream modules (e.g. + ``ReplanningAStarPlanner``) can consume it without modification. + + Outputs: + clicked_point: 3-D world-space point from the most recent viewer click. + tele_cmd_vel: Twist velocity commands from keyboard teleop, including stop events. + """ + + default_config = Config + + clicked_point: Out[PointStamped] + tele_cmd_vel: Out[Twist] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._ws_loop: asyncio.AbstractEventLoop | None = None + self._server_thread: threading.Thread | None = None + self._stop_event: asyncio.Event | None = None + self._server_ready = threading.Event() + + @rpc + def start(self) -> None: + super().start() + self._server_thread = threading.Thread( + target=self._run_server, daemon=True, name="rerun-ws-server" + ) + self._server_thread.start() + logger.info( + f"RerunWebSocketServer starting on ws://{self.config.host}:{self.config.port}/ws" + ) + + @rpc + def stop(self) -> None: + # Wait briefly for the server thread to initialise _stop_event so we + # don't silently skip the shutdown signal (race with _serve()). + self._server_ready.wait(timeout=5.0) + if ( + self._ws_loop is not None + and not self._ws_loop.is_closed() + and self._stop_event is not None + ): + self._ws_loop.call_soon_threadsafe(self._stop_event.set) + super().stop() + + def _run_server(self) -> None: + """Entry point for the background server thread.""" + self._ws_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._ws_loop) + try: + self._ws_loop.run_until_complete(self._serve()) + except Exception: + logger.exception("RerunWebSocketServer failed to start") + finally: + self._server_ready.set() # unblock stop() even on failure + self._ws_loop.close() + + async def _serve(self) -> None: + import websockets.asyncio.server as ws_server + + self._stop_event = asyncio.Event() + + async with ws_server.serve( + self._handle_client, + host=self.config.host, + port=self.config.port, + # Ping every 30 s, allow 30 s for pong — generous enough to + # survive brief network hiccups while still detecting dead clients. + ping_interval=30, + ping_timeout=30, + ): + self._server_ready.set() + logger.info( + f"RerunWebSocketServer listening on ws://{self.config.host}:{self.config.port}/ws" + ) + await self._stop_event.wait() + + async def _handle_client(self, websocket: Any) -> None: + if hasattr(websocket, "request") and websocket.request.path != "/ws": + await websocket.close(1008, "Not Found") + return + addr = websocket.remote_address + logger.info(f"RerunWebSocketServer: viewer connected from {addr}") + try: + async for raw in websocket: + self._dispatch(raw) + except websockets.ConnectionClosed as exc: + logger.debug(f"RerunWebSocketServer: client {addr} disconnected ({exc})") + + def _dispatch(self, raw: str | bytes) -> None: + try: + msg = json.loads(raw) + except json.JSONDecodeError: + logger.warning(f"RerunWebSocketServer: ignoring non-JSON message: {raw!r}") + return + + if not isinstance(msg, dict): + logger.warning(f"RerunWebSocketServer: expected JSON object, got {type(msg).__name__}") + return + + msg_type = msg.get("type") + + if msg_type == "click": + pt = PointStamped( + x=float(msg.get("x", 0)), + y=float(msg.get("y", 0)), + z=float(msg.get("z", 0)), + ts=float(msg.get("timestamp_ms", 0)) / 1000.0, + frame_id=str(msg.get("entity_path", "")), + ) + logger.debug(f"RerunWebSocketServer: click → {pt}") + self.clicked_point.publish(pt) + + elif msg_type == "twist": + twist = Twist( + linear=Vector3( + float(msg.get("linear_x", 0)), + float(msg.get("linear_y", 0)), + float(msg.get("linear_z", 0)), + ), + angular=Vector3( + float(msg.get("angular_x", 0)), + float(msg.get("angular_y", 0)), + float(msg.get("angular_z", 0)), + ), + ) + logger.debug(f"RerunWebSocketServer: twist → {twist}") + self.tele_cmd_vel.publish(twist) + + elif msg_type == "stop": + logger.debug("RerunWebSocketServer: stop") + self.tele_cmd_vel.publish(Twist.zero()) + + elif msg_type == "heartbeat": + logger.debug(f"RerunWebSocketServer: heartbeat ts={msg.get('timestamp_ms')}") + + else: + logger.warning(f"RerunWebSocketServer: unknown message type {msg_type!r}") + + +rerun_ws_server = RerunWebSocketServer.blueprint diff --git a/dimos/visualization/vis_module.py b/dimos/visualization/vis_module.py new file mode 100644 index 0000000000..688a6efb5b --- /dev/null +++ b/dimos/visualization/vis_module.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared visualization module factory for all robot blueprints.""" + +from typing import Any + +from dimos.core.blueprints import Blueprint, autoconnect +from dimos.core.global_config import ViewerBackend +from dimos.protocol.pubsub.impl.lcmpubsub import LCM + + +def vis_module( + viewer_backend: ViewerBackend, + rerun_config: dict[str, Any] | None = None, + foxglove_config: dict[str, Any] | None = None, +) -> Blueprint: + """Create a visualization blueprint based on the selected viewer backend. + + Bundles the appropriate viewer module (Rerun or Foxglove) together with + the ``WebsocketVisModule`` and ``RerunWebSocketServer`` so that the web + dashboard and remote viewer connections work out of the box. + + Example usage:: + + from dimos.core.global_config import global_config + viz = vis_module( + global_config.viewer, + rerun_config={ + "visual_override": { + "world/camera_info": lambda ci: ci.to_rerun(...), + }, + "static": { + "world/tf/base_link": lambda rr: [rr.Boxes3D(...)], + }, + }, + ) + """ + from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule + + if foxglove_config is None: + foxglove_config = {} + if rerun_config is None: + rerun_config = {} + rerun_config = {**rerun_config} + rerun_config.setdefault("pubsubs", [LCM()]) + + match viewer_backend: + case "foxglove": + from dimos.robot.foxglove_bridge import FoxgloveBridge + + return autoconnect( + FoxgloveBridge.blueprint(**foxglove_config), + WebsocketVisModule.blueprint(), + ) + case "rerun" | "rerun-web": + from dimos.visualization.rerun.bridge import _BACKEND_TO_MODE, RerunBridgeModule + from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + + viewer_mode = _BACKEND_TO_MODE.get(viewer_backend, "native") + return autoconnect( + RerunBridgeModule.blueprint(viewer_mode=viewer_mode, **rerun_config), + RerunWebSocketServer.blueprint(), + WebsocketVisModule.blueprint(), + ) + case "rerun-connect": + from dimos.visualization.rerun.bridge import RerunBridgeModule + from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + + return autoconnect( + RerunBridgeModule.blueprint(viewer_mode="connect", **rerun_config), + RerunWebSocketServer.blueprint(), + ) + case _: + return autoconnect(WebsocketVisModule.blueprint())