From 3f881ca2b3763ede778239a87eff1cfc98d95cd8 Mon Sep 17 00:00:00 2001 From: Weijia Liu Date: Sat, 30 Oct 2021 18:43:14 +0800 Subject: [PATCH] * FIX: fix hard negative mining --- losses.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/losses.py b/losses.py index fe7aeeb..864f9dd 100755 --- a/losses.py +++ b/losses.py @@ -18,7 +18,8 @@ def hard_negative_mining(loss, gt_confs, neg_ratio): pos_idx = gt_confs > 0 num_pos = tf.reduce_sum(tf.dtypes.cast(pos_idx, tf.int32), axis=1) num_neg = num_pos * neg_ratio - + + loss = tf.where(pos_idx, 0.0, loss) rank = tf.argsort(loss, axis=1, direction='DESCENDING') rank = tf.argsort(rank, axis=1) neg_idx = rank < tf.expand_dims(num_neg, 1)