Skip to content

Commit ec2813b

Browse files
Jintao MaJintao Ma
authored andcommitted
Localization Module Enhancements: AI-Human Collaborative Workflow & OSL Standard Compliance
1 parent cfddd41 commit ec2813b

10 files changed

Lines changed: 1212 additions & 152 deletions

File tree

.DS_Store

6 KB
Binary file not shown.

annotation_tool/controllers/history_manager.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _refresh_active_view(self):
5353
# Refresh Events (Table & Timeline)
5454
self.main.loc_manager._refresh_current_clip_events()
5555
# Refresh left side
56-
self.main.loc_manager.populate_tree()
56+
self.main.loc_manager.refresh_tree_icons()
5757

5858
# 2. Description Mode
5959
elif current_widget == self.main.ui.description_ui:
@@ -201,6 +201,68 @@ def _apply_state_change(self, cmd, is_undo):
201201

202202
self._refresh_active_view()
203203

204+
205+
# =========================================================
206+
# Redo/undo for localization smart annotation
207+
# =========================================================
208+
elif ctype == CmdType.LOC_SMART_CONFIRM:
209+
path = cmd['video_path']
210+
events = cmd['confirmed_events']
211+
212+
smart_events = self.model.smart_localization_events.get(path, [])
213+
temp_events = self.model.temp_smart_events.get(path, [])
214+
215+
if is_undo:
216+
for evt in events:
217+
if evt in smart_events:
218+
smart_events.remove(evt)
219+
temp_events.extend(events)
220+
temp_events.sort(key=lambda x: x.get('position_ms', 0))
221+
else:
222+
for evt in events:
223+
if evt in temp_events:
224+
temp_events.remove(evt)
225+
smart_events.extend(events)
226+
smart_events.sort(key=lambda x: x.get('position_ms', 0))
227+
228+
self.model.smart_localization_events[path] = smart_events
229+
self.model.temp_smart_events[path] = temp_events
230+
231+
self.main.loc_manager.refresh_tree_icons()
232+
self.main.loc_manager._display_smart_events(path)
233+
234+
elif ctype == CmdType.LOC_SMART_EVENT_DEL:
235+
path = cmd['video_path']
236+
evt = cmd['event']
237+
is_confirmed = cmd['is_confirmed']
238+
239+
if is_confirmed:
240+
events_list = self.model.smart_localization_events.setdefault(path, [])
241+
else:
242+
events_list = self.model.temp_smart_events.setdefault(path, [])
243+
244+
if is_undo:
245+
events_list.append(evt)
246+
events_list.sort(key=lambda x: x.get('position_ms', 0))
247+
else:
248+
if evt in events_list:
249+
events_list.remove(evt)
250+
251+
self.main.loc_manager.refresh_tree_icons()
252+
self.main.loc_manager._display_smart_events(path)
253+
254+
elif ctype == CmdType.LOC_SMART_RUN:
255+
path = cmd['video_path']
256+
old_events = cmd['old_events']
257+
new_events = cmd['new_events']
258+
259+
if is_undo:
260+
self.model.temp_smart_events[path] = copy.deepcopy(old_events)
261+
else:
262+
self.model.temp_smart_events[path] = copy.deepcopy(new_events)
263+
264+
self.main.loc_manager._display_smart_events(path)
265+
204266
# =========================================================
205267
# 3. Description Specific
206268
# =========================================================
@@ -430,4 +492,4 @@ def _apply_state_change(self, cmd, is_undo):
430492
if evt.get('head') == head and evt.get('label') == src:
431493
evt['label'] = dst
432494

433-
self._refresh_active_view()
495+
self._refresh_active_view()

annotation_tool/controllers/localization/loc_file_manager.py

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ def load_project(self, data, file_path):
167167

168168
# Process events
169169
raw_events = item.get("events", [])
170-
processed_events = []
170+
processed_events = []
171+
processed_smart_events = []
171172

172173
if isinstance(raw_events, list):
173174
for evt in raw_events:
@@ -178,16 +179,25 @@ def load_project(self, data, file_path):
178179
except ValueError:
179180
pos_ms = 0
180181

