diff --git a/ctlearn/tools/predict_LST1.py b/ctlearn/tools/predict_LST1.py index 4a385e57..e27e68c6 100644 --- a/ctlearn/tools/predict_LST1.py +++ b/ctlearn/tools/predict_LST1.py @@ -604,7 +604,7 @@ def start(self): colname, colname.replace("_tel", "") ) classification_subarray_table.add_column( - classification_is_valid[np.newaxis], name=f"{self.prefix}_telescopes" + [[val] for val in classification_is_valid], name=f"{self.prefix}_telescopes" ) # Save the prediction to the output file write_table( @@ -685,7 +685,7 @@ def start(self): colname, colname.replace("_tel", "") ) energy_subarray_table.add_column( - energy_is_valid[np.newaxis], name=f"{self.prefix}_telescopes" + [[val] for val in energy_is_valid], name=f"{self.prefix}_telescopes" ) # Save the prediction to the output file write_table( @@ -803,7 +803,7 @@ def start(self): colname, colname.replace("_tel", "") ) direction_subarray_table.add_column( - direction_is_valid[np.newaxis], name=f"{self.prefix}_telescopes" + [[val] for val in direction_is_valid], name=f"{self.prefix}_telescopes" ) # Save the prediction to the output file write_table(