From 562e5100d472b4028b656c5314e003a71ccd3525 Mon Sep 17 00:00:00 2001 From: Isac Paulsson Date: Mon, 8 Dec 2025 12:50:29 +0100 Subject: [PATCH 1/6] rework of teamassigner to use kmeans clustering --- services/team_assigner_service/.dockerignore | 8 - services/team_assigner_service/Dockerfile | 16 +- .../processing}/__init__.py | 0 .../processing/team_assigner.py | 101 ++++++++++++ .../team_assigner_service/requirements.txt | 3 +- .../team_assigner_service.py | 7 +- team_assigner/team_assigner.py | 154 ------------------ tracking/track_players.py | 2 +- 8 files changed, 111 insertions(+), 180 deletions(-) delete mode 100644 services/team_assigner_service/.dockerignore rename {team_assigner => services/team_assigner_service/processing}/__init__.py (100%) create mode 100644 services/team_assigner_service/processing/team_assigner.py delete mode 100644 team_assigner/team_assigner.py diff --git a/services/team_assigner_service/.dockerignore b/services/team_assigner_service/.dockerignore deleted file mode 100644 index 16d1b28..0000000 --- a/services/team_assigner_service/.dockerignore +++ /dev/null @@ -1,8 +0,0 @@ -* -!*.py -!requirements.txt -!../../tracking -!../../utils -!../../shared -!../../canvas -!../../ball_acq diff --git a/services/team_assigner_service/Dockerfile b/services/team_assigner_service/Dockerfile index ca79c63..889eaa6 100644 --- a/services/team_assigner_service/Dockerfile +++ b/services/team_assigner_service/Dockerfile @@ -1,4 +1,4 @@ -FROM pytorch/pytorch:2.9.1-cuda13.0-cudnn9-runtime +FROM python:3.12-slim WORKDIR /app @@ -6,19 +6,15 @@ RUN apt-get update && apt-get install -y \ ffmpeg libsm6 libxext6 \ && rm -rf /var/lib/apt/lists/* -COPY services/team_assigner_service/requirements.txt /app/ +COPY services/team_assigner_service/requirements.txt . -# Install PyTorch with CUDA 13.0 runtime first RUN pip3 install --upgrade pip && \ pip3 install -r requirements.txt -# pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu130 && \ -COPY utils /app/utils -COPY shared /app/shared -COPY team_assigner /app/team_assigner -COPY canvas /app/canvas -COPY services/team_assigner_service/team_assigner_service.py /app/ +COPY utils/ ./utils/ +COPY shared/ ./shared/ +COPY services/team_assigner_service/ ./team_assigner_service/ EXPOSE 8000 -CMD ["uvicorn", "team_assigner_service:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file +CMD ["uvicorn", "team_assigner_service.team_assigner_service:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/team_assigner/__init__.py b/services/team_assigner_service/processing/__init__.py similarity index 100% rename from team_assigner/__init__.py rename to services/team_assigner_service/processing/__init__.py diff --git a/services/team_assigner_service/processing/team_assigner.py b/services/team_assigner_service/processing/team_assigner.py new file mode 100644 index 0000000..dfddf25 --- /dev/null +++ b/services/team_assigner_service/processing/team_assigner.py @@ -0,0 +1,101 @@ +import cv2 +import numpy as np +from PIL import Image +from collections import defaultdict +from sklearn.cluster import KMeans + +import sys +sys.path.append('../') + +class TeamAssigner: + def __init__(self, crop_factor = 0.3): + self.crop_factor = crop_factor + + def get_center_crop(self, img_array, crop_factor=0.4): + h, w, _ = img_array.shape + h_center, w_center = h // 2, w // 2 + h_crop, w_crop = int(h * crop_factor), int(w * crop_factor) + + y1 = max(0, h_center - h_crop // 2) + y2 = min(h, h_center + h_crop // 2) + x1 = max(0, w_center - w_crop // 2) + x2 = min(w, w_center + w_crop // 2) + + return img_array[y1:y2, x1:x2] + + def _extract_features(self, pil_images): + features = [] + bins = (8, 4, 4) + + for pil in pil_images: + img = np.array(pil) + img_crop = self.get_center_crop(img, crop_factor=self.crop_factor) + + hsv = cv2.cvtColor(img_crop, cv2.COLOR_RGB2HSV) + hist = cv2.calcHist([hsv], [0, 1, 2], None, bins, [0, 180, 0, 256, 0, 256]) + + cv2.normalize(hist, hist) + features.append(hist.flatten()) + + return np.array(features) + + def get_player_teams_global(self, vid_frames, player_tracks): + player_features_map = defaultdict(list) + frame_assignments = [dict() for _ in range(len(vid_frames))] + + for frame_id, player_track in enumerate(player_tracks): + frame = vid_frames[frame_id] + + frame_crops = [] + frame_pids = [] + + for pid, info in player_track.items(): + x1, y1, x2, y2 = map(int, info['bbox']) + crop = frame[y1:y2, x1:x2] + + pil_img = Image.fromarray(crop) + + frame_crops.append(pil_img) + frame_pids.append(pid) + + if not frame_crops: + continue + + feats = self._extract_features(frame_crops) + + for pid, feat in zip(frame_pids, feats): + player_features_map[pid].append(feat) + + unique_pids = list(player_features_map.keys()) + averaged_features = [] + + for pid in unique_pids: + # Average features per player id for global embedding (Remove noise) + feats = np.array(player_features_map[pid]) + avg_feat = np.mean(feats, axis=0) + averaged_features.append(avg_feat) + + if not unique_pids: + return frame_assignments + + print(f" # Player ID's: {len(unique_pids)}") + + kmeans = KMeans(n_clusters=2, random_state=42) + + if len(unique_pids) >= 2: + labels = kmeans.fit_predict(averaged_features) + else: + labels = [1] * len(unique_pids) # Fallback in case of not enough players + + pid_to_team = {pid: int(label) + 1 for pid, label in zip(unique_pids, labels)} + + for frame_id, player_track in enumerate(player_tracks): + for pid in player_track.keys(): + if pid in pid_to_team: + team_id = pid_to_team[pid] + frame_assignments[frame_id][pid] = team_id + + return frame_assignments + + def get_player_teams_over_frames(self, vid_frames, player_tracks): + return self.get_player_teams_global(vid_frames, player_tracks) \ No newline at end of file diff --git a/services/team_assigner_service/requirements.txt b/services/team_assigner_service/requirements.txt index ade428f..7cc4f9b 100644 --- a/services/team_assigner_service/requirements.txt +++ b/services/team_assigner_service/requirements.txt @@ -4,6 +4,5 @@ requests boto3 opencv-python-headless pillow -transformers python-multipart -ultralytics \ No newline at end of file +scikit-learn \ No newline at end of file diff --git a/services/team_assigner_service/team_assigner_service.py b/services/team_assigner_service/team_assigner_service.py index 8ff925a..73c596d 100644 --- a/services/team_assigner_service/team_assigner_service.py +++ b/services/team_assigner_service/team_assigner_service.py @@ -7,7 +7,7 @@ import json from utils import read_video -from team_assigner import TeamAssigner +from team_assigner_service.processing.team_assigner import TeamAssigner def serialize_team_assignments(assignments): out = [] @@ -23,10 +23,7 @@ def serialize_team_assignments(assignments): app = FastAPI() -team_assigner = TeamAssigner( - team_A="WHITE shirt", - team_B="DARK BLUE shirt" - ) +team_assigner = TeamAssigner(crop_factor=0.3) @app.post("/assign_teams") async def assign_teams( diff --git a/team_assigner/team_assigner.py b/team_assigner/team_assigner.py deleted file mode 100644 index e58f097..0000000 --- a/team_assigner/team_assigner.py +++ /dev/null @@ -1,154 +0,0 @@ -import os -import cv2 -import torch -import random -import numpy as np -from time import time -from PIL import Image -from ultralytics import SAM -from transformers import CLIPProcessor, CLIPModel -from collections import defaultdict, deque, Counter - -import sys -sys.path.append('../') - -class TeamAssigner: - def __init__(self, - team_A= "WHITE shirt", - team_B= "DARK-BLUE shirt", - history_len = 50, - crop_factor = 0.375, - save_imgs = False, - crop = False - ): - self.team_colors = {} - self.history_len = history_len - self.player_team_cache_history = defaultdict(lambda: deque(maxlen=history_len)) - - self.crop_factor = crop_factor - self.save_imgs = save_imgs - self.crop = crop - - self.team_A = team_A - self.team_B = team_B - - self.load_model() - - def load_model(self): - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip").to(self.device) - self.processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip") - - def crop_img(self, pil_image): - width, height = pil_image.size - torso_height = int(height * self.crop_factor) - y_center = height // 2 - y1_new = max(y_center - torso_height // 2, 0) - y2_new = min(y_center + torso_height // 2, height) - cropped_pil_image = pil_image.crop((0, y1_new, width, y2_new)) - - return cropped_pil_image - - def get_player_color(self,frame,bbox): - image = frame[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] - - pil_image = Image.fromarray(image) - - if self.crop: - pil_image = self.crop_img(pil_image) - - if self.save_imgs: - r = random.randint(1, 1000000) - filename = f"masked_{r}.png" - pil_image.save(os.path.join("imgs/masked", filename)) - - team_classes = [self.team_A, self.team_B] - - inputs = self.processor(text=team_classes, images=pil_image, return_tensors="pt", padding=True) - inputs = {k: v.to(self.device) for k, v in inputs.items()} - - outputs = self.model(**inputs) - logits_per_image = outputs.logits_per_image - max_class_prob = logits_per_image.softmax(dim=1) - - player_color = team_classes[max_class_prob.argmax(dim=1)[0]] - return player_color - - def get_team_from_history(self, player_id): - history = list(self.player_team_cache_history[player_id]) - - if not history: - return None - - counter = Counter(history) - most_freq, _ = counter.most_common(1)[0] - print(counter) - print(most_freq) - return 1 if most_freq == self.team_A else 2 - - def get_player_team(self,frame,player_bbox,player_id): - player_color = self.get_player_color(frame,player_bbox) - self.player_team_cache_history[player_id].append(player_color) - team_id = self.get_team_from_history(player_id) - print(player_id) - return team_id - - def get_player_color_batch(self, pil_images): - team_classes = [self.team_A, self.team_B] - - inputs = self.processor( - text=team_classes, - images=pil_images, - return_tensors="pt", - padding=True - ) - - inputs = {k: v.to(self.device) for k, v in inputs.items()} - - with torch.no_grad(): - outputs = self.model(**inputs) - logits_per_image = outputs.logits_per_image - probs = logits_per_image.softmax(dim=1) - - pred_indices = probs.argmax(dim=1).tolist() - return [team_classes[i] for i in pred_indices] - - def process_frame_batched(self, frame, player_track): - pil_images = [] - player_ids = [] - bboxes = [] - - for pid, info in player_track.items(): - bbox = info['bbox'] - crop = frame[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] - pil_img = Image.fromarray(crop) - - if self.crop: - pil_img = self.crop_img(pil_img) - - if self.save_imgs: - r = random.randint(1, 1_000_000) - pil_img.save(f"imgs/masked/masked_{r}.png") - - pil_images.append(pil_img) - player_ids.append(pid) - bboxes.append(bbox) - - predicted_colors = self.get_player_color_batch(pil_images) - - assignment = {} - for pid, color in zip(player_ids, predicted_colors): - self.player_team_cache_history[pid].append(color) - assignment[pid] = self.get_team_from_history(pid) - - return assignment - - def get_player_teams_over_frames(self, vid_frames, player_tracks): - player_assignment = [] - - for frame_id, player_track in enumerate(player_tracks): - frame = vid_frames[frame_id] - assignment = self.process_frame_batched(frame, player_track) - player_assignment.append(assignment) - - return player_assignment \ No newline at end of file diff --git a/tracking/track_players.py b/tracking/track_players.py index 84e16a1..5df9e79 100644 --- a/tracking/track_players.py +++ b/tracking/track_players.py @@ -18,7 +18,7 @@ def detect_frames(self, vid_frames, batch_size=20, min_conf=0.5): def get_object_tracks(self, vid_frames): - detections = self.detect_frames(vid_frames) + detections = self.detect_frames(vid_frames, min_conf=0.3) tracks = [] for frame_id, detection in enumerate(detections): From a54bf2da33bd049e2fc46521b611815fabed2a00 Mon Sep 17 00:00:00 2001 From: Isac Paulsson Date: Mon, 8 Dec 2025 14:35:39 +0100 Subject: [PATCH 2/6] added crop skew parameter, shifts the crop up in the bounding box, added outlier check for short lived ids (crowd can massively skew feature distribution) --- .../processing/team_assigner.py | 21 ++++++++++++------- .../team_assigner_service.py | 2 +- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/services/team_assigner_service/processing/team_assigner.py b/services/team_assigner_service/processing/team_assigner.py index dfddf25..a4c4754 100644 --- a/services/team_assigner_service/processing/team_assigner.py +++ b/services/team_assigner_service/processing/team_assigner.py @@ -11,13 +11,14 @@ class TeamAssigner: def __init__(self, crop_factor = 0.3): self.crop_factor = crop_factor - def get_center_crop(self, img_array, crop_factor=0.4): + def get_center_crop(self, img_array, crop_factor=0.4, skew_factor=0.1): h, w, _ = img_array.shape h_center, w_center = h // 2, w // 2 h_crop, w_crop = int(h * crop_factor), int(w * crop_factor) - - y1 = max(0, h_center - h_crop // 2) - y2 = min(h, h_center + h_crop // 2) + h_skew = int(h * skew_factor) + + y1 = max(0, h_center - h_crop // 2 - h_skew) + y2 = min(h, h_center + h_crop // 2 - h_skew) x1 = max(0, w_center - w_crop // 2) x2 = min(w, w_center + w_crop // 2) @@ -32,9 +33,10 @@ def _extract_features(self, pil_images): img_crop = self.get_center_crop(img, crop_factor=self.crop_factor) hsv = cv2.cvtColor(img_crop, cv2.COLOR_RGB2HSV) - hist = cv2.calcHist([hsv], [0, 1, 2], None, bins, [0, 180, 0, 256, 0, 256]) + hist = cv2.calcHist([hsv], [0, 1, 2], None, bins, [0, 180, 0, 256, 0, 256]) cv2.normalize(hist, hist) + features.append(hist.flatten()) return np.array(features) @@ -71,9 +73,12 @@ def get_player_teams_global(self, vid_frames, player_tracks): for pid in unique_pids: # Average features per player id for global embedding (Remove noise) - feats = np.array(player_features_map[pid]) - avg_feat = np.mean(feats, axis=0) - averaged_features.append(avg_feat) + if len(player_features_map[pid]) > 5: # Id should exist for alteast 5 frames to be used in kmeans + feats = np.array(player_features_map[pid]) + avg_feat = np.mean(feats, axis=0) + averaged_features.append(avg_feat) + else: + averaged_features.append(averaged_features[-1]) if not unique_pids: return frame_assignments diff --git a/services/team_assigner_service/team_assigner_service.py b/services/team_assigner_service/team_assigner_service.py index 73c596d..94d7e0c 100644 --- a/services/team_assigner_service/team_assigner_service.py +++ b/services/team_assigner_service/team_assigner_service.py @@ -23,7 +23,7 @@ def serialize_team_assignments(assignments): app = FastAPI() -team_assigner = TeamAssigner(crop_factor=0.3) +team_assigner = TeamAssigner(crop_factor=0.2) @app.post("/assign_teams") async def assign_teams( From e7df32e0e9be4debc996595099c3924de758db20 Mon Sep 17 00:00:00 2001 From: Isac Paulsson Date: Mon, 8 Dec 2025 14:40:17 +0100 Subject: [PATCH 3/6] adjusted rare threshold --- services/team_assigner_service/processing/team_assigner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/services/team_assigner_service/processing/team_assigner.py b/services/team_assigner_service/processing/team_assigner.py index a4c4754..0ff88d8 100644 --- a/services/team_assigner_service/processing/team_assigner.py +++ b/services/team_assigner_service/processing/team_assigner.py @@ -73,7 +73,7 @@ def get_player_teams_global(self, vid_frames, player_tracks): for pid in unique_pids: # Average features per player id for global embedding (Remove noise) - if len(player_features_map[pid]) > 5: # Id should exist for alteast 5 frames to be used in kmeans + if len(player_features_map[pid]) > 3: # Id should exist for alteast 5 frames to be used in kmeans feats = np.array(player_features_map[pid]) avg_feat = np.mean(feats, axis=0) averaged_features.append(avg_feat) From 5a9e900c18bb313645241c5c8d0ca689d2ec4b92 Mon Sep 17 00:00:00 2001 From: Isac Paulsson Date: Mon, 8 Dec 2025 15:49:25 +0100 Subject: [PATCH 4/6] color extraction from kmeans clustering, color used for plots/videos --- canvas/top_down_overlay.py | 1 + services/orchestrator_service/Dockerfile | 1 - services/orchestrator_service/api_utils.py | 2 +- .../orchestrator_service.py | 15 +++++--- .../processing/team_assigner.py | 27 ++++++++++++- .../team_assigner_service.py | 5 ++- services/ui_service/plots.py | 38 +++++++++---------- services/ui_service/tabs/inference_tab.py | 14 +++++-- 8 files changed, 68 insertions(+), 35 deletions(-) diff --git a/canvas/top_down_overlay.py b/canvas/top_down_overlay.py index 26c3f8f..c054ba5 100644 --- a/canvas/top_down_overlay.py +++ b/canvas/top_down_overlay.py @@ -76,6 +76,7 @@ def draw_voronoi(self, minimap, subdiv, teams, alpha=0.2): def draw_players(self, minimap, positions, teams): for (mx, my), t in zip(positions, teams): cv2.circle(minimap, (mx, my), 10, self.color[t], -1) + cv2.circle(minimap, (mx, my), 10, (0, 0, 0), 2, lineType=cv2.LINE_AA) return minimap def draw_overlay(self, frames, td_track, x=0, y=0): diff --git a/services/orchestrator_service/Dockerfile b/services/orchestrator_service/Dockerfile index 01bb791..314fa1a 100644 --- a/services/orchestrator_service/Dockerfile +++ b/services/orchestrator_service/Dockerfile @@ -17,7 +17,6 @@ COPY utils /app/utils COPY canvas /app/canvas COPY shared /app/shared COPY ball_acq /app/ball_acq -COPY team_assigner /app/team_assigner RUN pip install --no-cache-dir -r requirements.txt diff --git a/services/orchestrator_service/api_utils.py b/services/orchestrator_service/api_utils.py index cabaf0c..56164e7 100644 --- a/services/orchestrator_service/api_utils.py +++ b/services/orchestrator_service/api_utils.py @@ -29,7 +29,7 @@ def get_team_assignments_from_service(local_video_path: str, player_tracks): r = requests.post(url, files=files) r.raise_for_status() data = r.json() - return data["team_assignments"] + return data def get_tracks_from_service(local_video_path: str): url = DETECTOR_URL diff --git a/services/orchestrator_service/orchestrator_service.py b/services/orchestrator_service/orchestrator_service.py index 476ad07..cba354f 100644 --- a/services/orchestrator_service/orchestrator_service.py +++ b/services/orchestrator_service/orchestrator_service.py @@ -40,8 +40,10 @@ async def process_video(video_name: str, reference_court: str): # 3) Get team assignments team_assignments_json = get_team_assignments_from_service(tmp_video_path, player_tracks) - team_assignments = deserialize_team_assignments(team_assignments_json) - + team_assignments = deserialize_team_assignments(team_assignments_json["team_assignments"]) + team_colors = team_assignments_json["team_colors"] + print("team colors:: ", team_colors) + print("col 1:", team_colors["1"]) # 4) Ball possession ball_sensor = BallAcquisitionSensor() ball_acquisition_list = ball_sensor.detect_ball_possession(player_tracks, ball_tracks) @@ -55,7 +57,7 @@ async def process_video(video_name: str, reference_court: str): H = get_homographies_from_service(tmp_video_path, tmp_ref_path) # 5) Draw overlays - player_draw = PlayerTrackDrawer() + player_draw = PlayerTrackDrawer(team_1_color=team_colors["1"], team_2_color=team_colors["2"]) ball_draw = BallTrackDrawer() player_vid_frames = player_draw.draw_annotations( @@ -66,9 +68,9 @@ async def process_video(video_name: str, reference_court: str): ) output_vid_frames = ball_draw.draw_annotations(player_vid_frames, ball_tracks) - top_down_overlay = TDOverlay(tmp_ref_path, base_court, xz=1280, yz=720) + top_down_overlay = TDOverlay(tmp_ref_path, base_court, t1_color=team_colors["1"], t2_color=team_colors["2"], xz=1280, yz=720) td_tracks = top_down_overlay.get_td_tracks(player_tracks, team_assignments, H) - minimap_frames = [np.zeros_like(frame) for frame in output_vid_frames] + minimap_frames = [np.zeros((720, 1280, 3), dtype=np.uint8) for _ in output_vid_frames] output_vid_minimap, control_stats = top_down_overlay.draw_overlay( minimap_frames, @@ -106,7 +108,8 @@ async def process_video(video_name: str, reference_court: str): "ball_tp": f"{ball_team_possessions}", "vid_name": f"{video_name}", "control_stats": json.dumps(control_stats), - 'pi_stats': json.dumps(passes_and_interceptions) + "pi_stats": json.dumps(passes_and_interceptions), + "team_colors": json.dumps(team_colors) }) if __name__ == "__main__": diff --git a/services/team_assigner_service/processing/team_assigner.py b/services/team_assigner_service/processing/team_assigner.py index 0ff88d8..27a3c9d 100644 --- a/services/team_assigner_service/processing/team_assigner.py +++ b/services/team_assigner_service/processing/team_assigner.py @@ -41,6 +41,24 @@ def _extract_features(self, pil_images): return np.array(features) + def get_rgb_from_histogram(self, hist_vector, bins=(8, 4, 4)): + hist_3d = hist_vector.reshape(bins) + + h_idx, s_idx, v_idx = np.unravel_index(np.argmax(hist_3d), bins) + + h_step = 180.0 / bins[0] + s_step = 256.0 / bins[1] + v_step = 256.0 / bins[2] + + h_val = int(h_idx * h_step + h_step / 2) + s_val = int(s_idx * s_step + s_step / 2) + v_val = int(v_idx * v_step + v_step / 2) + + hsv_pixel = np.array([[[h_val, s_val, v_val]]], dtype=np.uint8) + rgb_pixel = cv2.cvtColor(hsv_pixel, cv2.COLOR_HSV2RGB)[0][0] + + return tuple(map(int, rgb_pixel)) + def get_player_teams_global(self, vid_frames, player_tracks): player_features_map = defaultdict(list) frame_assignments = [dict() for _ in range(len(vid_frames))] @@ -73,7 +91,7 @@ def get_player_teams_global(self, vid_frames, player_tracks): for pid in unique_pids: # Average features per player id for global embedding (Remove noise) - if len(player_features_map[pid]) > 3: # Id should exist for alteast 5 frames to be used in kmeans + if len(player_features_map[pid]) > 3: # Id should exist for atleast this number of frames to be used in kmeans feats = np.array(player_features_map[pid]) avg_feat = np.mean(feats, axis=0) averaged_features.append(avg_feat) @@ -89,8 +107,13 @@ def get_player_teams_global(self, vid_frames, player_tracks): if len(unique_pids) >= 2: labels = kmeans.fit_predict(averaged_features) + + team1_color = self.get_rgb_from_histogram(kmeans.cluster_centers_[0], bins=(8, 4, 4)) + team2_color = self.get_rgb_from_histogram(kmeans.cluster_centers_[1], bins=(8, 4, 4)) + team_colors = {1: team1_color, 2: team2_color} else: labels = [1] * len(unique_pids) # Fallback in case of not enough players + team_colors = {1: (255, 255, 255), 2: (0, 0, 0)} pid_to_team = {pid: int(label) + 1 for pid, label in zip(unique_pids, labels)} @@ -100,7 +123,7 @@ def get_player_teams_global(self, vid_frames, player_tracks): team_id = pid_to_team[pid] frame_assignments[frame_id][pid] = team_id - return frame_assignments + return frame_assignments, team_colors def get_player_teams_over_frames(self, vid_frames, player_tracks): return self.get_player_teams_global(vid_frames, player_tracks) \ No newline at end of file diff --git a/services/team_assigner_service/team_assigner_service.py b/services/team_assigner_service/team_assigner_service.py index 94d7e0c..4a77a08 100644 --- a/services/team_assigner_service/team_assigner_service.py +++ b/services/team_assigner_service/team_assigner_service.py @@ -48,14 +48,15 @@ async def assign_teams( # Read video frames frames = read_video(str(tmp_video_path)) - team_assignments = team_assigner.get_player_teams_over_frames( + team_assignments, team_colors = team_assigner.get_player_teams_over_frames( vid_frames=frames, player_tracks=player_tracks, ) # Serialize payload = { - "team_assignments": serialize_team_assignments(team_assignments) + "team_assignments": serialize_team_assignments(team_assignments), + "team_colors": team_colors } safe = jsonable_encoder(payload) diff --git a/services/ui_service/plots.py b/services/ui_service/plots.py index 54cb20c..5f989c8 100644 --- a/services/ui_service/plots.py +++ b/services/ui_service/plots.py @@ -1,7 +1,7 @@ import matplotlib.pyplot as plt import numpy as np -def possession_plot(ball_tp): +def possession_plot(ball_tp, t1_color="#9cb2a0", t2_color="#9eaec6"): x, y = possession_to_percentages(ball_tp) y_percent = np.array(y) * 100 @@ -9,10 +9,10 @@ def possession_plot(ball_tp): ax_right = ax_left.twinx() for i, val in enumerate(y_percent): - ax_left.bar(i, val, color="#9eaec6", width=1.0) - ax_left.bar(i, 100 - val, bottom=val, color="#9cb2a0", width=1.0) + ax_left.bar(i, val, color=t1_color, width=1.0) + ax_left.bar(i, 100 - val, bottom=val, color=t2_color, width=1.0) - ax_left.plot(x, y_percent, color="black", linewidth=2) + ax_left.plot(x, y_percent, color="black", linewidth=3) ax_left.set_ylim(0, 100) ax_left.set_yticks(np.arange(0, 101, 5)) @@ -29,11 +29,11 @@ def possession_plot(ball_tp): return fig -def pi_plots(pi_stats): +def pi_plots(pi_stats, t1_color="#9cb2a0", t2_color="#9eaec6"): p1, p2, i1, i2 = extract_timeseries(pi_stats) frames = list(range(len(pi_stats))) - plot_pass = passes_plot(frames, p1, p2) - plot_intr = interceptions_plot(frames, i1, i2) + plot_pass = passes_plot(frames, p1, p2, t1_color, t2_color) + plot_intr = interceptions_plot(frames, i1, i2, t1_color, t2_color) return plot_pass, plot_intr def to_percent(v1, v2): @@ -75,29 +75,29 @@ def extract_timeseries(stats): return passes_t1, passes_t2, inter_t1, inter_t2 -def passes_plot(frames, passes_t1, passes_t2): +def passes_plot(frames, passes_t1, passes_t2, t1_color="#9cb2a0", t2_color="#9eaec6"): pct = to_percent(passes_t1, passes_t2) return percent_style_plot( - frames, pct, label="Passes" + frames, pct, label="Passes", t1_color=t1_color, t2_color=t2_color ) -def interceptions_plot(frames, inter_t1, inter_t2): +def interceptions_plot(frames, inter_t1, inter_t2, t1_color="#9cb2a0", t2_color="#9eaec6"): pct = to_percent(inter_t1, inter_t2) return percent_style_plot( - frames, pct, label="Interceptions" + frames, pct, "Interceptions", t1_color=t1_color, t2_color=t2_color ) -def percent_style_plot(x, pct_team1, label): +def percent_style_plot(x, pct_team1, label, t1_color="#9cb2a0", t2_color="#9eaec6"): y_percent = pct_team1 * 100 fig, ax_left = plt.subplots(figsize=(14, 4)) ax_right = ax_left.twinx() for i, val in enumerate(y_percent): - ax_left.bar(i, val, color="#9eaec6", width=1.0) - ax_left.bar(i, 100 - val, bottom=val, color="#9cb2a0", width=1.0) + ax_left.bar(i, val, color=t2_color, width=1.0) + ax_left.bar(i, 100 - val, bottom=val, color=t1_color, width=1.0) - ax_left.plot(x, y_percent, color="black", linewidth=2) + ax_left.plot(x, y_percent, color="black", linewidth=3) ax_left.set_ylim(0, 100) ax_left.set_yticks(np.arange(0, 101, 5)) @@ -114,7 +114,7 @@ def percent_style_plot(x, pct_team1, label): return fig -def control_plot(control_stats): +def control_plot(control_stats, t1_color="#9cb2a0", t2_color="#9eaec6"): y_percent = [] for frame in control_stats: a = float(frame.get("1", 0)) @@ -133,10 +133,10 @@ def control_plot(control_stats): ax_right = ax_left.twinx() for i, val in enumerate(y_percent): - ax_left.bar(i, val, color="#9cb2a0", width=1.0) - ax_left.bar(i, 100 - val, color="#9eaec6", bottom=val, width=1.0) + ax_left.bar(i, val, color=t1_color, width=1.0) + ax_left.bar(i, 100 - val, color=t2_color, bottom=val, width=1.0) - ax_left.plot(x, y_percent, color="black", linewidth=2) + ax_left.plot(x, y_percent, color="black", linewidth=3) ax_left.set_ylim(0, 100) ax_left.set_yticks(np.arange(0, 101, 10)) diff --git a/services/ui_service/tabs/inference_tab.py b/services/ui_service/tabs/inference_tab.py index 0e245e1..f862fb0 100644 --- a/services/ui_service/tabs/inference_tab.py +++ b/services/ui_service/tabs/inference_tab.py @@ -30,10 +30,16 @@ def run_inference(video_file, court_name): data = resp.json() vid_name = data["vid_name"] - - plot_poss = possession_plot(json.loads(data["ball_tp"])) - plot_ctrl = control_plot(json.loads(data["control_stats"])) - plot_pass, plot_intr = pi_plots(json.loads(data["pi_stats"])) + + team_colors = json.loads(data["team_colors"]) + team_hex_colors = { + k: '#%02x%02x%02x' % tuple(reversed(v)) + for k, v in team_colors.items() + } + + plot_poss = possession_plot(json.loads(data["ball_tp"]), team_hex_colors["1"], team_hex_colors["2"]) + plot_ctrl = control_plot(json.loads(data["control_stats"]), team_hex_colors["1"], team_hex_colors["2"]) + plot_pass, plot_intr = pi_plots(json.loads(data["pi_stats"]), team_hex_colors["1"], team_hex_colors["2"]) url_proc = f"{config.VIEWER_BASE}/video/{config.BUCKET_PROCESSED}/{vid_name}" url_mini = f"{config.VIEWER_BASE}/video/{config.BUCKET_MINIMAP}/{vid_name}" From b6598461a1a15ebc3bb79b6bb598e55cbf8f672e Mon Sep 17 00:00:00 2001 From: Isac Paulsson Date: Mon, 8 Dec 2025 15:58:29 +0100 Subject: [PATCH 5/6] desaturate colors for plots --- services/ui_service/tabs/inference_tab.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/services/ui_service/tabs/inference_tab.py b/services/ui_service/tabs/inference_tab.py index f862fb0..b3adbdd 100644 --- a/services/ui_service/tabs/inference_tab.py +++ b/services/ui_service/tabs/inference_tab.py @@ -1,12 +1,22 @@ import os +import cv2 import json import requests +import numpy as np import gradio as gr from shared.storage import upload_video import ui_service.config as config from ui_service.plots import possession_plot, control_plot, pi_plots from ui_service.utils import fetch_local_resource, list_courts +def to_desat_hex(bgr, factor=0.5): + pixel = np.array([[bgr]], dtype=np.uint8) + hsv = cv2.cvtColor(pixel, cv2.COLOR_BGR2HSV) + hsv[0, 0, 1] = np.clip(hsv[0, 0, 1] * factor, 0, 255).astype(np.uint8) + rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + r, g, b = rgb[0, 0] + return '#{:02x}{:02x}{:02x}'.format(r, g, b) + def run_inference(video_file, court_name): if not video_file or not court_name: return [None]*5 + ["Missing inputs"] @@ -33,7 +43,7 @@ def run_inference(video_file, court_name): team_colors = json.loads(data["team_colors"]) team_hex_colors = { - k: '#%02x%02x%02x' % tuple(reversed(v)) + k: to_desat_hex (v) for k, v in team_colors.items() } From 46f6de91b6f2f02460dcaa08166853c97ec4422f Mon Sep 17 00:00:00 2001 From: Isac Paulsson Date: Mon, 8 Dec 2025 16:14:22 +0100 Subject: [PATCH 6/6] saturation and brightness adjustment --- services/ui_service/tabs/inference_tab.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/services/ui_service/tabs/inference_tab.py b/services/ui_service/tabs/inference_tab.py index b3adbdd..1f473be 100644 --- a/services/ui_service/tabs/inference_tab.py +++ b/services/ui_service/tabs/inference_tab.py @@ -9,10 +9,15 @@ from ui_service.plots import possession_plot, control_plot, pi_plots from ui_service.utils import fetch_local_resource, list_courts -def to_desat_hex(bgr, factor=0.5): +def to_desat_hex(bgr, sat_factor=0.5, bright_factor=1.8): pixel = np.array([[bgr]], dtype=np.uint8) hsv = cv2.cvtColor(pixel, cv2.COLOR_BGR2HSV) - hsv[0, 0, 1] = np.clip(hsv[0, 0, 1] * factor, 0, 255).astype(np.uint8) + + hsv[0, 0, 1] = np.clip(hsv[0, 0, 1] * sat_factor, 0, 255).astype(np.uint8) + + val_new = hsv[0, 0, 2] * bright_factor + hsv[0, 0, 2] = np.clip(val_new, 0, 255).astype(np.uint8) + rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) r, g, b = rgb[0, 0] return '#{:02x}{:02x}{:02x}'.format(r, g, b)