181-
processed_events.append(
182-
{
182+
if "confidence" in evt or "score" in evt:
183+
conf = evt.get("confidence", evt.get("score", 1.0))
184+
processed_smart_events.append({
183185
"head": evt.get("head", "action"),
184186
"label": evt.get("label", "?"),
185187
"position_ms": pos_ms,
186-
}
187-
)
188+
"confidence": conf
189+
})
190+
else:
191+
processed_events.append({
192+
"head": evt.get("head", "action"),
193+
"label": evt.get("label", "?"),
194+
"position_ms": pos_ms,
195+
})
188196

189197
if processed_events:
190198
self.model.localization_events[final_path] = processed_events
199+
if processed_smart_events:
200+
self.model.smart_localization_events[final_path] = processed_smart_events
191201

192202
loaded_count += 1
193203

@@ -258,35 +268,46 @@ def _write_json(self, path):
258268

259269
for data in sorted_items:
260270
abs_path = data["path"]
271+
261272
events = self.model.localization_events.get(abs_path, [])
273+
smart_events = self.model.smart_localization_events.get(abs_path, [])
262274

263-
# Store path as relative if possible
264275
try:
265276
rel_path = os.path.relpath(abs_path, base_dir).replace(os.sep, "/")
266277
except Exception:
267278
rel_path = abs_path
268279

269-
# Convert events to export format
270280
export_events = []
281+
271282
for e in events:
272-
export_events.append(
273-
{
274-
"head": e.get("head"),
275-
"label": e.get("label"),
276-
"position_ms": str(e.get("position_ms")),
277-
}
278-
)
283+
export_events.append({
284+
"head": e.get("head"),
285+
"label": e.get("label"),
286+
"position_ms": int(e.get("position_ms", 0)),
287+
})
288+
289+
for e in smart_events:
290+
export_events.append({
291+
"head": e.get("head"),
292+
"label": e.get("label"),
293+
"position_ms": int(e.get("position_ms", 0)),
294+
"confidence": float(e.get("confidence", 0.99))
295+
})
296+
297+
export_events.sort(key=lambda x: x["position_ms"])
279298

280299
entry = {
300+
"id": data.get("name", ""),
281301
"inputs": [
282302
{
283303
"type": "video",
284304
"path": rel_path,
285305
"fps": 25.0,
286306
}
287307
],
288-
"events": export_events,
308+
"events": export_events
289309
}
310+
290311
output["data"].append(entry)
291312

292313
try:
@@ -297,32 +318,10 @@ def _write_json(self, path):
297318
self.main.statusBar().showMessage(f"Saved — {os.path.basename(path)}", 1500)
298319
return True
299320
except Exception as e:
321+
from PyQt6.QtWidgets import QMessageBox
300322
QMessageBox.critical(self.main, "Error", f"Save failed: {e}")
301323
return False
302-
303-
for video_path in sorted(self.model.localization_events.keys()):
304-
# 获取该视频所属的原始 item 定义(包含 inputs 视频源信息)
305-
base_item = next((item for item in self.model.action_item_data if item["path"] == video_path), None)
306-
if not base_item: continue
307-
308-
# 1. 获取手工(或已确认的)标注
309-
manual_events = self.model.localization_events.get(video_path, [])
310-
311-
# 2. 获取未确认的智能标注
312-
smart_events = self.model.smart_localization_events.get(video_path, [])
313-
314-
# 构建符合 OSL 标准规范的单条数据结构
315-
out_item = {
316-
"id": base_item.get("id", ""),
317-
"inputs": [{"path": f, "type": "video"} for f in base_item.get("source_files", [video_path])],
318-
"events": manual_events
319-
}
320-
321-
# 遵循原始结构添加 smart_events 字段(如果有的话)
322-
if smart_events:
323-
out_item["smart_events"] = smart_events
324-
325-
items.append(out_item)
324+
326325

