Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions canvas/top_down_overlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def draw_voronoi(self, minimap, subdiv, teams, alpha=0.2):
def draw_players(self, minimap, positions, teams):
for (mx, my), t in zip(positions, teams):
cv2.circle(minimap, (mx, my), 10, self.color[t], -1)
cv2.circle(minimap, (mx, my), 10, (0, 0, 0), 2, lineType=cv2.LINE_AA)
return minimap

def draw_overlay(self, frames, td_track, x=0, y=0):
Expand Down
1 change: 0 additions & 1 deletion services/orchestrator_service/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ COPY utils /app/utils
COPY canvas /app/canvas
COPY shared /app/shared
COPY ball_acq /app/ball_acq
COPY team_assigner /app/team_assigner

RUN pip install --no-cache-dir -r requirements.txt

Expand Down
2 changes: 1 addition & 1 deletion services/orchestrator_service/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_team_assignments_from_service(local_video_path: str, player_tracks):
r = requests.post(url, files=files)
r.raise_for_status()
data = r.json()
return data["team_assignments"]
return data

def get_tracks_from_service(local_video_path: str):
url = DETECTOR_URL
Expand Down
15 changes: 9 additions & 6 deletions services/orchestrator_service/orchestrator_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ async def process_video(video_name: str, reference_court: str):

# 3) Get team assignments
team_assignments_json = get_team_assignments_from_service(tmp_video_path, player_tracks)
team_assignments = deserialize_team_assignments(team_assignments_json)

team_assignments = deserialize_team_assignments(team_assignments_json["team_assignments"])
team_colors = team_assignments_json["team_colors"]
print("team colors:: ", team_colors)
print("col 1:", team_colors["1"])
# 4) Ball possession
ball_sensor = BallAcquisitionSensor()
ball_acquisition_list = ball_sensor.detect_ball_possession(player_tracks, ball_tracks)
Expand All @@ -55,7 +57,7 @@ async def process_video(video_name: str, reference_court: str):
H = get_homographies_from_service(tmp_video_path, tmp_ref_path)

# 5) Draw overlays
player_draw = PlayerTrackDrawer()
player_draw = PlayerTrackDrawer(team_1_color=team_colors["1"], team_2_color=team_colors["2"])
ball_draw = BallTrackDrawer()

player_vid_frames = player_draw.draw_annotations(
Expand All @@ -66,9 +68,9 @@ async def process_video(video_name: str, reference_court: str):
)
output_vid_frames = ball_draw.draw_annotations(player_vid_frames, ball_tracks)

top_down_overlay = TDOverlay(tmp_ref_path, base_court, xz=1280, yz=720)
top_down_overlay = TDOverlay(tmp_ref_path, base_court, t1_color=team_colors["1"], t2_color=team_colors["2"], xz=1280, yz=720)
td_tracks = top_down_overlay.get_td_tracks(player_tracks, team_assignments, H)
minimap_frames = [np.zeros_like(frame) for frame in output_vid_frames]
minimap_frames = [np.zeros((720, 1280, 3), dtype=np.uint8) for _ in output_vid_frames]

output_vid_minimap, control_stats = top_down_overlay.draw_overlay(
minimap_frames,
Expand Down Expand Up @@ -106,7 +108,8 @@ async def process_video(video_name: str, reference_court: str):
"ball_tp": f"{ball_team_possessions}",
"vid_name": f"{video_name}",
"control_stats": json.dumps(control_stats),
'pi_stats': json.dumps(passes_and_interceptions)
"pi_stats": json.dumps(passes_and_interceptions),
"team_colors": json.dumps(team_colors)
})

if __name__ == "__main__":
Expand Down
8 changes: 0 additions & 8 deletions services/team_assigner_service/.dockerignore

This file was deleted.

16 changes: 6 additions & 10 deletions services/team_assigner_service/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
FROM pytorch/pytorch:2.9.1-cuda13.0-cudnn9-runtime
FROM python:3.12-slim

