-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathroc.c
More file actions
79 lines (71 loc) · 2.34 KB
/
roc.c
File metadata and controls
79 lines (71 loc) · 2.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include "roc.h"

#include <math.h>
#include <stdint.h>
#include <stdlib.h>
// formula described here: https://dyakonov.org/2017/07/28/auc-roc-площадь-под-кривой-ошибок/
/*
 * One scored example as consumed by the ROC AUC routines in this file.
 *   label  - group key; AUC is computed separately for each distinct label
 *   actual - ground truth: 0 marks a negative example, any other value a positive
 *   pred   - predicted score; positives are expected to score higher
 */
struct roc_item {
    int label;
    int actual;
    double pred;
};
typedef struct roc_item RocItem;
/*
 * qsort comparator for RocItem.
 *
 * Orders items by label, then by ascending pred (predictions that are
 * EPS_EQUAL are treated as ties), then by actual.
 *
 * Returns a negative value, zero, or a positive value when *va sorts before,
 * equal to, or after *vb.
 *
 * NOTE(review): EPS_EQUAL is a tolerance comparison and therefore not
 * transitive, so for predictions closer together than the tolerance this is
 * not a strict weak ordering. NaN predictions are likewise unordered.
 */
int items_compare(const void* va, const void* vb) {
    const RocItem* a = (const RocItem*) va;
    const RocItem* b = (const RocItem*) vb;
    if (a->label != b->label) {
        return a->label < b->label ? -1 : 1;
    }
    if (!EPS_EQUAL(a->pred, b->pred)) {
        return a->pred < b->pred ? -1 : 1;
    }
    /* Equal keys must compare as 0, so that cmp(x, x) == 0 and
     * cmp(a, b) == -cmp(b, a), as qsort requires of its comparator. */
    return (a->actual > b->actual) - (a->actual < b->actual);
}
/*
 * Computes ROC AUC for one group of items that is already sorted by
 * ascending pred.
 *
 * AUC is the fraction of (positive, negative) pairs in which the positive
 * item has the larger prediction. Pairs whose predictions are EPS_EQUAL
 * count as half a pair. An item is a negative when actual == 0 and a
 * positive otherwise.
 *
 * arr    - items sorted by ascending pred
 * n      - number of items in arr
 * result - receives the AUC on success; left untouched on failure
 *
 * Returns 1 on success, or 0 when the group contains only negatives or only
 * positives (AUC is undefined in that case).
 */
int _roc_auc(const RocItem* arr, size_t n, double* result) {
    size_t i;
    size_t group_end = 0;        /* one past the last item of the current tie group */
    size_t zeroes_before = 0;    /* negatives in tie groups before the current one */
    size_t zeroes_in_group = 0;  /* negatives in the current tie group */
    uint64_t num_wholes = 0, num_halfs = 0;
    for (i = 0; i < n; i++) {
        if (i == group_end) {
            /* Start a new tie group at i. It holds arr[i] plus the following
             * consecutive items whose pred is EPS_EQUAL to arr[i].pred.
             * Boundaries are tracked by index rather than by comparing arr[i]
             * with arr[i - 1], so every item lands in exactly one group even
             * when a chain of close values is not transitively EPS_EQUAL.
             * arr[i] is counted unconditionally so the scan always advances
             * (e.g. for a NaN pred, which is not EPS_EQUAL to itself). */
            zeroes_before += zeroes_in_group;
            zeroes_in_group = (arr[i].actual == 0);
            for (group_end = i + 1;
                 group_end < n && EPS_EQUAL(arr[group_end].pred, arr[i].pred);
                 group_end++) {
                zeroes_in_group += (arr[group_end].actual == 0);
            }
        }
        if (arr[i].actual) {
            /* With ascending order, a positive outranks every negative in
             * earlier groups and ties with every negative in its own group. */
            num_wholes += zeroes_before;
            num_halfs += zeroes_in_group;
        }
    }
    zeroes_before += zeroes_in_group; /* now the total number of negatives */
    if (zeroes_before == 0 || zeroes_before == n) {
        return 0;
    }
    *result = (num_wholes + num_halfs * 0.5) / ((uint64_t)zeroes_before * (n - zeroes_before));
    return 1;
}
/*
 * Mean ROC AUC over groups of items that share the same label.
 *
 * Items are grouped by labels[i]. Within each group, actual[i] == 0 marks a
 * negative and any other value a positive, and pred[i] is the score. A group
 * contributes to the mean only when it contains both negatives and positives;
 * single-class groups are skipped.
 *
 * labels, actual, pred - parallel arrays of length n (read only)
 * n                    - number of items
 *
 * Returns the arithmetic mean of the per-group AUCs. Returns NAN when n == 0,
 * when no group contains both classes, or when the working buffer cannot be
 * allocated.
 */
double mean_roc_auc(const int* labels, const int* actual, const double* pred, size_t n) {
    RocItem* arr;
    double sum = 0;
    size_t count = 0;
    size_t prev_start = 0;
    size_t i;

    /* n * sizeof *arr can wrap around, and malloc may fail (or return NULL for
     * a zero-size request); never dereference an unchecked buffer. */
    if (n == 0 || n > SIZE_MAX / sizeof *arr) {
        return NAN;
    }
    arr = malloc(n * sizeof *arr);
    if (arr == NULL) {
        return NAN;
    }
    for (i = 0; i < n; i++) {
        arr[i].label = labels[i];
        arr[i].actual = actual[i];
        arr[i].pred = pred[i];
    }
    /* Sorting makes each label group contiguous and orders every group by
     * ascending pred, which the per-group AUC computation relies on. */
    qsort(arr, n, sizeof *arr, items_compare);
    for (i = 1; i <= n; i++) {
        /* A group ends at the end of the array or where the label changes. */
        if (i == n || arr[i].label != arr[i - 1].label) {
            double val;
            if (_roc_auc(arr + prev_start, i - prev_start, &val)) {
                sum += val;
                count++;
            }
            prev_start = i;
        }
    }
    free(arr);
    return count > 0 ? sum / count : NAN;
}