From a6091b7a29ec0482ca366f824827d4f097f1b2c7 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Wed, 28 May 2025 16:19:52 -0400 Subject: [PATCH 1/9] Update training schema --- element_deeplabcut/train.py | 333 ++++++++++++------------------------ 1 file changed, 106 insertions(+), 227 deletions(-) diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index 1c78045..c1bc9fe 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -1,7 +1,10 @@ """ -Code adapted from the Mathis Lab -MIT License Copyright (c) 2022 Mackenzie Mathis -DataJoint Schema for DeepLabCut 2.x, Supports 2D and 3D DLC via triangulation. +Module for training DeepLabCut models within DataJoint. +---------------------------- +This module leverages DeepLabCut's built-in training functionality and captures the +results in DataJoint tables. +The user is expected to make changes to the DeepLabCut-generated config.yaml file and +the pytorch-config.yaml file to specify the training parameters prior to running this module. """ import datajoint as dj @@ -105,259 +108,135 @@ def get_dlc_processed_data_dir() -> str: @schema -class VideoSet(dj.Manual): - """Collection of videos included in a given training set. +class DLCTrainingTask(dj.Manual): + """Table for creating a DLC training task. Attributes: - video_set_id (int): Unique ID for collection of videos.""" - - definition = """ # Set of vids in training set - video_set_id: int + task_id (int): Primary key. Unique identifier for the training task. + project_path (str): Path to the DeepLabCut project directory. + dlc_config (longblob): DeepLabCut-generated config.yaml file. + pytorch_config (longblob): DeepLabCut-generated pytorch-config.yaml file. + shuffle (int): Shuffle for the training task. + trainingsetindex (int): Index of the training fraction in config.yaml. + snapshot_file (filepath): Optional. Path to latest snapshot file if available. """ - class File(dj.Part): - """File IDs and paths in a given VideoSet - - Attributes: - VideoSet (foreign key): VideoSet key. - file_path ( varchar(255) ): Path to file on disk relative to root.""" - - definition = """ # Paths of training files (e.g., labeled pngs, CSV or video) - -> master - file_id: int - --- - file_path: varchar(255) - """ - - -@schema -class TrainingParamSet(dj.Lookup): - """Parameters used to train a model - - Attributes: - paramset_idx (smallint): Index uniqely identifying paramset. - paramset_desc ( varchar(128) ): Description of paramset. - param_set_hash (uuid): Hash identifying this paramset. - params (longblob): Dictionary of all applicable parameters. - Note: param_set_hash must be unique.""" - definition = """ - # Parameters to specify a DLC model training instance - # For DLC ≤ v2.0, include scorer_legacy = True in params - paramset_idx : smallint + task_id: int --- - paramset_desc: varchar(128) - param_set_hash : uuid # hash identifying this parameterset - unique index (param_set_hash) - params : longblob # dictionary of all applicable parameters + project_path: varchar(255) # Path to the DeepLabCut project directory + dlc_config: longblob + pytorch_config: longblob + shuffle: int + trainingsetindex: int + snapshot_file=null: filepath@dlc_training """ - required_parameters = ("shuffle", "trainingsetindex") - skipped_parameters = ("project_path", "video_sets") - - @classmethod - def insert_new_params( - cls, paramset_desc: str, params: dict, paramset_idx: int = None - ): - """ - Insert a new set of training parameters into dlc.TrainingParamSet. - - Args: - paramset_desc (str): Description of parameter set to be inserted - params (dict): Dictionary including all settings to specify model training. - Must include shuffle & trainingsetindex b/c not in config.yaml. - project_path and video_sets will be overwritten by config.yaml. - Note that trainingsetindex is 0-indexed - paramset_idx (int): optional, integer to represent parameters. - """ - - for required_param in cls.required_parameters: - assert required_param in params, ( - "Missing required parameter: " + required_param - ) - for skipped_param in cls.skipped_parameters: - if skipped_param in params: - params.pop(skipped_param) - - if paramset_idx is None: - paramset_idx = ( - dj.U().aggr(cls, n="max(paramset_idx)").fetch1("n") or 0 - ) + 1 - - param_dict = { - "paramset_idx": paramset_idx, - "paramset_desc": paramset_desc, - "params": params, - "param_set_hash": dict_to_uuid(params), - } - param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} - # If the specified param-set already exists - if param_query: - existing_paramset_idx = param_query.fetch1("paramset_idx") - if existing_paramset_idx == int(paramset_idx): # If existing_idx same: - return # job done - else: - cls.insert1(param_dict) # if duplicate, will raise duplicate error - @schema -class TrainingTask(dj.Manual): - """Staging table for pairing videosets and training parameter sets +class DLCModelTraining(dj.Computed): + """Table for training DeepLabCut models. Attributes: - VideoSet (foreign key): VideoSet Key. - TrainingParamSet (foreign key): TrainingParamSet key. - training_id (int): Unique ID for training task. - model_prefix ( varchar(32) ): Optional. Prefix for model files. - project_path ( varchar(255) ): Optional. DLC's project_path in config relative - to get_dlc_root_data_dir - """ - - definition = """ # Specification for a DLC model training instance - -> VideoSet # labeled video(s) for training - -> TrainingParamSet - training_id : int - --- - model_prefix='' : varchar(32) - project_path='' : varchar(255) # DLC's project_path in config relative to root + task_id (int): Foreign key to task_id in TrainingTask table. + trained_config (longblob): DeepLabCut-generated config.yaml file after training. + trained_pytorch_config (longblob): DeepLabCut-generated pytorch-config.yaml file + after training. + training_log_file (filepath): Path to the train.txt file. + training_snapshot_file (filepath): Path to the latest snapshot file after training. """ - -@schema -class ModelTraining(dj.Computed): - """Automated Model training information. - - Attributes: - TrainingTask (foreign key): TrainingTask key. - latest_snapshot (int unsigned): Latest exact snapshot index (i.e., never -1). - config_template (longblob): Stored full config file.""" - definition = """ -> TrainingTask --- - latest_snapshot: int unsigned # latest exact snapshot index (i.e., never -1) - config_template: longblob # stored full config file + trained_pose_cfg: longblob + trained_pytorch_config: longblob + training_log_file: filepath@dlc_training + training_snapshot_file: filepath@dlc_training """ - # To continue from previous training snapshot, devs suggest editing pose_cfg.yml - # https://github.com/DeepLabCut/DeepLabCut/issues/70 - def make(self, key): - import deeplabcut - - try: - from deeplabcut.utils.auxiliaryfunctions import ( - get_model_folder, - edit_config, - ) # isort:skip - except ImportError: - from deeplabcut.utils.auxiliaryfunctions import ( - GetModelFolder as get_model_folder, - ) # isort:skip - - """Launch training for each train.TrainingTask training_id via `.populate()`.""" - project_path, model_prefix = (TrainingTask & key).fetch1( - "project_path", "model_prefix" + """Run model training after verifying that the config files match what was + ingested in the TrainingTask table.""" + from deeplabcut.compat import train_network + import yaml + import pathlib + + # Fetch the task entry from TrainingTask + project_dir, dlc_config_db, pytorch_config_db = (TrainingTask & key).fetch1( + "project_path", "dlc_config", "pytorch_config" ) + dlc_config_db = yaml.safe_load(dlc_config_db) + pytorch_config_db = yaml.safe_load(pytorch_config_db) - project_path = find_full_path(get_dlc_root_data_dir(), project_path) - - # ---- Build and save DLC configuration (yaml) file ---- - _, dlc_config = dlc_reader.read_yaml(project_path) # load existing - dlc_config.update((TrainingParamSet & key).fetch1("params")) - dlc_config.update( - { - "project_path": project_path.as_posix(), - "modelprefix": model_prefix, - "train_fraction": dlc_config["TrainingFraction"][ - int(dlc_config["trainingsetindex"]) - ], - "training_filelist_datajoint": [ # don't overwrite origin video_sets - find_full_path(get_dlc_root_data_dir(), fp).as_posix() - for fp in (VideoSet.File & key).fetch("file_path") - ], - } + # Locate the model folder config files + dlc_config_path = get_dlc_root_data_dir() / (project_dir + "config.yaml") + pytorch_config_path = get_dlc_root_data_dir() / ( + project_dir + "pytorch-config.yaml" ) - # Write dlc config file to base project folder - dlc_cfg_filepath = dlc_reader.save_yaml(project_path, dlc_config) - - # ---- Update the project path in the DLC pose configuration (yaml) files ---- - model_folder = get_model_folder( - trainFraction=dlc_config["train_fraction"], - shuffle=dlc_config["shuffle"], - cfg=dlc_config, - modelprefix=dlc_config["modelprefix"], - ) - model_train_folder = project_path / model_folder / "train" - - # update path of the init_weight - with open(model_train_folder / "pose_cfg.yaml", "r") as f: - pose_cfg = yaml.safe_load(f) - init_weights_path = Path(pose_cfg["init_weights"]) - - if ( - "pose_estimation_tensorflow/models/pretrained" - in init_weights_path.as_posix() - ): - # this is the res_net models, construct new path here - init_weights_path = ( - Path(deeplabcut.__path__[0]) - / "pose_estimation_tensorflow/models/pretrained" - / init_weights_path.name + + # Load the model folder config files + with open(dlc_config_path, "r") as f: + dlc_config_file = yaml.safe_load(f) + with open(pytorch_config_path, "r") as f: + pytorch_config_file = yaml.safe_load(f) + + # Compare the contents + if dlc_config_db != dlc_config_file: + raise ValueError( + f"Contents of DLC config file: {dlc_config_path} do not match the database config file." + ) + if pytorch_config_db != pytorch_config_file: + raise ValueError( + f"Contents of PyTorch config file: {pytorch_config_path} do not match the database config file." ) - else: - # this is existing snapshot weights, update path here - init_weights_path = model_train_folder / init_weights_path.name - edit_config( - model_train_folder / "pose_cfg.yaml", - { - "project_path": project_path.as_posix(), - "init_weights": init_weights_path.as_posix(), - "dataset": Path(pose_cfg["dataset"]).as_posix(), - "metadataset": Path(pose_cfg["metadataset"]).as_posix(), - }, + # Proceed with training if files match + trainingsetindex, shuffle = (TrainingTask & key).fetch1( + "trainingsetindex", "shuffle" ) - # ---- Trigger DLC model training job ---- - train_network_input_args = list( - inspect.signature(deeplabcut.train_network).parameters + train_network( + config=dlc_config_path.as_posix(), + shuffle=shuffle, + trainingsetindex=trainingsetindex, ) - train_network_kwargs = { - k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v - for k, v in dlc_config.items() - if k in train_network_input_args - } - for k in ["shuffle", "trainingsetindex", "maxiters"]: - train_network_kwargs[k] = int(train_network_kwargs[k]) - - try: - deeplabcut.train_network(dlc_cfg_filepath, **train_network_kwargs) - except KeyboardInterrupt: # Instructions indicate to train until interrupt - print("DLC training stopped via Keyboard Interrupt") - - # DLC goes by snapshot magnitude when judging 'latest' for evaluation - # Here, we mean most recently generated - snapshots = sorted(model_train_folder.glob("snapshot*.index")) - max_modified_time = 0 - for snapshot in snapshots: - modified_time = snapshot.stat().st_mtime - if modified_time > max_modified_time: - latest_snapshot_file = snapshot - latest_snapshot = int( - re.search(r"(\d+)\.index", latest_snapshot_file.name).group(1) - ) - max_modified_time = modified_time - - # update snapshotindex in the config - snapshotindex = snapshots.index(latest_snapshot_file) - - dlc_config["snapshotindex"] = snapshotindex - edit_config( - dlc_cfg_filepath, - {"snapshotindex": snapshotindex}, + + # Fetch the trained pose config and pytorch config + iteration = dlc_config_file["iteration"] + training_dir_path = pathlib.Path( + project_dir + f"/dlc-pytorch-models/iteration-{iteration}/" + ) + trained_config_path = next( + (get_dlc_processed_data_dir() / training_dir_path).rglob( + "*/train/pose_cfg.yaml" + ) + ) + trained_pytorch_config_path = next( + (get_dlc_processed_data_dir() / training_dir_path).rglob( + "*/train/pytorch-config.yaml" + ) ) + training_log_filepath = next( + (get_dlc_processed_data_dir() / training_dir_path).glob("*/train/train.txt") + ) + training_snapshot_file = sorted( + (get_dlc_processed_data_dir() / training_dir_path).glob( + "*/train/snapshot_*.pth" + ) + )[-1] + # Insert the results into ModelTraining table self.insert1( - {**key, "latest_snapshot": latest_snapshot, "config_template": dlc_config} + { + **key, + "trained_pose_cfg": yaml.safe_load(trained_config_path), + "trained_pytorch_config": yaml.safe_load(trained_pytorch_config_path), + "training_log_file": training_log_filepath.relative_to( + get_dlc_processed_data_dir() + ), + "training_snapshot_file": training_snapshot_file.relative_to( + get_dlc_processed_data_dir() + ), + } ) From e6904e45f25627131ab0e4e3764e402f1b626001 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Fri, 30 May 2025 10:48:58 -0400 Subject: [PATCH 2/9] fix '_' -> '-' for external store --- element_deeplabcut/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index c1bc9fe..6060567 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -129,7 +129,7 @@ class DLCTrainingTask(dj.Manual): pytorch_config: longblob shuffle: int trainingsetindex: int - snapshot_file=null: filepath@dlc_training + snapshot_file=null: filepath@dlc-training """ @@ -151,8 +151,8 @@ class DLCModelTraining(dj.Computed): --- trained_pose_cfg: longblob trained_pytorch_config: longblob - training_log_file: filepath@dlc_training - training_snapshot_file: filepath@dlc_training + training_log_file: filepath@dlc-training + training_snapshot_file: filepath@dlc-training """ def make(self, key): From 9ffa4dc78aa9a0124bc800b156afd6110a840f59 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Fri, 30 May 2025 11:41:53 -0400 Subject: [PATCH 3/9] Update foreign_key references --- element_deeplabcut/model.py | 6 +++--- element_deeplabcut/train.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/element_deeplabcut/model.py b/element_deeplabcut/model.py index dd458aa..6379606 100644 --- a/element_deeplabcut/model.py +++ b/element_deeplabcut/model.py @@ -339,7 +339,7 @@ class Model(dj.Manual): project_path : varchar(255) # DLC's project_path in config relative to root model_prefix='' : varchar(32) model_description='' : varchar(300) - -> [nullable] train.TrainingParamSet + -> [nullable] train.DLCTrainingTask """ # project_path is the only item required downstream in the pose schema @@ -365,7 +365,7 @@ def insert_new_model( trainingsetindex, model_description="", model_prefix="", - paramset_idx: int = None, + task_id: int = None, prompt=True, params=None, ): @@ -457,7 +457,7 @@ def insert_new_model( "trainingsetindex": int(trainingsetindex), "engine": engine, "project_path": project_path.relative_to(root_dir).as_posix(), - "paramset_idx": paramset_idx, + "paramset_idx": task_id, "config_template": dlc_config, } diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index 6060567..1da16d8 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -147,7 +147,7 @@ class DLCModelTraining(dj.Computed): """ definition = """ - -> TrainingTask + -> DLCTrainingTask --- trained_pose_cfg: longblob trained_pytorch_config: longblob From 869f5ac920e446dda04a1cdd77569aa1896d2d29 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Fri, 30 May 2025 13:12:58 -0400 Subject: [PATCH 4/9] Minor fixes to ModelTraining make function --- element_deeplabcut/train.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index 1da16d8..901cef2 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -163,12 +163,10 @@ def make(self, key): import pathlib # Fetch the task entry from TrainingTask - project_dir, dlc_config_db, pytorch_config_db = (TrainingTask & key).fetch1( + project_dir, dlc_config_db, pytorch_config_db = (DLCTrainingTask & key).fetch1( "project_path", "dlc_config", "pytorch_config" ) - dlc_config_db = yaml.safe_load(dlc_config_db) - pytorch_config_db = yaml.safe_load(pytorch_config_db) - + # Locate the model folder config files dlc_config_path = get_dlc_root_data_dir() / (project_dir + "config.yaml") pytorch_config_path = get_dlc_root_data_dir() / ( From 4234346db7d9f9e8156c02982a386a8c5deede60 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Fri, 30 May 2025 13:57:34 -0400 Subject: [PATCH 5/9] Debug errors in make function --- element_deeplabcut/train.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index 901cef2..061abcd 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -166,38 +166,40 @@ def make(self, key): project_dir, dlc_config_db, pytorch_config_db = (DLCTrainingTask & key).fetch1( "project_path", "dlc_config", "pytorch_config" ) + root_dir = get_dlc_root_data_dir().parent / "outbox" + project_dir = pathlib.Path(project_dir.replace("\\","/")) - # Locate the model folder config files - dlc_config_path = get_dlc_root_data_dir() / (project_dir + "config.yaml") - pytorch_config_path = get_dlc_root_data_dir() / ( - project_dir + "pytorch-config.yaml" - ) - - # Load the model folder config files + # Locate and open the config file + dlc_config_path = root_dir / project_dir / "config.yaml" with open(dlc_config_path, "r") as f: dlc_config_file = yaml.safe_load(f) - with open(pytorch_config_path, "r") as f: - pytorch_config_file = yaml.safe_load(f) - + # Compare the contents if dlc_config_db != dlc_config_file: raise ValueError( f"Contents of DLC config file: {dlc_config_path} do not match the database config file." ) + # Locate and open the pytorch config file + iteration = dlc_config_file["iteration"] + pytorch_config_path = next((root_dir / project_dir / f"dlc-models-pytorch").glob(f"iteration-{iteration}/*/train/pytorch_config.yaml")) + with open(pytorch_config_path, "r") as f: + pytorch_config_file = yaml.safe_load(f) + + # Compare the contents if pytorch_config_db != pytorch_config_file: raise ValueError( f"Contents of PyTorch config file: {pytorch_config_path} do not match the database config file." ) # Proceed with training if files match - trainingsetindex, shuffle = (TrainingTask & key).fetch1( + training_set_index, train_shuffle = (DLCTrainingTask & key).fetch1( "trainingsetindex", "shuffle" ) train_network( config=dlc_config_path.as_posix(), - shuffle=shuffle, - trainingsetindex=trainingsetindex, + shuffle=train_shuffle, + trainingsetindex=training_set_index, ) # Fetch the trained pose config and pytorch config From f7b54b364938ec68ef6e88577ecaa9a32078ecd2 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Fri, 30 May 2025 14:01:26 -0400 Subject: [PATCH 6/9] Minor fix --- element_deeplabcut/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index 061abcd..3eda383 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -166,7 +166,7 @@ def make(self, key): project_dir, dlc_config_db, pytorch_config_db = (DLCTrainingTask & key).fetch1( "project_path", "dlc_config", "pytorch_config" ) - root_dir = get_dlc_root_data_dir().parent / "outbox" + root_dir = get_dlc_root_data_dir()[0].parent / "outbox" project_dir = pathlib.Path(project_dir.replace("\\","/")) # Locate and open the config file From 24cb3a6b404df9f1788dea5b96228b82ad99f1ad Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Fri, 30 May 2025 14:28:46 -0400 Subject: [PATCH 7/9] Debug network training errors --- element_deeplabcut/train.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index 3eda383..61aebd1 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -190,16 +190,15 @@ def make(self, key): raise ValueError( f"Contents of PyTorch config file: {pytorch_config_path} do not match the database config file." ) - - # Proceed with training if files match + training_set_index, train_shuffle = (DLCTrainingTask & key).fetch1( "trainingsetindex", "shuffle" ) train_network( config=dlc_config_path.as_posix(), - shuffle=train_shuffle, - trainingsetindex=training_set_index, + shuffle=int(train_shuffle), + trainingsetindex=int(training_set_index), ) # Fetch the trained pose config and pytorch config From 8d09b25a4f51aa4e571747cb739e48ef4d2e4520 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Fri, 30 May 2025 14:34:14 -0400 Subject: [PATCH 8/9] Black formatting --- element_deeplabcut/train.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index 61aebd1..446015d 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -167,13 +167,13 @@ def make(self, key): "project_path", "dlc_config", "pytorch_config" ) root_dir = get_dlc_root_data_dir()[0].parent / "outbox" - project_dir = pathlib.Path(project_dir.replace("\\","/")) - + project_dir = pathlib.Path(project_dir.replace("\\", "/")) + # Locate and open the config file dlc_config_path = root_dir / project_dir / "config.yaml" with open(dlc_config_path, "r") as f: dlc_config_file = yaml.safe_load(f) - + # Compare the contents if dlc_config_db != dlc_config_file: raise ValueError( @@ -181,7 +181,11 @@ def make(self, key): ) # Locate and open the pytorch config file iteration = dlc_config_file["iteration"] - pytorch_config_path = next((root_dir / project_dir / f"dlc-models-pytorch").glob(f"iteration-{iteration}/*/train/pytorch_config.yaml")) + pytorch_config_path = next( + (root_dir / project_dir / f"dlc-models-pytorch").glob( + f"iteration-{iteration}/*/train/pytorch_config.yaml" + ) + ) with open(pytorch_config_path, "r") as f: pytorch_config_file = yaml.safe_load(f) @@ -190,7 +194,7 @@ def make(self, key): raise ValueError( f"Contents of PyTorch config file: {pytorch_config_path} do not match the database config file." ) - + training_set_index, train_shuffle = (DLCTrainingTask & key).fetch1( "trainingsetindex", "shuffle" ) From 1bf93af7d890a0a2a2c1171060318d6036c0f695 Mon Sep 17 00:00:00 2001 From: kushalbakshi Date: Fri, 30 May 2025 14:37:19 -0400 Subject: [PATCH 9/9] Changes to post-training --- element_deeplabcut/train.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index 446015d..4fa834d 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -220,7 +220,13 @@ def make(self, key): "*/train/pytorch-config.yaml" ) ) - training_log_filepath = next( + with open(trained_config_path, "r") as f: + trained_config_file = yaml.safe_load(f) + + with open(trained_pytorch_config_path, "r") as f: + trained_pytorch_config_file = yaml.safe_load(f) + + training_log_file = next( (get_dlc_processed_data_dir() / training_dir_path).glob("*/train/train.txt") ) training_snapshot_file = sorted( @@ -233,13 +239,9 @@ def make(self, key): self.insert1( { **key, - "trained_pose_cfg": yaml.safe_load(trained_config_path), - "trained_pytorch_config": yaml.safe_load(trained_pytorch_config_path), - "training_log_file": training_log_filepath.relative_to( - get_dlc_processed_data_dir() - ), - "training_snapshot_file": training_snapshot_file.relative_to( - get_dlc_processed_data_dir() - ), + "trained_pose_cfg": trained_config_file, + "trained_pytorch_config": trained_pytorch_config_file, + "training_log_file": training_log_file, + "training_snapshot_file": training_snapshot_file, } )