From 2fb6b61cba9390a74500997cfb560386c5179797 Mon Sep 17 00:00:00 2001 From: Rondineli Gomes de Araujo Date: Sat, 17 Jan 2026 11:19:14 +0000 Subject: [PATCH] Adding reset parameters - usefull for transfer learnings --- oslactionspotting/apis/evaluate/utils.py | 5 ++++- oslactionspotting/datasets/json.py | 3 ++- oslactionspotting/models/builder.py | 23 ++++++++++++++++++++++- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/oslactionspotting/apis/evaluate/utils.py b/oslactionspotting/apis/evaluate/utils.py index d191279..5a4815c 100644 --- a/oslactionspotting/apis/evaluate/utils.py +++ b/oslactionspotting/apis/evaluate/utils.py @@ -166,7 +166,10 @@ def label2vector( else: frame = framerate * (seconds + 60 * minutes) - label = EVENT_DICTIONARY[event] + # in case of a event not in EVENT_DICTIONARY move with the evaluation + # as label maybe None it wont evaluate for all classes + # but only the ones configured in the config model file + label = EVENT_DICTIONARY.get(event) value = 1 if "visibility" in annotation.keys(): diff --git a/oslactionspotting/datasets/json.py b/oslactionspotting/datasets/json.py index c5bda38..a41357f 100644 --- a/oslactionspotting/datasets/json.py +++ b/oslactionspotting/datasets/json.py @@ -163,7 +163,8 @@ def annotation(self, annotation): frame = self.framerate * (seconds + 60 * minutes) cont = False - + # Initializing label var in case the cont is False dont thrown an error + label = None if event not in self.classes: cont = True else: diff --git a/oslactionspotting/models/builder.py b/oslactionspotting/models/builder.py index 15501b4..952b618 100644 --- a/oslactionspotting/models/builder.py +++ b/oslactionspotting/models/builder.py @@ -37,6 +37,27 @@ def build_model(cfg, verbose=True, default_args=None): neck=cfg.model.neck, runner=cfg.runner.type, ) + + # Adding reset_backbone, reset_neck and freeze_backbone will allow + # in case of transfer learning domains + if getattr(cfg.model, "freeze_backbone", False): + for p in model.model.backbone.parameters(): + p.requires_grad = False + logging.info(f"[INFO] Backbone is freeze_backbone: and set requires_grad=False") + + # 2. Freeze neck (optional) + if getattr(cfg.model, "freeze_neck", False): + for p in model.model.neck.parameters(): + logging.info(f"[INFO] Neck is set to freeze_neck: and set requires_grad=False") + p.requires_grad = False + + # 3. Reset head + if getattr(cfg.model, "reset_head", False): + for m in model.model.head.modules(): + if hasattr(m, "reset_parameters"): + m.reset_parameters() + logging.info("[INFO] Hed is set to reset: reset_head: reset_parameters() was called") + elif cfg.model.type == "E2E": model = E2EModel( cfg, @@ -97,4 +118,4 @@ def build_model(cfg, verbose=True, default_args=None): logging.info(f"Total trainable parameters: {total_params:,}") logging.info(f"Number of parameter groups: {len(parameters_per_layer)}") - return model \ No newline at end of file + return model