diff --git a/losses.py b/losses.py index fe7aeeb..864f9dd 100755 --- a/losses.py +++ b/losses.py @@ -18,7 +18,8 @@ def hard_negative_mining(loss, gt_confs, neg_ratio): pos_idx = gt_confs > 0 num_pos = tf.reduce_sum(tf.dtypes.cast(pos_idx, tf.int32), axis=1) num_neg = num_pos * neg_ratio - + + loss = tf.where(pos_idx, 0.0, loss) rank = tf.argsort(loss, axis=1, direction='DESCENDING') rank = tf.argsort(rank, axis=1) neg_idx = rank < tf.expand_dims(num_neg, 1)