diff --git a/scripts/python_smoke_test.py b/scripts/python_smoke_test.py index 22ba474..e03ef57 100644 --- a/scripts/python_smoke_test.py +++ b/scripts/python_smoke_test.py @@ -125,6 +125,25 @@ def main() -> None: assert resized_preprocess_info["layout"] == "hwc" assert np.isclose(resized_preprocessed[0, 0, 0], 1.0) + hwc_float = np.arange(24, dtype=np.float32).reshape(2, 4, 3) + chw_float = rusty_cv.hwc_to_chw_numpy(hwc_float) + assert chw_float.dtype == np.float32 + assert chw_float.shape == (3, 2, 4) + assert np.allclose(chw_float, np.transpose(hwc_float, (2, 0, 1))) + assert np.allclose(rusty_cv.chw_to_hwc_numpy(chw_float), hwc_float) + + bgr_array = rusty_cv.rgb_to_bgr_numpy(array) + assert bgr_array.dtype == np.uint8 + assert bgr_array.shape == array.shape + assert bgr_array[0, 0].tolist() == [0, 0, 255] + + nhwc_batch = np.stack([hwc_float, hwc_float + 100.0], axis=0) + nchw_batch = rusty_cv.nhwc_to_nchw_numpy(nhwc_batch) + assert nchw_batch.dtype == np.float32 + assert nchw_batch.shape == (2, 3, 2, 4) + assert np.allclose(nchw_batch, np.transpose(nhwc_batch, (0, 3, 1, 2))) + assert np.allclose(rusty_cv.nchw_to_nhwc_numpy(nchw_batch), nhwc_batch) + boxes = np.array( [ [0.0, 0.0, 10.0, 10.0], diff --git a/src/layout.rs b/src/layout.rs new file mode 100644 index 0000000..efeb952 --- /dev/null +++ b/src/layout.rs @@ -0,0 +1,256 @@ +/// Errors returned by tensor layout conversion helpers. +#[derive(Debug, Clone, PartialEq)] +pub enum LayoutError { + InvalidDimension { name: &'static str, value: u32 }, + InvalidChannels(u32), + InvalidDataLength { expected: usize, actual: usize }, +} + +impl std::fmt::Display for LayoutError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidDimension { name, value } => { + write!(f, "{name} must be greater than zero, got {value}") + } + Self::InvalidChannels(value) => { + write!(f, "channels must be greater than zero, got {}", value) + } + Self::InvalidDataLength { expected, actual } => write!( + f, + "tensor data length does not match shape, expected {} values but got {}", + expected, actual + ), + } + } +} + +impl std::error::Error for LayoutError {} + +/// Convert a HWC tensor into CHW layout. +pub fn hwc_to_chw( + data: &[T], + height: u32, + width: u32, + channels: u32, +) -> Result, LayoutError> { + validate_3d_shape(data.len(), height, width, channels)?; + let height = height as usize; + let width = width as usize; + let channels = channels as usize; + + let mut output = Vec::with_capacity(data.len()); + for channel in 0..channels { + for y in 0..height { + for x in 0..width { + let input_index = ((y * width) + x) * channels + channel; + output.push(data[input_index]); + } + } + } + Ok(output) +} + +/// Convert a CHW tensor into HWC layout. +pub fn chw_to_hwc( + data: &[T], + channels: u32, + height: u32, + width: u32, +) -> Result, LayoutError> { + validate_3d_shape(data.len(), height, width, channels)?; + let height = height as usize; + let width = width as usize; + let channels = channels as usize; + + let mut output = Vec::with_capacity(data.len()); + for y in 0..height { + for x in 0..width { + for channel in 0..channels { + let input_index = ((channel * height) + y) * width + x; + output.push(data[input_index]); + } + } + } + Ok(output) +} + +/// Convert a NHWC tensor into NCHW layout. +pub fn nhwc_to_nchw( + data: &[T], + batch: u32, + height: u32, + width: u32, + channels: u32, +) -> Result, LayoutError> { + validate_4d_shape(data.len(), batch, height, width, channels)?; + let batch = batch as usize; + let height = height as usize; + let width = width as usize; + let channels = channels as usize; + + let mut output = Vec::with_capacity(data.len()); + for n in 0..batch { + for channel in 0..channels { + for y in 0..height { + for x in 0..width { + let input_index = (((n * height) + y) * width + x) * channels + channel; + output.push(data[input_index]); + } + } + } + } + Ok(output) +} + +/// Convert a NCHW tensor into NHWC layout. +pub fn nchw_to_nhwc( + data: &[T], + batch: u32, + channels: u32, + height: u32, + width: u32, +) -> Result, LayoutError> { + validate_4d_shape(data.len(), batch, height, width, channels)?; + let batch = batch as usize; + let height = height as usize; + let width = width as usize; + let channels = channels as usize; + + let mut output = Vec::with_capacity(data.len()); + for n in 0..batch { + for y in 0..height { + for x in 0..width { + for channel in 0..channels { + let input_index = (((n * channels) + channel) * height + y) * width + x; + output.push(data[input_index]); + } + } + } + } + Ok(output) +} + +/// Swap the channel order of an RGB HWC tensor into BGR. +pub fn rgb_to_bgr(data: &[T], height: u32, width: u32) -> Result, LayoutError> { + validate_3d_shape(data.len(), height, width, 3)?; + let height = height as usize; + let width = width as usize; + + let mut output = Vec::with_capacity(data.len()); + for y in 0..height { + for x in 0..width { + let offset = ((y * width) + x) * 3; + output.push(data[offset + 2]); + output.push(data[offset + 1]); + output.push(data[offset]); + } + } + Ok(output) +} + +fn validate_3d_shape( + actual_len: usize, + height: u32, + width: u32, + channels: u32, +) -> Result<(), LayoutError> { + validate_positive_dimension("height", height)?; + validate_positive_dimension("width", width)?; + if channels == 0 { + return Err(LayoutError::InvalidChannels(channels)); + } + + let expected = height as usize * width as usize * channels as usize; + if actual_len != expected { + return Err(LayoutError::InvalidDataLength { + expected, + actual: actual_len, + }); + } + + Ok(()) +} + +fn validate_4d_shape( + actual_len: usize, + batch: u32, + height: u32, + width: u32, + channels: u32, +) -> Result<(), LayoutError> { + validate_positive_dimension("batch", batch)?; + validate_3d_shape(actual_len / batch as usize, height, width, channels)?; + + let expected = batch as usize * height as usize * width as usize * channels as usize; + if actual_len != expected { + return Err(LayoutError::InvalidDataLength { + expected, + actual: actual_len, + }); + } + + Ok(()) +} + +fn validate_positive_dimension(name: &'static str, value: u32) -> Result<(), LayoutError> { + if value == 0 { + return Err(LayoutError::InvalidDimension { name, value }); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn converts_hwc_to_chw_and_back() { + let input = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + + let chw = hwc_to_chw(&input, 2, 2, 3).unwrap(); + assert_eq!(chw, vec![1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12]); + + let roundtrip = chw_to_hwc(&chw, 3, 2, 2).unwrap(); + assert_eq!(roundtrip, input); + } + + #[test] + fn converts_nhwc_to_nchw_and_back() { + let input = vec![1f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + + let nchw = nhwc_to_nchw(&input, 1, 2, 2, 2).unwrap(); + assert_eq!(nchw, vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0]); + + let roundtrip = nchw_to_nhwc(&nchw, 1, 2, 2, 2).unwrap(); + assert_eq!(roundtrip, input); + } + + #[test] + fn swaps_rgb_to_bgr() { + let input = vec![255u8, 0, 10, 0, 128, 255]; + let output = rgb_to_bgr(&input, 1, 2).unwrap(); + assert_eq!(output, vec![10, 0, 255, 255, 128, 0]); + } + + #[test] + fn rejects_invalid_shapes() { + assert_eq!( + hwc_to_chw::(&[1, 2, 3], 0, 1, 3).unwrap_err(), + LayoutError::InvalidDimension { + name: "height", + value: 0, + } + ); + assert_eq!( + hwc_to_chw::(&[1, 2, 3], 1, 1, 0).unwrap_err(), + LayoutError::InvalidChannels(0) + ); + assert_eq!( + chw_to_hwc::(&[1, 2, 3], 3, 1, 2).unwrap_err(), + LayoutError::InvalidDataLength { + expected: 6, + actual: 3, + } + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 839e5fe..67f0582 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,8 @@ pub mod bbox; /// Crop geometry and image operations. pub mod crop; +/// Tensor layout and channel-order helpers. +pub mod layout; /// Letterbox geometry and image operations. pub mod letterbox; /// Image normalization operations. @@ -35,6 +37,8 @@ pub use bbox::{ }; /// Error returned by crop operations. pub use crop::{center_crop_image, crop_image, CropError, CropInfo, CropResult}; +/// Error returned by tensor layout operations. +pub use layout::{chw_to_hwc, hwc_to_chw, nchw_to_nhwc, nhwc_to_nchw, rgb_to_bgr, LayoutError}; /// Error returned by letterbox operations. pub use letterbox::{ compute_letterbox, letterbox_image, LetterboxError, LetterboxInfo, LetterboxResult, Padding, diff --git a/src/python.rs b/src/python.rs index 78f34f1..21f57c2 100644 --- a/src/python.rs +++ b/src/python.rs @@ -2,8 +2,11 @@ use std::io::Cursor; use image::imageops::FilterType; use image::{DynamicImage, ImageFormat, RgbImage}; -use numpy::ndarray::{Array1, Array2, Array3}; -use numpy::{PyArray1, PyArray2, PyArray3, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3}; +use numpy::ndarray::{Array1, Array2, Array3, Array4}; +use numpy::{ + PyArray1, PyArray2, PyArray3, PyArray4, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, + PyReadonlyArray4, +}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyModule}; @@ -12,6 +15,7 @@ use crate::bbox::{ self, BBoxError, BBoxXYWH, BBoxXYXY, Detection, NmsOptions, SoftNmsMethod, SoftNmsOptions, }; use crate::crop::{self, CropError}; +use crate::layout::{self, LayoutError}; use crate::letterbox::{self, LetterboxError}; use crate::normalize::{self, NormalizeError}; use crate::preprocess::{ @@ -67,6 +71,10 @@ fn map_crop_error(err: CropError) -> PyErr { PyValueError::new_err(err.to_string()) } +fn map_layout_error(err: LayoutError) -> PyErr { + PyValueError::new_err(err.to_string()) +} + fn map_normalize_error(err: NormalizeError) -> PyErr { PyValueError::new_err(err.to_string()) } @@ -268,6 +276,26 @@ fn boxes_xywh_to_numpy<'py>( Ok(PyArray2::from_owned_array(py, array)) } +fn array3_to_pyobject( + py: Python<'_>, + data: Vec, + shape: (usize, usize, usize), +) -> PyResult> { + let array = Array3::from_shape_vec(shape, data) + .map_err(|err| PyValueError::new_err(format!("failed to build NumPy array: {err}")))?; + Ok(PyArray3::from_owned_array(py, array).into_any().unbind()) +} + +fn array4_to_pyobject( + py: Python<'_>, + data: Vec, + shape: (usize, usize, usize, usize), +) -> PyResult> { + let array = Array4::from_shape_vec(shape, data) + .map_err(|err| PyValueError::new_err(format!("failed to build NumPy array: {err}")))?; + Ok(PyArray4::from_owned_array(py, array).into_any().unbind()) +} + fn scores_from_numpy(input: PyReadonlyArray1<'_, f32>) -> Vec { input.as_array().iter().copied().collect() } @@ -434,6 +462,173 @@ fn resize_image_numpy<'py>( rgb_image_to_numpy(py, result.image) } +#[pyfunction] +fn hwc_to_chw_numpy<'py>(py: Python<'py>, input_array: &Bound<'py, PyAny>) -> PyResult> { + if let Ok(array) = input_array.extract::>() { + let array_view = array.as_array(); + let (height, width, channels) = array_view.dim(); + let data = array_view.iter().copied().collect::>(); + let converted = layout::hwc_to_chw(&data, height as u32, width as u32, channels as u32) + .map_err(map_layout_error)?; + return array3_to_pyobject(py, converted, (channels, height, width)); + } + + if let Ok(array) = input_array.extract::>() { + let array_view = array.as_array(); + let (height, width, channels) = array_view.dim(); + let data = array_view.iter().copied().collect::>(); + let converted = layout::hwc_to_chw(&data, height as u32, width as u32, channels as u32) + .map_err(map_layout_error)?; + return array3_to_pyobject(py, converted, (channels, height, width)); + } + + Err(PyValueError::new_err( + "expected a HxWxC NumPy array with dtype uint8 or float32", + )) +} + +#[pyfunction] +fn chw_to_hwc_numpy<'py>(py: Python<'py>, input_array: &Bound<'py, PyAny>) -> PyResult> { + if let Ok(array) = input_array.extract::>() { + let array_view = array.as_array(); + let (channels, height, width) = array_view.dim(); + let data = array_view.iter().copied().collect::>(); + let converted = layout::chw_to_hwc(&data, channels as u32, height as u32, width as u32) + .map_err(map_layout_error)?; + return array3_to_pyobject(py, converted, (height, width, channels)); + } + + if let Ok(array) = input_array.extract::>() { + let array_view = array.as_array(); + let (channels, height, width) = array_view.dim(); + let data = array_view.iter().copied().collect::>(); + let converted = layout::chw_to_hwc(&data, channels as u32, height as u32, width as u32) + .map_err(map_layout_error)?; + return array3_to_pyobject(py, converted, (height, width, channels)); + } + + Err(PyValueError::new_err( + "expected a CxHxW NumPy array with dtype uint8 or float32", + )) +} + +#[pyfunction] +fn rgb_to_bgr_numpy<'py>(py: Python<'py>, input_array: &Bound<'py, PyAny>) -> PyResult> { + if let Ok(array) = input_array.extract::>() { + let array_view = array.as_array(); + let (height, width, channels) = array_view.dim(); + if channels != 3 { + return Err(PyValueError::new_err(format!( + "expected a HxWx3 NumPy array, got last dimension {}", + channels + ))); + } + let data = array_view.iter().copied().collect::>(); + let converted = + layout::rgb_to_bgr(&data, height as u32, width as u32).map_err(map_layout_error)?; + return array3_to_pyobject(py, converted, (height, width, channels)); + } + + if let Ok(array) = input_array.extract::>() { + let array_view = array.as_array(); + let (height, width, channels) = array_view.dim(); + if channels != 3 { + return Err(PyValueError::new_err(format!( + "expected a HxWx3 NumPy array, got last dimension {}", + channels + ))); + } + let data = array_view.iter().copied().collect::>(); + let converted = + layout::rgb_to_bgr(&data, height as u32, width as u32).map_err(map_layout_error)?; + return array3_to_pyobject(py, converted, (height, width, channels)); + } + + Err(PyValueError::new_err( + "expected a HxWx3 NumPy array with dtype uint8 or float32", + )) +} + +#[pyfunction] +fn nhwc_to_nchw_numpy<'py>( + py: Python<'py>, + input_array: &Bound<'py, PyAny>, +) -> PyResult> { + if let Ok(array) = input_array.extract::>() { + let array_view = array.as_array(); + let (batch, height, width, channels) = array_view.dim(); + let data = array_view.iter().copied().collect::>(); + let converted = layout::nhwc_to_nchw( + &data, + batch as u32, + height as u32, + width as u32, + channels as u32, + ) + .map_err(map_layout_error)?; + return array4_to_pyobject(py, converted, (batch, channels, height, width)); + } + + if let Ok(array) = input_array.extract::>() { + let array_view = array.as_array(); + let (batch, height, width, channels) = array_view.dim(); + let data = array_view.iter().copied().collect::>(); + let converted = layout::nhwc_to_nchw( + &data, + batch as u32, + height as u32, + width as u32, + channels as u32, + ) + .map_err(map_layout_error)?; + return array4_to_pyobject(py, converted, (batch, channels, height, width)); + } + + Err(PyValueError::new_err( + "expected a NxHxWxC NumPy array with dtype uint8 or float32", + )) +} + +#[pyfunction] +fn nchw_to_nhwc_numpy<'py>( + py: Python<'py>, + input_array: &Bound<'py, PyAny>, +) -> PyResult> { + if let Ok(array) = input_array.extract::>() { + let array_view = array.as_array(); + let (batch, channels, height, width) = array_view.dim(); + let data = array_view.iter().copied().collect::>(); + let converted = layout::nchw_to_nhwc( + &data, + batch as u32, + channels as u32, + height as u32, + width as u32, + ) + .map_err(map_layout_error)?; + return array4_to_pyobject(py, converted, (batch, height, width, channels)); + } + + if let Ok(array) = input_array.extract::>() { + let array_view = array.as_array(); + let (batch, channels, height, width) = array_view.dim(); + let data = array_view.iter().copied().collect::>(); + let converted = layout::nchw_to_nhwc( + &data, + batch as u32, + channels as u32, + height as u32, + width as u32, + ) + .map_err(map_layout_error)?; + return array4_to_pyobject(py, converted, (batch, height, width, channels)); + } + + Err(PyValueError::new_err( + "expected a NxCxHxW NumPy array with dtype uint8 or float32", + )) +} + #[pyfunction(name = "iou")] fn iou_py(a: (f32, f32, f32, f32), b: (f32, f32, f32, f32)) -> f32 { bbox::iou( @@ -1019,6 +1214,11 @@ fn rusty_cv(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(center_crop_image_numpy, m)?)?; m.add_function(wrap_pyfunction!(resize_image, m)?)?; m.add_function(wrap_pyfunction!(resize_image_numpy, m)?)?; + m.add_function(wrap_pyfunction!(hwc_to_chw_numpy, m)?)?; + m.add_function(wrap_pyfunction!(chw_to_hwc_numpy, m)?)?; + m.add_function(wrap_pyfunction!(rgb_to_bgr_numpy, m)?)?; + m.add_function(wrap_pyfunction!(nhwc_to_nchw_numpy, m)?)?; + m.add_function(wrap_pyfunction!(nchw_to_nhwc_numpy, m)?)?; m.add_function(wrap_pyfunction!(letterbox_image, m)?)?; m.add_function(wrap_pyfunction!(letterbox_image_numpy, m)?)?; m.add_function(wrap_pyfunction!(normalize_image_numpy, m)?)?;