diff --git a/ctlearn/tools/predict_model.py b/ctlearn/tools/predict_model.py index 4560ecf8..562bf2ff 100644 --- a/ctlearn/tools/predict_model.py +++ b/ctlearn/tools/predict_model.py @@ -471,7 +471,7 @@ def _predict_with_model(self, model_path): # Load the model from the specified path model = keras.saving.load_model(model_path) prediction_colname = ( - model.layers[-1].name if model.layers[-1].name != "softmax" else "type" + "type" if isinstance(model.layers[-1], keras.layers.Softmax) else model.layers[-1].name ) backbone_model, feature_vectors = None, None if self.dl1_features: