diff --git a/scripts/python_smoke_test.py b/scripts/python_smoke_test.py index fa322b1..3f080ea 100644 --- a/scripts/python_smoke_test.py +++ b/scripts/python_smoke_test.py @@ -125,6 +125,41 @@ def main() -> None: assert resized_preprocess_info["layout"] == "hwc" assert np.isclose(resized_preprocessed[0, 0, 0], 1.0) + batch_preprocessed, batch_infos = rusty_cv.preprocess_batch_numpy( + np.stack([array, array], axis=0), + 4, + 4, + mode="letterbox", + fill=(114, 114, 114), + filter="nearest", + mean=(0.0, 0.0, 0.0), + std=(1.0, 1.0, 1.0), + scale_to_unit=True, + layout="chw", + ) + assert batch_preprocessed.dtype == np.float32 + assert batch_preprocessed.shape == (2, 3, 4, 4) + assert len(batch_infos) == 2 + assert batch_infos[0]["mode"] == "letterbox" + assert batch_infos[0]["layout"] == "chw" + assert batch_infos[1]["geometry"]["resized_width"] == 4 + + list_batch_preprocessed, list_batch_infos = rusty_cv.preprocess_batch_numpy( + [array, array], + 3, + 2, + mode="resize", + filter="nearest", + mean=(0.0, 0.0, 0.0), + std=(255.0, 255.0, 255.0), + scale_to_unit=False, + layout="hwc", + ) + assert list_batch_preprocessed.shape == (2, 2, 3, 3) + assert list_batch_infos[0]["mode"] == "resize" + assert list_batch_infos[1]["layout"] == "hwc" + assert np.isclose(list_batch_preprocessed[0, 0, 0, 0], 1.0) + hwc_float = np.arange(24, dtype=np.float32).reshape(2, 4, 3) chw_float = rusty_cv.hwc_to_chw_numpy(hwc_float) assert chw_float.dtype == np.float32 diff --git a/src/lib.rs b/src/lib.rs index 2dc5c88..49c28a2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,8 +49,8 @@ pub use letterbox::{ pub use normalize::{normalize_image, NormalizeError, NormalizeInfo, NormalizeResult}; /// Error returned by fused preprocessing operations. pub use preprocess::{ - preprocess_image, PreprocessError, PreprocessGeometry, PreprocessInfo, PreprocessLayout, - PreprocessMode, PreprocessResult, + preprocess_batch, preprocess_image, PreprocessBatchResult, PreprocessError, PreprocessGeometry, + PreprocessInfo, PreprocessLayout, PreprocessMode, PreprocessResult, }; /// Error returned by direct resize operations. pub use resize::{resize_image, ResizeError, ResizeInfo, ResizeResult}; diff --git a/src/preprocess.rs b/src/preprocess.rs index 41c35b4..69db92d 100644 --- a/src/preprocess.rs +++ b/src/preprocess.rs @@ -43,6 +43,13 @@ pub struct PreprocessResult { pub info: PreprocessInfo, } +/// Result of preprocessing a batch of images into one contiguous tensor. +#[derive(Debug, Clone, PartialEq)] +pub struct PreprocessBatchResult { + pub data: Vec, + pub infos: Vec, +} + /// Errors for fused preprocessing operations. #[derive(Debug, Clone, PartialEq)] pub enum PreprocessError { @@ -154,6 +161,41 @@ pub fn preprocess_image( }) } +/// Resize or letterbox a batch of images and normalize them into one contiguous tensor buffer. +#[allow(clippy::too_many_arguments)] +pub fn preprocess_batch( + images: &[DynamicImage], + target_width: u32, + target_height: u32, + mode: PreprocessMode, + filter: FilterType, + mean: [f32; 3], + std: [f32; 3], + scale_to_unit: bool, + layout: PreprocessLayout, +) -> Result { + let mut data = Vec::new(); + let mut infos = Vec::with_capacity(images.len()); + + for image in images { + let result = preprocess_image( + image, + target_width, + target_height, + mode, + filter, + mean, + std, + scale_to_unit, + layout, + )?; + data.extend(result.data); + infos.push(result.info); + } + + Ok(PreprocessBatchResult { data, infos }) +} + #[cfg(test)] mod tests { use super::*; @@ -263,4 +305,32 @@ mod tests { } ); } + + #[test] + fn preprocesses_batch_into_nchw_tensor() { + let images = vec![ + image_from_pixels(2, 1, vec![[255, 0, 0], [0, 255, 0]]), + image_from_pixels(2, 1, vec![[0, 0, 255], [255, 255, 255]]), + ]; + + let result = preprocess_batch( + &images, + 2, + 1, + PreprocessMode::Resize, + FilterType::Nearest, + [0.0, 0.0, 0.0], + [1.0, 1.0, 1.0], + true, + PreprocessLayout::Chw, + ) + .unwrap(); + + assert_eq!(result.infos.len(), 2); + assert_eq!(result.infos[0].layout, PreprocessLayout::Chw); + assert_eq!(result.infos[1].layout, PreprocessLayout::Chw); + assert_eq!(result.data.len(), 12); + assert_eq!(result.data[0..6], [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]); + assert_eq!(result.data[6..12], [0.0, 1.0, 0.0, 1.0, 1.0, 1.0]); + } } diff --git a/src/python.rs b/src/python.rs index 88bdc41..45b8acd 100644 --- a/src/python.rs +++ b/src/python.rs @@ -9,7 +9,7 @@ use numpy::{ }; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyDict, PyModule}; +use pyo3::types::{PyBytes, PyDict, PyList, PyModule}; use crate::bbox::{ self, BBoxError, BBoxXYWH, BBoxXYXY, BoxRemap, Detection, NmsOptions, PostprocessOptions, @@ -160,6 +160,17 @@ fn preprocess_info_to_pydict<'py>( Ok(result) } +fn preprocess_infos_to_pylist<'py>( + py: Python<'py>, + infos: Vec, +) -> PyResult> { + let list = PyList::empty(py); + for info in infos { + list.append(preprocess_info_to_pydict(py, info)?)?; + } + Ok(list) +} + fn rgb_image_from_numpy(input: PyReadonlyArray3<'_, u8>) -> PyResult { let array = input.as_array(); let (height, width, channels) = array.dim(); @@ -184,6 +195,63 @@ fn rgb_image_from_numpy(input: PyReadonlyArray3<'_, u8>) -> PyResult { }) } +fn rgb_images_from_python_input(input: &Bound<'_, PyAny>) -> PyResult> { + if let Ok(array) = input.extract::>() { + let array = array.as_array(); + let (batch, height, width, channels) = array.dim(); + + if channels != 3 { + return Err(PyValueError::new_err(format!( + "expected a NxHxWx3 uint8 array, got last dimension {channels}" + ))); + } + + let mut images = Vec::with_capacity(batch); + for index in 0..batch { + let mut buffer = Vec::with_capacity(height * width * channels); + for y in 0..height { + for x in 0..width { + for c in 0..channels { + buffer.push(array[(index, y, x, c)]); + } + } + } + + let image = + RgbImage::from_vec(width as u32, height as u32, buffer).ok_or_else(|| { + PyValueError::new_err("failed to convert batched NumPy array into RGB images") + })?; + images.push(image); + } + + return Ok(images); + } + + let iterator = input.try_iter().map_err(|_| { + PyValueError::new_err( + "expected either a NxHxWx3 uint8 array or a sequence of HxWx3 uint8 arrays", + ) + })?; + + let mut images = Vec::new(); + for item in iterator { + let array = item?.extract::>().map_err(|_| { + PyValueError::new_err( + "all sequence items must be HxWx3 uint8 arrays for batch preprocessing", + ) + })?; + images.push(rgb_image_from_numpy(array)?); + } + + if images.is_empty() { + return Err(PyValueError::new_err( + "batch preprocessing requires at least one input image", + )); + } + + Ok(images) +} + fn float_array_to_numpy<'py>( py: Python<'py>, data: Vec, @@ -200,6 +268,23 @@ fn float_array_to_numpy<'py>( Ok(PyArray3::from_owned_array(py, array)) } +fn batch_float_array_to_numpy<'py>( + py: Python<'py>, + data: Vec, + batch: usize, + height: u32, + width: u32, + layout: PreprocessLayout, +) -> PyResult>> { + let shape = match layout { + PreprocessLayout::Hwc => (batch, height as usize, width as usize, 3), + PreprocessLayout::Chw => (batch, 3, height as usize, width as usize), + }; + let array = Array4::from_shape_vec(shape, data) + .map_err(|err| PyValueError::new_err(format!("failed to build NumPy array: {err}")))?; + Ok(PyArray4::from_owned_array(py, array)) +} + fn rgb_image_to_numpy<'py>(py: Python<'py>, image: RgbImage) -> PyResult>> { let (width, height) = image.dimensions(); let array = Array3::from_shape_vec((height as usize, width as usize, 3), image.into_raw()) @@ -1478,6 +1563,73 @@ fn preprocess_image_numpy<'py>( Ok((array, info)) } +#[pyfunction] +#[pyo3(signature = ( + input_arrays, + target_width, + target_height, + mode=None, + fill=(114, 114, 114), + filter=None, + mean=(0.0, 0.0, 0.0), + std=(1.0, 1.0, 1.0), + scale_to_unit=true, + layout=None +))] +#[allow(clippy::too_many_arguments)] +fn preprocess_batch_numpy<'py>( + py: Python<'py>, + input_arrays: &Bound<'py, PyAny>, + target_width: u32, + target_height: u32, + mode: Option<&str>, + fill: (u8, u8, u8), + filter: Option<&str>, + mean: (f32, f32, f32), + std: (f32, f32, f32), + scale_to_unit: bool, + layout: Option<&str>, +) -> PyResult<(Bound<'py, PyArray4>, Bound<'py, PyList>)> { + let images = rgb_images_from_python_input(input_arrays)?; + let filter = parse_filter(filter)?; + let mode = parse_preprocess_mode(mode, fill)?; + let layout = parse_preprocess_layout(layout)?; + let dynamic_images = images + .into_iter() + .map(DynamicImage::ImageRgb8) + .collect::>(); + + let result = py + .detach(move || { + preprocess::preprocess_batch( + &dynamic_images, + target_width, + target_height, + mode, + filter, + [mean.0, mean.1, mean.2], + [std.0, std.1, std.2], + scale_to_unit, + layout, + ) + }) + .map_err(map_preprocess_error)?; + + let first_info = result.infos.first().copied().ok_or_else(|| { + PyValueError::new_err("batch preprocessing requires at least one input image") + })?; + let infos = preprocess_infos_to_pylist(py, result.infos)?; + let array = batch_float_array_to_numpy( + py, + result.data, + infos.len(), + first_info.height, + first_info.width, + layout, + )?; + Ok((array, infos)) +} + #[pymodule] fn rusty_cv(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(compute_letterbox_py, m)?)?; @@ -1514,5 +1666,6 @@ fn rusty_cv(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(letterbox_image_numpy, m)?)?; m.add_function(wrap_pyfunction!(normalize_image_numpy, m)?)?; m.add_function(wrap_pyfunction!(preprocess_image_numpy, m)?)?; + m.add_function(wrap_pyfunction!(preprocess_batch_numpy, m)?)?; Ok(()) }