WORKDIR /app

RUN apt-get update && apt-get install -y \
ffmpeg libsm6 libxext6 \
&& rm -rf /var/lib/apt/lists/*

COPY services/team_assigner_service/requirements.txt /app/
COPY services/team_assigner_service/requirements.txt .

# Install PyTorch with CUDA 13.0 runtime first
RUN pip3 install --upgrade pip && \
pip3 install -r requirements.txt
# pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu130 && \

COPY utils /app/utils
COPY shared /app/shared
COPY team_assigner /app/team_assigner
COPY canvas /app/canvas
COPY services/team_assigner_service/team_assigner_service.py /app/
COPY utils/ ./utils/
COPY shared/ ./shared/
COPY services/team_assigner_service/ ./team_assigner_service/

EXPOSE 8000

CMD ["uvicorn", "team_assigner_service:app", "--host", "0.0.0.0", "--port", "8000"]
CMD ["uvicorn", "team_assigner_service.team_assigner_service:app", "--host", "0.0.0.0", "--port", "8000"]
File renamed without changes.
129 changes: 129 additions & 0 deletions services/team_assigner_service/processing/team_assigner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import cv2
import numpy as np
from PIL import Image
from collections import defaultdict
from sklearn.cluster import KMeans

import sys
sys.path.append('../')

class TeamAssigner:
def __init__(self, crop_factor = 0.3):
self.crop_factor = crop_factor

def get_center_crop(self, img_array, crop_factor=0.4, skew_factor=0.1):
h, w, _ = img_array.shape
h_center, w_center = h // 2, w // 2
h_crop, w_crop = int(h * crop_factor), int(w * crop_factor)
h_skew = int(h * skew_factor)

y1 = max(0, h_center - h_crop // 2 - h_skew)
y2 = min(h, h_center + h_crop // 2 - h_skew)
x1 = max(0, w_center - w_crop // 2)
x2 = min(w, w_center + w_crop // 2)

return img_array[y1:y2, x1:x2]

def _extract_features(self, pil_images):
features = []
bins = (8, 4, 4)

for pil in pil_images:
img = np.array(pil)
img_crop = self.get_center_crop(img, crop_factor=self.crop_factor)

hsv = cv2.cvtColor(img_crop, cv2.COLOR_RGB2HSV)

hist = cv2.calcHist([hsv], [0, 1, 2], None, bins, [0, 180, 0, 256, 0, 256])
cv2.normalize(hist, hist)

features.append(hist.flatten())

return np.array(features)

def get_rgb_from_histogram(self, hist_vector, bins=(8, 4, 4)):
hist_3d = hist_vector.reshape(bins)

h_idx, s_idx, v_idx = np.unravel_index(np.argmax(hist_3d), bins)

h_step = 180.0 / bins[0]
s_step = 256.0 / bins[1]
v_step = 256.0 / bins[2]

h_val = int(h_idx * h_step + h_step / 2)
s_val = int(s_idx * s_step + s_step / 2)
v_val = int(v_idx * v_step + v_step / 2)

hsv_pixel = np.array([[[h_val, s_val, v_val]]], dtype=np.uint8)
rgb_pixel = cv2.cvtColor(hsv_pixel, cv2.COLOR_HSV2RGB)[0][0]

return tuple(map(int, rgb_pixel))

def get_player_teams_global(self, vid_frames, player_tracks):
player_features_map = defaultdict(list)
frame_assignments = [dict() for _ in range(len(vid_frames))]

for frame_id, player_track in enumerate(player_tracks):
frame = vid_frames[frame_id]

frame_crops = []
frame_pids = []

for pid, info in player_track.items():
x1, y1, x2, y2 = map(int, info['bbox'])
crop = frame[y1:y2, x1:x2]

pil_img = Image.fromarray(crop)

frame_crops.append(pil_img)
frame_pids.append(pid)

if not frame_crops:
continue

feats = self._extract_features(frame_crops)

for pid, feat in zip(frame_pids, feats):
player_features_map[pid].append(feat)

unique_pids = list(player_features_map.keys())
averaged_features = []

for pid in unique_pids:
# Average features per player id for global embedding (Remove noise)
if len(player_features_map[pid]) > 3: # Id should exist for atleast this number of frames to be used in kmeans
feats = np.array(player_features_map[pid])
avg_feat = np.mean(feats, axis=0)
averaged_features.append(avg_feat)
else:
averaged_features.append(averaged_features[-1])

if not unique_pids:
return frame_assignments

print(f" # Player ID's: {len(unique_pids)}")

kmeans = KMeans(n_clusters=2, random_state=42)

if len(unique_pids) >= 2:
labels = kmeans.fit_predict(averaged_features)

team1_color = self.get_rgb_from_histogram(kmeans.cluster_centers_[0], bins=(8, 4, 4))
team2_color = self.get_rgb_from_histogram(kmeans.cluster_centers_[1], bins=(8, 4, 4))
team_colors = {1: team1_color, 2: team2_color}
else:
labels = [1] * len(unique_pids) # Fallback in case of not enough players
team_colors = {1: (255, 255, 255), 2: (0, 0, 0)}

pid_to_team = {pid: int(label) + 1 for pid, label in zip(unique_pids, labels)}

for frame_id, player_track in enumerate(player_tracks):
for pid in player_track.keys():
if pid in pid_to_team:
team_id = pid_to_team[pid]
frame_assignments[frame_id][pid] = team_id

return frame_assignments, team_colors

def get_player_teams_over_frames(self, vid_frames, player_tracks):
return self.get_player_teams_global(vid_frames, player_tracks)
3 changes: 1 addition & 2 deletions services/team_assigner_service/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ requests
boto3
opencv-python-headless
pillow
transformers
python-multipart
ultralytics
scikit-learn
12 changes: 5 additions & 7 deletions services/team_assigner_service/team_assigner_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json

from utils import read_video
from team_assigner import TeamAssigner
from team_assigner_service.processing.team_assigner import TeamAssigner

def serialize_team_assignments(assignments):
out = []
Expand All @@ -23,10 +23,7 @@ def serialize_team_assignments(assignments):

app = FastAPI()

team_assigner = TeamAssigner(
team_A="WHITE shirt",
team_B="DARK BLUE shirt"
)
team_assigner = TeamAssigner(crop_factor=0.2)

@app.post("/assign_teams")
async def assign_teams(
Expand All @@ -51,14 +48,15 @@ async def assign_teams(
# Read video frames
frames = read_video(str(tmp_video_path))

team_assignments = team_assigner.get_player_teams_over_frames(
team_assignments, team_colors = team_assigner.get_player_teams_over_frames(
vid_frames=frames,
player_tracks=player_tracks,
)

# Serialize
payload = {
"team_assignments": serialize_team_assignments(team_assignments)
"team_assignments": serialize_team_assignments(team_assignments),
"team_colors": team_colors
}

safe = jsonable_encoder(payload)
Expand Down
38 changes: 19 additions & 19 deletions services/ui_service/plots.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import matplotlib.pyplot as plt
import numpy as np

def possession_plot(ball_tp):
def possession_plot(ball_tp, t1_color="#9cb2a0", t2_color="#9eaec6"):
x, y = possession_to_percentages(ball_tp)
y_percent = np.array(y) * 100

fig, ax_left = plt.subplots(figsize=(14, 4))
ax_right = ax_left.twinx()

for i, val in enumerate(y_percent):
ax_left.bar(i, val, color="#9eaec6", width=1.0)
ax_left.bar(i, 100 - val, bottom=val, color="#9cb2a0", width=1.0)
ax_left.bar(i, val, color=t1_color, width=1.0)
ax_left.bar(i, 100 - val, bottom=val, color=t2_color, width=1.0)

ax_left.plot(x, y_percent, color="black", linewidth=2)
ax_left.plot(x, y_percent, color="black", linewidth=3)

ax_left.set_ylim(0, 100)
ax_left.set_yticks(np.arange(0, 101, 5))
Expand All @@ -29,11 +29,11 @@ def possession_plot(ball_tp):

return fig

def pi_plots(pi_stats):
def pi_plots(pi_stats, t1_color="#9cb2a0", t2_color="#9eaec6"):
p1, p2, i1, i2 = extract_timeseries(pi_stats)
frames = list(range(len(pi_stats)))
plot_pass = passes_plot(frames, p1, p2)
plot_intr = interceptions_plot(frames, i1, i2)
plot_pass = passes_plot(frames, p1, p2, t1_color, t2_color)
plot_intr = interceptions_plot(frames, i1, i2, t1_color, t2_color)
return plot_pass, plot_intr

def to_percent(v1, v2):
Expand Down Expand Up @@ -75,29 +75,29 @@ def extract_timeseries(stats):

return passes_t1, passes_t2, inter_t1, inter_t2

def passes_plot(frames, passes_t1, passes_t2):
def passes_plot(frames, passes_t1, passes_t2, t1_color="#9cb2a0", t2_color="#9eaec6"):
pct = to_percent(passes_t1, passes_t2)
return percent_style_plot(
frames, pct, label="Passes"
frames, pct, label="Passes", t1_color=t1_color, t2_color=t2_color
)

def interceptions_plot(frames, inter_t1, inter_t2):
def interceptions_plot(frames, inter_t1, inter_t2, t1_color="#9cb2a0", t2_color="#9eaec6"):
pct = to_percent(inter_t1, inter_t2)
return percent_style_plot(
frames, pct, label="Interceptions"
frames, pct, "Interceptions", t1_color=t1_color, t2_color=t2_color
)

def percent_style_plot(x, pct_team1, label):
def percent_style_plot(x, pct_team1, label, t1_color="#9cb2a0", t2_color="#9eaec6"):
y_percent = pct_team1 * 100

fig, ax_left = plt.subplots(figsize=(14, 4))
ax_right = ax_left.twinx()

for i, val in enumerate(y_percent):
ax_left.bar(i, val, color="#9eaec6", width=1.0)
ax_left.bar(i, 100 - val, bottom=val, color="#9cb2a0", width=1.0)
ax_left.bar(i, val, color=t2_color, width=1.0)
ax_left.bar(i, 100 - val, bottom=val, color=t1_color, width=1.0)

ax_left.plot(x, y_percent, color="black", linewidth=2)
ax_left.plot(x, y_percent, color="black", linewidth=3)

ax_left.set_ylim(0, 100)
ax_left.set_yticks(np.arange(0, 101, 5))
Expand All @@ -114,7 +114,7 @@ def percent_style_plot(x, pct_team1, label):

return fig

def control_plot(control_stats):
def control_plot(control_stats, t1_color="#9cb2a0", t2_color="#9eaec6"):
y_percent = []
for frame in control_stats:
a = float(frame.get("1", 0))
Expand All @@ -133,10 +133,10 @@ def control_plot(control_stats):
ax_right = ax_left.twinx()

for i, val in enumerate(y_percent):
ax_left.bar(i, val, color="#9cb2a0", width=1.0)
ax_left.bar(i, 100 - val, color="#9eaec6", bottom=val, width=1.0)
ax_left.bar(i, val, color=t1_color, width=1.0)
ax_left.bar(i, 100 - val, color=t2_color, bottom=val, width=1.0)

ax_left.plot(x, y_percent, color="black", linewidth=2)
ax_left.plot(x, y_percent, color="black", linewidth=3)

ax_left.set_ylim(0, 100)
ax_left.set_yticks(np.arange(0, 101, 10))
Expand Down
Loading