Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 97 additions & 4 deletions tool/benchmark/benchmark_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ class BagMetadataExtractor:
for timestamp-based pose sampling.
"""

@staticmethod
def get_bag_topics_with_message_count(bag_path: str):
"""Return topic metadata entries from a ROS2 bag."""
try:
info = rosbag2_py.Info()
metadata = info.read_metadata(bag_path, "")
return list(metadata.topics_with_message_count)
except Exception as e:
print(f"Failed to extract topic metadata from {bag_path}: {e}")
return []

@staticmethod
def get_bag_time_range(bag_path: str) -> Tuple[Optional[int], Optional[int]]:
"""
Expand Down Expand Up @@ -727,6 +738,67 @@ def sample_timestamps_from_bag(bag_path: str, num_samples: int) -> np.ndarray:
return timestamps


def sample_image_timestamps_from_bag(
bag_path: str,
num_samples: Optional[int] = None,
image_stride: Optional[int] = None,
topic_hint: str = "image",
) -> np.ndarray:
"""Sample timestamps directly from image messages by image index."""
if num_samples is None and image_stride is None:
raise ValueError("Either num_samples or image_stride must be provided")

reader = rosbag2_py.SequentialReader()
storage_options = rosbag2_py.StorageOptions(uri=bag_path, storage_id="sqlite3")
converter_options = rosbag2_py.ConverterOptions("", "")
reader.open(storage_options, converter_options)

topic_infos = BagMetadataExtractor.get_bag_topics_with_message_count(bag_path)
image_topics = []
for topic_info in topic_infos:
name = topic_info.topic_metadata.name
type_name = topic_info.topic_metadata.type
if "sensor_msgs/msg/Image" in type_name or topic_hint in name.lower():
image_topics.append(name)

if not image_topics:
raise RuntimeError(f"No image topics found in bag: {bag_path}")

topic = sorted(image_topics)[0]
print(f"Sampling evaluation images from topic: {topic}")
reader.set_filter(rosbag2_py.StorageFilter(topics=[topic]))

timestamps = []
while reader.has_next():
_topic, _data, t = reader.read_next()
timestamps.append(int(t))

if not timestamps:
raise RuntimeError(f"No image messages found on topic {topic} in bag {bag_path}")

total = len(timestamps)
start_idx = min(total - 1, max(0, int(total * 0.05)))
end_idx = max(start_idx + 1, int(total * 0.95))
candidate = timestamps[start_idx:end_idx]
if not candidate:
candidate = timestamps

if image_stride is not None:
if image_stride <= 0:
raise ValueError("image_stride must be positive")
sampled = candidate[::image_stride]
else:
count = min(num_samples, len(candidate))
indices = np.linspace(0, len(candidate) - 1, count, dtype=np.int64)
sampled = [candidate[int(i)] for i in indices]

sampled_arr = np.array(sampled, dtype=np.int64)
print(
f"Sampled {len(sampled_arr)} image timestamps by image index from {len(candidate)} candidates (total images={total})"
)
return sampled_arr


def query_poses_at_timestamps(
timestamps: np.ndarray, map_result_dir_b: str
) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
Expand Down Expand Up @@ -887,6 +959,8 @@ def run_benchmark(
num_samples: int,
timeout: float,
verbose_timer: bool = False,
sample_mode: str = "timestamp",
image_stride: Optional[int] = None,
) -> bool:
"""
Run benchmark using timestamp-based sampling instead of keyframe-based.
Expand Down Expand Up @@ -936,12 +1010,17 @@ def run_benchmark(
):
print("Error: Failed to localize bag B in map A")
return False

print(f"\nStep 3: Sampling {num_samples} timestamps from bag time ranges...")
sampled_timestamps = sample_timestamps_from_bag(bag_b_path, num_samples)
if sample_mode == "image-index":
print(f"\nStep 3: Sampling evaluation images from bag B by image index...")
sampled_timestamps = sample_image_timestamps_from_bag(
bag_b_path, num_samples=num_samples, image_stride=image_stride
)
else:
print(f"\nStep 3: Sampling {num_samples} timestamps from bag time ranges...")
sampled_timestamps = sample_timestamps_from_bag(bag_b_path, num_samples)

if sampled_timestamps is None or len(sampled_timestamps) == 0:
print("Error: Timestamp sampling failed")
print("Error: Timestamp/image-index sampling failed")
return False

print(f"\nStep 4: Querying poses at sampled timestamps...")
Expand Down Expand Up @@ -1012,6 +1091,18 @@ def main():
default=1.0,
help="Playback rate for ROS bags (default: 1.0x)",
)
parser.add_argument(
"--sample_mode",
choices=["timestamp", "image-index"],
default="timestamp",
help="Sampling mode for evaluation targets (default: timestamp)",
)
parser.add_argument(
"--image_stride",
type=int,
default=None,
help="When sample_mode=image-index, sample every N images instead of evenly choosing num_images",
)
parser.add_argument(
"--timeout",
type=int,
Expand Down Expand Up @@ -1053,6 +1144,8 @@ def main():
args.num_images,
args.timeout,
args.verbose_timer,
args.sample_mode,
args.image_stride,
)
if benchmark_return:
print("\nBenchmark completed!")
Expand Down
Loading