From 5dac45263d54e945563115012b039d9022c1f239 Mon Sep 17 00:00:00 2001 From: YANG Zhenfei Date: Wed, 15 Apr 2026 18:05:53 +0800 Subject: [PATCH] feat(python): add InputAligner on humble (#4) * feat(python): add InputAligner on humble * fix(python): use ROS time for InputAligner timestamps * fix(python): avoid rclpy.clock_type dependency on Humble * test(python): expand InputAligner parity coverage * test(python): add remaining InputAligner parity cases * style(python): polish InputAligner time helpers (cherry picked from commit 3eba8f9819b8e8a0dd2b5c785599077f0fd3417e) --- CMakeLists.txt | 14 +- src/message_filters/__init__.py | 2 + src/message_filters/input_aligner.py | 151 ++++++++++++++++++++ test/test_input_aligner.py | 203 +++++++++++++++++++++++++++ 4 files changed, 366 insertions(+), 4 deletions(-) create mode 100644 src/message_filters/input_aligner.py create mode 100644 test/test_input_aligner.py diff --git a/CMakeLists.txt b/CMakeLists.txt index acd79b5..63813e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -144,10 +144,16 @@ if(BUILD_TESTING) # python tests with python interfaces of message filters find_package(ament_cmake_pytest REQUIRED) - ament_add_pytest_test(test_time_synchronizer.py "test/test_time_synchronizer.py") - ament_add_pytest_test(test_approxsync.py "test/test_approxsync.py") - ament_add_pytest_test(test_message_filters_cache.py "test/test_message_filters_cache.py") - ament_add_pytest_test(test_message_filters_chain.py "test/test_message_filters_chain.py") + ament_add_pytest_test(test_time_synchronizer.py "test/test_time_synchronizer.py" + PYTHON_EXECUTABLE "${_PYTHON_EXECUTABLE}") + ament_add_pytest_test(test_approxsync.py "test/test_approxsync.py" + PYTHON_EXECUTABLE "${_PYTHON_EXECUTABLE}") + ament_add_pytest_test(test_message_filters_cache.py "test/test_message_filters_cache.py" + PYTHON_EXECUTABLE "${_PYTHON_EXECUTABLE}") + ament_add_pytest_test(test_message_filters_chain.py "test/test_message_filters_chain.py" + PYTHON_EXECUTABLE "${_PYTHON_EXECUTABLE}") + ament_add_pytest_test(test_input_aligner.py "test/test_input_aligner.py" + PYTHON_EXECUTABLE "${_PYTHON_EXECUTABLE}") endif() ament_package() diff --git a/src/message_filters/__init__.py b/src/message_filters/__init__.py index 74af36c..05fa50c 100644 --- a/src/message_filters/__init__.py +++ b/src/message_filters/__init__.py @@ -50,6 +50,8 @@ from rclpy.time import Time from rclpy.type_support import MsgT +from .input_aligner import InputAligner, QueueStatus + class SimpleFilter(object): diff --git a/src/message_filters/input_aligner.py b/src/message_filters/input_aligner.py new file mode 100644 index 0000000..dcd5d6b --- /dev/null +++ b/src/message_filters/input_aligner.py @@ -0,0 +1,151 @@ +from bisect import insort_right +from dataclasses import dataclass +import threading + +from builtin_interfaces.msg import Time as TimeMsg +from rclpy.duration import Duration +from rclpy.time import Time + + +@dataclass +class QueueStatus: + active: bool + queue_size: int + msgs_processed: int + msgs_dropped: int + + +class _Signal: + def __init__(self): + self.callbacks = {} + + def registerCallback(self, cb, *args): + conn = len(self.callbacks) + self.callbacks[conn] = (cb, args) + return conn + + def signalMessage(self, *msg): + for (cb, args) in self.callbacks.values(): + cb(*(msg + args)) + + +def _ros_zero_time(): + return Time.from_msg(TimeMsg()) + + +def _ros_max_time(): + zero = _ros_zero_time() + return Time(nanoseconds=9223372036854775807, clock_type=zero.clock_type) + + +class _EventQueue: + def __init__(self): + self.events = [] + self.next_ts = _ros_max_time() + self.period = Duration(seconds=0) + self.active = False + self.msgs_processed = 0 + self.msgs_dropped = 0 + + def first_timestamp(self): + if self.events: + first_ts = self.events[0][0] + self.next_ts = first_ts + self.period + self.active = True + return first_ts + if self.active: + return self.next_ts + return _ros_max_time() + + def pop_first(self): + self.events.pop(0) + self.msgs_processed += 1 + + def msg_dropped(self): + self.msgs_dropped += 1 + + def set_period(self, period): + self.period = period + + def set_active(self, active): + self.active = active + + def get_status(self): + return QueueStatus(self.active, len(self.events), self.msgs_processed, self.msgs_dropped) + + +class InputAligner: + def __init__(self, timeout, *filters): + self.timeout = timeout + zero_time = _ros_zero_time() + self.last_in_ts = zero_time + self.last_out_ts = zero_time + self.name = '' + self.lock = threading.Lock() + self.event_queues = [] + self.input_connections = [] + self.signals = [] + self.dispatch_timer = None + if filters: + self.connectInput(*filters) + + def connectInput(self, *filters): + self.disconnectAll() + self.event_queues = [_EventQueue() for _ in filters] + self.signals = [_Signal() for _ in filters] + self.input_connections = [f.registerCallback(self.add, idx) for idx, f in enumerate(filters)] + + def disconnectAll(self): + self.input_connections = [] + + def registerCallback(self, index, cb, *args): + return self.signals[index].registerCallback(cb, *args) + + def setName(self, name): + self.name = name + + def getName(self): + return self.name + + def add(self, msg, queue_index): + msg_timestamp = Time.from_msg(msg.header.stamp) + with self.lock: + queue = self.event_queues[queue_index] + if msg_timestamp < self.last_out_ts: + queue.msg_dropped() + return + if msg_timestamp > self.last_in_ts: + self.last_in_ts = msg_timestamp + insort_right(queue.events, (msg_timestamp, msg), key=lambda x: x[0].nanoseconds) + + def setInputPeriod(self, index, period): + self.event_queues[index].set_period(period) + + def getQueueStatus(self, index): + return self.event_queues[index].get_status() + + def setupDispatchTimer(self, node, update_rate): + self.dispatch_timer = node.create_timer(update_rate.nanoseconds / 1e9, self.dispatchMessages) + + def dispatchMessages(self): + with self.lock: + if not any(queue.events for queue in self.event_queues): + return + input_available = True + while input_available: + input_available = self._dispatch_first_message() + + def _dispatch_first_message(self): + timestamps = [queue.first_timestamp() for queue in self.event_queues] + idx = min(range(len(timestamps)), key=lambda i: timestamps[i].nanoseconds) + queue = self.event_queues[idx] + if queue.events: + stamp, msg = queue.events[0] + self.last_out_ts = stamp + self.signals[idx].signalMessage(msg) + queue.pop_first() + return True + if (self.last_in_ts - queue.first_timestamp()) >= self.timeout: + queue.set_active(False) + return True + return False diff --git a/test/test_input_aligner.py b/test/test_input_aligner.py new file mode 100644 index 0000000..9f10353 --- /dev/null +++ b/test/test_input_aligner.py @@ -0,0 +1,203 @@ +import time +import unittest + +from builtin_interfaces.msg import Time as TimeMsg +from message_filters import InputAligner, SimpleFilter +import rclpy +from rclpy.duration import Duration +from rclpy.time import Time + + +class Header: + def __init__(self, stamp=None): + self.stamp = stamp if stamp is not None else TimeMsg() + + +class Msg1: + def __init__(self, stamp=None, data=None): + self.header = Header(stamp) + self.data = data + + +class Msg2: + def __init__(self, stamp=None, data=None): + self.header = Header(stamp) + self.data = data + + +class TestInputAligner(unittest.TestCase): + @classmethod + def setUpClass(cls): + rclpy.init() + cls.node = rclpy.create_node('test_input_aligner_node') + + @classmethod + def tearDownClass(cls): + cls.node.destroy_node() + rclpy.shutdown() + + def setUp(self): + self.timeout = Duration(seconds=1.0) + self.update_rate = Duration(nanoseconds=10000000) + self.cb_content = [] + + def cb(self, msg): + self.cb_content.append(msg.data) + + def create_msg(self, cls, milliseconds, data): + return cls(stamp=Time(nanoseconds=int(milliseconds * 1e6)).to_msg(), data=data) + + def test_init(self): + f0, f1, f2, f3 = SimpleFilter(), SimpleFilter(), SimpleFilter(), SimpleFilter() + aligner1 = InputAligner(self.timeout, f0, f1, f2, f3) + self.assertEqual(len(aligner1.event_queues), 4) + aligner2 = InputAligner(self.timeout) + aligner2.connectInput(f0, f2, f3) + self.assertEqual(len(aligner2.event_queues), 3) + + def test_dispatch_inputs_in_order(self): + aligner = InputAligner(self.timeout) + aligner.connectInput(SimpleFilter(), SimpleFilter(), SimpleFilter(), SimpleFilter()) + for i in range(4): + aligner.registerCallback(i, self.cb) + aligner.setInputPeriod(i, Duration(nanoseconds=int(4e6))) + aligner.add(self.create_msg(Msg1, 3, 3), 2) + aligner.add(self.create_msg(Msg1, 1, 1), 0) + aligner.add(self.create_msg(Msg1, 7, 7), 2) + aligner.add(self.create_msg(Msg1, 5, 5), 0) + aligner.add(self.create_msg(Msg2, 2, 2), 3) + aligner.add(self.create_msg(Msg1, 9, 9), 0) + aligner.add(self.create_msg(Msg2, 4, 4), 1) + aligner.add(self.create_msg(Msg2, 8, 8), 1) + aligner.add(self.create_msg(Msg2, 6, 6), 3) + aligner.dispatchMessages() + self.assertEqual(self.cb_content, list(range(1, 10))) + + def test_ignores_inactive_inputs(self): + aligner = InputAligner(self.timeout) + aligner.connectInput(SimpleFilter(), SimpleFilter(), SimpleFilter()) + for i in range(3): + aligner.registerCallback(i, self.cb) + aligner.setInputPeriod(i, Duration(nanoseconds=int(2e6))) + aligner.add(self.create_msg(Msg1, 2, 2), 2) + aligner.add(self.create_msg(Msg2, 1, 1), 1) + aligner.add(self.create_msg(Msg1, 4, 4), 2) + aligner.add(self.create_msg(Msg2, 3, 3), 1) + aligner.add(self.create_msg(Msg2, 5, 5), 1) + aligner.dispatchMessages() + self.assertEqual(self.cb_content, [1, 2, 3, 4, 5]) + + def test_input_timeout(self): + self.timeout = Duration(nanoseconds=int(1e7)) + aligner = InputAligner(self.timeout) + aligner.connectInput(SimpleFilter(), SimpleFilter()) + for i in range(2): + aligner.registerCallback(i, self.cb) + aligner.setInputPeriod(i, Duration(nanoseconds=int(2e6))) + for i in range(1, 17, 2): + aligner.add(self.create_msg(Msg1, i, i), 0) + aligner.add(self.create_msg(Msg2, 2, 2), 1) + aligner.add(self.create_msg(Msg2, 4, 4), 1) + aligner.dispatchMessages() + self.assertEqual(self.cb_content, [1, 2, 3, 4, 5]) + aligner.add(self.create_msg(Msg1, 17, 17), 0) + aligner.dispatchMessages() + self.assertEqual(self.cb_content, [1, 2, 3, 4, 5, 7, 9, 11, 13, 15, 17]) + + def test_drops_msgs(self): + aligner = InputAligner(self.timeout) + aligner.connectInput(SimpleFilter(), SimpleFilter()) + for i in range(2): + aligner.registerCallback(i, self.cb) + aligner.setInputPeriod(i, Duration(nanoseconds=int(2e6))) + aligner.add(self.create_msg(Msg2, 4, 4), 1) + aligner.add(self.create_msg(Msg1, 3, 3), 0) + aligner.dispatchMessages() + self.assertEqual(self.cb_content, [3, 4]) + aligner.add(self.create_msg(Msg1, 1, 1), 0) + aligner.add(self.create_msg(Msg1, 5, 5), 0) + aligner.add(self.create_msg(Msg1, 7, 7), 0) + aligner.add(self.create_msg(Msg2, 2, 2), 1) + aligner.add(self.create_msg(Msg2, 6, 6), 1) + aligner.dispatchMessages() + self.assertEqual(self.cb_content, [3, 4, 5, 6, 7]) + + def test_dispatch_by_timer(self): + aligner = InputAligner(self.timeout) + aligner.connectInput(SimpleFilter(), SimpleFilter()) + aligner.setupDispatchTimer(self.node, self.update_rate) + for i in range(2): + aligner.registerCallback(i, self.cb) + aligner.setInputPeriod(i, Duration(nanoseconds=int(2e6))) + aligner.add(self.create_msg(Msg2, 2, 2), 1) + aligner.add(self.create_msg(Msg1, 1, 1), 0) + time.sleep(0.05) + rclpy.spin_once(self.node, timeout_sec=0.01) + self.assertEqual(self.cb_content, [1, 2]) + + def test_no_period_information(self): + self.timeout = Duration(nanoseconds=int(1e7)) + aligner = InputAligner(self.timeout) + aligner.connectInput(SimpleFilter(), SimpleFilter(), SimpleFilter()) + for i in range(3): + aligner.registerCallback(i, self.cb) + aligner.add(self.create_msg(Msg1, 6, 6), 0) + aligner.add(self.create_msg(Msg1, 2, 2), 2) + aligner.add(self.create_msg(Msg1, 4, 4), 2) + aligner.add(self.create_msg(Msg2, 1, 1), 1) + aligner.add(self.create_msg(Msg2, 3, 3), 1) + aligner.add(self.create_msg(Msg2, 5, 5), 1) + aligner.dispatchMessages() + self.assertEqual(self.cb_content, [1, 2, 3, 4]) + aligner.add(self.create_msg(Msg1, 16, 16), 0) + aligner.dispatchMessages() + self.assertEqual(self.cb_content, [1, 2, 3, 4, 5, 6, 16]) + + def test_get_queue_status(self): + self.timeout = Duration(nanoseconds=int(1e7)) + aligner = InputAligner(self.timeout) + aligner.connectInput(SimpleFilter(), SimpleFilter()) + for i in range(2): + aligner.registerCallback(i, self.cb) + aligner.setInputPeriod(i, Duration(nanoseconds=int(2e6))) + aligner.add(self.create_msg(Msg2, 2, 2), 1) + aligner.add(self.create_msg(Msg1, 3, 3), 0) + aligner.add(self.create_msg(Msg1, 5, 5), 0) + status_0 = aligner.getQueueStatus(0) + self.assertFalse(status_0.active) + self.assertEqual(status_0.queue_size, 2) + self.assertEqual(status_0.msgs_processed, 0) + self.assertEqual(status_0.msgs_dropped, 0) + status_1 = aligner.getQueueStatus(1) + self.assertFalse(status_1.active) + self.assertEqual(status_1.queue_size, 1) + self.assertEqual(status_1.msgs_processed, 0) + self.assertEqual(status_1.msgs_dropped, 0) + aligner.dispatchMessages() + status_0 = aligner.getQueueStatus(0) + self.assertTrue(status_0.active) + self.assertEqual(status_0.queue_size, 1) + self.assertEqual(status_0.msgs_processed, 1) + self.assertEqual(status_0.msgs_dropped, 0) + status_1 = aligner.getQueueStatus(1) + self.assertTrue(status_1.active) + self.assertEqual(status_1.queue_size, 0) + self.assertEqual(status_1.msgs_processed, 1) + self.assertEqual(status_1.msgs_dropped, 0) + aligner.add(self.create_msg(Msg1, 1, 1), 0) + aligner.add(self.create_msg(Msg1, 17, 17), 0) + aligner.dispatchMessages() + status_0 = aligner.getQueueStatus(0) + self.assertTrue(status_0.active) + self.assertEqual(status_0.queue_size, 0) + self.assertEqual(status_0.msgs_processed, 3) + self.assertEqual(status_0.msgs_dropped, 1) + status_1 = aligner.getQueueStatus(1) + self.assertFalse(status_1.active) + self.assertEqual(status_1.queue_size, 0) + self.assertEqual(status_1.msgs_processed, 1) + self.assertEqual(status_1.msgs_dropped, 0) + + +if __name__ == '__main__': + unittest.main()