Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
.DS_Store

# Compiled object and shared library files
src/*.o
src/*.so
src/*.dll
17 changes: 13 additions & 4 deletions R/swarmbHIVE.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,29 @@ swarmbHIVE <- function(X,
if (task == "classification") {
predicted_labels <- model$assignments
actual_labels <- y


# Build the confusion matrix over the union of actual and predicted
# labels so it is always square. Without shared levels, a class that the
# model never predicts is absent from the columns, and indexing tb[cl, cl]
# or sum(tb[, cl]) for that class throws "subscript out of bounds".
lvls <- sort(unique(c(as.character(actual_labels),
as.character(predicted_labels))))
actual_f <- factor(as.character(actual_labels), levels=lvls)
predicted_f <- factor(as.character(predicted_labels), levels=lvls)

if (metric == "accuracy") {
return(mean(predicted_labels == actual_labels))
} else if (metric == "balanced_accuracy") {
# Balanced accuracy across classes
# For multi-class, we can do macro-average recall
tbl <- table(actual_labels, predicted_labels)
tbl <- table(actual_f, predicted_f)
# row = actual, col = predicted
recalls <- diag(prop.table(tbl, margin=1))
return(mean(recalls, na.rm=TRUE))
} else if (metric == "f1") {
# For multi-class, compute macro-F1
# F1_class_i = 2 * precision_i * recall_i / (precision_i + recall_i)
tb <- table(actual_labels, predicted_labels)
tb <- table(actual_f, predicted_f)
# row = actual, col = pred
f1s <- c()
for (cl in rownames(tb)) {
Expand All @@ -172,7 +181,7 @@ swarmbHIVE <- function(X,
return(mean(f1s, na.rm=TRUE))
} else if (metric == "kappa") {
# Cohen's Kappa
tb <- table(actual_labels, predicted_labels)
tb <- table(actual_f, predicted_f)
n <- sum(tb)
p0 <- sum(diag(tb)) / n
# Expected agreement under random chance
Expand Down
1 change: 1 addition & 0 deletions src/Makevars
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
CXX_STD = CXX17
PKG_LIBS = $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)
1 change: 1 addition & 0 deletions src/Makevars.win
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
CXX_STD = CXX17
PKG_LIBS = $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)
Loading