Skip to content

Commit ec08bb0

Browse files
feat: Add Question/Answer mode to OSL Annotation Tool
- Introduced a new annotation mode for handling question and answer pairs. - Updated OSL JSON schema to include `questions` and `answers` fields. - Enhanced GUI to support the new Q/A mode, including a dedicated editor for managing questions and answers. - Implemented tests for Q/A workflows, ensuring proper functionality for adding, renaming, and deleting questions, as well as saving and reloading answers. - Updated existing tests to accommodate the new mode and ensure compatibility with the overall system.
1 parent 1fe440c commit ec08bb0

25 files changed

Lines changed: 1538 additions & 137 deletions

annotation_tool/controllers/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Owns runtime business logic: dataset lifecycle, mutation history, playback contr
1717
- `media_controller.py`: media playback and mute routing.
1818
- `welcome_controller.py`: welcome-page routing.
1919
- `hf_transfer_controller.py`: threaded Hugging Face download/upload orchestration for GUI menu actions.
20-
- `classification/`, `localization/`, `description/`, `dense_description/`: mode controllers.
20+
- `classification/`, `localization/`, `description/`, `dense_description/`, `question_answer/`: mode controllers.
2121

2222
## Key Functions and Responsibilities
2323
### `DatasetExplorerController`

annotation_tool/controllers/dataset_explorer_controller.py

Lines changed: 127 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ class DatasetExplorerController(QObject):
227227
dataSelected = pyqtSignal(str)
228228
sampleSelectionChanged = pyqtSignal(object)
229229
schemaContextChanged = pyqtSignal(dict)
230+
questionBankChanged = pyqtSignal(list)
230231
classificationActionListChanged = pyqtSignal(list)
231232
mediaRouteRequested = pyqtSignal(str, bool)
232233
mediaStopRequested = pyqtSignal()
@@ -239,6 +240,7 @@ class DatasetExplorerController(QObject):
239240
resetEditorsRequested = pyqtSignal()
240241
editorTabRequested = pyqtSignal(int)
241242
descSaveRequested = pyqtSignal()
243+
qaSaveRequested = pyqtSignal()
242244
clearMarkersRequested = pyqtSignal()
243245
annotationPanelsEnabledRequested = pyqtSignal(bool)
244246
headerDraftMutationRequested = pyqtSignal(dict)
@@ -260,7 +262,7 @@ class DatasetExplorerController(QObject):
260262
"description",
261263
"metadata",
262264
)
263-
HEADER_EXCLUDED_KEYS = {"data", "labels"}
265+
HEADER_EXCLUDED_KEYS = {"data", "labels", "questions"}
264266

