diff --git a/detection/validator/reward.py b/detection/validator/reward.py index 6f5e870..4f5053b 100644 --- a/detection/validator/reward.py +++ b/detection/validator/reward.py @@ -36,11 +36,14 @@ def reward(y_pred: np.array, y_true: np.array) -> float: """ preds = np.round(y_pred) - # accuracy = accuracy_score(y_true, preds) - cm = confusion_matrix(y_true, preds) + if len(y_true) == 0: + return 0, {'fp_score': 0, 'f1_score': 0, 'ap_score': 0} + + # Handle single-class case where confusion_matrix returns 1x1 + cm = confusion_matrix(y_true, preds, labels=[0, 1]) tn, fp, fn, tp = cm.ravel() - f1 = f1_score(y_true, preds) - ap_score = average_precision_score(y_true, y_pred) + f1 = f1_score(y_true, preds, zero_division=0) + ap_score = average_precision_score(y_true, y_pred) if len(np.unique(y_true)) > 1 else 0.0 res = {'fp_score': 1 - fp / len(y_pred), 'f1_score': f1,