diff --git a/visual_callbacks.py b/visual_callbacks.py index 974ed59..85869c2 100644 --- a/visual_callbacks.py +++ b/visual_callbacks.py @@ -114,7 +114,7 @@ def on_epoch_end(self, epoch, logs={}): pred = self.model.predict(self.X_val) max_pred = np.argmax(pred, axis=1) max_y = np.argmax(self.Y_val, axis=1) - cnf_mat = confusion_matrix(max_y, max_pred) + cnf_mat = confusion_matrix(max_y, max_pred, labels=range(len(self.classes))) if self.normalize: cnf_mat = cnf_mat.astype('float') / cnf_mat.sum(axis=1)[:, np.newaxis]