From e34a3174434ff609236fde47394fe06eb0fe622d Mon Sep 17 00:00:00 2001 From: Jules Giraud Date: Tue, 26 May 2026 19:07:41 +0200 Subject: [PATCH 1/3] Add box filtering helpers --- scripts/python_smoke_test.py | 42 +++++- src/bbox.rs | 258 +++++++++++++++++++++++++++++++++++ src/lib.rs | 7 +- src/python.rs | 113 +++++++++++++++ 4 files changed, 416 insertions(+), 4 deletions(-) diff --git a/scripts/python_smoke_test.py b/scripts/python_smoke_test.py index 22ba474..0f25044 100644 --- a/scripts/python_smoke_test.py +++ b/scripts/python_smoke_test.py @@ -144,6 +144,46 @@ def main() -> None: ) assert np.allclose(clipped_boxes[0], np.array([0.0, 3.0, 10.0, 20.0], dtype=np.float32)) + score_values = np.array([0.9, 0.8, 0.7], dtype=np.float32) + filtered_by_score = rusty_cv.filter_boxes_by_score_numpy(boxes, score_values, 0.75) + assert filtered_by_score["indices"].tolist() == [0, 1] + assert np.allclose( + filtered_by_score["boxes"], + np.array( + [ + [0.0, 0.0, 10.0, 10.0], + [1.0, 1.0, 11.0, 11.0], + ], + dtype=np.float32, + ), + ) + + filtered_by_area = rusty_cv.filter_boxes_by_area_numpy(boxes, min_area=90.0, max_area=110.0) + assert filtered_by_area["indices"].tolist() == [0, 1, 2] + + filtered_by_min_size = rusty_cv.filter_boxes_by_min_size_numpy(boxes, 10.0, 10.0) + assert filtered_by_min_size["indices"].tolist() == [0, 1, 2] + + clipped_and_filtered = rusty_cv.clip_and_filter_boxes_numpy( + np.array( + [ + [-5.0, 0.0, 6.0, 9.0], + [4.0, 4.0, 5.0, 5.0], + [9.0, 9.0, 15.0, 15.0], + ], + dtype=np.float32, + ), + 10, + 10, + min_width=2.0, + min_height=2.0, + ) + assert clipped_and_filtered["indices"].tolist() == [0] + assert np.allclose( + clipped_and_filtered["boxes"], + np.array([[-0.0, 0.0, 6.0, 9.0]], dtype=np.float32), + ) + resized_boxes = rusty_cv.resize_boxes_numpy( np.array([[10.0, 20.0, 40.0, 60.0]], dtype=np.float32), 100, @@ -169,7 +209,7 @@ def main() -> None: np.array([[100.0, 50.0, 300.0, 150.0]], dtype=np.float32), ) - scores = np.array([0.9, 0.8, 0.7], dtype=np.float32) + scores = score_values class_ids = np.array([0, 0, 1], dtype=np.int64) class_scores = np.array( [ diff --git a/src/bbox.rs b/src/bbox.rs index 4e1e53a..a3c945c 100644 --- a/src/bbox.rs +++ b/src/bbox.rs @@ -165,6 +165,13 @@ pub struct Detection { pub score: f32, } +/// Result returned by box filtering helpers that transform box geometry. +#[derive(Debug, Clone, PartialEq)] +pub struct BoxFilterResult { + pub boxes: Vec, + pub indices: Vec, +} + /// Errors for box postprocessing operations. #[derive(Debug, Clone, PartialEq)] pub enum BBoxError { @@ -189,6 +196,16 @@ pub enum BBoxError { width: u32, height: u32, }, + InvalidMinArea(f32), + InvalidMaxArea(f32), + InvalidAreaRange { + min_area: f32, + max_area: f32, + }, + InvalidMinSize { + min_width: f32, + min_height: f32, + }, NonFiniteBox { index: usize, }, @@ -239,6 +256,25 @@ impl std::fmt::Display for BBoxError { "image width and height must be greater than zero, got {}x{}", width, height ), + Self::InvalidMinArea(value) => { + write!(f, "min_area must be finite and non-negative, got {}", value) + } + Self::InvalidMaxArea(value) => { + write!(f, "max_area must be finite and non-negative, got {}", value) + } + Self::InvalidAreaRange { min_area, max_area } => write!( + f, + "max_area must be greater than or equal to min_area, got min_area={} and max_area={}", + min_area, max_area + ), + Self::InvalidMinSize { + min_width, + min_height, + } => write!( + f, + "min_width and min_height must be finite and non-negative, got {} and {}", + min_width, min_height + ), Self::NonFiniteBox { index } => { write!(f, "box at index {} contains a non-finite coordinate", index) } @@ -288,6 +324,100 @@ pub fn clip_boxes(boxes: &[BBoxXYXY], width: u32, height: u32) -> Result Result, BBoxError> { + if boxes.len() != scores.len() { + return Err(BBoxError::LengthMismatch { + boxes: boxes.len(), + scores: scores.len(), + }); + } + + if threshold.is_nan() { + return Err(BBoxError::InvalidScoreThreshold(threshold)); + } + + validate_boxes(boxes)?; + validate_scores(scores)?; + Ok(scores + .iter() + .enumerate() + .filter_map(|(index, score)| (*score >= threshold).then_some(index)) + .collect()) +} + +/// Keep the indices of boxes whose area falls within the requested range. +pub fn filter_boxes_by_area( + boxes: &[BBoxXYXY], + min_area: Option, + max_area: Option, +) -> Result, BBoxError> { + validate_boxes(boxes)?; + validate_area_bounds(min_area, max_area)?; + + Ok(boxes + .iter() + .enumerate() + .filter_map(|(index, bbox)| { + let area = bbox.area(); + let keep_min = min_area.is_none_or(|value| area >= value); + let keep_max = max_area.is_none_or(|value| area <= value); + (keep_min && keep_max).then_some(index) + }) + .collect()) +} + +/// Keep the indices of boxes whose width and height meet the requested minimums. +pub fn filter_boxes_by_min_size( + boxes: &[BBoxXYXY], + min_width: f32, + min_height: f32, +) -> Result, BBoxError> { + validate_boxes(boxes)?; + validate_min_size(min_width, min_height)?; + + Ok(boxes + .iter() + .enumerate() + .filter_map(|(index, bbox)| { + (bbox.width() >= min_width && bbox.height() >= min_height).then_some(index) + }) + .collect()) +} + +/// Clip boxes to image bounds, then keep only boxes whose clipped width and height are large enough. +pub fn clip_and_filter_boxes( + boxes: &[BBoxXYXY], + width: u32, + height: u32, + min_width: f32, + min_height: f32, +) -> Result { + validate_boxes(boxes)?; + validate_image_size(width, height)?; + validate_min_size(min_width, min_height)?; + + let clipped = clip_boxes(boxes, width, height)?; + let mut filtered_boxes = Vec::new(); + let mut kept_indices = Vec::new(); + + for (index, bbox) in clipped.into_iter().enumerate() { + if bbox.width() >= min_width && bbox.height() >= min_height { + filtered_boxes.push(bbox); + kept_indices.push(index); + } + } + + Ok(BoxFilterResult { + boxes: filtered_boxes, + indices: kept_indices, + }) +} + /// Map boxes from one image size to another with direct resize scaling. pub fn resize_boxes( boxes: &[BBoxXYXY], @@ -742,6 +872,42 @@ fn validate_image_size(width: u32, height: u32) -> Result<(), BBoxError> { Ok(()) } +fn validate_area_bounds(min_area: Option, max_area: Option) -> Result<(), BBoxError> { + if let Some(value) = min_area { + if !value.is_finite() || value < 0.0 { + return Err(BBoxError::InvalidMinArea(value)); + } + } + + if let Some(value) = max_area { + if !value.is_finite() || value < 0.0 { + return Err(BBoxError::InvalidMaxArea(value)); + } + } + + if let (Some(min_area), Some(max_area)) = (min_area, max_area) { + if max_area < min_area { + return Err(BBoxError::InvalidAreaRange { min_area, max_area }); + } + } + + Ok(()) +} + +fn validate_min_size(min_width: f32, min_height: f32) -> Result<(), BBoxError> { + let width_ok = min_width.is_finite() && min_width >= 0.0; + let height_ok = min_height.is_finite() && min_height >= 0.0; + + if !width_ok || !height_ok { + return Err(BBoxError::InvalidMinSize { + min_width, + min_height, + }); + } + + Ok(()) +} + fn validate_scores(scores: &[f32]) -> Result<(), BBoxError> { for (index, score) in scores.iter().copied().enumerate() { if !score.is_finite() { @@ -955,6 +1121,76 @@ mod tests { assert_eq!(restored, boxes); } + #[test] + fn filters_boxes_by_score() { + let boxes = vec![ + BBoxXYXY::from_xywh(0.0, 0.0, 10.0, 10.0), + BBoxXYXY::from_xywh(10.0, 10.0, 4.0, 4.0), + BBoxXYXY::from_xywh(20.0, 20.0, 2.0, 2.0), + ]; + let scores = vec![0.9, 0.4, 0.7]; + + let kept = filter_boxes_by_score(&boxes, &scores, 0.5).unwrap(); + + assert_eq!(kept, vec![0, 2]); + } + + #[test] + fn filters_boxes_by_area() { + let boxes = vec![ + BBoxXYXY::from_xywh(0.0, 0.0, 10.0, 10.0), + BBoxXYXY::from_xywh(10.0, 10.0, 4.0, 4.0), + BBoxXYXY::from_xywh(20.0, 20.0, 2.0, 2.0), + ]; + + let kept = filter_boxes_by_area(&boxes, Some(10.0), Some(20.0)).unwrap(); + + assert_eq!(kept, vec![1]); + } + + #[test] + fn filters_boxes_by_min_size() { + let boxes = vec![ + BBoxXYXY::from_xywh(0.0, 0.0, 10.0, 10.0), + BBoxXYXY::from_xywh(10.0, 10.0, 4.0, 6.0), + BBoxXYXY::from_xywh(20.0, 20.0, 2.0, 8.0), + ]; + + let kept = filter_boxes_by_min_size(&boxes, 4.0, 7.0).unwrap(); + + assert_eq!(kept, vec![0]); + } + + #[test] + fn clips_and_filters_boxes() { + let boxes = vec![ + BBoxXYXY::from_xywh(-4.0, 1.0, 8.0, 8.0), + BBoxXYXY::from_xywh(2.0, 2.0, 1.0, 1.0), + BBoxXYXY::from_xywh(8.0, 8.0, 6.0, 6.0), + ]; + + let result = clip_and_filter_boxes(&boxes, 10, 10, 2.0, 2.0).unwrap(); + + assert_eq!(result.indices, vec![0, 2]); + assert_eq!( + result.boxes, + vec![ + BBoxXYXY { + x1: 0.0, + y1: 1.0, + x2: 4.0, + y2: 9.0, + }, + BBoxXYXY { + x1: 8.0, + y1: 8.0, + x2: 10.0, + y2: 10.0, + }, + ] + ); + } + #[test] fn keeps_highest_scoring_boxes() { let boxes = vec![ @@ -1294,6 +1530,28 @@ mod tests { height: 10, } ); + assert!(matches!( + filter_boxes_by_score(&boxes, &[0.5], f32::NAN).unwrap_err(), + BBoxError::InvalidScoreThreshold(value) if value.is_nan() + )); + assert_eq!( + filter_boxes_by_area(&boxes, Some(-1.0), None).unwrap_err(), + BBoxError::InvalidMinArea(-1.0) + ); + assert_eq!( + filter_boxes_by_area(&boxes, Some(5.0), Some(4.0)).unwrap_err(), + BBoxError::InvalidAreaRange { + min_area: 5.0, + max_area: 4.0, + } + ); + assert_eq!( + filter_boxes_by_min_size(&boxes, -1.0, 0.0).unwrap_err(), + BBoxError::InvalidMinSize { + min_width: -1.0, + min_height: 0.0, + } + ); assert_eq!( soft_nms( &boxes, diff --git a/src/lib.rs b/src/lib.rs index 839e5fe..ce61180 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,10 +28,11 @@ mod python; /// Error returned by box postprocessing operations. pub use bbox::{ - batched_nms, batched_soft_nms, clip_boxes, iou, letterbox_boxes, multiclass_nms, + batched_nms, batched_soft_nms, clip_and_filter_boxes, clip_boxes, filter_boxes_by_area, + filter_boxes_by_min_size, filter_boxes_by_score, iou, letterbox_boxes, multiclass_nms, multiclass_soft_nms, nms, nms_with_options, resize_boxes, soft_nms, unletterbox_boxes, - xywh_to_xyxy, xyxy_to_xywh, BBoxError, BBoxXYWH, BBoxXYXY, Detection, NmsOptions, - SoftNmsMethod, SoftNmsOptions, + xywh_to_xyxy, xyxy_to_xywh, BBoxError, BBoxXYWH, BBoxXYXY, BoxFilterResult, Detection, + NmsOptions, SoftNmsMethod, SoftNmsOptions, }; /// Error returned by crop operations. pub use crop::{center_crop_image, crop_image, CropError, CropInfo, CropResult}; diff --git a/src/python.rs b/src/python.rs index 78f34f1..45c8737 100644 --- a/src/python.rs +++ b/src/python.rs @@ -380,6 +380,60 @@ fn detections_to_pydict<'py>( Ok(result) } +fn filtered_indices_to_pydict<'py>( + py: Python<'py>, + boxes: &[BBoxXYXY], + indices: Vec, + scores: Option<&[f32]>, +) -> PyResult> { + let selected_boxes = indices + .iter() + .map(|&index| boxes[index]) + .collect::>(); + let selected_indices = indices + .iter() + .map(|&index| index as i64) + .collect::>(); + + let result = PyDict::new(py); + result.set_item( + "indices", + PyArray1::from_owned_array(py, Array1::from_vec(selected_indices)), + )?; + result.set_item("boxes", boxes_xyxy_to_numpy(py, selected_boxes)?)?; + + if let Some(scores) = scores { + let selected_scores = indices + .iter() + .map(|&index| scores[index]) + .collect::>(); + result.set_item( + "scores", + PyArray1::from_owned_array(py, Array1::from_vec(selected_scores)), + )?; + } + + Ok(result) +} + +fn box_filter_result_to_pydict<'py>( + py: Python<'py>, + result_data: bbox::BoxFilterResult, +) -> PyResult> { + let result = PyDict::new(py); + let indices = result_data + .indices + .iter() + .map(|&index| index as i64) + .collect::>(); + result.set_item( + "indices", + PyArray1::from_owned_array(py, Array1::from_vec(indices)), + )?; + result.set_item("boxes", boxes_xyxy_to_numpy(py, result_data.boxes)?)?; + Ok(result) +} + #[pyfunction(name = "compute_letterbox")] fn compute_letterbox_py<'py>( py: Python<'py>, @@ -484,6 +538,61 @@ fn clip_boxes_numpy<'py>( boxes_xyxy_to_numpy(py, clipped) } +#[pyfunction] +fn filter_boxes_by_score_numpy<'py>( + py: Python<'py>, + boxes: PyReadonlyArray2<'_, f32>, + scores: PyReadonlyArray1<'_, f32>, + threshold: f32, +) -> PyResult> { + let boxes = boxes_from_numpy(boxes)?; + let scores = scores_from_numpy(scores); + let kept = bbox::filter_boxes_by_score(&boxes, &scores, threshold).map_err(map_bbox_error)?; + filtered_indices_to_pydict(py, &boxes, kept, Some(&scores)) +} + +#[pyfunction] +#[pyo3(signature = (boxes, min_area=None, max_area=None))] +fn filter_boxes_by_area_numpy<'py>( + py: Python<'py>, + boxes: PyReadonlyArray2<'_, f32>, + min_area: Option, + max_area: Option, +) -> PyResult> { + let boxes = boxes_from_numpy(boxes)?; + let kept = bbox::filter_boxes_by_area(&boxes, min_area, max_area).map_err(map_bbox_error)?; + filtered_indices_to_pydict(py, &boxes, kept, None) +} + +#[pyfunction] +fn filter_boxes_by_min_size_numpy<'py>( + py: Python<'py>, + boxes: PyReadonlyArray2<'_, f32>, + min_width: f32, + min_height: f32, +) -> PyResult> { + let boxes = boxes_from_numpy(boxes)?; + let kept = + bbox::filter_boxes_by_min_size(&boxes, min_width, min_height).map_err(map_bbox_error)?; + filtered_indices_to_pydict(py, &boxes, kept, None) +} + +#[pyfunction] +#[pyo3(signature = (boxes, width, height, min_width=0.0, min_height=0.0))] +fn clip_and_filter_boxes_numpy<'py>( + py: Python<'py>, + boxes: PyReadonlyArray2<'_, f32>, + width: u32, + height: u32, + min_width: f32, + min_height: f32, +) -> PyResult> { + let boxes = boxes_from_numpy(boxes)?; + let result_data = bbox::clip_and_filter_boxes(&boxes, width, height, min_width, min_height) + .map_err(map_bbox_error)?; + box_filter_result_to_pydict(py, result_data) +} + #[pyfunction] fn resize_boxes_numpy<'py>( py: Python<'py>, @@ -1004,6 +1113,10 @@ fn rusty_cv(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(xyxy_to_xywh_numpy, m)?)?; m.add_function(wrap_pyfunction!(xywh_to_xyxy_numpy, m)?)?; m.add_function(wrap_pyfunction!(clip_boxes_numpy, m)?)?; + m.add_function(wrap_pyfunction!(filter_boxes_by_score_numpy, m)?)?; + m.add_function(wrap_pyfunction!(filter_boxes_by_area_numpy, m)?)?; + m.add_function(wrap_pyfunction!(filter_boxes_by_min_size_numpy, m)?)?; + m.add_function(wrap_pyfunction!(clip_and_filter_boxes_numpy, m)?)?; m.add_function(wrap_pyfunction!(resize_boxes_numpy, m)?)?; m.add_function(wrap_pyfunction!(letterbox_boxes_numpy, m)?)?; m.add_function(wrap_pyfunction!(unletterbox_boxes_numpy, m)?)?; From ad463be05f0fa008e7cac34e60f3053ac2171474 Mon Sep 17 00:00:00 2001 From: Jules Giraud Date: Tue, 26 May 2026 19:19:02 +0200 Subject: [PATCH 2/3] Add fused detection postprocessing --- scripts/python_smoke_test.py | 32 ++++ src/bbox.rs | 320 +++++++++++++++++++++++++++++++++++ src/lib.rs | 7 +- src/python.rs | 181 +++++++++++++++++++- 4 files changed, 536 insertions(+), 4 deletions(-) diff --git a/scripts/python_smoke_test.py b/scripts/python_smoke_test.py index 0f25044..ce12858 100644 --- a/scripts/python_smoke_test.py +++ b/scripts/python_smoke_test.py @@ -298,6 +298,38 @@ def main() -> None: assert multiclass_soft["indices"].tolist() == [0, 1, 2, 1] assert multiclass_soft["class_ids"].tolist() == [0, 1, 1, 0] assert np.isclose(float(multiclass_soft["scores"][3]), 0.25546217, atol=1e-6) + + postprocessed = rusty_cv.postprocess_detections( + np.array( + [ + [160.0, 240.0, 480.0, 400.0], + [170.0, 250.0, 490.0, 410.0], + [100.0, 240.0, 140.0, 280.0], + ], + dtype=np.float32, + ), + np.array([0.95, 0.90, 0.70], dtype=np.float32), + class_ids=np.array([0, 0, 1], dtype=np.int64), + geometry_mode="letterbox", + processed_width=640, + processed_height=640, + original_width=400, + original_height=200, + clip=True, + ) + assert postprocessed["indices"].tolist() == [0, 2] + assert postprocessed["class_ids"].tolist() == [0, 1] + assert np.allclose( + postprocessed["boxes"], + np.array( + [ + [100.0, 50.0, 300.0, 150.0], + [62.5, 50.0, 87.5, 75.0], + ], + dtype=np.float32, + ), + ) + assert np.isclose(rusty_cv.iou((0.0, 0.0, 10.0, 10.0), (5.0, 5.0, 15.0, 15.0)), 25.0 / 175.0) print("python smoke test: ok") diff --git a/src/bbox.rs b/src/bbox.rs index a3c945c..5ddfb3b 100644 --- a/src/bbox.rs +++ b/src/bbox.rs @@ -172,6 +172,61 @@ pub struct BoxFilterResult { pub indices: Vec, } +/// Geometry remapping mode for fused detection postprocessing. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum BoxRemap { + None, + Current { + width: u32, + height: u32, + }, + Resize { + processed_width: u32, + processed_height: u32, + original_width: u32, + original_height: u32, + }, + Letterbox { + processed_width: u32, + processed_height: u32, + original_width: u32, + original_height: u32, + }, +} + +/// Options for fused detection postprocessing. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct PostprocessOptions { + pub iou_threshold: f32, + pub score_threshold: f32, + pub pre_nms_top_k: Option, + pub max_detections: Option, + pub min_width: f32, + pub min_height: f32, + pub clip: bool, +} + +impl Default for PostprocessOptions { + fn default() -> Self { + Self { + iou_threshold: 0.5, + score_threshold: f32::NEG_INFINITY, + pre_nms_top_k: None, + max_detections: None, + min_width: 0.0, + min_height: 0.0, + clip: false, + } + } +} + +/// Result returned by fused detection postprocessing. +#[derive(Debug, Clone, PartialEq)] +pub struct PostprocessResult { + pub boxes: Vec, + pub detections: Vec, +} + /// Errors for box postprocessing operations. #[derive(Debug, Clone, PartialEq)] pub enum BBoxError { @@ -502,6 +557,79 @@ pub fn unletterbox_boxes( .collect()) } +/// Run fused detection postprocessing over remapped boxes and class-aware NMS. +pub fn postprocess_detections( + boxes: &[BBoxXYXY], + scores: &[f32], + class_ids: &[usize], + remap: BoxRemap, + options: &PostprocessOptions, + soft_nms_options: Option<&SoftNmsOptions>, +) -> Result { + validate_postprocess_inputs(boxes, scores, class_ids, options)?; + + let (mut remapped_boxes, clip_bounds) = remap_boxes_for_postprocess(boxes, remap)?; + if options.clip { + if let Some((width, height)) = clip_bounds { + remapped_boxes = clip_boxes(&remapped_boxes, width, height)?; + } + } + + let candidate_indices = + filter_boxes_by_min_size(&remapped_boxes, options.min_width, options.min_height)?; + let candidate_boxes = candidate_indices + .iter() + .map(|&index| remapped_boxes[index]) + .collect::>(); + let candidate_scores = candidate_indices + .iter() + .map(|&index| scores[index]) + .collect::>(); + let candidate_class_ids = candidate_indices + .iter() + .map(|&index| class_ids[index]) + .collect::>(); + + let detections = if let Some(soft_options) = soft_nms_options { + batched_soft_nms( + &candidate_boxes, + &candidate_scores, + &candidate_class_ids, + soft_options, + )? + } else { + let nms_options = NmsOptions { + iou_threshold: options.iou_threshold, + score_threshold: options.score_threshold, + pre_nms_top_k: options.pre_nms_top_k, + max_detections: options.max_detections, + }; + batched_nms( + &candidate_boxes, + &candidate_scores, + &candidate_class_ids, + &nms_options, + )? + }; + + let mut ordered_boxes = Vec::with_capacity(detections.len()); + let mut mapped_detections = Vec::with_capacity(detections.len()); + for detection in detections { + let original_index = candidate_indices[detection.box_index]; + ordered_boxes.push(remapped_boxes[original_index]); + mapped_detections.push(Detection { + box_index: original_index, + class_id: detection.class_id, + score: detection.score, + }); + } + + Ok(PostprocessResult { + boxes: ordered_boxes, + detections: mapped_detections, + }) +} + /// Run single-class non-maximum suppression with custom options. /// /// Returns the kept indices in descending score order. @@ -828,6 +956,82 @@ fn validate_nms_inputs( Ok(()) } +fn validate_postprocess_inputs( + boxes: &[BBoxXYXY], + scores: &[f32], + class_ids: &[usize], + options: &PostprocessOptions, +) -> Result<(), BBoxError> { + if boxes.len() != scores.len() { + return Err(BBoxError::LengthMismatch { + boxes: boxes.len(), + scores: scores.len(), + }); + } + + if boxes.len() != class_ids.len() { + return Err(BBoxError::ClassLengthMismatch { + boxes: boxes.len(), + class_ids: class_ids.len(), + }); + } + + let nms_options = NmsOptions { + iou_threshold: options.iou_threshold, + score_threshold: options.score_threshold, + pre_nms_top_k: options.pre_nms_top_k, + max_detections: options.max_detections, + }; + validate_thresholds(&nms_options)?; + validate_boxes(boxes)?; + validate_scores(scores)?; + validate_min_size(options.min_width, options.min_height)?; + Ok(()) +} + +fn remap_boxes_for_postprocess( + boxes: &[BBoxXYXY], + remap: BoxRemap, +) -> Result<(Vec, Option<(u32, u32)>), BBoxError> { + match remap { + BoxRemap::None => Ok((boxes.to_vec(), None)), + BoxRemap::Current { width, height } => { + validate_image_size(width, height)?; + Ok((boxes.to_vec(), Some((width, height)))) + } + BoxRemap::Resize { + processed_width, + processed_height, + original_width, + original_height, + } => Ok(( + resize_boxes( + boxes, + processed_width, + processed_height, + original_width, + original_height, + )?, + Some((original_width, original_height)), + )), + BoxRemap::Letterbox { + processed_width, + processed_height, + original_width, + original_height, + } => Ok(( + unletterbox_boxes( + boxes, + original_width, + original_height, + processed_width, + processed_height, + )?, + Some((original_width, original_height)), + )), + } +} + fn validate_thresholds(options: &NmsOptions) -> Result<(), BBoxError> { if !(0.0..=1.0).contains(&options.iou_threshold) { return Err(BBoxError::InvalidIouThreshold(options.iou_threshold)); @@ -1191,6 +1395,122 @@ mod tests { ); } + #[test] + fn postprocesses_current_space_boxes() { + let boxes = vec![ + BBoxXYXY::from_xywh(-5.0, 0.0, 8.0, 8.0), + BBoxXYXY::from_xywh(0.0, 0.0, 1.0, 1.0), + BBoxXYXY::from_xywh(2.0, 2.0, 4.0, 4.0), + ]; + let scores = vec![0.9, 0.8, 0.7]; + let class_ids = vec![0usize, 0usize, 0usize]; + let options = PostprocessOptions { + min_width: 2.0, + min_height: 2.0, + clip: true, + ..PostprocessOptions::default() + }; + + let result = postprocess_detections( + &boxes, + &scores, + &class_ids, + BoxRemap::Current { + width: 4, + height: 4, + }, + &options, + None, + ) + .unwrap(); + + assert_eq!(result.detections.len(), 2); + assert_eq!(result.detections[0].box_index, 0); + assert_eq!(result.detections[1].box_index, 2); + assert_eq!( + result.boxes, + vec![ + BBoxXYXY { + x1: 0.0, + y1: 0.0, + x2: 3.0, + y2: 4.0, + }, + BBoxXYXY { + x1: 2.0, + y1: 2.0, + x2: 4.0, + y2: 4.0, + }, + ] + ); + } + + #[test] + fn postprocesses_letterboxed_boxes_back_to_original_space() { + let boxes = vec![ + BBoxXYXY { + x1: 160.0, + y1: 240.0, + x2: 480.0, + y2: 400.0, + }, + BBoxXYXY { + x1: 170.0, + y1: 250.0, + x2: 490.0, + y2: 410.0, + }, + BBoxXYXY { + x1: 100.0, + y1: 240.0, + x2: 140.0, + y2: 280.0, + }, + ]; + let scores = vec![0.95, 0.90, 0.70]; + let class_ids = vec![0usize, 0usize, 1usize]; + + let result = postprocess_detections( + &boxes, + &scores, + &class_ids, + BoxRemap::Letterbox { + processed_width: 640, + processed_height: 640, + original_width: 400, + original_height: 200, + }, + &PostprocessOptions { + clip: true, + ..PostprocessOptions::default() + }, + None, + ) + .unwrap(); + + assert_eq!(result.detections.len(), 2); + assert_eq!(result.detections[0].box_index, 0); + assert_eq!(result.detections[1].box_index, 2); + assert_eq!( + result.boxes, + vec![ + BBoxXYXY { + x1: 100.0, + y1: 50.0, + x2: 300.0, + y2: 150.0, + }, + BBoxXYXY { + x1: 62.5, + y1: 50.0, + x2: 87.5, + y2: 75.0, + }, + ] + ); + } + #[test] fn keeps_highest_scoring_boxes() { let boxes = vec![ diff --git a/src/lib.rs b/src/lib.rs index ce61180..0cdf469 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,9 +30,10 @@ mod python; pub use bbox::{ batched_nms, batched_soft_nms, clip_and_filter_boxes, clip_boxes, filter_boxes_by_area, filter_boxes_by_min_size, filter_boxes_by_score, iou, letterbox_boxes, multiclass_nms, - multiclass_soft_nms, nms, nms_with_options, resize_boxes, soft_nms, unletterbox_boxes, - xywh_to_xyxy, xyxy_to_xywh, BBoxError, BBoxXYWH, BBoxXYXY, BoxFilterResult, Detection, - NmsOptions, SoftNmsMethod, SoftNmsOptions, + multiclass_soft_nms, nms, nms_with_options, postprocess_detections, resize_boxes, soft_nms, + unletterbox_boxes, xywh_to_xyxy, xyxy_to_xywh, BBoxError, BBoxXYWH, BBoxXYXY, BoxFilterResult, + BoxRemap, Detection, NmsOptions, PostprocessOptions, PostprocessResult, SoftNmsMethod, + SoftNmsOptions, }; /// Error returned by crop operations. pub use crop::{center_crop_image, crop_image, CropError, CropInfo, CropResult}; diff --git a/src/python.rs b/src/python.rs index 45c8737..02faa16 100644 --- a/src/python.rs +++ b/src/python.rs @@ -9,7 +9,8 @@ use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyModule}; use crate::bbox::{ - self, BBoxError, BBoxXYWH, BBoxXYXY, Detection, NmsOptions, SoftNmsMethod, SoftNmsOptions, + self, BBoxError, BBoxXYWH, BBoxXYXY, BoxRemap, Detection, NmsOptions, PostprocessOptions, + SoftNmsMethod, SoftNmsOptions, }; use crate::crop::{self, CropError}; use crate::letterbox::{self, LetterboxError}; @@ -434,6 +435,89 @@ fn box_filter_result_to_pydict<'py>( Ok(result) } +fn postprocess_result_to_pydict<'py>( + py: Python<'py>, + result_data: bbox::PostprocessResult, +) -> PyResult> { + let mut indices = Vec::with_capacity(result_data.detections.len()); + let mut class_ids = Vec::with_capacity(result_data.detections.len()); + let mut scores = Vec::with_capacity(result_data.detections.len()); + + for detection in result_data.detections { + indices.push(detection.box_index as i64); + class_ids.push(detection.class_id as i64); + scores.push(detection.score); + } + + let result = PyDict::new(py); + result.set_item("boxes", boxes_xyxy_to_numpy(py, result_data.boxes)?)?; + result.set_item( + "indices", + PyArray1::from_owned_array(py, Array1::from_vec(indices)), + )?; + result.set_item( + "class_ids", + PyArray1::from_owned_array(py, Array1::from_vec(class_ids)), + )?; + result.set_item( + "scores", + PyArray1::from_owned_array(py, Array1::from_vec(scores)), + )?; + Ok(result) +} + +#[allow(clippy::too_many_arguments)] +fn parse_postprocess_remap( + geometry_mode: Option<&str>, + processed_width: Option, + processed_height: Option, + original_width: Option, + original_height: Option, +) -> PyResult { + match geometry_mode.map(|value| value.to_ascii_lowercase()) { + None => Ok(BoxRemap::None), + Some(mode) if mode == "current" => Ok(BoxRemap::Current { + width: processed_width.ok_or_else(|| { + PyValueError::new_err("processed_width is required for geometry_mode='current'") + })?, + height: processed_height.ok_or_else(|| { + PyValueError::new_err("processed_height is required for geometry_mode='current'") + })?, + }), + Some(mode) if mode == "resize" => Ok(BoxRemap::Resize { + processed_width: processed_width.ok_or_else(|| { + PyValueError::new_err("processed_width is required for geometry_mode='resize'") + })?, + processed_height: processed_height.ok_or_else(|| { + PyValueError::new_err("processed_height is required for geometry_mode='resize'") + })?, + original_width: original_width.ok_or_else(|| { + PyValueError::new_err("original_width is required for geometry_mode='resize'") + })?, + original_height: original_height.ok_or_else(|| { + PyValueError::new_err("original_height is required for geometry_mode='resize'") + })?, + }), + Some(mode) if mode == "letterbox" => Ok(BoxRemap::Letterbox { + processed_width: processed_width.ok_or_else(|| { + PyValueError::new_err("processed_width is required for geometry_mode='letterbox'") + })?, + processed_height: processed_height.ok_or_else(|| { + PyValueError::new_err("processed_height is required for geometry_mode='letterbox'") + })?, + original_width: original_width.ok_or_else(|| { + PyValueError::new_err("original_width is required for geometry_mode='letterbox'") + })?, + original_height: original_height.ok_or_else(|| { + PyValueError::new_err("original_height is required for geometry_mode='letterbox'") + })?, + }), + Some(mode) => Err(PyValueError::new_err(format!( + "unsupported geometry_mode '{mode}'. Use current, resize, or letterbox" + ))), + } +} + #[pyfunction(name = "compute_letterbox")] fn compute_letterbox_py<'py>( py: Python<'py>, @@ -886,6 +970,100 @@ fn multiclass_soft_nms_py<'py>( detections_to_pydict(py, detections) } +#[pyfunction(name = "postprocess_detections")] +#[pyo3(signature = ( + boxes, + scores, + class_ids=None, + geometry_mode=None, + processed_width=None, + processed_height=None, + original_width=None, + original_height=None, + clip=false, + iou_threshold=0.5, + score_threshold=None, + pre_nms_top_k=None, + max_detections=None, + min_width=0.0, + min_height=0.0, + soft=false, + soft_method=None, + sigma=0.5 +))] +#[allow(clippy::too_many_arguments)] +fn postprocess_detections_py<'py>( + py: Python<'py>, + boxes: PyReadonlyArray2<'_, f32>, + scores: PyReadonlyArray1<'_, f32>, + class_ids: Option>, + geometry_mode: Option<&str>, + processed_width: Option, + processed_height: Option, + original_width: Option, + original_height: Option, + clip: bool, + iou_threshold: f32, + score_threshold: Option, + pre_nms_top_k: Option, + max_detections: Option, + min_width: f32, + min_height: f32, + soft: bool, + soft_method: Option<&str>, + sigma: f32, +) -> PyResult> { + let boxes = boxes_from_numpy(boxes)?; + let scores = scores_from_numpy(scores); + let class_ids = if let Some(class_ids) = class_ids { + class_ids_from_numpy(class_ids)? + } else { + vec![0usize; boxes.len()] + }; + let remap = parse_postprocess_remap( + geometry_mode, + processed_width, + processed_height, + original_width, + original_height, + )?; + let options = PostprocessOptions { + iou_threshold, + score_threshold: score_threshold.unwrap_or(f32::NEG_INFINITY), + pre_nms_top_k, + max_detections, + min_width, + min_height, + clip, + }; + let soft_options = if soft { + Some(soft_nms_options( + parse_soft_nms_method(soft_method)?, + iou_threshold, + score_threshold, + sigma, + pre_nms_top_k, + max_detections, + )) + } else { + None + }; + + let result_data = py + .detach(move || { + bbox::postprocess_detections( + &boxes, + &scores, + &class_ids, + remap, + &options, + soft_options.as_ref(), + ) + }) + .map_err(map_bbox_error)?; + postprocess_result_to_pydict(py, result_data) +} + #[pyfunction] #[pyo3(signature = (input_bytes, x, y, width, height, output_format=None))] fn crop_image<'py>( @@ -1126,6 +1304,7 @@ fn rusty_cv(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(soft_nms_py, m)?)?; m.add_function(wrap_pyfunction!(batched_soft_nms_py, m)?)?; m.add_function(wrap_pyfunction!(multiclass_soft_nms_py, m)?)?; + m.add_function(wrap_pyfunction!(postprocess_detections_py, m)?)?; m.add_function(wrap_pyfunction!(crop_image, m)?)?; m.add_function(wrap_pyfunction!(crop_image_numpy, m)?)?; m.add_function(wrap_pyfunction!(center_crop_image, m)?)?; From cab1e053ec27ca4bae497e60d7e9cbab693d769d Mon Sep 17 00:00:00 2001 From: Jules Giraud Date: Wed, 27 May 2026 00:47:02 +0200 Subject: [PATCH 3/3] Fix clippy type complexity in bbox remap --- src/bbox.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/bbox.rs b/src/bbox.rs index b58bf67..2d0f2b6 100644 --- a/src/bbox.rs +++ b/src/bbox.rs @@ -226,6 +226,9 @@ pub struct PostprocessResult { pub boxes: Vec, pub detections: Vec, } + +type ClipBounds = Option<(u32, u32)>; +type RemappedBoxes = (Vec, ClipBounds); /// Errors for box postprocessing operations. #[derive(Debug, Clone, PartialEq)] pub enum BBoxError { @@ -991,7 +994,7 @@ fn validate_postprocess_inputs( fn remap_boxes_for_postprocess( boxes: &[BBoxXYXY], remap: BoxRemap, -) -> Result<(Vec, Option<(u32, u32)>), BBoxError> { +) -> Result { match remap { BoxRemap::None => Ok((boxes.to_vec(), None)), BoxRemap::Current { width, height } => {