diff --git a/scripts/python_smoke_test.py b/scripts/python_smoke_test.py index c7b6b8c..fa322b1 100644 --- a/scripts/python_smoke_test.py +++ b/scripts/python_smoke_test.py @@ -317,6 +317,38 @@ def main() -> None: assert multiclass_soft["indices"].tolist() == [0, 1, 2, 1] assert multiclass_soft["class_ids"].tolist() == [0, 1, 1, 0] assert np.isclose(float(multiclass_soft["scores"][3]), 0.25546217, atol=1e-6) + + postprocessed = rusty_cv.postprocess_detections( + np.array( + [ + [160.0, 240.0, 480.0, 400.0], + [170.0, 250.0, 490.0, 410.0], + [100.0, 240.0, 140.0, 280.0], + ], + dtype=np.float32, + ), + np.array([0.95, 0.90, 0.70], dtype=np.float32), + class_ids=np.array([0, 0, 1], dtype=np.int64), + geometry_mode="letterbox", + processed_width=640, + processed_height=640, + original_width=400, + original_height=200, + clip=True, + ) + assert postprocessed["indices"].tolist() == [0, 2] + assert postprocessed["class_ids"].tolist() == [0, 1] + assert np.allclose( + postprocessed["boxes"], + np.array( + [ + [100.0, 50.0, 300.0, 150.0], + [62.5, 50.0, 87.5, 75.0], + ], + dtype=np.float32, + ), + ) + assert np.isclose(rusty_cv.iou((0.0, 0.0, 10.0, 10.0), (5.0, 5.0, 15.0, 15.0)), 25.0 / 175.0) print("python smoke test: ok") diff --git a/src/bbox.rs b/src/bbox.rs index a3c945c..2d0f2b6 100644 --- a/src/bbox.rs +++ b/src/bbox.rs @@ -172,6 +172,63 @@ pub struct BoxFilterResult { pub indices: Vec, } +/// Geometry remapping mode for fused detection postprocessing. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum BoxRemap { + None, + Current { + width: u32, + height: u32, + }, + Resize { + processed_width: u32, + processed_height: u32, + original_width: u32, + original_height: u32, + }, + Letterbox { + processed_width: u32, + processed_height: u32, + original_width: u32, + original_height: u32, + }, +} + +/// Options for fused detection postprocessing. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct PostprocessOptions { + pub iou_threshold: f32, + pub score_threshold: f32, + pub pre_nms_top_k: Option, + pub max_detections: Option, + pub min_width: f32, + pub min_height: f32, + pub clip: bool, +} + +impl Default for PostprocessOptions { + fn default() -> Self { + Self { + iou_threshold: 0.5, + score_threshold: f32::NEG_INFINITY, + pre_nms_top_k: None, + max_detections: None, + min_width: 0.0, + min_height: 0.0, + clip: false, + } + } +} + +/// Result returned by fused detection postprocessing. +#[derive(Debug, Clone, PartialEq)] +pub struct PostprocessResult { + pub boxes: Vec, + pub detections: Vec, +} + +type ClipBounds = Option<(u32, u32)>; +type RemappedBoxes = (Vec, ClipBounds); /// Errors for box postprocessing operations. #[derive(Debug, Clone, PartialEq)] pub enum BBoxError { @@ -502,6 +559,79 @@ pub fn unletterbox_boxes( .collect()) } +/// Run fused detection postprocessing over remapped boxes and class-aware NMS. +pub fn postprocess_detections( + boxes: &[BBoxXYXY], + scores: &[f32], + class_ids: &[usize], + remap: BoxRemap, + options: &PostprocessOptions, + soft_nms_options: Option<&SoftNmsOptions>, +) -> Result { + validate_postprocess_inputs(boxes, scores, class_ids, options)?; + + let (mut remapped_boxes, clip_bounds) = remap_boxes_for_postprocess(boxes, remap)?; + if options.clip { + if let Some((width, height)) = clip_bounds { + remapped_boxes = clip_boxes(&remapped_boxes, width, height)?; + } + } + + let candidate_indices = + filter_boxes_by_min_size(&remapped_boxes, options.min_width, options.min_height)?; + let candidate_boxes = candidate_indices + .iter() + .map(|&index| remapped_boxes[index]) + .collect::>(); + let candidate_scores = candidate_indices + .iter() + .map(|&index| scores[index]) + .collect::>(); + let candidate_class_ids = candidate_indices + .iter() + .map(|&index| class_ids[index]) + .collect::>(); + + let detections = if let Some(soft_options) = soft_nms_options { + batched_soft_nms( + &candidate_boxes, + &candidate_scores, + &candidate_class_ids, + soft_options, + )? + } else { + let nms_options = NmsOptions { + iou_threshold: options.iou_threshold, + score_threshold: options.score_threshold, + pre_nms_top_k: options.pre_nms_top_k, + max_detections: options.max_detections, + }; + batched_nms( + &candidate_boxes, + &candidate_scores, + &candidate_class_ids, + &nms_options, + )? + }; + + let mut ordered_boxes = Vec::with_capacity(detections.len()); + let mut mapped_detections = Vec::with_capacity(detections.len()); + for detection in detections { + let original_index = candidate_indices[detection.box_index]; + ordered_boxes.push(remapped_boxes[original_index]); + mapped_detections.push(Detection { + box_index: original_index, + class_id: detection.class_id, + score: detection.score, + }); + } + + Ok(PostprocessResult { + boxes: ordered_boxes, + detections: mapped_detections, + }) +} + /// Run single-class non-maximum suppression with custom options. /// /// Returns the kept indices in descending score order. @@ -828,6 +958,82 @@ fn validate_nms_inputs( Ok(()) } +fn validate_postprocess_inputs( + boxes: &[BBoxXYXY], + scores: &[f32], + class_ids: &[usize], + options: &PostprocessOptions, +) -> Result<(), BBoxError> { + if boxes.len() != scores.len() { + return Err(BBoxError::LengthMismatch { + boxes: boxes.len(), + scores: scores.len(), + }); + } + + if boxes.len() != class_ids.len() { + return Err(BBoxError::ClassLengthMismatch { + boxes: boxes.len(), + class_ids: class_ids.len(), + }); + } + + let nms_options = NmsOptions { + iou_threshold: options.iou_threshold, + score_threshold: options.score_threshold, + pre_nms_top_k: options.pre_nms_top_k, + max_detections: options.max_detections, + }; + validate_thresholds(&nms_options)?; + validate_boxes(boxes)?; + validate_scores(scores)?; + validate_min_size(options.min_width, options.min_height)?; + Ok(()) +} + +fn remap_boxes_for_postprocess( + boxes: &[BBoxXYXY], + remap: BoxRemap, +) -> Result { + match remap { + BoxRemap::None => Ok((boxes.to_vec(), None)), + BoxRemap::Current { width, height } => { + validate_image_size(width, height)?; + Ok((boxes.to_vec(), Some((width, height)))) + } + BoxRemap::Resize { + processed_width, + processed_height, + original_width, + original_height, + } => Ok(( + resize_boxes( + boxes, + processed_width, + processed_height, + original_width, + original_height, + )?, + Some((original_width, original_height)), + )), + BoxRemap::Letterbox { + processed_width, + processed_height, + original_width, + original_height, + } => Ok(( + unletterbox_boxes( + boxes, + original_width, + original_height, + processed_width, + processed_height, + )?, + Some((original_width, original_height)), + )), + } +} + fn validate_thresholds(options: &NmsOptions) -> Result<(), BBoxError> { if !(0.0..=1.0).contains(&options.iou_threshold) { return Err(BBoxError::InvalidIouThreshold(options.iou_threshold)); @@ -1191,6 +1397,122 @@ mod tests { ); } + #[test] + fn postprocesses_current_space_boxes() { + let boxes = vec![ + BBoxXYXY::from_xywh(-5.0, 0.0, 8.0, 8.0), + BBoxXYXY::from_xywh(0.0, 0.0, 1.0, 1.0), + BBoxXYXY::from_xywh(2.0, 2.0, 4.0, 4.0), + ]; + let scores = vec![0.9, 0.8, 0.7]; + let class_ids = vec![0usize, 0usize, 0usize]; + let options = PostprocessOptions { + min_width: 2.0, + min_height: 2.0, + clip: true, + ..PostprocessOptions::default() + }; + + let result = postprocess_detections( + &boxes, + &scores, + &class_ids, + BoxRemap::Current { + width: 4, + height: 4, + }, + &options, + None, + ) + .unwrap(); + + assert_eq!(result.detections.len(), 2); + assert_eq!(result.detections[0].box_index, 0); + assert_eq!(result.detections[1].box_index, 2); + assert_eq!( + result.boxes, + vec![ + BBoxXYXY { + x1: 0.0, + y1: 0.0, + x2: 3.0, + y2: 4.0, + }, + BBoxXYXY { + x1: 2.0, + y1: 2.0, + x2: 4.0, + y2: 4.0, + }, + ] + ); + } + + #[test] + fn postprocesses_letterboxed_boxes_back_to_original_space() { + let boxes = vec![ + BBoxXYXY { + x1: 160.0, + y1: 240.0, + x2: 480.0, + y2: 400.0, + }, + BBoxXYXY { + x1: 170.0, + y1: 250.0, + x2: 490.0, + y2: 410.0, + }, + BBoxXYXY { + x1: 100.0, + y1: 240.0, + x2: 140.0, + y2: 280.0, + }, + ]; + let scores = vec![0.95, 0.90, 0.70]; + let class_ids = vec![0usize, 0usize, 1usize]; + + let result = postprocess_detections( + &boxes, + &scores, + &class_ids, + BoxRemap::Letterbox { + processed_width: 640, + processed_height: 640, + original_width: 400, + original_height: 200, + }, + &PostprocessOptions { + clip: true, + ..PostprocessOptions::default() + }, + None, + ) + .unwrap(); + + assert_eq!(result.detections.len(), 2); + assert_eq!(result.detections[0].box_index, 0); + assert_eq!(result.detections[1].box_index, 2); + assert_eq!( + result.boxes, + vec![ + BBoxXYXY { + x1: 100.0, + y1: 50.0, + x2: 300.0, + y2: 150.0, + }, + BBoxXYXY { + x1: 62.5, + y1: 50.0, + x2: 87.5, + y2: 75.0, + }, + ] + ); + } + #[test] fn keeps_highest_scoring_boxes() { let boxes = vec![ diff --git a/src/lib.rs b/src/lib.rs index 5c1f4e0..2dc5c88 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,9 +32,10 @@ mod python; pub use bbox::{ batched_nms, batched_soft_nms, clip_and_filter_boxes, clip_boxes, filter_boxes_by_area, filter_boxes_by_min_size, filter_boxes_by_score, iou, letterbox_boxes, multiclass_nms, - multiclass_soft_nms, nms, nms_with_options, resize_boxes, soft_nms, unletterbox_boxes, - xywh_to_xyxy, xyxy_to_xywh, BBoxError, BBoxXYWH, BBoxXYXY, BoxFilterResult, Detection, - NmsOptions, SoftNmsMethod, SoftNmsOptions, + multiclass_soft_nms, nms, nms_with_options, postprocess_detections, resize_boxes, soft_nms, + unletterbox_boxes, xywh_to_xyxy, xyxy_to_xywh, BBoxError, BBoxXYWH, BBoxXYXY, BoxFilterResult, + BoxRemap, Detection, NmsOptions, PostprocessOptions, PostprocessResult, SoftNmsMethod, + SoftNmsOptions, }; /// Error returned by crop operations. pub use crop::{center_crop_image, crop_image, CropError, CropInfo, CropResult}; diff --git a/src/python.rs b/src/python.rs index 1046a01..88bdc41 100644 --- a/src/python.rs +++ b/src/python.rs @@ -12,7 +12,8 @@ use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyModule}; use crate::bbox::{ - self, BBoxError, BBoxXYWH, BBoxXYXY, Detection, NmsOptions, SoftNmsMethod, SoftNmsOptions, + self, BBoxError, BBoxXYWH, BBoxXYXY, BoxRemap, Detection, NmsOptions, PostprocessOptions, + SoftNmsMethod, SoftNmsOptions, }; use crate::crop::{self, CropError}; use crate::layout::{self, LayoutError}; @@ -462,6 +463,88 @@ fn box_filter_result_to_pydict<'py>( Ok(result) } +fn postprocess_result_to_pydict<'py>( + py: Python<'py>, + result_data: bbox::PostprocessResult, +) -> PyResult> { + let mut indices = Vec::with_capacity(result_data.detections.len()); + let mut class_ids = Vec::with_capacity(result_data.detections.len()); + let mut scores = Vec::with_capacity(result_data.detections.len()); + + for detection in result_data.detections { + indices.push(detection.box_index as i64); + class_ids.push(detection.class_id as i64); + scores.push(detection.score); + } + + let result = PyDict::new(py); + result.set_item("boxes", boxes_xyxy_to_numpy(py, result_data.boxes)?)?; + result.set_item( + "indices", + PyArray1::from_owned_array(py, Array1::from_vec(indices)), + )?; + result.set_item( + "class_ids", + PyArray1::from_owned_array(py, Array1::from_vec(class_ids)), + )?; + result.set_item( + "scores", + PyArray1::from_owned_array(py, Array1::from_vec(scores)), + )?; + Ok(result) +} + +#[allow(clippy::too_many_arguments)] +fn parse_postprocess_remap( + geometry_mode: Option<&str>, + processed_width: Option, + processed_height: Option, + original_width: Option, + original_height: Option, +) -> PyResult { + match geometry_mode.map(|value| value.to_ascii_lowercase()) { + None => Ok(BoxRemap::None), + Some(mode) if mode == "current" => Ok(BoxRemap::Current { + width: processed_width.ok_or_else(|| { + PyValueError::new_err("processed_width is required for geometry_mode='current'") + })?, + height: processed_height.ok_or_else(|| { + PyValueError::new_err("processed_height is required for geometry_mode='current'") + })?, + }), + Some(mode) if mode == "resize" => Ok(BoxRemap::Resize { + processed_width: processed_width.ok_or_else(|| { + PyValueError::new_err("processed_width is required for geometry_mode='resize'") + })?, + processed_height: processed_height.ok_or_else(|| { + PyValueError::new_err("processed_height is required for geometry_mode='resize'") + })?, + original_width: original_width.ok_or_else(|| { + PyValueError::new_err("original_width is required for geometry_mode='resize'") + })?, + original_height: original_height.ok_or_else(|| { + PyValueError::new_err("original_height is required for geometry_mode='resize'") + })?, + }), + Some(mode) if mode == "letterbox" => Ok(BoxRemap::Letterbox { + processed_width: processed_width.ok_or_else(|| { + PyValueError::new_err("processed_width is required for geometry_mode='letterbox'") + })?, + processed_height: processed_height.ok_or_else(|| { + PyValueError::new_err("processed_height is required for geometry_mode='letterbox'") + })?, + original_width: original_width.ok_or_else(|| { + PyValueError::new_err("original_width is required for geometry_mode='letterbox'") + })?, + original_height: original_height.ok_or_else(|| { + PyValueError::new_err("original_height is required for geometry_mode='letterbox'") + })?, + }), + Some(mode) => Err(PyValueError::new_err(format!( + "unsupported geometry_mode '{mode}'. Use current, resize, or letterbox" + ))), + } +} #[pyfunction(name = "compute_letterbox")] fn compute_letterbox_py<'py>( py: Python<'py>, @@ -1081,6 +1164,100 @@ fn multiclass_soft_nms_py<'py>( detections_to_pydict(py, detections) } +#[pyfunction(name = "postprocess_detections")] +#[pyo3(signature = ( + boxes, + scores, + class_ids=None, + geometry_mode=None, + processed_width=None, + processed_height=None, + original_width=None, + original_height=None, + clip=false, + iou_threshold=0.5, + score_threshold=None, + pre_nms_top_k=None, + max_detections=None, + min_width=0.0, + min_height=0.0, + soft=false, + soft_method=None, + sigma=0.5 +))] +#[allow(clippy::too_many_arguments)] +fn postprocess_detections_py<'py>( + py: Python<'py>, + boxes: PyReadonlyArray2<'_, f32>, + scores: PyReadonlyArray1<'_, f32>, + class_ids: Option>, + geometry_mode: Option<&str>, + processed_width: Option, + processed_height: Option, + original_width: Option, + original_height: Option, + clip: bool, + iou_threshold: f32, + score_threshold: Option, + pre_nms_top_k: Option, + max_detections: Option, + min_width: f32, + min_height: f32, + soft: bool, + soft_method: Option<&str>, + sigma: f32, +) -> PyResult> { + let boxes = boxes_from_numpy(boxes)?; + let scores = scores_from_numpy(scores); + let class_ids = if let Some(class_ids) = class_ids { + class_ids_from_numpy(class_ids)? + } else { + vec![0usize; boxes.len()] + }; + let remap = parse_postprocess_remap( + geometry_mode, + processed_width, + processed_height, + original_width, + original_height, + )?; + let options = PostprocessOptions { + iou_threshold, + score_threshold: score_threshold.unwrap_or(f32::NEG_INFINITY), + pre_nms_top_k, + max_detections, + min_width, + min_height, + clip, + }; + let soft_options = if soft { + Some(soft_nms_options( + parse_soft_nms_method(soft_method)?, + iou_threshold, + score_threshold, + sigma, + pre_nms_top_k, + max_detections, + )) + } else { + None + }; + + let result_data = py + .detach(move || { + bbox::postprocess_detections( + &boxes, + &scores, + &class_ids, + remap, + &options, + soft_options.as_ref(), + ) + }) + .map_err(map_bbox_error)?; + postprocess_result_to_pydict(py, result_data) +} + #[pyfunction] #[pyo3(signature = (input_bytes, x, y, width, height, output_format=None))] fn crop_image<'py>( @@ -1321,6 +1498,7 @@ fn rusty_cv(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(soft_nms_py, m)?)?; m.add_function(wrap_pyfunction!(batched_soft_nms_py, m)?)?; m.add_function(wrap_pyfunction!(multiclass_soft_nms_py, m)?)?; + m.add_function(wrap_pyfunction!(postprocess_detections_py, m)?)?; m.add_function(wrap_pyfunction!(crop_image, m)?)?; m.add_function(wrap_pyfunction!(crop_image_numpy, m)?)?; m.add_function(wrap_pyfunction!(center_crop_image, m)?)?;