327326
def _clear_workspace(self, full_reset=False):
328327
"""
@@ -336,7 +335,7 @@ def _clear_workspace(self, full_reset=False):
336335
self.main.loc_manager.center_panel.media_preview.stop()
337336
self.main.loc_manager.center_panel.media_preview.player.setSource(QUrl())
338337

339-
# [FIX] Reset timeline UI (markers + label + slider)
338+
# [FIX] Reset timeline UI (markers + label + slider)
340339
tl = self.main.loc_manager.center_panel.timeline
341340
tl.set_markers([])
342341
tl.set_duration(0)
@@ -345,6 +344,13 @@ def _clear_workspace(self, full_reset=False):
345344
# Right panel: clear table and schema
346345
self.main.loc_manager.right_panel.table.set_data([])
347346
self.main.loc_manager.right_panel.annot_mgmt.update_schema({})
347+
if hasattr(self.main.loc_manager.right_panel, "smart_widget"):
348+
smart_ui = self.main.loc_manager.right_panel.smart_widget
349+
350+
smart_ui.reset_ui()
351+
352+
smart_ui.predicted_table.set_data([])
353+
smart_ui.confirmed_table.set_data([])
348354

349355
# Reset model data
350356
self.model.reset(full_reset)

annotation_tool/controllers/localization/loc_inference.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,38 @@ def __init__(self, video_path, start_ms, end_ms, config_path):
2222

2323
def run(self):
2424
try:
25+
import torch
26+
if not torch.cuda.is_available():
27+
torch.cuda.FloatTensor = torch.FloatTensor
28+
torch.cuda.LongTensor = torch.LongTensor
29+
torch.cuda.IntTensor = torch.IntTensor
30+
torch.cuda.DoubleTensor = torch.DoubleTensor
31+
# ==========================================
32+
2533
# Import library inside thread to avoid blocking main thread at startup
2634
from opensportslib import model
35+
import subprocess
2736

2837
with tempfile.TemporaryDirectory() as tmp_dir:
38+
# Use FFmpeg to cut clips
39+
clip_video_path = os.path.join(tmp_dir, "clipped_segment.mp4")
40+
41+
def ms_to_ffmpeg(ms):
42+
s = ms // 1000
43+
return f"{s // 3600:02}:{(s % 3600) // 60:02}:{s % 60:02}.{ms % 1000:03}"
44+
45+
start_time_str = ms_to_ffmpeg(self.start_ms)
46+
duration_ms = self.end_ms - self.start_ms if self.end_ms > 0 else 0
47+
48+
cmd = ['ffmpeg', '-y', '-ss', start_time_str, '-i', self.video_path]
49+
if duration_ms > 0:
50+
cmd += ['-t', ms_to_ffmpeg(duration_ms)]
51+
cmd += ['-c', 'copy', clip_video_path]
52+
53+
subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
54+
2955
tmp_input_json = os.path.join(tmp_dir, "temp_test.json")
3056
tmp_config_yaml = os.path.join(tmp_dir, "temp_config.yaml")
31-
tmp_output_json = os.path.join(tmp_dir, "predictions.json")
3257

3358
# --- 1. Load and dynamically patch the YAML config ---
3459
with open(self.config_path, 'r', encoding='utf-8') as f:
@@ -37,34 +62,35 @@ def run(self):
3762
classes = config_dict.get("DATA", {}).get("classes", [])
3863

3964
# 🚀 [MAC CPU ADAPTATION & PATH FIXES] 🚀
40-
# Force CPU mode and disable Multi-GPU dynamically
4165
if "SYSTEM" not in config_dict: config_dict["SYSTEM"] = {}
4266
config_dict["SYSTEM"]["work_dir"] = tmp_dir
4367
config_dict["SYSTEM"]["device"] = "cpu"
44-
config_dict["SYSTEM"]["GPU"] = 0
45-
config_dict["SYSTEM"]["gpu_id"] = 0
68+
config_dict["SYSTEM"]["GPU"] = -1
69+
config_dict["SYSTEM"]["gpu_id"] = -1
4670

4771
if "MODEL" not in config_dict: config_dict["MODEL"] = {}
4872
config_dict["MODEL"]["multi_gpu"] = False
49-
50-
# Override dataloader paths for test
73+
5174
if "DATA" in config_dict and "test" in config_dict["DATA"]:
52-
config_dict["DATA"]["test"]["video_path"] = os.path.dirname(self.video_path)
75+
config_dict["DATA"]["test"]["video_path"] = tmp_dir
5376
config_dict["DATA"]["test"]["path"] = tmp_input_json
5477
config_dict["DATA"]["test"]["results"] = "predictions"
55-
78+
79+
if "dataloader" not in config_dict["DATA"]["test"]:
80+
config_dict["DATA"]["test"]["dataloader"] = {}
81+
config_dict["DATA"]["test"]["dataloader"]["pin_memory"] = False
82+
5683
with open(tmp_config_yaml, 'w', encoding='utf-8') as f:
5784
yaml.dump(config_dict, f)
5885

59-
# --- 2. Create temporary JSON for the single video ---
86+
# --- 2. Create temporary JSON for the clipped video ---
6087
test_data = {
6188
"version": "2.0",
6289
"task": "action_spotting",
6390
"labels": {"ball_action": {"type": "single_label", "labels": classes}},
6491
"data": [{
6592
"id": "inf_vid",
66-
"inputs": [{"path": self.video_path, "type": "video", "fps": 25.0}],
67-
# 必须放一个 Dummy event 骗过 DataLoader
93+
"inputs": [{"path": clip_video_path, "type": "video", "fps": 25.0}],
6894
"events": [{"head": "ball_action", "label": classes[0] if classes else "Unknown", "position_ms": 0}]
6995
}]
7096
}
@@ -75,34 +101,29 @@ def run(self):
75101
loc_model = model.localization(config=tmp_config_yaml)
76102

77103
try:
78-
# 运行推理。这里一定会抛出 FileNotFoundError,因为框架底层的评估器找不到文件
79104
loc_model.infer(
80105
test_set=tmp_input_json,
81106
pretrained="jeetv/snpro-snbas-2024"
82107
)
83-
except FileNotFoundError:
84-
# [关键修复 4]:霸气忽略!
85-
# 因为报错发生在推理完成之后的“评估阶段”,所以我们直接 catch 掉这个错误,
86-
# 假装无事发生,直接进入下一步去深层文件夹里捞生成的 JSON。
108+
109+
except Exception as eval_err:
110+
print(f"Ignored evaluation error: {eval_err}")
87111
pass
88112

89-
# --- 4. Parse result JSON ---
90-
# 递归搜索临时文件夹下的所有 .json 文件(完美穿透 checkpoints/xxx 嵌套文件夹)
113+
# --- 4. Parse result JSON and compensate timestamps ---
91114
search_pattern = os.path.join(tmp_dir, "**", "*.json")
92115
all_jsons = glob.glob(search_pattern, recursive=True)
93116

94117
valid_preds = []
95118
for f in all_jsons:
96119
filename = os.path.basename(f)
97-
# 排除掉我们自己生成的输入数据和配置文件
98120
if "temp_test" not in filename and "temp_config" not in filename:
99121
valid_preds.append(f)
100122

101123
if valid_preds:
102-
# 找到最新生成的那一个(防止有多个旧文件干扰)
103124
actual_output_json = max(valid_preds, key=os.path.getctime)
104125
else:
105-
raise FileNotFoundError(f"Could not find any generated prediction JSON in {tmp_dir}/checkpoints/")
126+
raise FileNotFoundError(f"Could not find any generated prediction JSON in {tmp_dir}")
106127

107128
predicted_events = []
108129
if os.path.exists(actual_output_json):
@@ -111,16 +132,20 @@ def run(self):
111132

112133
raw_evts = output_data.get("data", [{}])[0].get("events", [])
113134
for evt in raw_evts:
114-
p_ms = int(evt.get("position_ms", 0))
135+
p_ms_relative = int(evt.get("position_ms", 0))
115136

116-
if p_ms == 0 and evt.get("label") == (classes[0] if classes else "Unknown"):
137+
if p_ms_relative == 0 and evt.get("label") == (classes[0] if classes else "Unknown"):
117138
continue
139+
p_ms_absolute = p_ms_relative + self.start_ms
118140

119-
if p_ms >= self.start_ms and (self.end_ms == 0 or p_ms <= self.end_ms):
141+
if self.end_ms == 0 or p_ms_absolute <= self.end_ms:
142+
# Get confidence
143+
conf = evt.get("confidence", evt.get("score", 0.99))
120144
predicted_events.append({
121-
"head": evt.get("head", "ball_action"),
145+
"head": "ball_action",
122146
"label": evt.get("label", "Unknown"),
123-
"position_ms": p_ms
147+
"position_ms": p_ms_absolute,
148+
"confidence": conf
124149
})
125150

126151
self.finished_signal.emit(predicted_events)

0 commit comments

Comments
 (0)