265267
def __init__(
266268
self,
@@ -342,6 +344,18 @@ def modalities(self):
342344
def modalities(self, value):
343345
self.dataset_json["modalities"] = list(value) if isinstance(value, list) else ["video"]
344346

347+
@property
348+
def question_definitions(self) -> list:
349+
questions = self.dataset_json.get("questions")
350+
if not isinstance(questions, list):
351+
questions = []
352+
self.dataset_json["questions"] = questions
353+
return questions
354+
355+
@question_definitions.setter
356+
def question_definitions(self, value):
357+
self.dataset_json["questions"] = list(value) if isinstance(value, list) else []
358+
345359
@property
346360
def project_header_known(self) -> dict:
347361
return {
@@ -552,6 +566,9 @@ def _emit_selected_sample(self, sample_id: str):
552566
def _emit_schema_context(self):
553567
self.schemaContextChanged.emit(copy.deepcopy(self.label_definitions))
554568

569+
def _emit_question_bank_context(self):
570+
self.questionBankChanged.emit(copy.deepcopy(self.question_definitions))
571+
555572
def _emit_classification_action_list(self):
556573
self.classificationActionListChanged.emit(copy.deepcopy(self.action_item_data))
557574

@@ -721,6 +738,7 @@ def clear_annotations_for_path(self, path: str):
721738
"events",
722739
"captions",
723740
"dense_captions",
741+
"answers",
724742
):
725743
sample.pop(field, None)
726744

@@ -881,6 +899,8 @@ def _prompt_unsaved_close_action(self) -> str:
881899
def save_project(self):
882900
if self._active_mode_idx() == 2:
883901
self.descSaveRequested.emit()
902+
if self._active_mode_idx() == 4:
903+
self.qaSaveRequested.emit()
884904

885905
if not self.current_json_path:
886906
return self.export_project()
@@ -889,6 +909,8 @@ def save_project(self):
889909
def export_project(self):
890910
if self._active_mode_idx() == 2:
891911
self.descSaveRequested.emit()
912+
if self._active_mode_idx() == 4:
913+
self.qaSaveRequested.emit()
892914

893915
path, _ = QFileDialog.getSaveFileName(
894916
self.panel,
@@ -1001,9 +1023,66 @@ def _default_dataset_json(self):
10011023
"modalities": ["video"],
10021024
"metadata": {},
10031025
"labels": {},
1026+
"questions": [],
10041027
"data": [],
10051028
}
10061029

1030+
@staticmethod
1031+
def _normalize_question_id(question_id: str) -> str:
1032+
return str(question_id or "").strip()
1033+
1034+
def _normalize_questions_payload(self, questions) -> list:
1035+
normalized = []
1036+
seen_ids = set()
1037+
for raw_question in list(questions or []):
1038+
if not isinstance(raw_question, dict):
1039+
continue
1040+
1041+
question_id = self._normalize_question_id(raw_question.get("id"))
1042+
question_text = str(raw_question.get("question") or "").strip()
1043+
if not question_id or not question_text:
1044+
continue
1045+
if question_id in seen_ids:
1046+
continue
1047+
1048+
seen_ids.add(question_id)
1049+
normalized.append({"id": question_id, "question": question_text})
1050+
return normalized
1051+
1052+
@staticmethod
1053+
def _normalize_sample_answers_payload(answers, valid_question_ids: set) -> list:
1054+
normalized = []
1055+
seen_question_ids = set()
1056+
for raw_answer in list(answers or []):
1057+
if not isinstance(raw_answer, dict):
1058+
continue
1059+
question_id = str(raw_answer.get("question_id") or "").strip()
1060+
if (
1061+
not question_id
1062+
or question_id not in valid_question_ids
1063+
or question_id in seen_question_ids
1064+
):
1065+
continue
1066+
answer_text = str(raw_answer.get("answer") or "").strip()
1067+
if not answer_text:
1068+
continue
1069+
normalized.append({"question_id": question_id, "answer": answer_text})
1070+
seen_question_ids.add(question_id)
1071+
return normalized
1072+
1073+
def next_question_id(self) -> str:
1074+
max_suffix = 0
1075+
for question in self.question_definitions:
1076+
if not isinstance(question, dict):
1077+
continue
1078+
question_id = self._normalize_question_id(question.get("id"))
1079+
if not question_id.startswith("q"):
1080+
continue
1081+
suffix = question_id[1:]
1082+
if suffix.isdigit():
1083+
max_suffix = max(max_suffix, int(suffix))
1084+
return f"q{max_suffix + 1}"
1085+
10071086
def _normalize_dataset_json(self, data):
10081087
if not isinstance(data, dict):
10091088
return None, "Root JSON must be an object."
@@ -1016,6 +1095,8 @@ def _normalize_dataset_json(self, data):
10161095

10171096
if not isinstance(normalized.get("labels"), dict):
10181097
normalized["labels"] = {}
1098+
normalized["questions"] = self._normalize_questions_payload(normalized.get("questions"))
1099+
valid_question_ids = {question["id"] for question in normalized["questions"]}
10191100
if not isinstance(normalized.get("metadata"), dict):
10201101
normalized["metadata"] = {}
10211102
if not isinstance(normalized.get("modalities"), list):
@@ -1061,6 +1142,15 @@ def _normalize_dataset_json(self, data):
10611142
if isinstance(event, dict):
10621143
event["position_ms"] = _safe_int(event.get("position_ms", 0))
10631144

1145+
normalized_answers = self._normalize_sample_answers_payload(
1146+
sample.get("answers"),
1147+
valid_question_ids,
1148+
)
1149+
if normalized_answers:
1150+
sample["answers"] = normalized_answers
1151+
else:
1152+
sample.pop("answers", None)
1153+
10641154
cleaned_data.append(sample)
10651155

10661156
normalized["data"] = cleaned_data
@@ -1430,10 +1520,12 @@ def _sample_supports_mode(self, sample: dict, mode_idx: int) -> bool:
14301520
return any(isinstance(cap, dict) and str(cap.get("text", "")).strip() for cap in captions)
14311521
if mode_idx == 3:
14321522
return bool(sample.get("dense_captions"))
1523+
if mode_idx == 4:
1524+
return self._has_non_empty_answers(sample)
14331525
return False
14341526

14351527
def _available_mode_indices_for_sample(self, sample: dict):
1436-
return [mode_idx for mode_idx in (0, 1, 2, 3) if self._sample_supports_mode(sample, mode_idx)]
1528+
return [mode_idx for mode_idx in (0, 1, 2, 3, 4) if self._sample_supports_mode(sample, mode_idx)]
14371529

14381530
def _reconcile_annotation_tab_for_sample(self, sample: dict) -> bool:
14391531
available_modes = self._available_mode_indices_for_sample(sample)
@@ -1558,10 +1650,28 @@ def _label_state_for_sample(self, sample):
15581650
for cap in captions
15591651
)
15601652

