-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathPrecisionRecall.cpp
More file actions
95 lines (75 loc) · 2.93 KB
/
PrecisionRecall.cpp
File metadata and controls
95 lines (75 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include "PrecisionRecall.h"

#include <algorithm>
#include <cstddef>
#include <fstream>
#include <set>
#include <stdexcept>
#include <vector>

using namespace std;
/**
 * Computes precision and recall of thresholded prediction scores against ground truth.
 *
 * A sample is predicted positive when preds[i] > threshold. It is a ground-truth
 * positive when gt[i] > 0. Every other label (0 or negative) is a negative, so both
 * {-1, +1} and {0, 1} labelings are handled consistently.
 *
 * @param gt        Ground-truth labels, one per prediction.
 * @param preds     Prediction scores.
 * @param threshold Decision threshold (strictly greater means "positive").
 * @param nGroundTruthDetections Recall denominator. Pass a negative value to use the
 *        number of positives present in gt (truePos + falseNeg).
 * @param precision Output: tp / (tp + fp), or 1 when nothing is predicted positive.
 * @param recall    Output: tp / denominator, or 1 when the denominator is 0.
 * @throws std::runtime_error if gt and preds differ in size.
 */
static void computePrecisionRecallForThreshold(const std::vector<float> &gt, const std::vector<float>& preds,
    float threshold, int nGroundTruthDetections, float& precision, float& recall)
{
    // gt is indexed in lockstep with preds, so a size mismatch would read out of bounds.
    if (gt.size() != preds.size())
        throw std::runtime_error("ERROR: gt and preds must have the same size");

    int truePos = 0, falsePos = 0, falseNeg = 0;
    for (std::size_t i = 0; i < preds.size(); ++i) {
        // One positivity rule for both branches; 0 labels are negatives everywhere.
        const bool isPositive = gt[i] > 0;
        if (preds[i] > threshold) {
            if (isPositive) ++truePos;
            else ++falsePos;
        } else if (preds[i] <= threshold) {  // both comparisons are false for NaN: such samples are skipped
            if (isPositive) ++falseNeg;
        }
    }

    precision = (truePos + falsePos == 0) ? 1.0f : float(truePos) / (truePos + falsePos);

    // Guard the denominator that is actually divided by. This avoids division by zero
    // when the caller passes 0. It also yields recall 0 (not 1) when ground truth
    // exists but none of it was found.
    const int nGt = (nGroundTruthDetections >= 0) ? nGroundTruthDetections : (truePos + falseNeg);
    recall = (nGt == 0) ? 1.0f : float(truePos) / nGt;
}
/**
 * Comparator that orders curve points by recall.
 *
 * @return true when b has a strictly higher recall than a.
 */
bool sortByRecall(const PrecisionRecallPoint& a, const PrecisionRecallPoint& b)
{
    return b.recall > a.recall;
}
PrecisionRecall::PrecisionRecall(const std::vector<float> >, const std::vector<float>& preds, int nGroundTruthDetections)
{
std::set<float> thresholds;
for (int i = 0; i < preds.size(); ++i) {
thresholds.insert(preds[i]);
}
_data.resize(0);
for(std::set<float>::iterator th = thresholds.begin(); th != thresholds.end(); th++) {
PrecisionRecallPoint pr;
computePrecisionRecallForThreshold(gt, preds, *th, nGroundTruthDetections, pr.precision, pr.recall);
pr.threshold = *th;
_data.push_back(pr);
}
std::sort(_data.begin(), _data.end(), sortByRecall);
// Remove jags in precision recall curve
float maxPrecision = -1;
for(std::vector<PrecisionRecallPoint>::reverse_iterator pr = _data.rbegin(); pr != _data.rend(); pr++) {
pr->precision = max(maxPrecision, pr->precision);
maxPrecision = max(pr->precision, maxPrecision);
}
// Compute average precision as area under the curve
_averagePrecision = 0.0;
for(std::vector<PrecisionRecallPoint>::iterator pr = _data.begin() + 1, prPrev = _data.begin(); pr != _data.end(); pr++, prPrev++) {
float xdiff = pr->recall - prPrev->recall;
float ydiff = pr->precision - prPrev->precision;
_averagePrecision += xdiff * prPrev->precision + xdiff * ydiff / 2.0;
}
}
/**
 * Writes the curve as plain text.
 *
 * The output is a "# precision recall threshold" header line, then one
 * "precision recall threshold" line per point, in the order stored in _data.
 *
 * @param filename Path of the file to create or overwrite.
 * @throws std::runtime_error if the file cannot be opened or the data cannot be written.
 */
void PrecisionRecall::save(const char* filename) const
{
    std::ofstream f(filename);
    // A failed open sets failbit, not badbit, so check the stream state as a whole.
    if (!f) throw std::runtime_error("ERROR: Could not open file for writing");

    f << "# precision recall threshold\n";
    for (std::vector<PrecisionRecallPoint>::const_iterator pr = _data.begin(); pr != _data.end(); ++pr) {
        f << pr->precision << " " << pr->recall << " " << pr->threshold << "\n";
    }

    // Buffered write errors may only surface on flush/close, so close explicitly and check.
    f.close();
    if (f.fail()) throw std::runtime_error("ERROR: Could not write to file");
}
double PrecisionRecall::getBestThreshold() const
{
double bestFMeasure = -1, bestThreshold = -1;
for(std::vector<PrecisionRecallPoint>::const_iterator pr = _data.begin(); pr != _data.end(); pr++) {
double fMeasure = (pr->precision * pr->recall) / (pr->precision + pr->recall);
if(fMeasure > bestFMeasure) {
bestFMeasure = fMeasure;
bestThreshold = pr->threshold;
}
}
return bestThreshold;
}