Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions scripts/python_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,38 @@ def main() -> None:
assert multiclass_soft["indices"].tolist() == [0, 1, 2, 1]
assert multiclass_soft["class_ids"].tolist() == [0, 1, 1, 0]
assert np.isclose(float(multiclass_soft["scores"][3]), 0.25546217, atol=1e-6)

postprocessed = rusty_cv.postprocess_detections(
np.array(
[
[160.0, 240.0, 480.0, 400.0],
[170.0, 250.0, 490.0, 410.0],
[100.0, 240.0, 140.0, 280.0],
],
dtype=np.float32,
),
np.array([0.95, 0.90, 0.70], dtype=np.float32),
class_ids=np.array([0, 0, 1], dtype=np.int64),
geometry_mode="letterbox",
processed_width=640,
processed_height=640,
original_width=400,
original_height=200,
clip=True,
)
assert postprocessed["indices"].tolist() == [0, 2]
assert postprocessed["class_ids"].tolist() == [0, 1]
assert np.allclose(
postprocessed["boxes"],
np.array(
[
[100.0, 50.0, 300.0, 150.0],
[62.5, 50.0, 87.5, 75.0],
],
dtype=np.float32,
),
)

assert np.isclose(rusty_cv.iou((0.0, 0.0, 10.0, 10.0), (5.0, 5.0, 15.0, 15.0)), 25.0 / 175.0)

print("python smoke test: ok")
Expand Down
322 changes: 322 additions & 0 deletions src/bbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,63 @@ pub struct BoxFilterResult {
pub indices: Vec<usize>,
}

/// Geometry remapping mode for fused detection postprocessing.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BoxRemap {
None,
Current {
width: u32,
height: u32,
},
Resize {
processed_width: u32,
processed_height: u32,
original_width: u32,
original_height: u32,
},
Letterbox {
processed_width: u32,
processed_height: u32,
original_width: u32,
original_height: u32,
},
}

/// Options for fused detection postprocessing.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PostprocessOptions {
pub iou_threshold: f32,
pub score_threshold: f32,
pub pre_nms_top_k: Option<usize>,
pub max_detections: Option<usize>,
pub min_width: f32,
pub min_height: f32,
pub clip: bool,
}

impl Default for PostprocessOptions {
fn default() -> Self {
Self {
iou_threshold: 0.5,
score_threshold: f32::NEG_INFINITY,
pre_nms_top_k: None,
max_detections: None,
min_width: 0.0,
min_height: 0.0,
clip: false,
}
}
}

/// Result returned by fused detection postprocessing.
#[derive(Debug, Clone, PartialEq)]
pub struct PostprocessResult {
pub boxes: Vec<BBoxXYXY>,
pub detections: Vec<Detection>,
}

type ClipBounds = Option<(u32, u32)>;
type RemappedBoxes = (Vec<BBoxXYXY>, ClipBounds);
/// Errors for box postprocessing operations.
#[derive(Debug, Clone, PartialEq)]
pub enum BBoxError {
Expand Down Expand Up @@ -502,6 +559,79 @@ pub fn unletterbox_boxes(
.collect())
}

/// Run fused detection postprocessing over remapped boxes and class-aware NMS.
pub fn postprocess_detections(
boxes: &[BBoxXYXY],
scores: &[f32],
class_ids: &[usize],
remap: BoxRemap,
options: &PostprocessOptions,
soft_nms_options: Option<&SoftNmsOptions>,
) -> Result<PostprocessResult, BBoxError> {
validate_postprocess_inputs(boxes, scores, class_ids, options)?;

let (mut remapped_boxes, clip_bounds) = remap_boxes_for_postprocess(boxes, remap)?;
if options.clip {
if let Some((width, height)) = clip_bounds {
remapped_boxes = clip_boxes(&remapped_boxes, width, height)?;
}
}

let candidate_indices =
filter_boxes_by_min_size(&remapped_boxes, options.min_width, options.min_height)?;
let candidate_boxes = candidate_indices
.iter()
.map(|&index| remapped_boxes[index])
.collect::<Vec<_>>();
let candidate_scores = candidate_indices
.iter()
.map(|&index| scores[index])
.collect::<Vec<_>>();
let candidate_class_ids = candidate_indices
.iter()
.map(|&index| class_ids[index])
.collect::<Vec<_>>();

let detections = if let Some(soft_options) = soft_nms_options {
batched_soft_nms(
&candidate_boxes,
&candidate_scores,
&candidate_class_ids,
soft_options,
)?
} else {
let nms_options = NmsOptions {
iou_threshold: options.iou_threshold,
score_threshold: options.score_threshold,
pre_nms_top_k: options.pre_nms_top_k,
max_detections: options.max_detections,
};
batched_nms(
&candidate_boxes,
&candidate_scores,
&candidate_class_ids,
&nms_options,
)?
};

let mut ordered_boxes = Vec::with_capacity(detections.len());
let mut mapped_detections = Vec::with_capacity(detections.len());
for detection in detections {
let original_index = candidate_indices[detection.box_index];
ordered_boxes.push(remapped_boxes[original_index]);
mapped_detections.push(Detection {
box_index: original_index,
class_id: detection.class_id,
score: detection.score,
});
}

Ok(PostprocessResult {
boxes: ordered_boxes,
detections: mapped_detections,
})
}

