From 3205b9877407196bd0037142f5da7b27f1fa1489 Mon Sep 17 00:00:00 2001 From: Jules Giraud Date: Wed, 27 May 2026 01:05:10 +0200 Subject: [PATCH] Add segmentation mask geometry helpers --- scripts/python_smoke_test.py | 40 +++++ src/lib.rs | 7 + src/mask.rs | 291 +++++++++++++++++++++++++++++++++++ src/python.rs | 151 ++++++++++++++++++ 4 files changed, 489 insertions(+) create mode 100644 src/mask.rs diff --git a/scripts/python_smoke_test.py b/scripts/python_smoke_test.py index fa322b1..4ae98b0 100644 --- a/scripts/python_smoke_test.py +++ b/scripts/python_smoke_test.py @@ -20,6 +20,46 @@ def main() -> None: assert info["padding"]["top"] == 140 assert info["padding"]["bottom"] == 140 + mask = np.array( + [ + [0, 255], + [255, 0], + ], + dtype=np.uint8, + ) + resized_mask = rusty_cv.resize_mask_numpy(mask, 4, 2) + assert resized_mask.dtype == np.uint8 + assert resized_mask.shape == (2, 4) + + letterboxed_mask, letterboxed_mask_info = rusty_cv.letterbox_mask_numpy(mask, 4, 4, fill=7) + assert letterboxed_mask.shape == (4, 4) + assert letterboxed_mask_info["padding"]["top"] == 0 + restored_mask = rusty_cv.unletterbox_mask_numpy(letterboxed_mask, 2, 2, 4, 4) + assert restored_mask.shape == (2, 2) + assert np.array_equal(restored_mask, mask) + + thresholded_mask = rusty_cv.threshold_mask_numpy( + np.array([[0.1, 0.8], [0.6, 0.2]], dtype=np.float32), + threshold=0.5, + ) + assert np.array_equal( + thresholded_mask, + np.array([[0, 255], [255, 0]], dtype=np.uint8), + ) + + box = rusty_cv.mask_to_box_numpy( + np.array( + [ + [0, 0, 0, 0], + [0, 255, 255, 0], + [0, 255, 255, 0], + [0, 0, 0, 0], + ], + dtype=np.uint8, + ) + ) + assert box == (1.0, 1.0, 3.0, 3.0) + resized = rusty_cv.resize_image(PNG_1X1_RED, 4, 2, filter="nearest", output_format="png") letterboxed = rusty_cv.letterbox_image( PNG_1X1_RED, diff --git a/src/lib.rs b/src/lib.rs index 2dc5c88..8e4b633 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,8 @@ pub mod crop; pub mod layout; /// Letterbox geometry and image operations. pub mod letterbox; +/// Segmentation mask geometry helpers. +pub mod mask; /// Image normalization operations. pub mod normalize; /// Fused inference preprocessing operations. @@ -45,6 +47,11 @@ pub use layout::{chw_to_hwc, hwc_to_chw, nchw_to_nhwc, nhwc_to_nchw, rgb_to_bgr, pub use letterbox::{ compute_letterbox, letterbox_image, LetterboxError, LetterboxInfo, LetterboxResult, Padding, }; +/// Error returned by segmentation mask operations. +pub use mask::{ + letterbox_mask, mask_to_box, resize_mask, threshold_mask, unletterbox_mask, + LetterboxMaskResult, MaskError, ResizeMaskResult, +}; /// Error returned by normalization operations. pub use normalize::{normalize_image, NormalizeError, NormalizeInfo, NormalizeResult}; /// Error returned by fused preprocessing operations. diff --git a/src/mask.rs b/src/mask.rs new file mode 100644 index 0000000..fdbd687 --- /dev/null +++ b/src/mask.rs @@ -0,0 +1,291 @@ +use crate::bbox::BBoxXYXY; +use crate::letterbox::{self, LetterboxError, LetterboxInfo}; + +/// Result of resizing a mask to exact dimensions. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ResizeMaskResult { + pub data: Vec, + pub width: u32, + pub height: u32, +} + +/// Result of letterboxing a mask. +#[derive(Debug, Clone, PartialEq)] +pub struct LetterboxMaskResult { + pub data: Vec, + pub info: LetterboxInfo, +} + +/// Errors for segmentation mask operations. +#[derive(Debug, Clone, PartialEq)] +pub enum MaskError { + InvalidDimensions { width: u32, height: u32 }, + InvalidDataLength { expected: usize, actual: usize }, + InvalidThreshold(f32), + Letterbox(LetterboxError), +} + +impl std::fmt::Display for MaskError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidDimensions { width, height } => write!( + f, + "mask width and height must be greater than zero, got {}x{}", + width, height + ), + Self::InvalidDataLength { expected, actual } => write!( + f, + "mask data length does not match shape, expected {} values but got {}", + expected, actual + ), + Self::InvalidThreshold(value) => { + write!(f, "threshold must be finite, got {}", value) + } + Self::Letterbox(err) => err.fmt(f), + } + } +} + +impl std::error::Error for MaskError {} + +/// Resize a mask using nearest-neighbor sampling. +pub fn resize_mask( + mask: &[u8], + width: u32, + height: u32, + target_width: u32, + target_height: u32, +) -> Result { + validate_mask(mask, width, height)?; + validate_dimensions(target_width, target_height)?; + + let width_usize = width as usize; + let height_usize = height as usize; + let target_width_usize = target_width as usize; + let target_height_usize = target_height as usize; + let mut output = vec![0u8; target_width_usize * target_height_usize]; + + for y in 0..target_height_usize { + let source_y = y * height_usize / target_height_usize; + for x in 0..target_width_usize { + let source_x = x * width_usize / target_width_usize; + output[y * target_width_usize + x] = mask[source_y * width_usize + source_x]; + } + } + + Ok(ResizeMaskResult { + data: output, + width: target_width, + height: target_height, + }) +} + +/// Letterbox a mask into the target dimensions. +pub fn letterbox_mask( + mask: &[u8], + width: u32, + height: u32, + target_width: u32, + target_height: u32, + fill: u8, +) -> Result { + validate_mask(mask, width, height)?; + let info = letterbox::compute_letterbox(width, height, target_width, target_height) + .map_err(MaskError::Letterbox)?; + let resized = resize_mask(mask, width, height, info.resized_width, info.resized_height)?; + let mut canvas = vec![fill; target_width as usize * target_height as usize]; + + for y in 0..info.resized_height as usize { + let dest_y = y + info.padding.top as usize; + let dest_row_start = dest_y * target_width as usize + info.padding.left as usize; + let src_row_start = y * info.resized_width as usize; + let src_row_end = src_row_start + info.resized_width as usize; + canvas[dest_row_start..dest_row_start + info.resized_width as usize] + .copy_from_slice(&resized.data[src_row_start..src_row_end]); + } + + Ok(LetterboxMaskResult { data: canvas, info }) +} + +/// Remove letterbox padding and resize the mask back to the original dimensions. +pub fn unletterbox_mask( + mask: &[u8], + target_width: u32, + target_height: u32, + original_width: u32, + original_height: u32, +) -> Result { + validate_mask(mask, target_width, target_height)?; + let info = + letterbox::compute_letterbox(original_width, original_height, target_width, target_height) + .map_err(MaskError::Letterbox)?; + + let cropped_width = info.resized_width as usize; + let cropped_height = info.resized_height as usize; + let mut cropped = vec![0u8; cropped_width * cropped_height]; + + for y in 0..cropped_height { + let src_y = y + info.padding.top as usize; + let src_row_start = src_y * target_width as usize + info.padding.left as usize; + let src_row_end = src_row_start + cropped_width; + let dst_row_start = y * cropped_width; + cropped[dst_row_start..dst_row_start + cropped_width] + .copy_from_slice(&mask[src_row_start..src_row_end]); + } + + resize_mask( + &cropped, + info.resized_width, + info.resized_height, + original_width, + original_height, + ) +} + +/// Threshold a float mask into a binary `u8` mask with values `{0, 255}`. +pub fn threshold_mask( + mask: &[f32], + width: u32, + height: u32, + threshold: f32, +) -> Result, MaskError> { + validate_mask_length(mask.len(), width, height)?; + if !threshold.is_finite() { + return Err(MaskError::InvalidThreshold(threshold)); + } + + Ok(mask + .iter() + .map(|value| if *value >= threshold { 255 } else { 0 }) + .collect()) +} + +/// Compute one bounding box covering all non-zero pixels in a mask. +pub fn mask_to_box(mask: &[u8], width: u32, height: u32) -> Result, MaskError> { + validate_mask(mask, width, height)?; + + let width_usize = width as usize; + let mut min_x = width; + let mut min_y = height; + let mut max_x = 0u32; + let mut max_y = 0u32; + let mut found = false; + + for y in 0..height as usize { + for x in 0..width_usize { + if mask[y * width_usize + x] != 0 { + found = true; + min_x = min_x.min(x as u32); + min_y = min_y.min(y as u32); + max_x = max_x.max(x as u32 + 1); + max_y = max_y.max(y as u32 + 1); + } + } + } + + if !found { + return Ok(None); + } + + Ok(Some(BBoxXYXY { + x1: min_x as f32, + y1: min_y as f32, + x2: max_x as f32, + y2: max_y as f32, + })) +} + +fn validate_mask(mask: &[u8], width: u32, height: u32) -> Result<(), MaskError> { + validate_dimensions(width, height)?; + validate_mask_length(mask.len(), width, height) +} + +fn validate_mask_length(actual_len: usize, width: u32, height: u32) -> Result<(), MaskError> { + let expected = width as usize * height as usize; + if actual_len != expected { + return Err(MaskError::InvalidDataLength { + expected, + actual: actual_len, + }); + } + Ok(()) +} + +fn validate_dimensions(width: u32, height: u32) -> Result<(), MaskError> { + if width == 0 || height == 0 { + return Err(MaskError::InvalidDimensions { width, height }); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn resizes_mask_with_nearest_neighbor() { + let mask = vec![0u8, 255, 255, 0]; + let result = resize_mask(&mask, 2, 2, 4, 2).unwrap(); + + assert_eq!(result.width, 4); + assert_eq!(result.height, 2); + assert_eq!(result.data, vec![0, 0, 255, 255, 255, 255, 0, 0]); + } + + #[test] + fn letterboxes_and_unletterboxes_mask() { + let mask = vec![0u8, 255]; + let letterboxed = letterbox_mask(&mask, 2, 1, 4, 4, 7).unwrap(); + + assert_eq!(letterboxed.info.resized_width, 4); + assert_eq!(letterboxed.info.padding.top, 1); + assert_eq!(letterboxed.data[0], 7); + + let restored = unletterbox_mask(&letterboxed.data, 4, 4, 2, 1).unwrap(); + assert_eq!(restored.data, mask); + } + + #[test] + fn thresholds_mask() { + let mask = vec![0.1f32, 0.6, 0.5, 0.2]; + let thresholded = threshold_mask(&mask, 2, 2, 0.5).unwrap(); + assert_eq!(thresholded, vec![0, 255, 255, 0]); + } + + #[test] + fn extracts_mask_box() { + let mask = vec![0u8, 0, 0, 0, 0, 255, 255, 0, 0, 255, 255, 0, 0, 0, 0, 0]; + let bbox = mask_to_box(&mask, 4, 4).unwrap().unwrap(); + assert_eq!( + bbox, + BBoxXYXY { + x1: 1.0, + y1: 1.0, + x2: 3.0, + y2: 3.0, + } + ); + } + + #[test] + fn rejects_invalid_masks() { + assert_eq!( + resize_mask(&[0u8], 0, 1, 1, 1).unwrap_err(), + MaskError::InvalidDimensions { + width: 0, + height: 1 + } + ); + assert!(matches!( + threshold_mask(&[0.0, 0.0, 0.0, 0.0], 2, 2, f32::NAN).unwrap_err(), + MaskError::InvalidThreshold(value) if value.is_nan() + )); + assert_eq!( + mask_to_box(&[0u8], 2, 2).unwrap_err(), + MaskError::InvalidDataLength { + expected: 4, + actual: 1, + } + ); + } +} diff --git a/src/python.rs b/src/python.rs index 88bdc41..f4a5b9b 100644 --- a/src/python.rs +++ b/src/python.rs @@ -18,6 +18,7 @@ use crate::bbox::{ use crate::crop::{self, CropError}; use crate::layout::{self, LayoutError}; use crate::letterbox::{self, LetterboxError}; +use crate::mask::{self, MaskError}; use crate::normalize::{self, NormalizeError}; use crate::preprocess::{ self, PreprocessError, PreprocessGeometry, PreprocessLayout, PreprocessMode, @@ -76,6 +77,10 @@ fn map_layout_error(err: LayoutError) -> PyErr { PyValueError::new_err(err.to_string()) } +fn map_mask_error(err: MaskError) -> PyErr { + PyValueError::new_err(err.to_string()) +} + fn map_normalize_error(err: NormalizeError) -> PyErr { PyValueError::new_err(err.to_string()) } @@ -160,6 +165,25 @@ fn preprocess_info_to_pydict<'py>( Ok(result) } +fn grayscale_mask_from_numpy(input: PyReadonlyArray2<'_, u8>) -> Vec { + input.as_array().iter().copied().collect() +} + +fn grayscale_mask_f32_from_numpy(input: PyReadonlyArray2<'_, f32>) -> Vec { + input.as_array().iter().copied().collect() +} + +fn mask_u8_to_numpy<'py>( + py: Python<'py>, + data: Vec, + height: u32, + width: u32, +) -> PyResult>> { + let array = Array2::from_shape_vec((height as usize, width as usize), data) + .map_err(|err| PyValueError::new_err(format!("failed to build NumPy array: {err}")))?; + Ok(PyArray2::from_owned_array(py, array)) +} + fn rgb_image_from_numpy(input: PyReadonlyArray3<'_, u8>) -> PyResult { let array = input.as_array(); let (height, width, channels) = array.dim(); @@ -559,6 +583,128 @@ fn compute_letterbox_py<'py>( letterbox_info_to_pydict(py, info) } +#[pyfunction] +#[pyo3(signature = (mask_array, target_width, target_height))] +fn resize_mask_numpy<'py>( + py: Python<'py>, + mask_array: PyReadonlyArray2<'_, u8>, + target_width: u32, + target_height: u32, +) -> PyResult>> { + let array = mask_array.as_array(); + let (height, width) = array.dim(); + let mask = grayscale_mask_from_numpy(mask_array); + let result = mask::resize_mask( + &mask, + width as u32, + height as u32, + target_width, + target_height, + ) + .map_err(map_mask_error)?; + mask_u8_to_numpy(py, result.data, result.height, result.width) +} + +#[pyfunction] +#[pyo3(signature = (mask_array, target_width, target_height, fill=0))] +fn letterbox_mask_numpy<'py>( + py: Python<'py>, + mask_array: PyReadonlyArray2<'_, u8>, + target_width: u32, + target_height: u32, + fill: u8, +) -> PyResult<(Bound<'py, PyArray2>, Bound<'py, PyDict>)> { + let array = mask_array.as_array(); + let (height, width) = array.dim(); + let mask = grayscale_mask_from_numpy(mask_array); + let result = mask::letterbox_mask( + &mask, + width as u32, + height as u32, + target_width, + target_height, + fill, + ) + .map_err(map_mask_error)?; + let info = letterbox_info_to_pydict(py, result.info)?; + let array = mask_u8_to_numpy(py, result.data, target_height, target_width)?; + Ok((array, info)) +} + +#[pyfunction] +fn unletterbox_mask_numpy<'py>( + py: Python<'py>, + mask_array: PyReadonlyArray2<'_, u8>, + original_width: u32, + original_height: u32, + target_width: u32, + target_height: u32, +) -> PyResult>> { + let array = mask_array.as_array(); + let (height, width) = array.dim(); + if width as u32 != target_width || height as u32 != target_height { + return Err(PyValueError::new_err(format!( + "mask shape does not match target dimensions, got {}x{} and expected {}x{}", + width, height, target_width, target_height + ))); + } + let mask = grayscale_mask_from_numpy(mask_array); + let result = mask::unletterbox_mask( + &mask, + target_width, + target_height, + original_width, + original_height, + ) + .map_err(map_mask_error)?; + mask_u8_to_numpy(py, result.data, result.height, result.width) +} + +#[pyfunction] +#[pyo3(signature = (mask_array, threshold=0.5))] +fn threshold_mask_numpy<'py>( + py: Python<'py>, + mask_array: &Bound<'py, PyAny>, + threshold: f32, +) -> PyResult>> { + if let Ok(mask_array) = mask_array.extract::>() { + let array = mask_array.as_array(); + let (height, width) = array.dim(); + let mask = grayscale_mask_f32_from_numpy(mask_array); + let thresholded = mask::threshold_mask(&mask, width as u32, height as u32, threshold) + .map_err(map_mask_error)?; + return mask_u8_to_numpy(py, thresholded, height as u32, width as u32); + } + + if let Ok(mask_array) = mask_array.extract::>() { + let array = mask_array.as_array(); + let (height, width) = array.dim(); + let mask = mask_array + .as_array() + .iter() + .map(|value| *value as f32) + .collect::>(); + let thresholded = mask::threshold_mask(&mask, width as u32, height as u32, threshold) + .map_err(map_mask_error)?; + return mask_u8_to_numpy(py, thresholded, height as u32, width as u32); + } + + Err(PyValueError::new_err( + "expected a HxW NumPy array with dtype uint8 or float32", + )) +} + +#[pyfunction] +fn mask_to_box_numpy( + mask_array: PyReadonlyArray2<'_, u8>, +) -> PyResult> { + let array = mask_array.as_array(); + let (height, width) = array.dim(); + let mask = grayscale_mask_from_numpy(mask_array); + let bbox = mask::mask_to_box(&mask, width as u32, height as u32).map_err(map_mask_error)?; + Ok(bbox.map(|bbox| (bbox.x1, bbox.y1, bbox.x2, bbox.y2))) +} + #[pyfunction] #[pyo3(signature = (input_bytes, target_width, target_height, filter=None, output_format=None))] fn resize_image<'py>( @@ -1481,6 +1627,11 @@ fn preprocess_image_numpy<'py>( #[pymodule] fn rusty_cv(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(compute_letterbox_py, m)?)?; + m.add_function(wrap_pyfunction!(resize_mask_numpy, m)?)?; + m.add_function(wrap_pyfunction!(letterbox_mask_numpy, m)?)?; + m.add_function(wrap_pyfunction!(unletterbox_mask_numpy, m)?)?; + m.add_function(wrap_pyfunction!(threshold_mask_numpy, m)?)?; + m.add_function(wrap_pyfunction!(mask_to_box_numpy, m)?)?; m.add_function(wrap_pyfunction!(iou_py, m)?)?; m.add_function(wrap_pyfunction!(xyxy_to_xywh_numpy, m)?)?; m.add_function(wrap_pyfunction!(xywh_to_xyxy_numpy, m)?)?;