Skip to content

Commit 2cd5855

Browse files
author
Mark Saroufim
committed
Refactor: Extract shared problem sync logic into libkernelbot
- Create problem_sync.py module with shared logic for downloading repos, parsing competition YAMLs, and creating/updating leaderboards - Simplify API endpoint to use shared sync_problems() function - Reduces code duplication between API and Discord cog
1 parent 35c1392 commit 2cd5855

2 files changed

Lines changed: 320 additions & 171 deletions

File tree

src/kernelbot/api/main.py

Lines changed: 20 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,10 @@
33
import datetime
44
import json
55
import os
6-
import subprocess
7-
import tempfile
86
import time
97
from dataclasses import asdict
10-
from pathlib import Path
118
from typing import Annotated, Any, Optional
129

13-
import yaml
1410
from fastapi import Depends, FastAPI, Header, HTTPException, Request, UploadFile
1511
from fastapi.responses import JSONResponse, StreamingResponse
1612

@@ -20,6 +16,7 @@
2016
from libkernelbot.consts import SubmissionMode
2117
from libkernelbot.db_types import IdentityType
2218
from libkernelbot.leaderboard_db import LeaderboardDB, LeaderboardRankedEntry
19+
from libkernelbot.problem_sync import sync_problems
2320
from libkernelbot.submission import (
2421
ProcessedSubmissionRequest,
2522
SubmissionRequest,
@@ -28,7 +25,6 @@
2825
from libkernelbot.task import make_task_definition
2926
from libkernelbot.utils import (
3027
KernelBotError,
31-
parse_deadline,
3228
resolve_problem_directory,
3329
setup_logging,
3430
)
@@ -610,7 +606,7 @@ async def admin_get_submission(
610606

611607

612608
@app.post("/admin/update-problems")
613-
async def admin_update_problems( # noqa: C901
609+
async def admin_update_problems(
614610
payload: dict,
615611
_: Annotated[None, Depends(require_admin)],
616612
db_context=Depends(get_db),
@@ -621,179 +617,32 @@ async def admin_update_problems( # noqa: C901
621617
Downloads the repository, parses competition YAML files, and creates/updates leaderboards.
622618
"""
623619
repository = payload.get("repository", "gpu-mode/reference-kernels")
624-
problem_set = payload.get("problem_set") # Optional - if None, process all
620+
problem_set = payload.get("problem_set")
625621
branch = payload.get("branch", "main")
626622
force = payload.get("force", False)
627623

628-
if "/" in branch:
629-
raise HTTPException(status_code=400, detail="Branch names with slashes are not supported")
630-
631-
url = f"https://github.com/{repository}/archive/{branch}.zip"
632-
folder_name = repository.split("/")[-1] + "-" + branch
633-
634-
created = []
635-
updated = []
636-
skipped = []
637-
errors = []
638-
639-
with tempfile.TemporaryDirectory() as temp_dir:
640-
# Download the repository
641-
try:
642-
subprocess.check_call(
643-
["wget", "-q", "-O", temp_dir + "/problems.zip", url],
644-
encoding="utf-8",
645-
timeout=60,
646-
)
647-
except subprocess.CalledProcessError as e:
648-
raise HTTPException(
649-
status_code=400,
650-
detail=f"Could not download repository from {url}: {e}"
651-
) from e
652-
except subprocess.TimeoutExpired as e:
653-
raise HTTPException(
654-
status_code=408,
655-
detail="Timeout downloading repository"
656-
) from e
657-
658-
# Unzip
659-
try:
660-
subprocess.check_call(
661-
["unzip", "-q", temp_dir + "/problems.zip", "-d", temp_dir],
662-
encoding="utf-8",
663-
timeout=30,
664-
)
665-
except subprocess.CalledProcessError as e:
666-
raise HTTPException(
667-
status_code=400,
668-
detail=f"Could not unzip repository: {e}"
669-
) from e
670-
671-
problem_dir = Path(temp_dir) / folder_name / "problems"
672-
if not problem_dir.exists():
673-
raise HTTPException(
674-
status_code=400,
675-
detail="No 'problems' directory found in repository"
676-
)
677-
678-
# Find competition YAML files
679-
if problem_set is None:
680-
yaml_files = list(problem_dir.glob("*.yaml"))
681-
else:
682-
yaml_file = problem_dir / f"{problem_set}.yaml"
683-
if not yaml_file.exists():
684-
available = [f.stem for f in problem_dir.glob("*.yaml")]
685-
raise HTTPException(
686-
status_code=400,
687-
detail=f"Problem set '{problem_set}' not found. Available: {available}"
688-
)
689-
yaml_files = [yaml_file]
690-
691-
# Get existing leaderboards
692-
with db_context as db:
693-
existing_leaderboards = {lb["name"]: lb for lb in db.get_leaderboards()}
694-
695-
# Process each competition YAML
696-
for yaml_file in yaml_files:
697-
try:
698-
with open(yaml_file) as f:
699-
competition = yaml.safe_load(f)
700-
701-
for problem in competition.get("problems", []):
702-
problem_name = problem.get("name")
703-
directory = problem.get("directory")
704-
deadline_str = problem.get("deadline")
705-
gpus = problem.get("gpus", [])
706-
707-
if not problem_name or not directory:
708-
errors.append({"name": problem_name or "unknown", "error": "Missing name or directory"})
709-
continue
710-
711-
source_path = problem_dir / directory
712-
if not source_path.exists():
713-
errors.append({"name": problem_name, "error": f"Directory {directory} not found"})
714-
continue
715-
716-
try:
717-
definition = make_task_definition(source_path)
718-
except Exception as e:
719-
errors.append({"name": problem_name, "error": f"Failed to parse task.yml: {e}"})
720-
continue
721-
722-
deadline = parse_deadline(deadline_str) if deadline_str else None
723-
if deadline is None:
724-
deadline = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=365)
725-
726-
# Check if leaderboard exists
727-
if problem_name in existing_leaderboards:
728-
old_lb = existing_leaderboards[problem_name]
729-
# Check if update is needed
730-
old_deadline = old_lb["deadline"]
731-
if hasattr(old_deadline, "tzinfo") and old_deadline.tzinfo is None:
732-
old_deadline = old_deadline.replace(tzinfo=datetime.timezone.utc)
733-
734-
deadline_changed = old_deadline != deadline
735-
task_changed = old_lb["task"] != definition.task
736-
737-
if not deadline_changed and not task_changed:
738-
skipped.append({"name": problem_name, "reason": "no changes"})
739-
continue
740-
741-
if task_changed and not force:
742-
# Check if only safe changes (description/templates)
743-
old_task = old_lb["task"]
744-
new_task = definition.task
745-
if (old_task.files != new_task.files or
746-
old_task.config != new_task.config or
747-
old_task.lang != new_task.lang or
748-
old_task.benchmarks != new_task.benchmarks):
749-
skipped.append({
750-
"name": problem_name,
751-
"reason": "significant task changes require --force"
752-
})
753-
continue
754-
755-
# Update the leaderboard
756-
try:
757-
with db_context as db:
758-
db.update_leaderboard(problem_name, deadline, definition)
759-
updated.append(problem_name)
760-
except Exception as e:
761-
errors.append({"name": problem_name, "error": f"Update failed: {e}"})
762-
else:
763-
# Create new leaderboard
764-
if not gpus:
765-
gpus = definition.gpus if definition.gpus else []
766-
if not gpus:
767-
errors.append({"name": problem_name, "error": "No GPUs specified in task.yml or YAML"})
768-
continue
769-
770-
try:
771-
with db_context as db:
772-
db.create_leaderboard(
773-
name=problem_name,
774-
deadline=deadline,
775-
definition=definition,
776-
creator_id=0, # API-created
777-
forum_id=-1, # No Discord forum
778-
gpu_types=gpus,
779-
)
780-
created.append(problem_name)
781-
except Exception as e:
782-
errors.append({"name": problem_name, "error": f"Create failed: {e}"})
783-
784-
except yaml.YAMLError as e:
785-
errors.append({"name": yaml_file.stem, "error": f"Invalid YAML: {e}"})
786-
except Exception as e:
787-
errors.append({"name": yaml_file.stem, "error": str(e)})
624+
try:
625+
result = sync_problems(
626+
db_context=db_context,
627+
repository=repository,
628+
problem_set=problem_set,
629+
branch=branch,
630+
force=force,
631+
creator_id=0, # API-created
632+
forum_id=-1, # No Discord forum
633+
)
634+
except ValueError as e:
635+
raise HTTPException(status_code=400, detail=str(e)) from e
788636

789637
return {
790638
"status": "ok",
791-
"created": created,
792-
"updated": updated,
793-
"skipped": skipped,
794-
"errors": errors,
639+
"created": result.created,
640+
"updated": result.updated,
641+
"skipped": result.skipped,
642+
"errors": result.errors,
795643
}
796644

645+
797646
@app.get("/leaderboards")
798647
async def get_leaderboards(db_context=Depends(get_db)):
799648
"""An endpoint that returns all leaderboards.

0 commit comments

Comments
 (0)