Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 5 additions & 26 deletions ctlearn/tools/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,6 @@ class TrainCTLearnModel(Tool):
).tag(config=True)


overwrite = Bool(help="Overwrite output dir if it exists").tag(config=True)

aliases = {
"signal": "TrainCTLearnModel.input_dir_signal",
"background": "TrainCTLearnModel.input_dir_background",
Expand All @@ -239,33 +237,14 @@ class TrainCTLearnModel(Tool):
("o", "output"): "TrainCTLearnModel.output_dir",
}

flags = {
"overwrite": (
{"TrainCTLearnModel": {"overwrite": True}},
"Overwrite existing files",
),
}

classes = (
[
CTLearnModel,
DLDataReader,
]
+ classes_with_traits(CTLearnModel)
+ classes_with_traits(DLDataReader)
)
classes = classes_with_traits(CTLearnModel) + classes_with_traits(DLDataReader)

def setup(self):
# Check if the output directory exists and if it should be overwritten
# Check if the output directory exists
if self.output_dir.exists():
if not self.overwrite:
raise ToolConfigurationError(
f"Output directory {self.output_dir} already exists. Use --overwrite to overwrite."
)
else:
# Remove the output directory if it exists
self.log.info("Removing existing output directory %s", self.output_dir)
shutil.rmtree(self.output_dir)
raise ToolConfigurationError(
f"Output directory {self.output_dir} already exists."
)
# Create a MirroredStrategy.
self.strategy = tf.distribute.MirroredStrategy()
atexit.register(self.strategy._extended._collective_ops._lock.locked) # type: ignore
Expand Down