diff --git a/ctlearn/tools/predict_model.py b/ctlearn/tools/predict_model.py index 32819a8a..6a32f496 100644 --- a/ctlearn/tools/predict_model.py +++ b/ctlearn/tools/predict_model.py @@ -973,9 +973,6 @@ def _create_feature_vectors_table( """ # Create the feature vector table feature_vector_table = example_identifiers.copy() - feature_vector_table.remove_columns( - ["pointing_azimuth", "pointing_altitude", "time"] - ) columns_list, shapes_list = [], [] if classification_feature_vectors is not None: is_valid_col = ~np.isnan( @@ -1007,6 +1004,9 @@ def _create_feature_vectors_table( ) ) if direction_feature_vectors is not None: + feature_vector_table.remove_columns( + ["pointing_azimuth", "pointing_altitude", "time"] + ) is_valid_col = ~np.isnan( np.min(direction_feature_vectors, axis=1), dtype=bool )