diff --git a/data/.lfs/go2_bigoffice.db.tar.gz b/data/.lfs/go2_bigoffice.db.tar.gz index cad393bfcc..315610b5cb 100644 --- a/data/.lfs/go2_bigoffice.db.tar.gz +++ b/data/.lfs/go2_bigoffice.db.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2d48cb0b8250bb2878d1008093d45ea377406de00ad42f0f96d7b382e1a9617b -size 191193336 +oid sha256:142f7a7d64d3b77c97acd0d15d53e9ea28c4f558776a6bb3919a4da32c2f4d37 +size 192241937 diff --git a/dimos/agents/agent_test_runner.py b/dimos/agents/agent_test_runner.py index 9bedd613f4..78e68bc139 100644 --- a/dimos/agents/agent_test_runner.py +++ b/dimos/agents/agent_test_runner.py @@ -49,8 +49,8 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: super().start() - self._disposables.add(Disposable(self.agent.subscribe(self._on_agent_message))) - self._disposables.add(Disposable(self.agent_idle.subscribe(self._on_agent_idle))) + self.register_disposable(Disposable(self.agent.subscribe(self._on_agent_message))) + self.register_disposable(Disposable(self.agent_idle.subscribe(self._on_agent_idle))) # Signal that subscription is ready self._subscription_ready.set() diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py index e0200a6323..4e0d4c8291 100644 --- a/dimos/agents/mcp/mcp_client.py +++ b/dimos/agents/mcp/mcp_client.py @@ -168,7 +168,7 @@ def start(self) -> None: def _on_human_input(string: str) -> None: self._message_queue.put(HumanMessage(content=string)) - self._disposables.add(Disposable(self.human_input.subscribe(_on_human_input))) + self.register_disposable(Disposable(self.human_input.subscribe(_on_human_input))) @rpc def on_system_modules(self, _modules: list[RPCClient]) -> None: diff --git a/dimos/agents/skills/demo_robot.py b/dimos/agents/skills/demo_robot.py index 2917ec2d76..9e7ac8433b 100644 --- a/dimos/agents/skills/demo_robot.py +++ b/dimos/agents/skills/demo_robot.py @@ -25,7 +25,7 @@ class DemoRobot(Module): def start(self) -> None: super().start() 
- self._disposables.add(interval(1.0).subscribe(lambda _: self._publish_gps_location())) + self.register_disposable(interval(1.0).subscribe(lambda _: self._publish_gps_location())) def stop(self) -> None: super().stop() diff --git a/dimos/agents/skills/google_maps_skill_container.py b/dimos/agents/skills/google_maps_skill_container.py index ee48e51653..259f3ced6c 100644 --- a/dimos/agents/skills/google_maps_skill_container.py +++ b/dimos/agents/skills/google_maps_skill_container.py @@ -15,6 +15,8 @@ import json from typing import Any +from reactivex.disposable import Disposable + from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module import Module @@ -49,7 +51,7 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: super().start() - self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) # type: ignore[arg-type] + self.register_disposable(Disposable(self.gps_location.subscribe(self._on_gps_location))) @rpc def stop(self) -> None: diff --git a/dimos/agents/skills/gps_nav_skill.py b/dimos/agents/skills/gps_nav_skill.py index c6f86951be..96fdfa25ad 100644 --- a/dimos/agents/skills/gps_nav_skill.py +++ b/dimos/agents/skills/gps_nav_skill.py @@ -14,6 +14,8 @@ import json +from reactivex.disposable import Disposable + from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module import Module @@ -37,7 +39,7 @@ class GpsNavSkillContainer(Module): @rpc def start(self) -> None: super().start() - self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) # type: ignore[arg-type] + self.register_disposable(Disposable(self.gps_location.subscribe(self._on_gps_location))) @rpc def stop(self) -> None: diff --git a/dimos/agents/skills/navigation.py b/dimos/agents/skills/navigation.py index e366465959..d625179619 100644 --- a/dimos/agents/skills/navigation.py +++ b/dimos/agents/skills/navigation.py @@ -68,8 +68,8 @@ def __init__(self, **kwargs: Any) -> 
None: @rpc def start(self) -> None: super().start() - self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) - self._disposables.add(Disposable(self.odom.subscribe(self._on_odom))) + self.register_disposable(Disposable(self.color_image.subscribe(self._on_color_image))) + self.register_disposable(Disposable(self.odom.subscribe(self._on_odom))) self._skill_started = True @rpc diff --git a/dimos/agents/skills/osm.py b/dimos/agents/skills/osm.py index a89e86044f..2172ed5dc0 100644 --- a/dimos/agents/skills/osm.py +++ b/dimos/agents/skills/osm.py @@ -13,6 +13,8 @@ # limitations under the License. +from reactivex.disposable import Disposable + from dimos.agents.annotation import skill from dimos.core.module import Module from dimos.core.stream import In @@ -39,7 +41,7 @@ def __init__(self) -> None: def start(self) -> None: super().start() if hasattr(self.gps_location, "subscribe"): - self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) # type: ignore[arg-type] + self.register_disposable(Disposable(self.gps_location.subscribe(self._on_gps_location))) else: logger.warning( "OsmSkill: gps_location stream does not support direct subscribe (RemoteIn)" diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py index 9f97a23d53..4fe19f203d 100644 --- a/dimos/agents/skills/person_follow.py +++ b/dimos/agents/skills/person_follow.py @@ -93,9 +93,9 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: super().start() - self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) + self.register_disposable(Disposable(self.color_image.subscribe(self._on_color_image))) if self.config.use_3d_navigation: - self._disposables.add(Disposable(self.global_map.subscribe(self._on_pointcloud))) + self.register_disposable(Disposable(self.global_map.subscribe(self._on_pointcloud))) @rpc def stop(self) -> None: diff --git a/dimos/agents/vlm_agent.py 
b/dimos/agents/vlm_agent.py index 114302b397..7e05cd7379 100644 --- a/dimos/agents/vlm_agent.py +++ b/dimos/agents/vlm_agent.py @@ -16,6 +16,7 @@ from langchain.chat_models import init_chat_model from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from reactivex.disposable import Disposable from dimos.agents.system_prompt import SYSTEM_PROMPT from dimos.core.core import rpc @@ -60,8 +61,8 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: super().start() - self._disposables.add(self.color_image.subscribe(self._on_image)) # type: ignore[arg-type] - self._disposables.add(self.query_stream.subscribe(self._on_query)) # type: ignore[arg-type] + self.register_disposable(Disposable(self.color_image.subscribe(self._on_image))) + self.register_disposable(Disposable(self.query_stream.subscribe(self._on_query))) @rpc def stop(self) -> None: diff --git a/dimos/agents/vlm_stream_tester.py b/dimos/agents/vlm_stream_tester.py index 80353dbfe0..d916d1da8f 100644 --- a/dimos/agents/vlm_stream_tester.py +++ b/dimos/agents/vlm_stream_tester.py @@ -16,6 +16,7 @@ import time from langchain_core.messages import AIMessage, HumanMessage +from reactivex.disposable import Disposable from dimos.core.core import rpc from dimos.core.module import Module @@ -62,8 +63,8 @@ def __init__( # type: ignore[no-untyped-def] @rpc def start(self) -> None: super().start() - self._disposables.add(self.color_image.subscribe(self._on_image)) # type: ignore[arg-type] - self._disposables.add(self.answer_stream.subscribe(self._on_answer)) # type: ignore[arg-type] + self.register_disposable(Disposable(self.color_image.subscribe(self._on_image))) + self.register_disposable(Disposable(self.answer_stream.subscribe(self._on_answer))) self._worker = threading.Thread(target=self._run_queries, daemon=True) self._worker.start() diff --git a/dimos/agents/web_human_input.py b/dimos/agents/web_human_input.py index 2b84736d27..5d7e075810 100644 --- 
a/dimos/agents/web_human_input.py +++ b/dimos/agents/web_human_input.py @@ -64,11 +64,11 @@ def start(self) -> None: # Subscribe to both text input sources # 1. Direct text from web interface unsub = self._web_interface.query_stream.subscribe(self._human_transport.publish) - self._disposables.add(unsub) + self.register_disposable(unsub) # 2. Transcribed text from STT unsub = stt_node.emit_text().subscribe(self._human_transport.publish) - self._disposables.add(unsub) + self.register_disposable(unsub) self._thread = Thread(target=self._web_interface.run, daemon=True) self._thread.start() diff --git a/dimos/core/module.py b/dimos/core/module.py index 1c5b311883..3aba358b0b 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -30,13 +30,12 @@ ) from langchain_core.tools import tool -from reactivex.disposable import CompositeDisposable from dimos.core.core import T, rpc from dimos.core.global_config import GlobalConfig, global_config from dimos.core.introspection.module.info import extract_module_info from dimos.core.introspection.module.render import render_module_io -from dimos.core.resource import Resource +from dimos.core.resource import CompositeResource from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out, RemoteOut, Transport from dimos.protocol.rpc.pubsubrpc import LCMRPC @@ -92,7 +91,7 @@ class _BlueprintPartial(Protocol): def __call__(self, **kwargs: Any) -> "Blueprint": ... -class ModuleBase(Configurable[ModuleConfigT], Resource): +class ModuleBase(Configurable[ModuleConfigT], CompositeResource): # This won't type check against the TypeVar, but we need it as the default. 
default_config: type[ModuleConfigT] = ModuleConfig # type: ignore[assignment] @@ -100,7 +99,6 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): _tf: TFSpec[Any] | None = None _loop: asyncio.AbstractEventLoop | None = None _loop_thread: threading.Thread | None - _disposables: CompositeDisposable _bound_rpc_calls: dict[str, RpcCall] = {} _module_closed: bool = False _module_closed_lock: threading.Lock @@ -111,7 +109,6 @@ def __init__(self, config_args: dict[str, Any]): super().__init__(**config_args) self._module_closed_lock = threading.Lock() self._loop, self._loop_thread = get_loop() - self._disposables = CompositeDisposable() try: self.rpc = self.config.rpc_transport() self.rpc.serve_module_rpc(self) @@ -132,6 +129,7 @@ def start(self) -> None: @rpc def stop(self) -> None: + super().stop() self._close_module() def _close_module(self) -> None: @@ -158,14 +156,12 @@ def _close_module(self) -> None: if hasattr(self, "_tf") and self._tf is not None: self._tf.stop() self._tf = None - if hasattr(self, "_disposables"): - self._disposables.dispose() - # Break the In/Out -> owner -> self reference cycle so the instance - # can be freed by refcount instead of waiting for GC. - for attr in list(vars(self).values()): - if isinstance(attr, (In, Out)): - attr.owner = None + # Stop transports and break the In/Out -> owner -> self reference + # cycle so the instance can be freed by refcount instead of waiting for GC. 
+ for attr in [*self.inputs.values(), *self.outputs.values()]: + attr.stop() + attr.owner = None def _close_rpc(self) -> None: if self.rpc: @@ -188,7 +184,6 @@ def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] """Restore object from pickled state.""" self.__dict__.update(state) # Reinitialize runtime attributes - self._disposables = CompositeDisposable() self._module_closed_lock = threading.Lock() self._loop = None self._loop_thread = None diff --git a/dimos/core/resource.py b/dimos/core/resource.py index a4c008b806..a924ed8be3 100644 --- a/dimos/core/resource.py +++ b/dimos/core/resource.py @@ -16,7 +16,7 @@ from abc import abstractmethod import sys -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar if sys.version_info >= (3, 11): from typing import Self @@ -29,6 +29,8 @@ from reactivex.abc import DisposableBase from reactivex.disposable import CompositeDisposable +D = TypeVar("D", bound=DisposableBase) + class Resource(DisposableBase): @abstractmethod @@ -75,18 +77,17 @@ def __exit__( class CompositeResource(Resource): """Resource that owns child disposables, disposed on stop().""" - _disposables: CompositeDisposable - - def __init__(self) -> None: - self._disposables = CompositeDisposable() + _disposables: CompositeDisposable | None = None - def register_disposables(self, *disposables: DisposableBase) -> None: - """Register child disposables to be disposed when this resource stops.""" - for d in disposables: - self._disposables.add(d) + def register_disposable(self, disposable: D) -> D: + """Register a child disposable to be disposed when this resource stops.""" + if self._disposables is None: + self._disposables = CompositeDisposable() + self._disposables.add(disposable) + return disposable - def start(self) -> None: - pass + def start(self) -> None: ... 
def stop(self) -> None: - self._disposables.dispose() + if self._disposables is not None: + self._disposables.dispose() diff --git a/dimos/core/stream.py b/dimos/core/stream.py index 7791968a29..41462ddbaa 100644 --- a/dimos/core/stream.py +++ b/dimos/core/stream.py @@ -135,6 +135,10 @@ def __str__(self) -> str: + ("" if not self._transport else " via " + str(self._transport)) ) + def stop(self) -> None: + if self._transport is not None: + self._transport.stop() + class Out(Stream[T], ObservableMixin[T]): _transport: Transport # type: ignore[type-arg] diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index f9a89829d5..e69cde2fef 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -47,7 +47,7 @@ def _odom(msg) -> None: self.mov.publish(msg.position) unsub = self.odometry.subscribe(_odom) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) def _lidar(msg) -> None: self.lidar_msg_count += 1 @@ -57,7 +57,7 @@ def _lidar(msg) -> None: print("RCV: unknown time", msg) unsub = self.lidar.subscribe(_lidar) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) def test_classmethods() -> None: diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index b8165658d9..5a34ed3d65 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -77,11 +77,11 @@ def on_image(image: Image) -> None: self.color_image.publish(image) self._latest_image = image - self._disposables.add( + self.register_disposable( stream.subscribe(on_image), ) - self._disposables.add( + self.register_disposable( rx.interval(1.0).subscribe(lambda _: self.publish_metadata()), ) diff --git a/dimos/hardware/sensors/camera/realsense/camera.py b/dimos/hardware/sensors/camera/realsense/camera.py index 821982981d..48ecde4331 100644 --- a/dimos/hardware/sensors/camera/realsense/camera.py +++ 
b/dimos/hardware/sensors/camera/realsense/camera.py @@ -162,7 +162,7 @@ def start(self) -> None: if self.config.enable_pointcloud and self.config.enable_depth: interval_sec = 1.0 / self.config.pointcloud_fps - self._disposables.add( + self.register_disposable( backpressure(rx.interval(interval_sec)).subscribe( on_next=lambda _: self._generate_pointcloud(), on_error=lambda e: print(f"Pointcloud error: {e}"), @@ -170,7 +170,7 @@ def start(self) -> None: ) interval_sec = 1.0 / self.config.camera_info_fps - self._disposables.add( + self.register_disposable( rx.interval(interval_sec).subscribe( on_next=lambda _: self._publish_camera_info(), on_error=lambda e: print(f"CameraInfo error: {e}"), diff --git a/dimos/hardware/sensors/camera/zed/camera.py b/dimos/hardware/sensors/camera/zed/camera.py index dd429c29cf..a646554b48 100644 --- a/dimos/hardware/sensors/camera/zed/camera.py +++ b/dimos/hardware/sensors/camera/zed/camera.py @@ -180,7 +180,7 @@ def start(self) -> None: self._enable_tracking() interval_sec = 1.0 / self.config.camera_info_fps - self._disposables.add( + self.register_disposable( rx.interval(interval_sec).subscribe( on_next=lambda _: self._publish_camera_info(), on_error=lambda e: print(f"CameraInfo error: {e}"), @@ -193,7 +193,7 @@ def start(self) -> None: if self.config.enable_pointcloud and self.config.enable_depth: interval_sec = 1.0 / self.config.pointcloud_fps - self._disposables.add( + self.register_disposable( backpressure(rx.interval(interval_sec)).subscribe( on_next=lambda _: self._generate_pointcloud(), on_error=lambda e: print(f"Pointcloud error: {e}"), diff --git a/dimos/hardware/sensors/fake_zed_module.py b/dimos/hardware/sensors/fake_zed_module.py index 16e85aa93c..21c1d27599 100644 --- a/dimos/hardware/sensors/fake_zed_module.py +++ b/dimos/hardware/sensors/fake_zed_module.py @@ -224,7 +224,7 @@ def start(self) -> None: unsub = self._get_color_stream().subscribe( lambda msg: self.color_image.publish(msg) if self._running else None ) - 
self._disposables.add(unsub) + self.register_disposable(unsub) logger.info("Started color image replay stream") except Exception as e: logger.warning(f"Color image stream not available: {e}") @@ -234,7 +234,7 @@ def start(self) -> None: unsub = self._get_depth_stream().subscribe( lambda msg: self.depth_image.publish(msg) if self._running else None ) - self._disposables.add(unsub) + self.register_disposable(unsub) logger.info("Started depth image replay stream") except Exception as e: logger.warning(f"Depth image stream not available: {e}") @@ -244,7 +244,7 @@ def start(self) -> None: unsub = self._get_pose_stream().subscribe( lambda msg: self._publish_pose(msg) if self._running else None ) - self._disposables.add(unsub) + self.register_disposable(unsub) logger.info("Started pose replay stream") except Exception as e: logger.warning(f"Pose stream not available: {e}") @@ -254,7 +254,7 @@ def start(self) -> None: unsub = self._get_camera_info_stream().subscribe( lambda msg: self.camera_info.publish(msg) if self._running else None ) - self._disposables.add(unsub) + self.register_disposable(unsub) logger.info("Started camera info replay stream") except Exception as e: logger.warning(f"Camera info stream not available: {e}") diff --git a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py index f3de842b46..9273a22fbb 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py +++ b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py @@ -30,7 +30,7 @@ mid360_fastlio_voxels = autoconnect( FastLio2.blueprint(), - VoxelGridMapper.blueprint(publish_interval=1.0, voxel_size=voxel_size, carve_columns=False), + VoxelGridMapper.blueprint(voxel_size=voxel_size, carve_columns=False), RerunBridgeModule.blueprint( visual_override={ "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), diff --git a/dimos/mapping/costmapper.py b/dimos/mapping/costmapper.py index 
87ed64d404..0ec376a88f 100644 --- a/dimos/mapping/costmapper.py +++ b/dimos/mapping/costmapper.py @@ -60,7 +60,7 @@ def _calculate_and_time( elapsed_ms = (time.perf_counter() - start) * 1000 return grid, elapsed_ms, rx_monotonic - self._disposables.add( + self.register_disposable( self.global_map.observable() # type: ignore[no-untyped-call] .pipe(ops.map(_calculate_and_time)) .subscribe(lambda result: _publish_costmap(result[0], result[1], result[2])) diff --git a/dimos/mapping/pointclouds/test_occupancy_speed.py b/dimos/mapping/pointclouds/test_occupancy_speed.py index ac4085e971..115ee73ae0 100644 --- a/dimos/mapping/pointclouds/test_occupancy_speed.py +++ b/dimos/mapping/pointclouds/test_occupancy_speed.py @@ -18,7 +18,7 @@ import pytest from dimos.mapping.pointclouds.occupancy import OCCUPANCY_ALGOS -from dimos.mapping.voxels import VoxelGridMapper +from dimos.mapping.voxels import VoxelGrid from dimos.utils.cli.plot import bar from dimos.utils.data import get_data, get_data_dir from dimos.utils.testing.replay import TimedSensorReplay @@ -26,18 +26,18 @@ @pytest.mark.tool def test_build_map(): - mapper = VoxelGridMapper(publish_interval=-1) + grid = VoxelGrid() for _ts, frame in TimedSensorReplay("unitree_go2_bigoffice/lidar").iterate(): - mapper.add_frame(frame) + grid.add_frame(frame) pickle_file = get_data_dir() / "unitree_go2_bigoffice_map.pickle" - global_pcd = mapper.get_global_pointcloud2() + global_pcd = grid.get_global_pointcloud2() with open(pickle_file, "wb") as f: pickle.dump(global_pcd, f) - mapper.stop() + grid.dispose() def test_costmap_calc(): diff --git a/dimos/mapping/test_voxels.py b/dimos/mapping/test_voxels.py index bb5f4ed764..fc95b4652b 100644 --- a/dimos/mapping/test_voxels.py +++ b/dimos/mapping/test_voxels.py @@ -19,7 +19,7 @@ import pytest from dimos.core.transport import LCMTransport -from dimos.mapping.voxels import VoxelGridMapper +from dimos.mapping.voxels import VoxelGrid from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 
from dimos.utils.data import get_data from dimos.utils.testing.moment import OutputMoment @@ -28,10 +28,10 @@ @pytest.fixture -def mapper() -> Generator[VoxelGridMapper, None, None]: - mapper = VoxelGridMapper() - yield mapper - mapper.stop() +def grid() -> Generator[VoxelGrid, None, None]: + g = VoxelGrid() + yield g + g.dispose() class Go2MapperMoment(Go2Moment): @@ -78,21 +78,19 @@ def two_perspectives_loop(moment: MomentFactory) -> None: @pytest.mark.tool -def test_carving( - mapper: VoxelGridMapper, moment1: Go2MapperMoment, moment2: Go2MapperMoment -) -> None: +def test_carving(grid: VoxelGrid, moment1: Go2MapperMoment, moment2: Go2MapperMoment) -> None: lidar_frame1 = moment1.lidar.value assert lidar_frame1 is not None lidar_frame2 = moment2.lidar.value assert lidar_frame2 is not None - # Carving mapper (default, carve_columns=True) - mapper.add_frame(lidar_frame1) - mapper.add_frame(lidar_frame2) - count_carving = mapper.size() + # Carving grid (default, carve_columns=True) + grid.add_frame(lidar_frame1) + grid.add_frame(lidar_frame2) + count_carving = grid.size() - voxel_size = mapper.config.voxel_size + voxel_size = grid._voxel_size pts1 = np.asarray(lidar_frame1.pointcloud.points) pts2 = np.asarray(lidar_frame2.pointcloud.points) combined_vox = np.floor(np.vstack([pts1, pts2]) / voxel_size).astype(np.int64) @@ -109,7 +107,7 @@ def test_carving( ) -def test_injest_a_few(mapper: VoxelGridMapper) -> None: +def test_ingest_a_few(grid: VoxelGrid) -> None: data_dir = get_data("unitree_go2_office_walk2") lidar_store = TimedSensorReplay(f"{data_dir}/lidar") @@ -117,9 +115,9 @@ def test_injest_a_few(mapper: VoxelGridMapper) -> None: frame = lidar_store.find_closest_seek(i) assert frame is not None print("add", frame) - mapper.add_frame(frame) + grid.add_frame(frame) - assert len(mapper.get_global_pointcloud2()) == 30136 + assert len(grid.get_global_pointcloud2()) == 30136 @pytest.mark.parametrize( @@ -134,10 +132,10 @@ def test_roundtrip(moment1: Go2MapperMoment, 
voxel_size: float, expected_points: lidar_frame = moment1.lidar.value assert lidar_frame is not None - mapper = VoxelGridMapper(voxel_size=voxel_size) - mapper.add_frame(lidar_frame) + grid = VoxelGrid(voxel_size=voxel_size) + grid.add_frame(lidar_frame) - global1 = mapper.get_global_pointcloud2() + global1 = grid.get_global_pointcloud2() assert len(global1) == expected_points # loseless roundtrip @@ -146,15 +144,15 @@ def test_roundtrip(moment1: Go2MapperMoment, voxel_size: float, expected_points: # TODO: we want __eq__ on PointCloud2 - should actually compare # all points in both frames - mapper.add_frame(global1) + grid.add_frame(global1) # no new information, no global map change - assert len(mapper.get_global_pointcloud2()) == len(global1) + assert len(grid.get_global_pointcloud2()) == len(global1) moment1.publish() - mapper.stop() + grid.dispose() -def test_roundtrip_range_preserved(mapper: VoxelGridMapper) -> None: +def test_roundtrip_range_preserved(grid: VoxelGrid) -> None: """Test that input coordinate ranges are preserved in output.""" data_dir = get_data("unitree_go2_office_walk2") lidar_store = TimedSensorReplay(f"{data_dir}/lidar") @@ -163,12 +161,12 @@ def test_roundtrip_range_preserved(mapper: VoxelGridMapper) -> None: assert frame is not None input_pts = np.asarray(frame.pointcloud.points) - mapper.add_frame(frame) + grid.add_frame(frame) - out_pcd = mapper.get_global_pointcloud().to_legacy() + out_pcd = grid.get_global_pointcloud().to_legacy() out_pts = np.asarray(out_pcd.points) - voxel_size = mapper.config.voxel_size + voxel_size = grid._voxel_size tolerance = voxel_size # Allow one voxel of difference at boundaries # TODO: we want __eq__ on PointCloud2 - should actually compare diff --git a/dimos/mapping/voxels.py b/dimos/mapping/voxels.py index 92cbeed03e..71352a7fcf 100644 --- a/dimos/mapping/voxels.py +++ b/dimos/mapping/voxels.py @@ -12,61 +12,65 @@ # See the License for the specific language governing permissions and # limitations under 
the License. +from __future__ import annotations + import time -from typing import Any +from typing import TYPE_CHECKING, Any -import numpy as np import open3d as o3d # type: ignore[import-untyped] import open3d.core as o3c # type: ignore[import-untyped] -from reactivex import interval, operators as ops -from reactivex.disposable import Disposable -from reactivex.subject import Subject -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig +from dimos.core.module import ModuleConfig from dimos.core.stream import In, Out +from dimos.memory2.module import StreamModule +from dimos.memory2.stream import Stream +from dimos.memory2.transform import Transformer from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.logging_config import setup_logger -from dimos.utils.reactive import backpressure -logger = setup_logger() +if TYPE_CHECKING: + from collections.abc import Iterator + from dimos.memory2.type.observation import Observation -class Config(ModuleConfig): - frame_id: str = "world" - # -1 never publishes, 0 publishes on every frame, >0 publishes at interval in seconds - publish_interval: float = 0 - voxel_size: float = 0.05 - block_count: int = 2_000_000 - device: str = "CUDA:0" - carve_columns: bool = True +logger = setup_logger() -class VoxelGridMapper(Module[Config]): - default_config = Config +class VoxelGrid: + """Pure voxel grid accumulator using Open3D VoxelBlockGrid. - lidar: In[PointCloud2] - global_map: Out[PointCloud2] + No Module/framework dependency. Can be used standalone or wrapped + by VoxelGridMapper (Module) or VoxelMapTransformer (memory2 Transformer). 
+ """ - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) + def __init__( + self, + voxel_size: float = 0.05, + block_count: int = 2_000_000, + device: str = "CUDA:0", + carve_columns: bool = True, + frame_id: str = "world", + ) -> None: + self._voxel_size = voxel_size + self._carve_columns = carve_columns + self._frame_id = frame_id dev = ( - o3c.Device(self.config.device) - if (self.config.device.startswith("CUDA") and o3c.cuda.is_available()) + o3c.Device(device) + if (device.startswith("CUDA") and o3c.cuda.is_available()) else o3c.Device("CPU:0") ) - logger.info(f"VoxelGridMapper using device: {dev}") + logger.info(f"VoxelGrid using device: {dev}") - self.vbg = o3d.t.geometry.VoxelBlockGrid( + self.vbg: o3d.t.geometry.VoxelBlockGrid | None = o3d.t.geometry.VoxelBlockGrid( attr_names=("dummy",), attr_dtypes=(o3c.uint8,), attr_channels=(o3c.SizeVector([1]),), - voxel_size=self.config.voxel_size, + voxel_size=voxel_size, block_resolution=1, - block_count=self.config.block_count, + block_count=block_count, device=dev, ) @@ -74,71 +78,27 @@ def __init__(self, **kwargs: Any) -> None: self._voxel_hashmap = self.vbg.hashmap() self._key_dtype = self._voxel_hashmap.key_tensor().dtype self._latest_frame_ts: float = 0.0 + self._disposed = False - @rpc - def start(self) -> None: - super().start() - - # Subject to trigger publishing, with backpressure to drop if busy - self._publish_trigger: Subject[None] = Subject() - self._disposables.add( - backpressure(self._publish_trigger) - .pipe(ops.map(lambda _: self.publish_global_map())) - .subscribe() - ) - - lidar_unsub = self.lidar.subscribe(self._on_frame) - self._disposables.add(Disposable(lidar_unsub)) - - # If publish_interval > 0, publish on timer; otherwise publish on each frame - if self.config.publish_interval > 0: - self._disposables.add( - interval(self.config.publish_interval).subscribe( - lambda _: self._publish_trigger.on_next(None) - ) - ) - - @rpc - def stop(self) -> None: - super().stop() - 
# Free tensor-tracked objects eagerly so Open3D does not report them as leaks. - self.get_global_pointcloud.invalidate_cache(self) - self.get_global_pointcloud2.invalidate_cache(self) - self.vbg = None - self._voxel_hashmap = None - - def _on_frame(self, frame: PointCloud2) -> None: - self.add_frame(frame) - if self.config.publish_interval == 0: - self._publish_trigger.on_next(None) - - def publish_global_map(self) -> None: - pc = self.get_global_pointcloud2() - self.global_map.publish(pc) - - def size(self) -> int: - return self._voxel_hashmap.size() # type: ignore[no-any-return] - - def __len__(self) -> int: - return self.size() + def _check_disposed(self) -> None: + if self._disposed: + raise RuntimeError("VoxelGrid has been disposed and cannot be used") - # @timed() # TODO: fix thread leak in timed decorator def add_frame(self, frame: PointCloud2) -> None: - # Track latest frame timestamp for proper latency measurement - if hasattr(frame, "ts") and frame.ts: + self._check_disposed() + if frame.ts is not None: self._latest_frame_ts = frame.ts - # we are potentially moving into CUDA here pcd = ensure_tensor_pcd(frame.pointcloud, self._dev) if pcd.is_empty(): return pts = pcd.point["positions"].to(self._dev, o3c.float32) - vox = (pts / self.config.voxel_size).floor().to(self._key_dtype) + vox = (pts / self._voxel_size).floor().to(self._key_dtype) keys_Nx3 = vox.contiguous() - if self.config.carve_columns: + if self._carve_columns: self._carve_and_insert(keys_Nx3) else: self._voxel_hashmap.activate(keys_Nx3) @@ -152,10 +112,8 @@ def _carve_and_insert(self, new_keys: o3c.Tensor) -> None: self._voxel_hashmap.activate(new_keys) return - # Extract (X, Y) from incoming keys xy_keys = new_keys[:, :2].contiguous() - # Build temp hashmap for O(1) (X,Y) membership lookup xy_hashmap = o3c.HashMap( init_capacity=xy_keys.shape[0], key_dtype=self._key_dtype, @@ -167,7 +125,6 @@ def _carve_and_insert(self, new_keys: o3c.Tensor) -> None: dummy_vals = 
o3c.Tensor.zeros((xy_keys.shape[0], 1), o3c.uint8, self._dev) xy_hashmap.insert(xy_keys, dummy_vals) - # Get existing keys from main hashmap active_indices = self._voxel_hashmap.active_buf_indices() if active_indices.shape[0] == 0: self._voxel_hashmap.activate(new_keys) @@ -176,36 +133,126 @@ def _carve_and_insert(self, new_keys: o3c.Tensor) -> None: existing_keys = self._voxel_hashmap.key_tensor()[active_indices] existing_xy = existing_keys[:, :2].contiguous() - # Find which existing keys have (X,Y) in the incoming set _, found_mask = xy_hashmap.find(existing_xy) - # Erase those columns to_erase = existing_keys[found_mask] if to_erase.shape[0] > 0: self._voxel_hashmap.erase(to_erase) - # Insert new keys self._voxel_hashmap.activate(new_keys) - # returns PointCloud2 message (ready to send off down the pipeline) @simple_mcache def get_global_pointcloud2(self) -> PointCloud2: + self._check_disposed() return PointCloud2( - # we are potentially moving out of CUDA here ensure_legacy_pcd(self.get_global_pointcloud()), - frame_id=self.frame_id, + frame_id=self._frame_id, ts=self._latest_frame_ts if self._latest_frame_ts else time.time(), ) @simple_mcache - # @timed() def get_global_pointcloud(self) -> o3d.t.geometry.PointCloud: + self._check_disposed() + assert self.vbg is not None voxel_coords, _ = self.vbg.voxel_coordinates_and_flattened_indices() - pts = voxel_coords + (self.config.voxel_size * 0.5) + pts = voxel_coords + (self._voxel_size * 0.5) out = o3d.t.geometry.PointCloud(device=self._dev) out.point["positions"] = pts return out + def size(self) -> int: + self._check_disposed() + return self._voxel_hashmap.size() # type: ignore[no-any-return] + + def __len__(self) -> int: + return self.size() + + def dispose(self) -> None: + """Free GPU resources. 
The object is unusable after this call.""" + if self._disposed: + return + self._disposed = True + self.get_global_pointcloud.invalidate_cache(self) # type: ignore[attr-defined] + self.get_global_pointcloud2.invalidate_cache(self) # type: ignore[attr-defined] + self.vbg = None + self._voxel_hashmap = None + + +class VoxelMapTransformer(Transformer[PointCloud2, PointCloud2]): + """Accumulate PointCloud2 observations into a global voxel map. + + Assumes input clouds are already in world frame. + All keyword arguments except ``emit_every`` are forwarded to + :class:`VoxelGrid`. + + Args: + emit_every: Yield the current accumulated map every *n* frames. + ``1`` (default) = yield after every frame (live-compatible). + ``0`` = yield only when upstream exhausts (batch mode). + **grid_kwargs: Forwarded to ``VoxelGrid()``. + """ + + def __init__(self, *, emit_every: int = 1, **grid_kwargs: Any) -> None: + self.emit_every = emit_every + self._grid_kwargs = grid_kwargs + + def _make_obs( + self, grid: VoxelGrid, last_obs: Observation[PointCloud2], count: int + ) -> Observation[PointCloud2]: + # pose=None: the global map is in world frame, per-observation pose is meaningless + return last_obs.derive( + data=grid.get_global_pointcloud2(), + pose=None, + tags={**last_obs.tags, "frame_count": count}, + ) + + def __call__( + self, upstream: Iterator[Observation[PointCloud2]] + ) -> Iterator[Observation[PointCloud2]]: + grid = VoxelGrid(**self._grid_kwargs) + try: + last_obs: Observation[PointCloud2] | None = None + count = 0 + + for obs in upstream: + grid.add_frame(obs.data) + last_obs = obs + count += 1 + + if self.emit_every > 0 and count % self.emit_every == 0: + yield self._make_obs(grid, last_obs, count) + + # Yield on exhaustion: always in batch mode, or if there are un-emitted frames + if last_obs is not None and (self.emit_every == 0 or count % self.emit_every != 0): + yield self._make_obs(grid, last_obs, count) + finally: + grid.dispose() + + +class 
VoxelGridMapperConfig(ModuleConfig): + """Configuration for VoxelGridMapper.""" + + voxel_size: float = 0.05 + block_count: int = 2_000_000 + device: str = "CUDA:0" + carve_columns: bool = True + frame_id: str = "world" + + +class VoxelGridMapper(StreamModule[VoxelGridMapperConfig]): + """Accumulate lidar point clouds into a global voxel map.""" + + default_config = VoxelGridMapperConfig + + def pipeline(self, stream: Stream[PointCloud2]) -> Stream[PointCloud2]: + cfg = self.config.model_dump( + include=set(VoxelGridMapperConfig.model_fields) - set(ModuleConfig.model_fields) + ) + return stream.transform(VoxelMapTransformer(**cfg)) + + lidar: In[PointCloud2] + global_map: Out[PointCloud2] + def ensure_tensor_pcd( pcd_any: o3d.t.geometry.PointCloud | o3d.geometry.PointCloud, @@ -220,14 +267,7 @@ def ensure_tensor_pcd( "Input must be a legacy PointCloud or a tensor PointCloud" ) - # Legacy CPU point cloud -> tensor - if isinstance(pcd_any, o3d.geometry.PointCloud): - return o3d.t.geometry.PointCloud.from_legacy(pcd_any, o3c.float32, device) - - pts = np.asarray(pcd_any.points, dtype=np.float32) - pcd_t = o3d.t.geometry.PointCloud(device=device) - pcd_t.point["positions"] = o3c.Tensor(pts, o3c.float32, device) - return pcd_t + return o3d.t.geometry.PointCloud.from_legacy(pcd_any, o3c.float32, device) def ensure_legacy_pcd( diff --git a/dimos/memory/embedding.py b/dimos/memory/embedding.py index df047292a0..9dece58bb7 100644 --- a/dimos/memory/embedding.py +++ b/dimos/memory/embedding.py @@ -56,7 +56,7 @@ class EmbeddingMemory(Module[Config]): def get_costmap(self) -> OccupancyGrid: if self._costmap_getter is None: self._costmap_getter = getter_hot(self.global_costmap.pure_observable()) - self._disposables.add(self._costmap_getter) + self.register_disposable(self._costmap_getter) return self._costmap_getter() @rpc diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index c861993de9..d330b10fd5 100644 --- a/dimos/memory2/backend.py +++ 
b/dimos/memory2/backend.py @@ -19,6 +19,7 @@ from dataclasses import replace from typing import TYPE_CHECKING, Any, Generic, TypeVar +from dimos.core.resource import CompositeResource from dimos.memory2.codecs.base import Codec, codec_id from dimos.memory2.notifier.subject import SubjectNotifier from dimos.memory2.type.observation import _UNLOADED @@ -39,12 +40,9 @@ T = TypeVar("T") -class Backend(Generic[T]): +class Backend(CompositeResource, Generic[T]): """Orchestrates metadata, blob, vector, and live stores for one stream. - - This is a concrete class — NOT a protocol. All shared orchestration logic - (encode → insert → store blob → index vector → notify) lives here, - eliminating duplication between ListObservationStore and SqliteObservationStore. + Shared orchestration (encode → insert → store blob → index vector → notify) lives here. """ def __init__( @@ -57,13 +55,21 @@ def __init__( notifier: Notifier[T] | None = None, eager_blobs: bool = False, ) -> None: - self.metadata_store = metadata_store + super().__init__() + self.metadata_store = self.register_disposable(metadata_store) self.codec = codec - self.blob_store = blob_store - self.vector_store = vector_store - self.notifier: Notifier[T] = notifier or SubjectNotifier() + self.blob_store = self.register_disposable(blob_store) if blob_store else None + self.vector_store = self.register_disposable(vector_store) if vector_store else None + self.notifier: Notifier[T] = self.register_disposable(notifier or SubjectNotifier()) self.eager_blobs = eager_blobs + def start(self) -> None: + self.metadata_store.start() + if self.blob_store is not None: + self.blob_store.start() + if self.vector_store is not None: + self.vector_store.start() + @property def name(self) -> str: return self.metadata_store.name @@ -237,8 +243,3 @@ def serialize(self) -> dict[str, Any]: "vector_store": self.vector_store.serialize() if self.vector_store else None, "notifier": self.notifier.serialize(), } - - def stop(self) -> None: - """Stop the metadata store (closes per-stream connections if any).""" - if
hasattr(self.metadata_store, "stop"): - self.metadata_store.stop() diff --git a/dimos/memory2/blobstore/sqlite.py b/dimos/memory2/blobstore/sqlite.py index 1cb5f1aa38..8092a34d1d 100644 --- a/dimos/memory2/blobstore/sqlite.py +++ b/dimos/memory2/blobstore/sqlite.py @@ -78,7 +78,7 @@ def start(self) -> None: if self._conn is None: assert self._path is not None disposable, self._conn = open_disposable_sqlite_connection(self._path) - self.register_disposables(disposable) + self.register_disposable(disposable) def put(self, stream_name: str, key: int, data: bytes) -> None: self._ensure_table(stream_name) diff --git a/dimos/memory2/module.py b/dimos/memory2/module.py new file mode 100644 index 0000000000..881b1d929a --- /dev/null +++ b/dimos/memory2/module.py @@ -0,0 +1,110 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +from typing import Any + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfigT +from dimos.memory2.store.null import NullStore +from dimos.memory2.stream import Stream + + +class StreamModule(Module[ModuleConfigT]): + """Module base class that wires a memory2 stream pipeline. 
+ + **Static pipeline** + + class VoxelGridMapper(StreamModule): + pipeline = Stream().transform(VoxelMapTransformer()) + lidar: In[PointCloud2] + global_map: Out[PointCloud2] + + **Config-driven pipeline** + + class VoxelGridMapper(StreamModule[VoxelGridMapperConfig]): + def pipeline(self, stream: Stream) -> Stream: + return stream.transform(VoxelMapTransformer(**self.config.model_dump())) + + lidar: In[PointCloud2] + global_map: Out[PointCloud2] + + On start, the single ``In`` port feeds a live-only ``NullStore``, and the pipeline + is applied to the live stream, publishing results to the single ``Out`` port. + + The store acts as a bridge between the push-based Module In port + and the pull-based memory2 stream pipeline — it also enables replay and + persistence if the store is swapped for a persistent backend later. + """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + @rpc + def start(self) -> None: + super().start() + + if len(self.inputs) != 1 or len(self.outputs) != 1: + raise TypeError( + f"{self.__class__.__name__} must have exactly one In and one Out port, " + f"found {len(self.inputs)} In and {len(self.outputs)} Out" + ) + + ((in_name, inp_port),) = self.inputs.items() + ((_, out_port),) = self.outputs.items() + + store = self.register_disposable(NullStore()) + store.start() + + stream: Stream[Any] = store.stream(in_name, inp_port.type) + + # we push input into the stream + inp_port.subscribe(lambda msg: stream.append(msg)) + + live = stream.live() + # and we push stream output to the output port + self._apply_pipeline(live).subscribe( + lambda obs: out_port.publish(obs.data), + ) + + def _apply_pipeline(self, stream: Stream[Any]) -> Stream[Any]: + """Apply the pipeline to a live stream. + + Handles both static (class attr) and dynamic (method) pipelines.
+ """ + pipeline = getattr(self.__class__, "pipeline", None) + if pipeline is None: + raise TypeError( + f"{self.__class__.__name__} must define a 'pipeline' attribute or method" + ) + + # Method pipeline: self.pipeline(stream) -> stream + if inspect.isfunction(pipeline): + result = pipeline(self, stream) + if not isinstance(result, Stream): + raise TypeError( + f"{self.__class__.__name__}.pipeline() must return a Stream, got {type(result).__name__}" + ) + return result + + # Static class attr: Stream (unbound chain) or Transformer + if isinstance(pipeline, Stream): + return stream.chain(pipeline) + return stream.transform(pipeline) + + @rpc + def stop(self) -> None: + super().stop() diff --git a/dimos/memory2/notifier/base.py b/dimos/memory2/notifier/base.py index 022d26d4e0..bb25a1cbf6 100644 --- a/dimos/memory2/notifier/base.py +++ b/dimos/memory2/notifier/base.py @@ -17,6 +17,7 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Any, Generic, TypeVar +from dimos.core.resource import Resource from dimos.memory2.registry import qual from dimos.protocol.service.spec import BaseConfig, Configurable @@ -33,7 +34,7 @@ class NotifierConfig(BaseConfig): pass -class Notifier(Configurable[NotifierConfig], Generic[T]): +class Notifier(Configurable[NotifierConfig], Resource, Generic[T]): """Push-notification for live observation delivery. Decouples the notification mechanism from storage. The built-in @@ -47,6 +48,12 @@ class Notifier(Configurable[NotifierConfig], Generic[T]): def __init__(self, **kwargs: Any) -> None: Configurable.__init__(self, **kwargs) + def start(self) -> None: + pass + + def stop(self) -> None: + pass + @abstractmethod def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: """Register *buf* to receive new observations. 
Returns a diff --git a/dimos/memory2/notifier/subject.py b/dimos/memory2/notifier/subject.py index d1b8d7f888..4b43d28c0a 100644 --- a/dimos/memory2/notifier/subject.py +++ b/dimos/memory2/notifier/subject.py @@ -68,3 +68,11 @@ def notify(self, obs: Observation[T]) -> None: subs = list(self._subscribers) for buf in subs: buf.put(obs) + + def stop(self) -> None: + """Close all subscribed buffers, unblocking any live iterators.""" + with self._lock: + subs = list(self._subscribers) + self._subscribers.clear() + for buf in subs: + buf.close() diff --git a/dimos/memory2/observationstore/memory.py b/dimos/memory2/observationstore/memory.py index 529cd06394..faeb0fbec1 100644 --- a/dimos/memory2/observationstore/memory.py +++ b/dimos/memory2/observationstore/memory.py @@ -14,6 +14,7 @@ from __future__ import annotations +from collections import deque import threading from typing import TYPE_CHECKING, Any, TypeVar @@ -30,10 +31,17 @@ class ListObservationStoreConfig(ObservationStoreConfig): name: str = "" + max_size: int | None = None class ListObservationStore(ObservationStore[T]): - """In-memory metadata store for experimentation. Thread-safe.""" + """In-memory metadata store for experimentation. Thread-safe. + + ``max_size`` controls how many observations are retained: + - ``None`` (default) — keep all (unbounded). + - ``N`` — rolling window of the most recent N observations. + - ``0`` — discard immediately (live-only, no history). 
+ """ default_config = ListObservationStoreConfig config: ListObservationStoreConfig @@ -41,7 +49,8 @@ class ListObservationStore(ObservationStore[T]): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self._name = self.config.name - self._observations: list[Observation[T]] = [] + max_size = self.config.max_size + self._observations: deque[Observation[T]] = deque(maxlen=max_size) self._next_id = 0 self._lock = threading.Lock() diff --git a/dimos/memory2/observationstore/sqlite.py b/dimos/memory2/observationstore/sqlite.py index 5d680c540a..960bb2ce55 100644 --- a/dimos/memory2/observationstore/sqlite.py +++ b/dimos/memory2/observationstore/sqlite.py @@ -273,7 +273,7 @@ def start(self) -> None: if self._conn is None: assert self._path is not None disposable, self._conn = open_disposable_sqlite_connection(self._path) - self.register_disposables(disposable) + self.register_disposable(disposable) self._ensure_tables() def _ensure_tables(self) -> None: diff --git a/dimos/memory2/store/README.md b/dimos/memory2/store/README.md index ff18640c0b..4766c24998 100644 --- a/dimos/memory2/store/README.md +++ b/dimos/memory2/store/README.md @@ -1,25 +1,50 @@ -# store — Store implementations +# store — Store and ObservationStore implementations -Metadata index backends for memory. Each index implements the `ObservationStore` protocol to provide observation metadata storage with query support. The concrete `Backend` class handles orchestration (blob, vector, live) on top of any index. +Store is the top-level user-facing entry point. You create one, ask it for named streams, and use those streams. 
Internally, each stream gets a Backend that orchestrates the lower-level pieces: -## Existing implementations +``` +Store + └── stream("lidar") → Backend + ├── ObservationStore (metadata: id, timestamp, tags, frame_id) + ├── BlobStore (raw bytes: encoded payloads) + ├── VectorStore (embeddings: similarity search) + └── Notifier (live push: new observation events) +``` + +- **ObservationStore** stores observation *metadata* and handles queries (filters, ordering, limit/offset, text search). Doesn't touch raw data or vectors. +- **BlobStore** stores/retrieves encoded payloads by `(stream_name, row_id)`. Just a key-value byte store. +- **VectorStore** stores/retrieves embedding vectors, handles similarity search. +- **Notifier** pushes new observations to live subscribers (for `.live()` tails). + +The **Backend** is the glue — on `append()` it encodes the payload, inserts metadata into ObservationStore, stores the blob in BlobStore, indexes the vector in VectorStore, and notifies live subscribers. On iterate, it queries ObservationStore for metadata, attaches lazy blob loaders, and handles vector search routing. + +**Store** sits above all that — it manages the mapping of stream names to Backends, handles config inheritance (store-level defaults vs per-stream overrides), and provides the `store.stream("name")` / `store.streams.name` API. `MemoryStore` vs `SqliteStore` vs `NullStore` differ in which component implementations they wire up by default and how they persist the registry of known streams. 
+ +## Store implementations + +| Store | File | Description | +|----------------|-------------|------------------------------------------------------| +| `MemoryStore` | `memory.py` | In-memory store for experimentation | +| `SqliteStore` | `sqlite.py` | SQLite-backed persistent store (WAL, registry, vec0) | +| `NullStore` | `null.py` | Live-only O(1) memory, no history/replay | + +## ObservationStore implementations -| ObservationStore | File | Status | Storage | -|-----------------|-------------|----------|-------------------------------------| -| `ListObservationStore` | `memory.py` | Complete | In-memory lists, brute-force search | -| `SqliteObservationStore` | `sqlite.py` | Complete | SQLite (WAL, R*Tree, vec0) | +| ObservationStore | File | Storage | +|--------------------------|----------------------------|-------------------------------------| +| `ListObservationStore` | `observationstore/memory.py` | In-memory deque, brute-force search. `max_size` controls retention (None=all, N=rolling window, 0=discard) | +| `SqliteObservationStore` | `observationstore/sqlite.py` | SQLite (WAL, R*Tree, vec0) | -## Writing a new index +## Writing a new ObservationStore -### 1. Implement the ObservationStore protocol +### 1. 
Subclass ObservationStore ```python from dimos.memory2.observationstore.base import ObservationStore -from dimos.memory2.type.filter import StreamQuery -from dimos.memory2.type.observation import Observation -class MyObservationStore(Generic[T]): - def __init__(self, name: str) -> None: +class MyObservationStore(ObservationStore[T]): + def __init__(self, name: str, **kwargs: Any) -> None: + super().__init__(**kwargs) self._name = name @property @@ -35,8 +60,8 @@ class MyObservationStore(Generic[T]): def query(self, q: StreamQuery) -> Iterator[Observation[T]]: """Yield observations matching the query.""" - # The index handles metadata query fields: - # q.filters — list of Filter objects (each has .matches(obs)) + # The query carries metadata fields: + # q.filters — tuple of Filter objects (each has .matches(obs)) # q.order_field — sort field name (e.g. "ts") # q.order_desc — sort direction # q.limit_val — max results @@ -53,7 +78,7 @@ class MyObservationStore(Generic[T]): ... ``` -`ObservationStore` is a `@runtime_checkable` Protocol — no base class needed, just implement the methods. +`ObservationStore` is an abstract base class (extends `CompositeResource` and `Configurable`). ### 2. Create a Store subclass @@ -66,10 +91,11 @@ class MyStore(Store): def _create_backend( self, name: str, payload_type: type | None = None, **config: Any ) -> Backend: - index = MyObservationStore(name) - codec = codec_for(payload_type) + obs = MyObservationStore(name) + obs.start() + codec = self._resolve_codec(payload_type, config.get("codec")) return Backend( - index=index, + metadata_store=obs, codec=codec, blob_store=config.get("blob_store"), vector_store=config.get("vector_store"), @@ -84,29 +110,32 @@ class MyStore(Store): self._streams.pop(name, None) ``` -The Store creates a `Backend` composite for each stream. The `Backend` handles all orchestration (encode → insert → store blob → index vector → notify) so your index only needs to handle metadata. 
+The Store creates a `Backend` composite for each stream. The `Backend` handles all orchestration (encode -> insert -> store blob -> index vector -> notify) so your ObservationStore only needs to handle metadata. -### 3. Add to the grid test +### 3. Add to the test grid -In `test_impl.py`, add your store to the fixture so all standard tests run against it: +In `conftest.py`, add your store fixture and include it in the parametrized `session` fixture so all standard tests run against it: ```python -@pytest.fixture(params=["memory", "sqlite", "myindex"]) -def store(request, tmp_path): - if request.param == "myindex": - return MyStore(...) - ... +@pytest.fixture +def my_store() -> Iterator[MyStore]: + with MyStore() as store: + yield store + +@pytest.fixture(params=["memory_store", "sqlite_store", "my_store"]) +def session(request): + return request.getfixturevalue(request.param) ``` Use `pytest.mark.xfail` for features not yet implemented — the grid test covers: append, fetch, iterate, count, first/last, exists, all filters, ordering, limit/offset, embeddings, text search. ### Query contract -The index must handle the `StreamQuery` metadata fields. Vector search and blob loading are handled by the `Backend` composite — the index never needs to deal with them. +The ObservationStore must handle the `StreamQuery` metadata fields. Vector search and blob loading are handled by the `Backend` composite — the ObservationStore never needs to deal with them. -`StreamQuery.apply(iterator)` provides a complete Python-side execution path — filters, text search, vector search, ordering, offset/limit — all as in-memory operations. ObservationStorees can use it in three ways: +`StreamQuery.apply(iterator)` provides a complete Python-side execution path — filters, text search, vector search, ordering, offset/limit — all as in-memory operations. 
ObservationStores can use it in three ways: -**Full delegation** — simplest, good enough for in-memory indexes: +**Full delegation** — simplest, good enough for in-memory stores: ```python def query(self, q: StreamQuery) -> Iterator[Observation[T]]: return q.apply(iter(self._data)) @@ -127,4 +156,4 @@ def query(self, q: StreamQuery) -> Iterator[Observation[T]]: **Full push-down** — translate everything to native queries (SQL WHERE, FTS5 MATCH) without calling `apply()` at all. -For filters, each `Filter` object has a `.matches(obs) -> bool` method that indexes can use directly if they don't have a native equivalent. +For filters, each `Filter` object has a `.matches(obs) -> bool` method that ObservationStores can use directly if they don't have a native equivalent. diff --git a/dimos/memory2/store/base.py b/dimos/memory2/store/base.py index cf571f23b0..ffb4ace8cd 100644 --- a/dimos/memory2/store/base.py +++ b/dimos/memory2/store/base.py @@ -120,17 +120,14 @@ def _create_backend( obs = config.pop("observation_store", self.config.observation_store) if obs is None or isinstance(obs, type): obs = (obs or ListObservationStore)(name=name) - obs.start() bs = config.pop("blob_store", self.config.blob_store) if isinstance(bs, type): bs = bs() - bs.start() vs = config.pop("vector_store", self.config.vector_store) if isinstance(vs, type): vs = vs() - vs.start() notifier = config.pop("notifier", self.config.notifier) if notifier is None or isinstance(notifier, type): @@ -154,6 +151,7 @@ def stream(self, name: str, payload_type: type[T] | None = None, **overrides: An if name not in self._streams: resolved = {**self.config.model_dump(exclude_none=True), **overrides} backend = self._create_backend(name, payload_type, **resolved) + backend.start() self._streams[name] = Stream(source=backend) return cast("Stream[T]", self._streams[name]) @@ -163,4 +161,11 @@ def list_streams(self) -> list[str]: def delete_stream(self, name: str) -> None: """Delete a stream by name (from cache and 
underlying storage).""" - self._streams.pop(name, None) + stream = self._streams.pop(name, None) + if stream is not None: + stream.stop() + + def stop(self) -> None: + for stream in self._streams.values(): + stream.stop() + super().stop() diff --git a/dimos/memory2/store/memory.py b/dimos/memory2/store/memory.py index 6aecde29dd..5b4523aac6 100644 --- a/dimos/memory2/store/memory.py +++ b/dimos/memory2/store/memory.py @@ -12,10 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.memory2.store.base import Store +from typing import Any + +from dimos.memory2.backend import Backend +from dimos.memory2.observationstore.memory import ListObservationStore +from dimos.memory2.store.base import Store, StoreConfig + + +class MemoryStoreConfig(StoreConfig): + max_size: int | None = None class MemoryStore(Store): - """In-memory store for experimentation.""" + """In-memory store for experimentation. + + ``max_size`` controls how many observations each stream retains: + - ``None`` (default) — keep all (unbounded). + - ``N`` — rolling window of the most recent N observations. + - ``0`` — discard immediately (live-only, no history). + """ + + default_config = MemoryStoreConfig + config: MemoryStoreConfig - pass + def _create_backend( + self, name: str, payload_type: type[Any] | None = None, **config: Any + ) -> Backend[Any]: + if "observation_store" not in config and self.config.observation_store is None: + obs: ListObservationStore[Any] = ListObservationStore( + name=name, max_size=self.config.max_size + ) + config["observation_store"] = obs + return super()._create_backend(name, payload_type, **config) diff --git a/dimos/memory2/store/null.py b/dimos/memory2/store/null.py new file mode 100644 index 0000000000..71f02c4aee --- /dev/null +++ b/dimos/memory2/store/null.py @@ -0,0 +1,29 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from dimos.memory2.store.memory import MemoryStore + + +class NullStore(MemoryStore): + """Live-only store — O(1) memory, no history/replay. + + Shorthand for ``MemoryStore(max_size=0)``. + Observations get IDs (for live dedup) but are immediately discarded. + """ + + def __init__(self, **kwargs: Any) -> None: + kwargs.setdefault("max_size", 0) + super().__init__(**kwargs) diff --git a/dimos/memory2/store/sqlite.py b/dimos/memory2/store/sqlite.py index b655e0a8bc..1071e9977f 100644 --- a/dimos/memory2/store/sqlite.py +++ b/dimos/memory2/store/sqlite.py @@ -14,8 +14,11 @@ from __future__ import annotations +import os import sqlite3 -from typing import Any +from typing import Annotated, Any + +from pydantic import BeforeValidator from dimos.memory2.backend import Backend from dimos.memory2.blobstore.base import BlobStore @@ -33,7 +36,9 @@ class SqliteStoreConfig(StoreConfig): """Config for SQLite-backed store.""" - path: str = "memory.db" + path: Annotated[ + str, BeforeValidator(lambda v: os.fspath(v) if isinstance(v, os.PathLike) else v) + ] = "memory.db" page_size: int = 256 @@ -51,7 +56,7 @@ def __init__(self, **kwargs: Any) -> None: def _open_connection(self) -> sqlite3.Connection: """Open a new WAL-mode connection with sqlite-vec loaded.""" disposable, connection = open_disposable_sqlite_connection(self.config.path) - self.register_disposables(disposable) + 
self.register_disposable(disposable) return connection def _assemble_backend(self, name: str, stored: dict[str, Any]) -> Backend[Any]: @@ -75,7 +80,6 @@ def _assemble_backend(self, name: str, stored: dict[str, Any]) -> Backend[Any]: bs = deserialize_component(bs_data) else: bs = SqliteBlobStore(conn=backend_conn) - bs.start() vs_data = stored.get("vector_store") if vs_data is not None: @@ -86,7 +90,6 @@ def _assemble_backend(self, name: str, stored: dict[str, Any]) -> Backend[Any]: vs = deserialize_component(vs_data) else: vs = SqliteVectorStore(conn=backend_conn) - vs.start() notifier_data = stored.get("notifier") if notifier_data is not None: @@ -105,8 +108,6 @@ def _assemble_backend(self, name: str, stored: dict[str, Any]) -> Backend[Any]: blob_store_conn_match=blob_store_conn_match and eager_blobs, page_size=page_size, ) - metadata_store.start() - backend: Backend[Any] = Backend( metadata_store=metadata_store, codec=codec, @@ -161,13 +162,9 @@ def _create_backend( # Inject conn-shared instances unless user provided overrides if not isinstance(config.get("blob_store"), BlobStore): - bs = SqliteBlobStore(conn=backend_conn) - bs.start() - config["blob_store"] = bs + config["blob_store"] = SqliteBlobStore(conn=backend_conn) if not isinstance(config.get("vector_store"), VectorStore): - vs = SqliteVectorStore(conn=backend_conn) - vs.start() - config["vector_store"] = vs + config["vector_store"] = SqliteVectorStore(conn=backend_conn) # Resolve codec early — needed for SqliteObservationStore codec = self._resolve_codec(payload_type, config.get("codec")) @@ -184,7 +181,6 @@ def _create_backend( blob_store_conn_match=blob_conn_match and eager_blobs, page_size=config.pop("page_size", self.config.page_size), ) - obs_store.start() config["observation_store"] = obs_store backend = super()._create_backend(name, payload_type, **config) diff --git a/dimos/memory2/store/test_null.py b/dimos/memory2/store/test_null.py new file mode 100644 index 0000000000..3461ff3d9d --- 
/dev/null +++ b/dimos/memory2/store/test_null.py @@ -0,0 +1,56 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for NullStore and max_size=0 discard behavior.""" + +from __future__ import annotations + +from dimos.memory2.store.null import NullStore + + +def test_max_size_zero_monotonic_ids() -> None: + """NullStore assigns monotonically increasing IDs despite discarding data.""" + store = NullStore() + with store: + stream = store.stream("test", str) + obs0 = stream.append("hello") + obs1 = stream.append("world") + obs2 = stream.append("!") + + assert obs0.id == 0 + assert obs1.id == 1 + assert obs2.id == 2 + + +def test_max_size_zero_empty_query() -> None: + """NullStore queries always return empty.""" + store = NullStore() + with store: + stream = store.stream("test", str) + stream.append("data") + assert stream.count() == 0 + assert stream.fetch() == [] + + +def test_null_store_discards_history() -> None: + """NullStore retains no history, even after several appends.""" + store = NullStore() + with store: + stream = store.stream("test", int) + stream.append(1) + stream.append(2) + stream.append(3) + + assert stream.count() == 0 + assert stream.fetch() == [] diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index 545d387c32..75bf6ab6a0 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -15,9 +15,9 @@ from __future__ import annotations import time -from typing import TYPE_CHECKING,
Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast -from dimos.core.resource import Resource +from dimos.core.resource import CompositeResource from dimos.memory2.buffer import BackpressureBuffer, KeepLast from dimos.memory2.transform import FnIterTransformer, FnTransformer, Transformer from dimos.memory2.type.filter import ( @@ -32,6 +32,7 @@ TimeRangeFilter, ) from dimos.memory2.type.observation import EmbeddedObservation, Observation +from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: from collections.abc import Callable, Iterator @@ -44,53 +45,61 @@ T = TypeVar("T") R = TypeVar("R") +logger = setup_logger() -class Stream(Resource, Generic[T]): +class Stream(CompositeResource, Generic[T]): """Lazy, pull-based stream over observations. Every filter/transform method returns a new Stream — no computation happens until iteration. Backends handle query application for stored data; transform sources apply filters as Python predicates. - Implements Resource so live streams can be cleanly stopped via - ``stop()`` or used as a context manager. + Implements CompositeResource so subscriptions created via ``.subscribe()`` + and ``.publish()`` are tracked and disposed on ``stop()``. + + An *unbound* stream (``Stream()``) records a chain of transforms + without a real source. 
Use ``.chain()`` to apply it to a bound stream:: + + pipeline = Stream().transform(VoxelMapTransformer()).map(postprocess) + store.stream("lidar", PointCloud2).live().chain(pipeline) """ def __init__( self, - source: Backend[T] | Stream[Any], + source: Backend[T] | Stream[Any] | None = None, *, - xf: Transformer[Any, T] | None = None, + transform: Transformer[Any, T] | None = None, query: StreamQuery = StreamQuery(), ) -> None: + super().__init__() self._source = source - self._xf = xf + if source is not None: + self.register_disposable(source) + self._transform = transform self._query = query - def start(self) -> None: - pass - def stop(self) -> None: - """Close the live buffer (if any), unblocking iteration.""" buf = self._query.live_buffer if buf is not None: buf.close() - if isinstance(self._source, Stream): - self._source.stop() + super().stop() def __str__(self) -> str: # Walk the source chain to collect (xf, query) pairs chain: list[tuple[Any, StreamQuery]] = [] current: Any = self while isinstance(current, Stream): - chain.append((current._xf, current._query)) + chain.append((current._transform, current._query)) current = current._source chain.reverse() # innermost first - # current is the Backend - name = getattr(current, "name", "?") - result = f'Stream("{name}")' + # current is the Backend (or None for unbound) + if current is None: + result = "Stream(unbound)" + else: + name = getattr(current, "name", "?") + result = f'Stream("{name}")' for xf, query in chain: if xf is not None: @@ -110,9 +119,10 @@ def is_live(self) -> bool: return False def __iter__(self) -> Iterator[Observation[T]]: - return self._build_iter() - - def _build_iter(self) -> Iterator[Observation[T]]: + if self._source is None: + raise TypeError( + "Cannot iterate an unbound stream. Use .chain() to apply it to a real stream first." 
+ ) if isinstance(self._source, Stream): return self._iter_transform() # Backend handles all query application (including live if requested) @@ -120,8 +130,8 @@ def _build_iter(self) -> Iterator[Observation[T]]: def _iter_transform(self) -> Iterator[Observation[T]]: """Iterate a transform source, applying query filters in Python.""" - assert isinstance(self._source, Stream) and self._xf is not None - it: Iterator[Observation[T]] = self._xf(iter(self._source)) + assert isinstance(self._source, Stream) and self._transform is not None + it: Iterator[Observation[T]] = self._transform(iter(self._source)) return self._query.apply(it, live=self.is_live()) def _replace_query(self, **overrides: Any) -> Stream[T]: @@ -137,7 +147,7 @@ def _replace_query(self, **overrides: Any) -> Stream[T]: search_k=overrides.get("search_k", q.search_k), search_text=overrides.get("search_text", q.search_text), ) - return Stream(self._source, xf=self._xf, query=new_q) + return Stream(self._source, transform=self._transform, query=new_q) def _with_filter(self, f: Filter) -> Stream[T]: return self._replace_query(filters=(*self._query.filters, f)) @@ -210,7 +220,7 @@ def detect(upstream): """ if not isinstance(xf, Transformer): xf = FnIterTransformer(xf) - return Stream(source=self, xf=xf, query=StreamQuery()) + return Stream(source=self, transform=xf, query=StreamQuery()) def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T]: """Return a stream whose iteration never ends — backfill then live tail. @@ -221,9 +231,9 @@ def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> St Default buffer: KeepLast(). The backend handles subscription, dedup, and backpressure — how it does so is its business. """ - if isinstance(self._source, Stream): + if isinstance(self._source, Stream) or self._source is None: raise TypeError( - "Cannot call .live() on a transform stream. " + "Cannot call .live() on a transform/unbound stream. 
" "Call .live() on the source stream, then .transform()." ) buf = buffer if buffer is not None else KeepLast() @@ -234,8 +244,10 @@ def save(self, target: Stream[T]) -> Stream[T]: Returns the target stream for continued querying. """ - if isinstance(target._source, Stream): - raise TypeError("Cannot save to a transform stream. Target must be backend-backed.") + if isinstance(target._source, Stream) or target._source is None: + raise TypeError( + "Cannot save to a transform/unbound stream. Target must be backend-backed." + ) backend = target._source for obs in self: backend.append(obs) @@ -264,7 +276,7 @@ def last(self) -> Observation[T]: def count(self) -> int: """Count matching observations.""" - if not isinstance(self._source, Stream): + if self._source is not None and not isinstance(self._source, Stream): return self._source.count(self._query) if self.is_live(): raise TypeError(".count() on a live transform stream would block forever.") @@ -328,13 +340,79 @@ def subscribe( on_error: Callable[[Exception], None] | None = None, on_completed: Callable[[], None] | None = None, ) -> DisposableBase: - """Subscribe to this stream as an RxPY Observable.""" - return self.observable().subscribe( # type: ignore[call-overload] - on_next=on_next, - on_error=on_error, - on_completed=on_completed, + """Subscribe to this stream as an RxPY Observable. + + The subscription is tracked and disposed when this stream is stopped. + """ + return self.register_disposable( + self.observable().subscribe( # type: ignore[call-overload] + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + ) + ) + + def publish(self, out: Any) -> DisposableBase: + """Publish each observation's data to a Module ``Out`` port. + + Iteration runs on the dimos thread pool (via :meth:`subscribe`). + Returns a ``DisposableBase`` suitable for ``register_disposable()``. 
+ + Example:: + + lidar.live().transform(VoxelMapTransformer()).publish(self.global_map) + """ + + def _on_error(e: Exception) -> None: + logger.error("Stream.publish() pipeline error: %s", e, exc_info=True) + + return self.subscribe( + on_next=lambda obs: out.publish(obs.data), + on_error=_on_error, ) + def chain(self, other: Stream[R]) -> Stream[R]: + """Append operations from an unbound stream to this stream. + + Extracts the transform/filter chain from *other* (which must be + unbound) and replays it on top of ``self``:: + + pipeline = Stream().transform(VoxelMapTransformer()).map(postprocess) + store.stream("lidar").live().chain(pipeline) + """ + ops: list[tuple[Transformer[Any, Any] | None, StreamQuery]] = [] + current: Stream[Any] | None | Any = other + found_root = False + while isinstance(current, Stream): + ops.append((current._transform, current._query)) + if current._source is None: + found_root = True + break + current = current._source + if not found_root: + raise TypeError("Can only chain an unbound stream (created with Stream())") + + # Validate no unsupported query fields in the unbound chain + for _, query in ops: + if query.search_vec is not None or query.search_text is not None: + raise TypeError("search() / search_text() cannot be used on unbound streams") + if query.live_buffer is not None: + raise TypeError("live() cannot be used on unbound streams") + + result: Stream[Any] = self + for xf, query in reversed(ops): + if xf is not None: + result = result.transform(xf) + for f in query.filters: + result = result._with_filter(f) + if query.limit_val is not None: + result = result.limit(query.limit_val) + if query.offset_val is not None and query.offset_val != 0: + result = result.offset(query.offset_val) + if query.order_field is not None: + result = result.order_by(query.order_field, desc=query.order_desc) + return cast("Stream[R]", result) + def append( self, payload: T, @@ -345,8 +423,10 @@ def append( embedding: Embedding | None = None, ) -> 
Observation[T]: """Append to the backing store. Only works if source is a Backend.""" - if isinstance(self._source, Stream): - raise TypeError("Cannot append to a transform stream. Append to the source stream.") + if isinstance(self._source, Stream) or self._source is None: + raise TypeError( + "Cannot append to a transform/unbound stream. Append to the source stream." + ) _ts = ts if ts is not None else time.time() _tags = tags or {} if embedding is not None: diff --git a/dimos/memory2/test_e2e.py b/dimos/memory2/test_e2e.py index efea5a59a2..31d5ee1720 100644 --- a/dimos/memory2/test_e2e.py +++ b/dimos/memory2/test_e2e.py @@ -126,6 +126,31 @@ def test_import_lidar( assert lidar.count() == count print(f"Imported {count} lidar frames") + def test_embed_and_save(self, session: SqliteStore, clip: CLIPModel) -> None: + """Embed video frames at 1Hz and persist to an embedded stream.""" + video = session.stream("color_image", Image) + + # Clear any prior run so the test is idempotent + if "color_image_embedded" in session.list_streams(): + session.delete_stream("color_image_embedded") + + embedded = session.stream("color_image_embedded", Image) + + # Downsample to 1Hz, then embed + pipeline = ( + video.transform(QualityWindow(lambda img: img.sharpness, window=1.0)) + .transform(EmbedImages(clip)) + .save(embedded) + ) + + count = 0 + for obs in pipeline: + count += 1 + print(f" [{count}] ts={obs.ts:.2f} pose={obs.pose}") + + assert count > 0 + print(f"Embedded {count} frames (1Hz from {video.count()} total)") + def test_query_imported_data(self, session: SqliteStore) -> None: video = session.stream("color_image", Image) lidar = session.stream("lidar", PointCloud2) @@ -256,38 +281,12 @@ def test_cross_stream_time_alignment(self, session: SqliteStore) -> None: overlap_start = max(v_first, l_first) overlap_end = min(v_last, l_last) assert overlap_start < overlap_end, "Video and lidar should overlap in time" - assert overlap_start < overlap_end, "Video and lidar should 
overlap in time" @pytest.mark.tool class TestEmbedImages: """CLIP-embed imported video frames and search by text.""" - def test_embed_and_save(self, session: SqliteStore, clip: CLIPModel) -> None: - """Embed video frames at 1Hz and persist to an embedded stream.""" - video = session.stream("color_image", Image) - - # Clear any prior run so the test is idempotent - if "color_image_embedded" in session.list_streams(): - session.delete_stream("color_image_embedded") - - embedded = session.stream("color_image_embedded", Image) - - # Downsample to 1Hz, then embed - pipeline = ( - video.transform(QualityWindow(lambda img: img.sharpness, window=1.0)) - .transform(EmbedImages(clip)) - .save(embedded) - ) - - count = 0 - for obs in pipeline: - count += 1 - print(f" [{count}] ts={obs.ts:.2f} pose={obs.pose}") - - assert count > 0 - print(f"Embedded {count} frames (1Hz from {video.count()} total)") - def test_search_by_text(self, session: SqliteStore, clip: CLIPModel) -> None: """Search embedded frames with a text query.""" embedded = session.stream("color_image_embedded", Image) diff --git a/dimos/memory2/test_module.py b/dimos/memory2/test_module.py new file mode 100644 index 0000000000..a944539063 --- /dev/null +++ b/dimos/memory2/test_module.py @@ -0,0 +1,131 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Grid tests for StreamModule — same e2e logic across all pipeline styles.""" + +from __future__ import annotations + +from collections.abc import Iterator +import threading + +import pytest +from reactivex.scheduler import ThreadPoolScheduler + +from dimos.core.module import ModuleConfig +from dimos.core.stream import In, Out +from dimos.core.transport import pLCMTransport +from dimos.memory2.module import StreamModule +from dimos.memory2.stream import Stream +from dimos.memory2.transform import Transformer +from dimos.memory2.type.observation import Observation + +# -- Shared transformer --------------------------------------------------- + + +class Double(Transformer[int, int]): + def __init__(self, factor: int = 2) -> None: + self.factor = factor + + def __call__(self, upstream: Iterator[Observation[int]]) -> Iterator[Observation[int]]: + for obs in upstream: + yield obs.derive(data=obs.data * self.factor) + + +# -- Pipeline styles ------------------------------------------------------- + + +class StaticStreamModule(StreamModule): + """Pipeline as a static Stream chain on the class.""" + + pipeline = Stream().transform(Double()) + numbers: In[int] + doubled: Out[int] + + +class StaticTransformerModule(StreamModule): + """Pipeline as a bare Transformer on the class.""" + + pipeline = Double() + numbers: In[int] + doubled: Out[int] + + +class MethodPipelineConfig(ModuleConfig): + factor: int = 2 + + +class MethodPipelineModule(StreamModule[MethodPipelineConfig]): + """Pipeline as a method with access to self.config.""" + + default_config = MethodPipelineConfig + + def pipeline(self, stream: Stream) -> Stream: + return stream.transform(Double(factor=self.config.factor)) + + numbers: In[int] + doubled: Out[int] + + +# -- Grid ------------------------------------------------------------------ + +module_cases = [ + pytest.param(StaticStreamModule, id="static-stream"), + pytest.param(StaticTransformerModule, id="static-transformer"), + 
pytest.param(MethodPipelineModule, id="method-pipeline"), +] + + +@pytest.mark.parametrize("module_cls", module_cases) +def test_blueprint_ports(module_cls: type[StreamModule]) -> None: + """All pipeline styles produce a blueprint with the correct In/Out ports.""" + bp = module_cls.blueprint() + + assert len(bp.blueprints) == 1 + atom = bp.blueprints[0] + stream_names = {s.name for s in atom.streams} + assert "numbers" in stream_names + assert "doubled" in stream_names + + +def _reset_thread_pool() -> None: + """Shut down and replace the global RxPY thread pool so conftest thread-leak check passes.""" + import dimos.utils.threadpool as tp + + tp.scheduler.executor.shutdown(wait=True) + tp.scheduler = ThreadPoolScheduler(max_workers=tp.get_max_workers()) + + +@pytest.mark.tool +@pytest.mark.parametrize("module_cls", module_cases) +def test_e2e_runtime_wiring(module_cls: type[StreamModule]) -> None: + """Push data into In port, assert doubled data arrives on Out port.""" + module = module_cls() + module.numbers.transport = pLCMTransport("/test/numbers") + module.doubled.transport = pLCMTransport("/test/doubled") + + received: list[int] = [] + done = threading.Event() + + unsub = module.doubled.subscribe(lambda msg: (received.append(msg), done.set())) + + module.start() + try: + module.numbers.transport.publish(42) + assert done.wait(timeout=5.0), f"Timed out, received={received}" + assert received == [84] + finally: + unsub() + module.stop() + _reset_thread_pool() + _reset_thread_pool() diff --git a/dimos/memory2/test_save.py b/dimos/memory2/test_save.py index 13ee73d46a..8ebb12082b 100644 --- a/dimos/memory2/test_save.py +++ b/dimos/memory2/test_save.py @@ -101,7 +101,7 @@ def test_save_rejects_transform_target(self) -> None: base = make_stream(2) transform_stream = base.transform(FnTransformer(lambda obs: obs.derive(obs.data))) - with pytest.raises(TypeError, match="Cannot save to a transform stream"): + with pytest.raises(TypeError, match="Cannot save to a 
transform"): source.save(transform_stream) def test_save_target_queryable(self) -> None: diff --git a/dimos/memory2/test_store.py b/dimos/memory2/test_store.py index dfba6d6d2b..aa525c8758 100644 --- a/dimos/memory2/test_store.py +++ b/dimos/memory2/test_store.py @@ -24,6 +24,7 @@ import pytest +from dimos.memory2.backend import Backend from dimos.memory2.blobstore.base import BlobStore from dimos.memory2.vectorstore.base import VectorStore @@ -525,3 +526,94 @@ def test_accessor_dynamic(self, session: Store) -> None: assert "late" not in dir(session.streams) session.stream("late", str) assert "late" in dir(session.streams) + + +class TestStoreLifecycle: + """Cleanup chain: Store → Stream → Backend → components.""" + + def test_stop_stream_keeps_other_streams(self, session: Store) -> None: + """Stopping one stream doesn't affect another.""" + s1 = session.stream("a", int) + s2 = session.stream("b", int) + s1.append(1) + s2.append(2) + + s1.stop() + + # s2 still works + s2.append(3) + assert [obs.data for obs in s2] == [2, 3] + + def test_store_stop_stops_backends(self, session: Store) -> None: + """Store.stop() disposes backends transitively via streams.""" + s1 = session.stream("x", int) + s2 = session.stream("y", int) + s1.append(10) + s2.append(20) + + backend1 = s1._source + backend2 = s2._source + assert isinstance(backend1, Backend) + assert isinstance(backend2, Backend) + + session.stop() + + # Both backends' disposables are disposed + assert backend1._disposables is not None + assert backend1._disposables.is_disposed + assert backend2._disposables is not None + assert backend2._disposables.is_disposed + + def test_stream_stop_stops_backend(self, session: Store) -> None: + """stream.stop() disposes its backend (Stream owns Backend).""" + s = session.stream("owned", int) + s.append(42) + + backend = s._source + assert isinstance(backend, Backend) + + s.stop() + + assert backend._disposables is not None + assert backend._disposables.is_disposed + + def 
test_stream_stop_stops_backend_components(self, session: Store) -> None: + """stream.stop() cascades through backend to its components.""" + s = session.stream("cascade", int) + backend = s._source + assert isinstance(backend, Backend) + + s.stop() + + # Backend registers notifier as disposable, so it gets disposed + assert backend._disposables is not None + assert backend._disposables.is_disposed + # Notifier's own disposables may be None (no children registered), + # but the backend's disposal cascade is what matters. + + def test_delete_stream_stops_backend(self, session: Store) -> None: + """delete_stream() stops the stream+backend and removes from cache.""" + s = session.stream("ephemeral", int) + s.append(1) + + backend = s._source + assert isinstance(backend, Backend) + assert "ephemeral" in session.list_streams() + + session.delete_stream("ephemeral") + + assert backend._disposables is not None + assert backend._disposables.is_disposed + assert "ephemeral" not in session.list_streams() + + def test_backend_stop_stops_components(self, session: Store) -> None: + """Backend.stop() propagates to metadata_store, blob_store, vector_store.""" + s = session.stream("z", int) + backend = s._source + assert isinstance(backend, Backend) + + session.stop() + + # Backend always registers its components, so _disposables is always set + assert backend._disposables is not None + assert backend._disposables.is_disposed diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py index 03c3caec76..e53cd15d9f 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -26,14 +26,14 @@ import pytest from dimos.memory2.buffer import KeepLast, Unbounded +from dimos.memory2.store.memory import MemoryStore +from dimos.memory2.stream import Stream from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer from dimos.memory2.type.observation import Observation if TYPE_CHECKING: from collections.abc import Callable - from 
dimos.memory2.stream import Stream - @pytest.fixture def make_stream(session) -> Callable[..., Stream[int]]: @@ -50,11 +50,6 @@ def f(n: int = 5, start_ts: float = 0.0): return f -# ═══════════════════════════════════════════════════════════════════ -# 1. Basic iteration -# ═══════════════════════════════════════════════════════════════════ - - class TestBasicIteration: """Streams are lazy iterables — nothing runs until you iterate.""" @@ -85,11 +80,6 @@ def test_stream_is_reiterable(self, make_stream): assert first == second == [0, 10, 20] -# ═══════════════════════════════════════════════════════════════════ -# 2. Temporal filters -# ═══════════════════════════════════════════════════════════════════ - - class TestTemporalFilters: """Temporal filters constrain observations by timestamp.""" @@ -119,11 +109,6 @@ def test_chained_temporal_filters(self, make_stream): assert [o.ts for o in result] == [3.0, 4.0, 5.0, 6.0] -# ═══════════════════════════════════════════════════════════════════ -# 3. Spatial filter -# ═══════════════════════════════════════════════════════════════════ - - class TestSpatialFilter: """.near(pose, radius) filters by Euclidean distance.""" @@ -145,11 +130,6 @@ def test_near_excludes_no_pose(self, memory_session): assert [o.data for o in result] == ["has_pose"] -# ═══════════════════════════════════════════════════════════════════ -# 4. Tags filter -# ═══════════════════════════════════════════════════════════════════ - - class TestTagsFilter: """.filter_tags() matches on observation metadata.""" @@ -171,11 +151,6 @@ def test_filter_multiple_tags(self, memory_session): assert [o.data for o in result] == ["a"] -# ═══════════════════════════════════════════════════════════════════ -# 5. 
Ordering, limit, offset -# ═══════════════════════════════════════════════════════════════════ - - class TestOrderLimitOffset: def test_limit(self, make_stream): result = make_stream(10).limit(3).fetch() @@ -220,11 +195,6 @@ def test_drain(self, make_stream): assert make_stream(0).drain() == 0 -# ═══════════════════════════════════════════════════════════════════ -# 6. Functional API: .filter(), .map() -# ═══════════════════════════════════════════════════════════════════ - - class TestFunctionalAPI: """Functional combinators receive the full Observation.""" @@ -249,11 +219,6 @@ def test_map_preserves_ts(self, make_stream): assert [o.data for o in result] == ["0", "10", "20"] -# ═══════════════════════════════════════════════════════════════════ -# 7. Transform chaining -# ═══════════════════════════════════════════════════════════════════ - - class TestTransformChaining: """Transforms chain lazily — each obs flows through the full pipeline.""" @@ -352,9 +317,109 @@ def __call__(self, upstream): assert len(calls) == 3 -# ═══════════════════════════════════════════════════════════════════ -# 8. 
Store -# ═══════════════════════════════════════════════════════════════════ +class TestUnboundStream: + """Unbound streams: pipelines built without a source, applied later via .chain().""" + + def test_creation(self) -> None: + """Stream() with no args creates an unbound stream.""" + s = Stream() + assert s._transform is None + + def test_multi_transform_chain(self) -> None: + """Unbound pipeline with multiple transforms produces correct results when bound.""" + + class Double(Transformer[int, int]): + def __call__(self, upstream): + for obs in upstream: + yield obs.derive(data=obs.data * 2) + + pipeline = Stream().transform(Double()).map(lambda obs: obs.derive(data=obs.data + 1)) + + store = MemoryStore() + with store: + stream = store.stream("test", int) + stream.append(5) + stream.append(10) + + result = stream.chain(pipeline).fetch() + assert [obs.data for obs in result] == [11, 21] + + def test_iteration_raises(self) -> None: + """Iterating an unbound stream raises TypeError.""" + s = Stream().transform(FnTransformer(lambda obs: obs)) + with pytest.raises(TypeError, match="unbound"): + list(s) + + def test_chain_applies_transforms(self) -> None: + """chain() replays unbound transforms on a real stream.""" + store = MemoryStore() + with store: + stream = store.stream("test", int) + stream.append(10) + stream.append(20) + stream.append(30) + + class Double(Transformer[int, int]): + def __call__(self, upstream): + for obs in upstream: + yield obs.derive(data=obs.data * 2) + + pipeline = Stream().transform(Double()) + result = stream.chain(pipeline).fetch() + assert [obs.data for obs in result] == [20, 40, 60] + + def test_chain_multiple_transforms(self) -> None: + """chain() preserves order of multiple transforms.""" + store = MemoryStore() + with store: + stream = store.stream("test", int) + stream.append(5) + + class Double(Transformer[int, int]): + def __call__(self, upstream): + for obs in upstream: + yield obs.derive(data=obs.data * 2) + + class 
AddTen(Transformer[int, int]): + def __call__(self, upstream): + for obs in upstream: + yield obs.derive(data=obs.data + 10) + + pipeline = Stream().transform(Double()).transform(AddTen()) + result = stream.chain(pipeline).fetch() + assert result[0].data == 20 # (5 * 2) + 10 + + def test_chain_preserves_filters(self) -> None: + """chain() replays filters from the unbound stream.""" + store = MemoryStore() + with store: + stream = store.stream("test", int) + stream.append(10, ts=1.0) + stream.append(20, ts=2.0) + stream.append(30, ts=3.0) + + pipeline = Stream().after(1.5) + result = stream.chain(pipeline).fetch() + assert [obs.data for obs in result] == [20, 30] + + def test_chain_rejects_bound_stream(self) -> None: + """chain() raises if passed a bound (non-unbound) stream.""" + store = MemoryStore() + with store: + s1 = store.stream("a", int) + s2 = store.stream("b", int) + with pytest.raises(TypeError, match="unbound"): + s1.chain(s2) + + def test_live_rejects_unbound(self) -> None: + """live() raises on an unbound stream.""" + with pytest.raises(TypeError, match="unbound"): + Stream().live() + + def test_str(self) -> None: + """Unbound streams display as Stream(unbound).""" + s = Stream() + assert "unbound" in str(s) class TestStore: @@ -385,11 +450,6 @@ def test_delete_stream(self, memory_store): assert "temp" not in memory_store.list_streams() -# ═══════════════════════════════════════════════════════════════════ -# 9. Lazy data loading -# ═══════════════════════════════════════════════════════════════════ - - class TestLazyData: """Observation.data supports lazy loading with cleanup.""" @@ -430,11 +490,6 @@ def test_derive_preserves_metadata(self): assert derived.data == "transformed" -# ═══════════════════════════════════════════════════════════════════ -# 10. 
Live mode -# ═══════════════════════════════════════════════════════════════════ - - class TestLiveMode: """Live streams yield backfill then block for new observations.""" diff --git a/dimos/memory2/test_visualizer.py b/dimos/memory2/test_visualizer.py index 033d60205f..0830c946fd 100644 --- a/dimos/memory2/test_visualizer.py +++ b/dimos/memory2/test_visualizer.py @@ -16,6 +16,7 @@ from __future__ import annotations +import pickle from typing import TYPE_CHECKING import pytest @@ -24,8 +25,12 @@ from dimos.memory2.transform import Batch, QualityWindow from dimos.models.embedding.clip import CLIPModel from dimos.models.vl.florence import Florence2Model +from dimos.models.vl.moondream import MoondreamVlModel +from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.sensor_msgs.Image import Image -from dimos.utils.data import get_data_dir +from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC +from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.utils.data import get_data, get_data_dir if TYPE_CHECKING: from collections.abc import Iterator @@ -105,13 +110,11 @@ def test_search_near_pose(self, store: SqliteStore) -> None: # VIS GOAL: draw 2d detections somehow, or project into 3d, draw 3d bounding boxes def test_detect_objects(self, store: SqliteStore, clip: CLIPModel) -> None: """CLIP pre-filter + VLM detection on top candidates.""" - from dimos.models.vl.moondream import MoondreamVlModel - vlm = MoondreamVlModel() embedded = store.streams.color_image_embedded lidar = store.streams.lidar - for obs in embedded.search(clip.embed_text("bottle"), k=10).map( + for obs in embedded.search(clip.embed_text("bottle"), k=1).map( lambda obs: obs.derive(data=vlm.query_detections(obs.data, "bottle")) ): print(f"ts={obs.ts:.2f} sim={obs.similarity:.3f} pose={obs.pose}") @@ -135,13 +138,65 @@ def test_search_reconstruct_full_path(self, store: SqliteStore) -> None: def test_agent_visual_description_passive(self, store: 
SqliteStore) -> None: florence = Florence2Model() with florence: - pipeline = store.streams.color_image.transform( - QualityWindow(lambda img: img.sharpness, window=5.0) - # we are batch processing images here, - # so we can use the more efficient batch captioning API - # (instead of using .map() and calling caption() for each image, - ).transform(Batch(lambda imgs: florence.caption_batch(*imgs))) + pipeline = ( + store.streams.color_image.limit(200) + .transform( + QualityWindow(lambda img: img.sharpness, window=5.0) + # we are batch processing images here, + # so we can use the more efficient batch captioning API + # (instead of using .map() and calling caption() for each image, + ) + .transform(Batch(lambda imgs: florence.caption_batch(*imgs))) + ) # this can be stored, further embedded etc for obs in pipeline: print(obs.ts, obs.data) + + def test_build_global_map(self, store: SqliteStore) -> None: + global_map = pickle.loads(get_data("unitree_go2_bigoffice_map.pickle").read_bytes()) + print(f"Global map: {len(global_map)}") + + # we semantically search, then detect with a detection model + # + # VIS GOAL: draw 2d detections somehow, or project into 3d, draw 3d bounding boxes + def test_detect_objects_smart(self, store: SqliteStore, clip: CLIPModel) -> None: + """CLIP pre-filter + VLM detection on top candidates.""" + vlm = MoondreamVlModel() + embedded = store.streams.color_image_embedded + lidar = store.streams.lidar + + # find a location in the world with highest semantic similarity to a bottle + bottle_pos = embedded.search(clip.embed_text("bottle"), k=1).first().pose_stamped + + for obs in ( + store.streams.color_image + # find all frames within 60 seconds of the semantic hotspot + .at(bottle_pos.ts, tolerance=60.0) + # filter the frames within 1m radius near the semantic hotspot + .near(bottle_pos, radius=1.0) + # select highest quality frames from these results (based on sharpness) + .transform(QualityWindow(lambda img: img.sharpness, window=1.0)) + # run 
detection on these frames to find bottles + .map(lambda obs: obs.derive(data=vlm.query_detections(obs.data, "bottle"))) + ): + print(f"ts={obs.ts:.2f} pose={obs.pose_stamped}") + + # find the lidar frame captured closest in time to an image + lidar_frame = lidar.at(obs.ts).first().data + + for det in obs.data.detections: + print(det) + # project each bottle into 3D using lidar frame + # known camera intrinsics + extrinsics + det3d = Detection3DPC.from_2d( + det, + lidar_frame, + camera_info=GO2Connection.camera_info_static, + world_to_optical_transform=Transform( + ts=obs.ts, + translation=obs.pose_stamped.position, + rotation=obs.pose_stamped.orientation, + ).inverse(), + ) + print(det3d) diff --git a/dimos/memory2/test_voxel_map.py b/dimos/memory2/test_voxel_map.py new file mode 100644 index 0000000000..0fd254be60 --- /dev/null +++ b/dimos/memory2/test_voxel_map.py @@ -0,0 +1,135 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import time +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from dimos.mapping.voxels import VoxelMapTransformer +from dimos.memory2.store.sqlite import SqliteStore +from dimos.memory2.type.observation import Observation +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.utils.data import get_data + +if TYPE_CHECKING: + from collections.abc import Iterator + + +@pytest.fixture(scope="module") +def store() -> Iterator[SqliteStore]: + db = SqliteStore(path=get_data("go2_bigoffice.db")) + with db: + yield db + + +def _make_obs(obs_id: int, points: np.ndarray, ts: float = 0.0) -> Observation[PointCloud2]: + return Observation(id=obs_id, ts=ts, _data=PointCloud2.from_numpy(points)) + + +def _unit_cube_points(n: int = 100) -> np.ndarray: + rng = np.random.default_rng(42) + return rng.random((n, 3)).astype(np.float32) + + +def test_accumulate_two_frames() -> None: + """Two non-overlapping frames produce a larger global map.""" + pts = _unit_cube_points(50) + obs1 = _make_obs(0, pts, ts=1.0) + obs2 = _make_obs(1, pts + 10.0, ts=2.0) # offset by 10m, no overlap + + xf = VoxelMapTransformer(voxel_size=0.5, carve_columns=False) + results = list(xf(iter([obs1, obs2]))) + + assert len(results) == 2 # emit_every=1 default + global_map = results[-1].data # last result has the full accumulated map + + single_results = list(VoxelMapTransformer(voxel_size=0.5)(iter([obs1]))) + assert len(global_map) > len(single_results[0].data) + + +def test_empty_stream() -> None: + xf = VoxelMapTransformer(voxel_size=0.5) + assert list(xf(iter([]))) == [] + + +def test_frame_count_tag() -> None: + pts = _unit_cube_points(30) + obs = [_make_obs(i, pts, ts=float(i)) for i in range(5)] + + xf = VoxelMapTransformer(voxel_size=0.5, device="CPU:0") + results = list(xf(iter(obs))) + + assert len(results) == 5 # emit_every=1 (default), one result per frame + assert results[-1].tags["frame_count"] == 5 + + +def 
test_emit_every_batch_mode() -> None: + """emit_every=0 yields only on exhaustion (batch mode).""" + pts = _unit_cube_points(30) + obs = [_make_obs(i, pts, ts=float(i)) for i in range(5)] + + xf = VoxelMapTransformer(voxel_size=0.5, device="CPU:0", emit_every=0) + results = list(xf(iter(obs))) + + assert len(results) == 1 + assert results[0].tags["frame_count"] == 5 + + +def test_emit_every_n() -> None: + """emit_every=3 yields after every 3rd frame, plus remainder on exhaustion.""" + pts = _unit_cube_points(30) + obs = [_make_obs(i, pts, ts=float(i)) for i in range(7)] + + xf = VoxelMapTransformer(voxel_size=0.5, device="CPU:0", emit_every=3) + results = list(xf(iter(obs))) + + # 7 frames / emit_every=3 → yields at frame 3, 6, then remainder (7) on exhaustion + assert len(results) == 3 + assert results[0].tags["frame_count"] == 3 + assert results[1].tags["frame_count"] == 6 + assert results[2].tags["frame_count"] == 7 + + +# -- Integration tests against real replay data -- + + +@pytest.mark.tool +def test_build_global_map(store: SqliteStore) -> None: + t_total = time.perf_counter() + + lidar = store.stream("lidar", PointCloud2) + n_frames = lidar.count() + + t0 = time.perf_counter() + result = lidar.transform(VoxelMapTransformer(voxel_size=0.05)).last() + t_transform = time.perf_counter() - t0 + + t_total = time.perf_counter() - t_total + + global_map = result.data + frame_count = result.tags["frame_count"] + + assert frame_count == n_frames + assert len(global_map) > 0 + + print( + lidar.summary(), + f"\n{frame_count} frames -> {len(global_map)} voxels" + f"\n transform: {t_transform:.2f}s ({t_transform / frame_count * 1000:.1f}ms/frame)" + f"\n total wall: {t_total:.2f}s", + ) diff --git a/dimos/memory2/transform.py b/dimos/memory2/transform.py index 20d6bf0baf..5754ac36e3 100644 --- a/dimos/memory2/transform.py +++ b/dimos/memory2/transform.py @@ -105,6 +105,19 @@ def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R yield 
o.derive(data=r) +def stride(n: int) -> FnIterTransformer[T, T]: + """Yield every *n*-th observation, skipping the rest.""" + if n < 1: + raise ValueError(f"stride(n) requires n >= 1, got {n}") + + def _stride(upstream: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + for i, obs in enumerate(upstream): + if i % n == 0: + yield obs + + return FnIterTransformer(_stride) + + class QualityWindow(Transformer[T, T]): """Keeps the highest-quality item per time window. diff --git a/dimos/memory2/type/observation.py b/dimos/memory2/type/observation.py index 0a6dd16ea5..03a8819867 100644 --- a/dimos/memory2/type/observation.py +++ b/dimos/memory2/type/observation.py @@ -22,6 +22,7 @@ from collections.abc import Callable from dimos.models.embedding.base import Embedding + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped T = TypeVar("T") @@ -50,6 +51,15 @@ class Observation(Generic[T]): _loader: Callable[[], T] | None = field(default=None, repr=False) _data_lock: threading.Lock = field(default_factory=threading.Lock, repr=False) + @property + def pose_stamped(self) -> PoseStamped: + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + if self.pose is None: + raise LookupError("No pose set on this observation") + x, y, z, qx, qy, qz, qw = self.pose + return PoseStamped(ts=self.ts, position=(x, y, z), orientation=(qx, qy, qz, qw)) + @property def data(self) -> T: val = self._data diff --git a/dimos/memory2/utils/sqlite.py b/dimos/memory2/utils/sqlite.py index e242a6e1f5..02a48f22b7 100644 --- a/dimos/memory2/utils/sqlite.py +++ b/dimos/memory2/utils/sqlite.py @@ -14,12 +14,13 @@ from __future__ import annotations +from pathlib import Path import sqlite3 from reactivex.disposable import Disposable -def open_sqlite_connection(path: str) -> sqlite3.Connection: +def open_sqlite_connection(path: str | Path) -> sqlite3.Connection: """Open a WAL-mode SQLite connection with sqlite-vec loaded.""" import sqlite_vec @@ -33,7 +34,7 @@ def 
open_sqlite_connection(path: str) -> sqlite3.Connection: def open_disposable_sqlite_connection( - path: str, + path: str | Path, ) -> tuple[Disposable, sqlite3.Connection]: """Open a WAL-mode SQLite connection and return (disposable, connection). diff --git a/dimos/memory2/vectorstore/sqlite.py b/dimos/memory2/vectorstore/sqlite.py index cd6573cc0c..31ebba45d6 100644 --- a/dimos/memory2/vectorstore/sqlite.py +++ b/dimos/memory2/vectorstore/sqlite.py @@ -76,7 +76,7 @@ def start(self) -> None: if self._conn is None: assert self._path is not None disposable, self._conn = open_disposable_sqlite_connection(self._path) - self.register_disposables(disposable) + self.register_disposable(disposable) def put(self, stream_name: str, key: int, embedding: Embedding) -> None: vec = embedding.to_numpy().tolist() diff --git a/dimos/navigation/bbox_navigation.py b/dimos/navigation/bbox_navigation.py index c96ba9efad..2be8015721 100644 --- a/dimos/navigation/bbox_navigation.py +++ b/dimos/navigation/bbox_navigation.py @@ -48,10 +48,10 @@ def start(self) -> None: unsub = self.camera_info.subscribe( lambda msg: setattr(self, "camera_intrinsics", [msg.K[0], msg.K[4], msg.K[2], msg.K[5]]) ) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) unsub = self.detection2d.subscribe(self._on_detection) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) @rpc def stop(self) -> None: diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index e2f408b538..bee3a83b85 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -153,22 +153,22 @@ def start(self) -> None: super().start() unsub = self.global_costmap.subscribe(self._on_costmap) - self._disposables.add(Disposable(unsub)) + 
self.register_disposable(Disposable(unsub)) unsub = self.odom.subscribe(self._on_odometry) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.goal_reached.transport is not None: unsub = self.goal_reached.subscribe(self._on_goal_reached) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.explore_cmd.transport is not None: unsub = self.explore_cmd.subscribe(self._on_explore_cmd) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.stop_explore_cmd.transport is not None: unsub = self.stop_explore_cmd.subscribe(self._on_stop_explore_cmd) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) @rpc def stop(self) -> None: diff --git a/dimos/navigation/patrolling/module.py b/dimos/navigation/patrolling/module.py index 48ee59699b..647eeae989 100644 --- a/dimos/navigation/patrolling/module.py +++ b/dimos/navigation/patrolling/module.py @@ -62,11 +62,11 @@ def __init__(self, g: GlobalConfig = global_config) -> None: def start(self) -> None: super().start() - self._disposables.add(Disposable(self.odom.subscribe(self._on_odom))) - self._disposables.add( + self.register_disposable(Disposable(self.odom.subscribe(self._on_odom))) + self.register_disposable( Disposable(self.global_costmap.subscribe(self._router.handle_occupancy_grid)) ) - self._disposables.add(Disposable(self.goal_reached.subscribe(self._on_goal_reached))) + self.register_disposable(Disposable(self.goal_reached.subscribe(self._on_goal_reached))) @rpc def stop(self) -> None: diff --git a/dimos/navigation/replanning_a_star/module.py b/dimos/navigation/replanning_a_star/module.py index 26c540a254..2375af20ce 100644 --- a/dimos/navigation/replanning_a_star/module.py +++ b/dimos/navigation/replanning_a_star/module.py @@ -53,16 +53,18 @@ def __init__(self, **kwargs: Any) -> None: def start(self) -> None: super().start() - 
self._disposables.add(Disposable(self.odom.subscribe(self._planner.handle_odom))) - self._disposables.add( + self.register_disposable(Disposable(self.odom.subscribe(self._planner.handle_odom))) + self.register_disposable( Disposable(self.global_costmap.subscribe(self._planner.handle_global_costmap)) ) - self._disposables.add( + self.register_disposable( Disposable(self.goal_request.subscribe(self._planner.handle_goal_request)) ) - self._disposables.add(Disposable(self.target.subscribe(self._planner.handle_goal_request))) + self.register_disposable( + Disposable(self.target.subscribe(self._planner.handle_goal_request)) + ) - self._disposables.add( + self.register_disposable( Disposable( self.clicked_point.subscribe( lambda pt: self._planner.handle_goal_request(pt.to_pose_stamped()) @@ -70,14 +72,14 @@ def start(self) -> None: ) ) - self._disposables.add(self._planner.path.subscribe(self.path.publish)) + self.register_disposable(self._planner.path.subscribe(self.path.publish)) - self._disposables.add(self._planner.cmd_vel.subscribe(self.cmd_vel.publish)) + self.register_disposable(self._planner.cmd_vel.subscribe(self.cmd_vel.publish)) - self._disposables.add(self._planner.goal_reached.subscribe(self.goal_reached.publish)) + self.register_disposable(self._planner.goal_reached.subscribe(self.goal_reached.publish)) if "DEBUG_NAVIGATION" in os.environ: - self._disposables.add( + self.register_disposable( self._planner.navigation_costmap.subscribe(self.navigation_costmap.publish) ) diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index ef76539d5f..44d1c300c4 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -131,7 +131,7 @@ def __init__(self, **kwargs: Any) -> None: def start(self) -> None: self._running = True - self._disposables.add( + self.register_disposable( self._local_pointcloud_subject.pipe( ops.sample(1.0 / self.config.local_pointcloud_freq), ).subscribe( @@ -140,7 +140,7 @@ def start(self) -> None: ) ) - 
self._disposables.add( + self.register_disposable( self._global_map_subject.pipe( ops.sample(1.0 / self.config.global_map_freq), ).subscribe( diff --git a/dimos/perception/detection/type/detection2d/bbox.py b/dimos/perception/detection/type/detection2d/bbox.py index 9ce3f11b96..a38d6f3bce 100644 --- a/dimos/perception/detection/type/detection2d/bbox.py +++ b/dimos/perception/detection/type/detection2d/bbox.py @@ -98,33 +98,6 @@ def to_repr_dict(self) -> dict[str, Any]: "bbox": f"[{x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f}]", } - def center_to_3d( - self, - pixel: tuple[int, int], - camera_info: CameraInfo, # type: ignore[name-defined] - assumed_depth: float = 1.0, - ) -> PoseStamped: # type: ignore[name-defined] - """Unproject 2D pixel coordinates to 3D position in camera optical frame. - - Args: - camera_info: Camera calibration information - assumed_depth: Assumed depth in meters (default 1.0m from camera) - - Returns: - Vector3 position in camera optical frame coordinates - """ - # Extract camera intrinsics - fx, fy = camera_info.K[0], camera_info.K[4] - cx, cy = camera_info.K[2], camera_info.K[5] - - # Unproject pixel to normalized camera coordinates - x_norm = (pixel[0] - cx) / fx - y_norm = (pixel[1] - cy) / fy - - # Create 3D point at assumed depth in camera optical frame - # Camera optical frame: X right, Y down, Z forward - return Vector3(x_norm * assumed_depth, y_norm * assumed_depth, assumed_depth) # type: ignore[name-defined] - # return focused image, only on the bbox def cropped_image(self, padding: int = 20) -> Image: """Return a cropped version of the image focused on the bounding box. 
diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py index da9fe62370..3342ef9a5e 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py @@ -297,11 +297,11 @@ def _on_frame(img: Image) -> None: f"buffered={len(self._accumulator._buffer)}" ) - self._disposables.add( + self.register_disposable( frame_subject.pipe(sharpness_barrier(self.config.fps)).subscribe(_on_frame) ) unsub_image = self.color_image.subscribe(frame_subject.on_next) - self._disposables.add(Disposable(unsub_image)) + self.register_disposable(Disposable(unsub_image)) # Odometry tracking for entity world positioning (optional — # module works without it, entities just won't have world positions) @@ -313,14 +313,14 @@ def _on_odom(msg: PoseStamped) -> None: if self.odom.transport is not None: unsub_odom = self.odom.subscribe(_on_odom) - self._disposables.add(Disposable(unsub_odom)) + self.register_disposable(Disposable(unsub_odom)) else: logger.warning( "[temporal-memory] odom stream not connected — entity positions will be (0,0,0)" ) # Periodic window analysis - self._disposables.add( + self.register_disposable( interval(self.config.stride_s).subscribe(lambda _: self._analyze_window()) ) logger.info("TemporalMemory started") diff --git a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py index fc1895373c..5407bf97a1 100644 --- a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py +++ b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py @@ -537,7 +537,7 @@ def emit_frames(observer, scheduler): # type: ignore[no-untyped-def] time.sleep(0.5) observer.on_completed() - self._disposables.add( + self.register_disposable( reactivex.create(emit_frames) .pipe( 
ops.observe_on(reactivex.scheduler.NewThreadScheduler()), diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index a8970c61d8..a42033e1b0 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -147,7 +147,7 @@ def on_aligned_frames(frames_tuple) -> None: # type: ignore[no-untyped-def] match_tolerance=0.5, # 500ms tolerance ) unsub = aligned_frames.subscribe(on_aligned_frames) - self._disposables.add(unsub) + self.register_disposable(unsub) # Subscribe to camera info stream separately (doesn't need alignment) def on_camera_info(camera_info_msg: CameraInfo) -> None: @@ -162,7 +162,7 @@ def on_camera_info(camera_info_msg: CameraInfo) -> None: ] unsub = self.camera_info.subscribe(on_camera_info) # type: ignore[assignment] - self._disposables.add(Disposable(unsub)) # type: ignore[arg-type] + self.register_disposable(Disposable(unsub)) # type: ignore[arg-type] @rpc def stop(self) -> None: diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py index a53d331aef..5261d039f7 100644 --- a/dimos/perception/object_tracker_2d.py +++ b/dimos/perception/object_tracker_2d.py @@ -96,7 +96,7 @@ def on_frame(frame_msg: Image) -> None: self._frame_arrival_time = arrival_time unsub = self.color_image.subscribe(on_frame) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) logger.info("ObjectTracker2D module started") @rpc diff --git a/dimos/perception/object_tracker_3d.py b/dimos/perception/object_tracker_3d.py index 317a58dba0..f6945920fb 100644 --- a/dimos/perception/object_tracker_3d.py +++ b/dimos/perception/object_tracker_3d.py @@ -99,7 +99,7 @@ def on_aligned_frames(frames_tuple) -> None: # type: ignore[no-untyped-def] match_tolerance=0.5, # 500ms tolerance ) unsub = aligned_frames.subscribe(on_aligned_frames) - self._disposables.add(unsub) + self.register_disposable(unsub) # Subscribe to camera info def on_camera_info(camera_info_msg: 
CameraInfo) -> None: diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index 13a3c8e289..18389007ca 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -196,10 +196,10 @@ def set_video(image_msg: Image) -> None: else: logger.warning("Received image message without data attribute") - self._disposables.add(Disposable(self.color_image.subscribe(set_video))) + self.register_disposable(Disposable(self.color_image.subscribe(set_video))) # Start periodic processing using interval - self._disposables.add( + self.register_disposable( interval(self._process_interval).subscribe(lambda _: self._process_frame()) ) diff --git a/dimos/robot/drone/connection_module.py b/dimos/robot/drone/connection_module.py index 863f719bad..485f8d8383 100644 --- a/dimos/robot/drone/connection_module.py +++ b/dimos/robot/drone/connection_module.py @@ -22,7 +22,7 @@ from typing import Any from dimos_lcm.std_msgs import String -from reactivex.disposable import CompositeDisposable, Disposable +from reactivex.disposable import Disposable from dimos.agents.annotation import skill from dimos.core.core import rpc @@ -42,13 +42,6 @@ logger = setup_logger() -def _add_disposable(composite: CompositeDisposable, item: Disposable | Any) -> None: - if isinstance(item, Disposable): - composite.add(item) - elif callable(item): - composite.add(Disposable(item)) - - class Config(ModuleConfig): connection_string: str = "udp:0.0.0.0:14550" video_port: int = 5600 @@ -126,8 +119,7 @@ def start(self) -> None: if self.video_stream.start(): logger.info("Video stream started") # Subscribe to video, store latest frame and publish it - _add_disposable( - self._disposables, + self.register_disposable( self.video_stream.get_stream().subscribe(self._store_and_publish_frame), ) # # TEMPORARY - DELETE AFTER RECORDING @@ -139,29 +131,25 @@ def start(self) -> None: logger.warning("Video stream failed to start") # Subscribe to drone streams - 
_add_disposable( - self._disposables, self.connection.odom_stream().subscribe(self._publish_tf) - ) - _add_disposable( - self._disposables, self.connection.status_stream().subscribe(self._publish_status) - ) - _add_disposable( - self._disposables, self.connection.telemetry_stream().subscribe(self._publish_telemetry) + self.register_disposable(self.connection.odom_stream().subscribe(self._publish_tf)) + self.register_disposable(self.connection.status_stream().subscribe(self._publish_status)) + self.register_disposable( + self.connection.telemetry_stream().subscribe(self._publish_telemetry) ) # Subscribe to movement commands - _add_disposable(self._disposables, self.movecmd.subscribe(self.move)) + self.register_disposable(Disposable(self.movecmd.subscribe(self.move))) # Subscribe to Twist movement commands if self.movecmd_twist.transport: - _add_disposable(self._disposables, self.movecmd_twist.subscribe(self._on_move_twist)) + self.register_disposable(Disposable(self.movecmd_twist.subscribe(self._on_move_twist))) if self.gps_goal.transport: - _add_disposable(self._disposables, self.gps_goal.subscribe(self._on_gps_goal)) + self.register_disposable(Disposable(self.gps_goal.subscribe(self._on_gps_goal))) if self.tracking_status.transport: - _add_disposable( - self._disposables, self.tracking_status.subscribe(self._on_tracking_status) + self.register_disposable( + Disposable(self.tracking_status.subscribe(self._on_tracking_status)) ) # Start telemetry update thread diff --git a/dimos/robot/drone/test_drone.py b/dimos/robot/drone/test_drone.py index 0b30c22c35..2b9517614a 100644 --- a/dimos/robot/drone/test_drone.py +++ b/dimos/robot/drone/test_drone.py @@ -240,13 +240,13 @@ def test_connection_module_replay_mode(self) -> None: mock_conn_instance = MagicMock() mock_conn_instance.connected = True mock_conn_instance.odom_stream.return_value.subscribe = MagicMock( - return_value=lambda: None + return_value=MagicMock() ) mock_conn_instance.status_stream.return_value.subscribe 
= MagicMock( - return_value=lambda: None + return_value=MagicMock() ) mock_conn_instance.telemetry_stream.return_value.subscribe = MagicMock( - return_value=lambda: None + return_value=MagicMock() ) mock_conn_instance.disconnect = MagicMock() mock_fake_conn.return_value = mock_conn_instance @@ -255,7 +255,7 @@ def test_connection_module_replay_mode(self) -> None: mock_video_instance = MagicMock() mock_video_instance.start.return_value = True mock_video_instance.get_stream.return_value.subscribe = MagicMock( - return_value=lambda: None + return_value=MagicMock() ) mock_video_instance.stop = MagicMock() mock_fake_video.return_value = mock_video_instance @@ -264,7 +264,7 @@ def test_connection_module_replay_mode(self) -> None: module = DroneConnectionModule(connection_string="replay") module.video = MagicMock() module.movecmd = MagicMock() - module.movecmd.subscribe = MagicMock(return_value=lambda: None) + module.movecmd.subscribe = MagicMock(return_value=MagicMock()) module.tf = MagicMock() try: diff --git a/dimos/robot/test_all_blueprints_generation.py b/dimos/robot/test_all_blueprints_generation.py index c4b9652e47..28a2d1fa66 100644 --- a/dimos/robot/test_all_blueprints_generation.py +++ b/dimos/robot/test_all_blueprints_generation.py @@ -33,7 +33,7 @@ "dimos/core/test_blueprints.py", } BLUEPRINT_METHODS = {"transports", "global_config", "remappings", "requirements", "configurators"} -_EXCLUDED_MODULE_NAMES = {"Module", "ModuleBase"} +_EXCLUDED_MODULE_NAMES = {"Module", "ModuleBase", "StreamModule"} def test_all_blueprints_is_current() -> None: diff --git a/dimos/robot/unitree/b1/connection.py b/dimos/robot/unitree/b1/connection.py index 11af31b296..26fe3db933 100644 --- a/dimos/robot/unitree/b1/connection.py +++ b/dimos/robot/unitree/b1/connection.py @@ -121,24 +121,24 @@ def start(self) -> None: # Subscribe to input streams if self.cmd_vel: unsub = self.cmd_vel.subscribe(self.handle_twist_stamped) - self._disposables.add(Disposable(unsub)) + 
self.register_disposable(Disposable(unsub)) if self.mode_cmd: unsub = self.mode_cmd.subscribe(self.handle_mode) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.odom_in: unsub = self.odom_in.subscribe(self._publish_odom_pose) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) # Subscribe to ROS In ports if self.ros_cmd_vel: unsub = self.ros_cmd_vel.subscribe(self.handle_twist_stamped) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.ros_odom_in: unsub = self.ros_odom_in.subscribe(self._publish_odom_pose) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.ros_tf: unsub = self.ros_tf.subscribe(self._on_ros_tf) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) # Start threads self.running = True diff --git a/dimos/robot/unitree/g1/connection.py b/dimos/robot/unitree/g1/connection.py index bc2ca7d3d9..c7ac64800c 100644 --- a/dimos/robot/unitree/g1/connection.py +++ b/dimos/robot/unitree/g1/connection.py @@ -92,7 +92,7 @@ def start(self) -> None: assert self.connection is not None self.connection.start() - self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) + self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) @rpc def stop(self) -> None: diff --git a/dimos/robot/unitree/g1/sim.py b/dimos/robot/unitree/g1/sim.py index 22fc33a978..14b39961bb 100644 --- a/dimos/robot/unitree/g1/sim.py +++ b/dimos/robot/unitree/g1/sim.py @@ -69,10 +69,10 @@ def start(self) -> None: assert self.connection is not None self.connection.start() - self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) - self._disposables.add(self.connection.odom_stream().subscribe(self._publish_sim_odom)) - self._disposables.add(self.connection.lidar_stream().subscribe(self.lidar.publish)) - 
self._disposables.add(self.connection.video_stream().subscribe(self.color_image.publish)) + self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) + self.register_disposable(self.connection.odom_stream().subscribe(self._publish_sim_odom)) + self.register_disposable(self.connection.lidar_stream().subscribe(self.lidar.publish)) + self.register_disposable(self.connection.video_stream().subscribe(self.color_image.publish)) self._camera_info_thread = Thread( target=self._publish_camera_info_loop, diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py index 5123dc9a31..a449d9f448 100644 --- a/dimos/robot/unitree/go2/connection.py +++ b/dimos/robot/unitree/go2/connection.py @@ -237,10 +237,10 @@ def onimage(image: Image) -> None: self.color_image.publish(image) self._latest_video_frame = image - self._disposables.add(self.connection.lidar_stream().subscribe(self.lidar.publish)) - self._disposables.add(self.connection.odom_stream().subscribe(self._publish_tf)) - self._disposables.add(self.connection.video_stream().subscribe(onimage)) - self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) + self.register_disposable(self.connection.lidar_stream().subscribe(self.lidar.publish)) + self.register_disposable(self.connection.odom_stream().subscribe(self._publish_tf)) + self.register_disposable(self.connection.video_stream().subscribe(onimage)) + self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) self._camera_info_thread = Thread( target=self.publish_camera_info, diff --git a/dimos/robot/unitree/type/map.py b/dimos/robot/unitree/type/map.py index 4ec9419c53..8b462f86fe 100644 --- a/dimos/robot/unitree/type/map.py +++ b/dimos/robot/unitree/type/map.py @@ -67,11 +67,11 @@ def __init__(self, **kwargs: Any) -> None: def start(self) -> None: super().start() - self._disposables.add(Disposable(self.lidar.subscribe(self.add_frame))) + 
self.register_disposable(Disposable(self.lidar.subscribe(self.add_frame))) if self.global_publish_interval is not None: unsub = interval(self.global_publish_interval).subscribe(self._publish) - self._disposables.add(unsub) + self.register_disposable(unsub) @rpc def stop(self) -> None: diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index d051154065..2e92810611 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -294,8 +294,8 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: super().start() - self._disposables.add(Disposable(self.cmd_vel.subscribe(self._on_cmd_vel))) - self._disposables.add(Disposable(self.terrain_map.subscribe(self._on_terrain))) + self.register_disposable(Disposable(self.cmd_vel.subscribe(self._on_cmd_vel))) + self.register_disposable(Disposable(self.terrain_map.subscribe(self._on_terrain))) self._running.set() self._sim_thread = threading.Thread(target=self._sim_loop, daemon=True) self._sim_thread.start() diff --git a/dimos/simulation/unity/test_unity_sim.py b/dimos/simulation/unity/test_unity_sim.py index 1e7dfea8b2..e0e740e081 100644 --- a/dimos/simulation/unity/test_unity_sim.py +++ b/dimos/simulation/unity/test_unity_sim.py @@ -65,6 +65,9 @@ def subscribe(self, cb, *_a): self._subscribers.append(cb) return lambda: self._subscribers.remove(cb) + def stop(self): + pass + def _wire(module) -> dict[str, _MockTransport]: ts = {} diff --git a/dimos/utils/demo_image_encoding.py b/dimos/utils/demo_image_encoding.py index 84b91acf79..6601b1659d 100644 --- a/dimos/utils/demo_image_encoding.py +++ b/dimos/utils/demo_image_encoding.py @@ -76,7 +76,7 @@ class ReceiverModule(Module): def start(self) -> None: super().start() - self._disposables.add(Disposable(self.image.subscribe(self._on_image))) + self.register_disposable(Disposable(self.image.subscribe(self._on_image))) self._open_file = open("/tmp/receiver-times", "w") def stop(self) -> None: diff --git 
a/dimos/utils/testing/collector.py b/dimos/utils/testing/collector.py index bcc3150e73..faf9464843 100644 --- a/dimos/utils/testing/collector.py +++ b/dimos/utils/testing/collector.py @@ -30,7 +30,7 @@ class CallbackCollector: assert len(collector.results) == 3 """ - def __init__(self, n: int, timeout: float = 2.0) -> None: + def __init__(self, n: int, timeout: float = 5.0) -> None: self.results: list[tuple[Any, Any]] = [] self._done = threading.Event() self._n = n diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index 8b1cda443c..40c29ba4ef 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -322,12 +322,12 @@ def start(self) -> None: if hasattr(pubsub, "start"): pubsub.start() # type: ignore[union-attr] unsub = pubsub.subscribe_all(self._on_message) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) # Add pubsub stop as disposable for pubsub in self.config.pubsubs: if hasattr(pubsub, "stop"): - self._disposables.add(Disposable(pubsub.stop)) # type: ignore[union-attr] + self.register_disposable(Disposable(pubsub.stop)) # type: ignore[union-attr] self._log_static() diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 685ca2b1ee..519a6d1f4b 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -173,25 +173,25 @@ def start(self) -> None: try: unsub = self.odom.subscribe(self._on_robot_pose) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) except Exception: ... try: unsub = self.gps_location.subscribe(self._on_gps_location) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) except Exception: ... 
try: unsub = self.path.subscribe(self._on_path) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) except Exception: ... try: unsub = self.global_costmap.subscribe(self._on_global_costmap) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) except Exception: ... diff --git a/docs/agents/index.md b/docs/agents/index.md index ec9d66e886..4170a0e898 100644 --- a/docs/agents/index.md +++ b/docs/agents/index.md @@ -1,19 +1,8 @@ # For Agents -These docs are mostly for coding agents - -```sh -tree . -P '*.md' --prune -``` - - -``` -. -├── docs +├── testing.md (docs about writing tests) +├── docs (these are docs about writing docs) │   ├── codeblocks.md │   ├── doclinks.md │   └── index.md └── index.md - -2 directories, 4 files -``` diff --git a/docs/agents/style.md b/docs/agents/style.md new file mode 100644 index 0000000000..37354cc681 --- /dev/null +++ b/docs/agents/style.md @@ -0,0 +1,49 @@ +# Code Style Guidelines + +Rules for writing code in dimos. These address recurring issues found in code review. + +## No comment banners + +Don't use decorative section dividers or box comments. + +```python +# BAD +# ═══════════════════════════════════════════════════════════════════ +# 1. Basic iteration +# ═══════════════════════════════════════════════════════════════════ + +# BAD +# ------------------------------------------------------------------- +# Section name +# ------------------------------------------------------------------- + +# GOOD — just use a plain comment if a section heading is needed +# Basic iteration +``` + +If a file has enough sections to warrant banners, it should probably be split into separate files instead. 
For example, instead of one large `test_something.py` with banner-separated sections, create a `something/` directory: + +``` +# BAD +test_something.py (500 lines with banner-separated sections) + +# GOOD +something/ + test_iteration.py + test_lifecycle.py + test_queries.py +``` + +## No `__init__.py` re-exports + +Never add imports to `__init__.py` files. Re-exporting from `__init__.py` makes imports too wide and slow — importing one symbol pulls in the entire package tree. + +```python +# BAD — dimos/memory2/__init__.py +from dimos.memory2.store import Store, SqliteStore +from dimos.memory2.stream import Stream + +# GOOD — import directly from the module +from dimos.memory2.store.base import Store +from dimos.memory2.stream import Stream +``` diff --git a/docs/agents/testing.md b/docs/agents/testing.md new file mode 100644 index 0000000000..45614c81d2 --- /dev/null +++ b/docs/agents/testing.md @@ -0,0 +1,149 @@ +# Testing Guidelines + +Rules for writing tests in dimos. These address recurring issues found in code review. + +For grid testing (spec/impl tests across multiple backends), see [Grid Testing Strategy](/docs/development/grid_testing.md). + +## Imports at the top + +All imports must be at module level, not inside test functions. + +```python +# BAD +def test_something() -> None: + import threading + from dimos.core.transport import pLCMTransport + ... + +# GOOD +import threading +from dimos.core.transport import pLCMTransport + +def test_something() -> None: + ... +``` + +## Always clean up resources + +Use context managers or try/finally. If a test creates a resource, it must be cleaned up even if assertions fail. 
+ +```python +# BAD - store.stop() never called +def test_something() -> None: + store = ListObservationStore(name="test", max_size=0) + store.start() + assert store.count(StreamQuery()) == 0 + +# BAD - module.stop() skipped if assertion fails +def test_wiring() -> None: + module = MyModule() + module.start() + assert received == [84] + module.stop() + +# GOOD - context manager (ideal) +def test_something() -> None: + store = ListObservationStore(name="test", max_size=0) + with store: + assert store.count(StreamQuery()) == 0 + +# GOOD - try/finally +def test_wiring() -> None: + module = MyModule() + module.start() + try: + assert received == [84] + finally: + module.stop() +``` + +When a resource is shared across multiple tests, use a pytest fixture with `yield` instead of repeating context managers in each test: + +```python +# GOOD - fixture handles lifecycle for all tests that use it +@pytest.fixture(scope="module") +def store() -> Iterator[SqliteStore]: + db = SqliteStore(path=str(DB_PATH)) + with db: + yield db + +def test_query(store: SqliteStore) -> None: + assert store.stream("video", Image).count() > 0 + +def test_search(store: SqliteStore) -> None: + results = store.stream("video", Image).limit(5).fetch() + assert len(results) == 5 +``` + +## No conditional logic in assertions + +Tests must be deterministic. If you don't know the state, the test is wrong. + +```python +# BAD - assertion may never execute +if hasattr(obj, "_disposables") and obj._disposables is not None: + assert obj._disposables.is_disposed + +# BAD - masks whether disposables were created +assert obj._disposables is None or obj._disposables.is_disposed + +# GOOD - explicit about what we expect +assert obj._disposables is not None +assert obj._disposables.is_disposed +``` + +## Print statements + +- **Unit tests**: no prints. Use assertions. +- **`@pytest.mark.tool` tests** (integration/exploration): prints are fine for progress and inspection output. 
+ +## Avoid unnecessary sleeps + +Don't use `time.sleep()` to wait for async operations. Use `threading.Event` to synchronize emitter/receiver patterns. + +```python +# BAD - arbitrary sleep, fragile +module.start() +time.sleep(0.5) +module.numbers.transport.publish(42) +time.sleep(1.0) +assert len(received) == 1 + +# GOOD - use threading.Event with a timeout +done = threading.Event() +unsub = module.doubled.subscribe(lambda msg: (received.append(msg), done.set())) +module.start() +module.numbers.transport.publish(42) +assert done.wait(timeout=5.0), f"Timed out, received={received}" +assert received == [84] +``` + +## Private fields + +Configuration fields on non-Pydantic classes should be private (underscore-prefixed) unless they are part of the public API. + +```python +# BAD +self.voxel_size = voxel_size +self.carve_columns = carve_columns + +# GOOD +self._voxel_size = voxel_size +self._carve_columns = carve_columns +``` + +## Type ignores + +Avoid `# type: ignore` by using proper types: + +```python +# BAD +self.vbg = None # type: ignore[assignment] + +# GOOD - type as Optional +self.vbg: VoxelBlockGrid | None = VoxelBlockGrid(...) +# then later: +self.vbg = None # no ignore needed +``` + +Type ignores are acceptable when caused by untyped third-party libraries (e.g. `open3d`) or decorator-generated attributes (e.g. `@simple_mcache` adding `invalidate_cache`). 
diff --git a/examples/simplerobot/simplerobot.py b/examples/simplerobot/simplerobot.py index 517684d7cd..902736f06a 100644 --- a/examples/simplerobot/simplerobot.py +++ b/examples/simplerobot/simplerobot.py @@ -68,11 +68,11 @@ class SimpleRobot(Module[SimpleRobotConfig]): @rpc def start(self) -> None: - self._disposables.add(self.cmd_vel.observable().subscribe(self._on_twist)) - self._disposables.add( + self.register_disposable(self.cmd_vel.observable().subscribe(self._on_twist)) + self.register_disposable( rx.interval(1.0 / self.config.update_rate).subscribe(lambda _: self._update()) ) - self._disposables.add( + self.register_disposable( rx.interval(1.0).subscribe(lambda _: print(f"\033[34m{self._pose}\033[0m")) )