def deduplicate_first_valid(
    self,
    table: "Table",
    keys=("obs_id", "event_id"),
    valid_col="CTLearn_is_valid",
):
    """
    Return a new table with exactly one row per unique combination of ``keys``.

    For each group of rows sharing the same values in ``keys``, the first
    row (in input order) whose ``valid_col`` entry is True is kept. If no
    row of the group is valid, the first row of the group is kept instead.

    Parameters
    ----------
    table : astropy.table.Table
        Table that may hold several rows per group. It is not modified.
    keys : sequence of str
        Names of the columns that identify a group.
    valid_col : str
        Name of the boolean column flagging valid rows.

    Returns
    -------
    astropy.table.Table
        The selected rows, ordered by ``keys`` in ascending order.
    """
    key_names = list(keys)
    key_columns = [np.asarray(table[name]) for name in key_names]
    # Valid rows must sort *before* invalid ones inside a group, hence the
    # inverted flag (False < True).
    is_invalid = ~np.asarray(table[valid_col], dtype=bool)

    # ``Table.sort(reverse=...)`` only takes a single bool, so a per-key
    # direction cannot be expressed with it. ``np.lexsort`` is stable and
    # treats the LAST entry as the primary key: sort by keys[0], keys[1],
    # ..., then by validity, with ties left in their original input order.
    order = np.lexsort([is_invalid] + key_columns[::-1])

    # A sorted row opens a new group when any key differs from the row
    # before it; the very first sorted row (if any) always opens a group.
    starts_group = np.zeros(len(order), dtype=bool)
    starts_group[:1] = True
    for column in key_columns:
        ordered = column[order]
        starts_group[1:] |= ordered[1:] != ordered[:-1]

    # Indexing with an integer array yields a new object, so the input
    # table is left untouched.
    return table[order[starts_group]]
@@ -1253,6 +1276,12 @@ def start(self): classification_subarray_table[f"{self.prefix}_telescopes"] = ( reco_telescopes ) + # Deduplicate the subarray classification table to have only one entry per event + classification_subarray_table = super().deduplicate_first_valid( + table=classification_subarray_table, + keys=SUBARRAY_EVENT_KEYS, + valid_col=f"{self.prefix}_is_valid", + ) # Sort the subarray classification table classification_subarray_table.sort(SUBARRAY_EVENT_KEYS) # Save the prediction to the output file @@ -1381,6 +1410,12 @@ def start(self): energy_subarray_table[f"{self.prefix}_telescopes"] = ( reco_telescopes ) + # Deduplicate the subarray energy table to have only one entry per event + energy_subarray_table = super().deduplicate_first_valid( + table=energy_subarray_table, + keys=SUBARRAY_EVENT_KEYS, + valid_col=f"{self.prefix}_is_valid", + ) # Sort the subarray energy table energy_subarray_table.sort(SUBARRAY_EVENT_KEYS) # Save the prediction to the output file @@ -1537,6 +1572,12 @@ def start(self): direction_subarray_table[f"{self.prefix}_telescopes"] = ( reco_telescopes ) + # Deduplicate the subarray direction table to have only one entry per event + direction_subarray_table = super().deduplicate_first_valid( + table=direction_subarray_table, + keys=SUBARRAY_EVENT_KEYS, + valid_col=f"{self.prefix}_is_valid", + ) # Sort the subarray geometry table direction_subarray_table.sort(SUBARRAY_EVENT_KEYS) # Save the prediction to the output file @@ -1717,7 +1758,7 @@ def start(self): self.log.info("Starting the prediction...") classification_feature_vectors = None if self.load_type_model_from is not None: - # Predict the energy of the primary particle + # Predict the classification of the primary particle classification_table, classification_feature_vectors = ( super()._predict_classification(example_identifiers) ) @@ -1730,7 +1771,7 @@ def start(self): shapes=[(len(nonexample_identifiers),)], ) classification_table = 
vstack([classification_table, nan_table]) - # Add is_valid column to the energy table + # Add is_valid column to the classification table classification_table.add_column( ~np.isnan( classification_table[f"{self.prefix}_tel_prediction"].data, @@ -1745,6 +1786,12 @@ def start(self): classification_table.rename_column( f"{self.prefix}_tel_is_valid", f"{self.prefix}_is_valid" ) + # Deduplicate the subarray classification table to have only one entry per event + classification_table = super().deduplicate_first_valid( + table=classification_table, + keys=SUBARRAY_EVENT_KEYS, + valid_col=f"{self.prefix}_is_valid", + ) classification_table.sort(SUBARRAY_EVENT_KEYS) # Add the default values and meta data to the table add_defaults_and_meta( @@ -1793,6 +1840,12 @@ def start(self): energy_table.rename_column( f"{self.prefix}_tel_is_valid", f"{self.prefix}_is_valid" ) + # Deduplicate the subarray energy table to have only one entry per event + energy_table = super().deduplicate_first_valid( + table=energy_table, + keys=SUBARRAY_EVENT_KEYS, + valid_col=f"{self.prefix}_is_valid", + ) energy_table.sort(SUBARRAY_EVENT_KEYS) # Add the default values and meta data to the table add_defaults_and_meta( @@ -1845,6 +1898,12 @@ def start(self): ~np.isnan(direction_table[f"{self.prefix}_alt"].data, dtype=bool), name=f"{self.prefix}_is_valid", ) + # Deduplicate the subarray direction table to have only one entry per event + direction_table = super().deduplicate_first_valid( + table=direction_table, + keys=SUBARRAY_EVENT_KEYS, + valid_col=f"{self.prefix}_is_valid", + ) direction_table.sort(SUBARRAY_EVENT_KEYS) # Add the default values and meta data to the table add_defaults_and_meta(