forked from treble-maker123/deep-face-hashing
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcalc_pre_rec.py
More file actions
31 lines (28 loc) · 1.01 KB
/
calc_pre_rec.py
File metadata and controls
31 lines (28 loc) · 1.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from pdb import set_trace

import numpy as np
from sklearn.metrics import precision_recall_curve
def _safe_ratio(num, den):
    """Return ``num / den`` elementwise as an (at least 1-D) float array.

    Entries whose ratio is undefined (NaN, e.g. 0 / 0) become 0; other
    values, including +/-inf from x / 0 with x != 0, are left unchanged.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        # atleast_1d so that in-place NaN masking also works when the
        # reduction produced a scalar (1-D inputs).
        out = np.atleast_1d(np.true_divide(num, den))
    out[np.isnan(out)] = 0
    return out


def calc_pre_rec(hamm_dist, gt, radius):
    """Calculate precision/recall statistics and the precision-recall curve.

    Distances ``<= radius`` are collapsed to 0. An entry is predicted as a
    match when its (collapsed) distance is 0. Precision, recall and their
    harmonic mean (F1) are computed per column (reduction over axis 0) and
    then averaged over columns.

    Args:
        hamm_dist: Array-like of Hamming distances. Presumably shaped
            (n_database, n_queries), one column per query; TODO confirm
            against callers. A 1-D input is treated as a single column.
        gt: Array-like ground-truth relevance, same shape as ``hamm_dist``.
            Assumed binary (1 = relevant), as required by
            ``precision_recall_curve``.
        radius: Hamming radius. Distances ``<= radius`` count as distance 0.

    Returns:
        Tuple ``(precision, recall, f1, pre_curve, rec_curve)``:
        ``precision``, ``recall`` and ``f1`` are column means, where a column
        whose ratio is undefined (0 / 0) contributes 0. ``pre_curve`` and
        ``rec_curve`` are the "micro-averaged" curves returned by
        ``sklearn.metrics.precision_recall_curve`` over all entries.
    """
    hamm_dist = np.asarray(hamm_dist)
    gt = np.asarray(gt)

    # Distances within the radius count as 0.
    dist = hamm_dist * (hamm_dist > radius)

    # Normalize distances into scores: the smaller the distance, the closer
    # to 1. When max_val == 0 (or NaN) the ratio is NaN; those entries are
    # scored as 1, i.e. a perfect match.
    max_val = dist.max()
    with np.errstate(divide="ignore", invalid="ignore"):
        scores = ((max_val - dist) / max_val) ** 2
    scores[np.isnan(scores)] = 1

    # "Micro average" curve over every (row, column) entry.
    pre_curve, rec_curve, _ = precision_recall_curve(gt.ravel(), scores.ravel())

    # pred == 1 is what the model believes to be the person.
    pred = (dist == 0).astype("int8")
    # True positives per column.
    tp = (pred * gt).sum(axis=0)

    rec = _safe_ratio(tp, gt.sum(axis=0))
    pre = _safe_ratio(tp, pred.sum(axis=0))
    # Harmonic mean of precision and recall (F1).
    hmean = _safe_ratio(2 * (pre * rec), pre + rec)

    return pre.mean(), rec.mean(), hmean.mean(), pre_curve, rec_curve