From 7df449f69cb6f740e6267fe66df24f7bffdfbd7c Mon Sep 17 00:00:00 2001 From: Moshe Plotkin Date: Wed, 16 Jan 2019 11:11:53 -0500 Subject: [PATCH] Fixed issue with confusion_matrix When test data does not contain all of the classes, the column ad row for that class was being skipped. Passing a list of classes to sklearn, fixes the problem. --- visual_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/visual_callbacks.py b/visual_callbacks.py index 974ed59..85869c2 100644 --- a/visual_callbacks.py +++ b/visual_callbacks.py @@ -114,7 +114,7 @@ def on_epoch_end(self, epoch, logs={}): pred = self.model.predict(self.X_val) max_pred = np.argmax(pred, axis=1) max_y = np.argmax(self.Y_val, axis=1) - cnf_mat = confusion_matrix(max_y, max_pred) + cnf_mat = confusion_matrix(max_y, max_pred, labels=range(len(self.classes))) if self.normalize: cnf_mat = cnf_mat.astype('float') / cnf_mat.sum(axis=1)[:, np.newaxis]