/// Run single-class non-maximum suppression with custom options.
///
/// Returns the kept indices in descending score order.
Expand Down Expand Up @@ -828,6 +958,82 @@ fn validate_nms_inputs(
Ok(())
}

fn validate_postprocess_inputs(
boxes: &[BBoxXYXY],
scores: &[f32],
class_ids: &[usize],
options: &PostprocessOptions,
) -> Result<(), BBoxError> {
if boxes.len() != scores.len() {
return Err(BBoxError::LengthMismatch {
boxes: boxes.len(),
scores: scores.len(),
});
}

if boxes.len() != class_ids.len() {
return Err(BBoxError::ClassLengthMismatch {
boxes: boxes.len(),
class_ids: class_ids.len(),
});
}

let nms_options = NmsOptions {
iou_threshold: options.iou_threshold,
score_threshold: options.score_threshold,
pre_nms_top_k: options.pre_nms_top_k,
max_detections: options.max_detections,
};
validate_thresholds(&nms_options)?;
validate_boxes(boxes)?;
validate_scores(scores)?;
validate_min_size(options.min_width, options.min_height)?;
Ok(())
}

fn remap_boxes_for_postprocess(
boxes: &[BBoxXYXY],
remap: BoxRemap,
) -> Result<RemappedBoxes, BBoxError> {
match remap {
BoxRemap::None => Ok((boxes.to_vec(), None)),
BoxRemap::Current { width, height } => {
validate_image_size(width, height)?;
Ok((boxes.to_vec(), Some((width, height))))
}
BoxRemap::Resize {
processed_width,
processed_height,
original_width,
original_height,
} => Ok((
resize_boxes(
boxes,
processed_width,
processed_height,
original_width,
original_height,
)?,
Some((original_width, original_height)),
)),
BoxRemap::Letterbox {
processed_width,
processed_height,
original_width,
original_height,
} => Ok((
unletterbox_boxes(
boxes,
original_width,
original_height,
processed_width,
processed_height,
)?,
Some((original_width, original_height)),
)),
}
}

fn validate_thresholds(options: &NmsOptions) -> Result<(), BBoxError> {
if !(0.0..=1.0).contains(&options.iou_threshold) {
return Err(BBoxError::InvalidIouThreshold(options.iou_threshold));
Expand Down Expand Up @@ -1191,6 +1397,122 @@ mod tests {
);
}

#[test]
fn postprocesses_current_space_boxes() {
let boxes = vec![
BBoxXYXY::from_xywh(-5.0, 0.0, 8.0, 8.0),
BBoxXYXY::from_xywh(0.0, 0.0, 1.0, 1.0),
BBoxXYXY::from_xywh(2.0, 2.0, 4.0, 4.0),
];
let scores = vec![0.9, 0.8, 0.7];
let class_ids = vec![0usize, 0usize, 0usize];
let options = PostprocessOptions {
min_width: 2.0,
min_height: 2.0,
clip: true,
..PostprocessOptions::default()
};

let result = postprocess_detections(
&boxes,
&scores,
&class_ids,
BoxRemap::Current {
width: 4,
height: 4,
},
&options,
None,
)
.unwrap();

assert_eq!(result.detections.len(), 2);
assert_eq!(result.detections[0].box_index, 0);
assert_eq!(result.detections[1].box_index, 2);
assert_eq!(
result.boxes,
vec![
BBoxXYXY {
x1: 0.0,
y1: 0.0,
x2: 3.0,
y2: 4.0,
},
BBoxXYXY {
x1: 2.0,
y1: 2.0,
x2: 4.0,
y2: 4.0,
},
]
);
}

#[test]
fn postprocesses_letterboxed_boxes_back_to_original_space() {
let boxes = vec![
BBoxXYXY {
x1: 160.0,
y1: 240.0,
x2: 480.0,
y2: 400.0,
},
BBoxXYXY {
x1: 170.0,
y1: 250.0,
x2: 490.0,
y2: 410.0,
},
BBoxXYXY {
x1: 100.0,
y1: 240.0,
x2: 140.0,
y2: 280.0,
},
];
let scores = vec![0.95, 0.90, 0.70];
let class_ids = vec![0usize, 0usize, 1usize];

let result = postprocess_detections(
&boxes,
&scores,
&class_ids,
BoxRemap::Letterbox {
processed_width: 640,
processed_height: 640,
original_width: 400,
original_height: 200,
},
&PostprocessOptions {
clip: true,
..PostprocessOptions::default()
},
None,
)
.unwrap();

assert_eq!(result.detections.len(), 2);
assert_eq!(result.detections[0].box_index, 0);
assert_eq!(result.detections[1].box_index, 2);
assert_eq!(
result.boxes,
vec![
BBoxXYXY {
x1: 100.0,
y1: 50.0,
x2: 300.0,
y2: 150.0,
},
BBoxXYXY {
x1: 62.5,
y1: 50.0,
x2: 87.5,
y2: 75.0,
},
]
);
}

#[test]
fn keeps_highest_scoring_boxes() {
let boxes = vec![
Expand Down
Loading