1561-
hand = bool(_ManualAnnotationRecord(sample)) or bool(sample.get("events")) or bool(sample.get("dense_captions")) or has_caption_text
1653+
hand = (
1654+
bool(_ManualAnnotationRecord(sample))
1655+
or bool(sample.get("events"))
1656+
or bool(sample.get("dense_captions"))
1657+
or has_caption_text
1658+
or self._has_non_empty_answers(sample)
1659+
)
15621660
smart = self._has_smart_labels(sample) or self._has_smart_events(sample)
15631661
return bool(hand), bool(smart)
15641662

1663+
@staticmethod
1664+
def _has_non_empty_answers(sample: dict) -> bool:
1665+
answers = sample.get("answers")
1666+
if not isinstance(answers, list):
1667+
return False
1668+
for entry in answers:
1669+
if not isinstance(entry, dict):
1670+
continue
1671+
if str(entry.get("answer") or "").strip():
1672+
return True
1673+
return False
1674+
15651675
@staticmethod
15661676
def _has_smart_labels(sample: dict) -> bool:
15671677
labels = sample.get("labels")
@@ -1784,6 +1894,7 @@ def _build_new_sample(self, source_group):
17841894
"events": [],
17851895
"captions": [],
17861896
"dense_captions": [],
1897+
"answers": [],
17871898
}
17881899

17891900
def handle_add_sample(self):
@@ -1859,6 +1970,8 @@ def _dataset_json_for_write(self, save_path: str):
18591970

18601971
base_dir = os.path.dirname(os.path.abspath(save_path))
18611972
written = copy.deepcopy(normalized)
1973+
written["questions"] = self._normalize_questions_payload(written.get("questions"))
1974+
valid_question_ids = {question["id"] for question in written["questions"]}
18621975
for sample in written.get("data", []):
18631976
new_inputs = []
18641977

@@ -1884,6 +1997,14 @@ def _dataset_json_for_write(self, save_path: str):
18841997
sample.pop("captions", None)
18851998
if not sample.get("dense_captions"):
18861999
sample.pop("dense_captions", None)
2000+
normalized_answers = self._normalize_sample_answers_payload(
2001+
sample.get("answers"),
2002+
valid_question_ids,
2003+
)
2004+
if normalized_answers:
2005+
sample["answers"] = normalized_answers
2006+
else:
2007+
sample.pop("answers", None)
18872008
if not sample.get("metadata"):
18882009
sample.pop("metadata", None)
18892010
# Never persist retired smart-* keys.
@@ -1896,6 +2017,7 @@ def _dataset_json_for_write(self, save_path: str):
18962017
definition.pop("label_colors", None)
18972018
written.setdefault("metadata", {})
18982019
written.setdefault("modalities", ["video"])
2020+
written.setdefault("questions", [])
18992021
if not written.get("description"):
19002022
written["description"] = ""
19012023
return written
@@ -1917,6 +2039,7 @@ def _write_dataset_json(self, save_path: str):
19172039
self._add_recent_project(self.current_json_path)
19182040
self._rebuild_runtime_index()
19192041
self._refresh_header_panel()
2042+
self._refresh_schema_panels()
19202043
self.saveStateRefreshRequested.emit()
19212044
self.statusMessageRequested.emit("Saved", f"Saved to {os.path.basename(save_path)}", 1500)
19222045
return True
@@ -1927,3 +2050,4 @@ def _write_dataset_json(self, save_path: str):
19272050
def _refresh_schema_panels(self):
19282051
self.schemaRefreshRequested.emit()
19292052
self._emit_schema_context()
2053+
self._emit_question_bank_context()

0 commit comments

Comments
 (0)