Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 63 additions & 9 deletions alphatrion/artifact/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
class Artifact:
def __init__(self, team_id: str, insecure: bool = False):
self._team_id = team_id
self._url = os.environ.get(envs.ARTIFACT_REGISTRY_URL)
self._url = self._url.replace("https://", "").replace("http://", "")
self._url = get_registry_url()
self._client = oras.client.OrasClient(
hostname=self._url.strip("/"), auth_backend="token", insecure=insecure
)
Expand Down Expand Up @@ -51,9 +50,8 @@ def push(
if version is None:
version = utiltime.now_2_hash()

url = self._url if self._url.endswith("/") else f"{self._url}/"
path = f"{self._team_id}/{repo_name}:{version}"
target = f"{url}{path}"
target = f"{self._url}/{path}"

try:
self._client.push(target, files=files_to_push, disable_path_validation=True)
Expand All @@ -63,19 +61,75 @@ def push(
return path

def list_versions(self, repo_name: str) -> list[str]:
url = self._url if self._url.endswith("/") else f"{self._url}/"
target = f"{url}{self._team_id}/{repo_name}"
target = f"{self._url}/{self._team_id}/{repo_name}"
try:
tags = self._client.get_tags(target)
return tags
except Exception as e:
raise RuntimeError("Failed to list artifacts versions") from e
# Check if it's a "not found" error (404, repository doesn't exist)
# TODO: it's not a proper way but let's do it for now.
error_msg = str(e).lower()
if (
"404" in error_msg
or "not found" in error_msg
or "does not exist" in error_msg
):
# Return empty list if repository doesn't exist yet
# This is expected for projects without artifacts
return []
# Re-raise other errors
raise RuntimeError(f"Failed to list artifacts versions: {e}") from e

def pull(
self, repo_name: str, version: str, output_dir: str | None = None
) -> list[str]:
"""
Pull artifacts from the registry.

:param repo_name: the name of the repository to pull from
:param version: the version (tag) to pull
:param output_dir: optional directory to save files to
(defaults to ORAS temp directory)
:return: list of absolute file paths that were downloaded
"""
path = f"{self._team_id}/{repo_name}:{version}"
target = f"{self._url}/{path}"

if output_dir:
os.makedirs(output_dir, exist_ok=True)
original_dir = os.getcwd()
os.chdir(output_dir)

try:
# ORAS client returns list of filenames
filenames = self._client.pull(target)

# Get current directory (where files were downloaded)
download_dir = os.getcwd()

# Return absolute paths to downloaded files
return [os.path.abspath(os.path.join(download_dir, f)) for f in filenames]
except Exception as e:
raise RuntimeError(f"Failed to pull artifacts: {e}") from e
finally:
if output_dir:
os.chdir(original_dir)

def delete(self, repo_name: str, versions: str | list[str]):
url = self._url if self._url.endswith("/") else f"{self._url}/"
target = f"{url}{self._team_id}/{repo_name}"
target = f"{self._url}/{self._team_id}/{repo_name}"

try:
self._client.delete_tags(target, tags=versions)
except Exception as e:
raise RuntimeError("Failed to delete artifact versions") from e


def get_registry_url() -> str:
"""Get the ORAS registry URL from environment variables."""
registry_url = os.environ.get(envs.ARTIFACT_REGISTRY_URL)
if not registry_url:
raise RuntimeError("ARTIFACT_REGISTRY_URL not configured")
# Ensure URL has scheme
if not registry_url.startswith(("http://", "https://")):
registry_url = f"http://{registry_url}"
return registry_url.rstrip("/")
11 changes: 9 additions & 2 deletions alphatrion/log/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,13 @@ async def log_metrics(metrics: dict[str, float]) -> bool:
async def log_execution(
output: dict[str, Any],
input: dict[str, Any] | None = None,
phase: str = "success",
kind: ExecutionKind = ExecutionKind.RUN,
):
execution = None

if kind == ExecutionKind.RUN:
execution = build_run_execution(output=output, input=input)
execution = build_run_execution(output=output, input=input, phase=phase)
else:
raise NotImplementedError(
f"Logging record of kind {execution.kind} is not implemented yet."
Expand Down Expand Up @@ -196,5 +197,11 @@ async def log_execution(
)
runtime.metadb.update_run(
run_id=current_run_id.get(),
meta={EXECUTION_RESULT: {"path": path, "size": file_size}},
meta={
EXECUTION_RESULT: {
"path": path,
"size": file_size,
"file_name": "execution.json",
}
},
)
102 changes: 1 addition & 101 deletions alphatrion/server/cmd/app.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
# ruff: noqa: E501
# ruff: noqa: B904

import os
from importlib.metadata import version

import httpx
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from strawberry.fastapi import GraphQLRouter

from alphatrion import envs
from alphatrion.server.graphql.schema import schema

app = FastAPI()
Expand Down Expand Up @@ -41,99 +37,3 @@ def health_check():
@app.get("/version")
def get_version():
return {"version": version("alphatrion"), "status": "ok"}


# ORAS Registry Proxy Endpoints
def get_registry_url() -> str:
"""Get the ORAS registry URL from environment variables."""
registry_url = os.environ.get(envs.ARTIFACT_REGISTRY_URL)
if not registry_url:
raise HTTPException(
status_code=500, detail="ARTIFACT_REGISTRY_URL not configured"
)
# Ensure URL has scheme
if not registry_url.startswith(("http://", "https://")):
# Default to https if no scheme specified
registry_url = f"https://{registry_url}"
return registry_url.rstrip("/")


@app.get("/api/artifacts/repositories")
async def list_repositories():
"""Proxy request to ORAS registry to list all repositories."""
registry_url = get_registry_url()
async with httpx.AsyncClient() as client:
try:
response = await client.get(
f"{registry_url}/v2/_catalog",
timeout=30.0,
)
response.raise_for_status()
return response.json()
except httpx.HTTPError as e:
raise HTTPException(status_code=500, detail=f"Registry request failed: {e}")


@app.get("/api/artifacts/repositories/{team}/{project}/tags")
async def list_tags(team: str, project: str):
"""Proxy request to ORAS registry to list tags for a repository."""
registry_url = get_registry_url()
repo_path = f"{team}/{project}"
async with httpx.AsyncClient() as client:
try:
response = await client.get(
f"{registry_url}/v2/{repo_path}/tags/list",
timeout=30.0,
)
response.raise_for_status()
return response.json()
except httpx.HTTPError as e:
raise HTTPException(status_code=500, detail=f"Failed to list tags: {e}")


@app.get("/api/artifacts/repositories/{team}/{project}/manifests/{tag}")
async def get_manifest(team: str, project: str, tag: str):
"""Proxy request to ORAS registry to get manifest for a specific tag."""
registry_url = get_registry_url()
repo_path = f"{team}/{project}"
async with httpx.AsyncClient() as client:
try:
response = await client.get(
f"{registry_url}/v2/{repo_path}/manifests/{tag}",
headers={
"Accept": "application/vnd.oci.image.manifest.v1+json, application/vnd.docker.distribution.manifest.v2+json"
},
timeout=30.0,
)
response.raise_for_status()
return response.json()
except httpx.HTTPError as e:
raise HTTPException(status_code=500, detail=f"Failed to get manifest: {e}")


@app.get("/api/artifacts/repositories/{team}/{project}/blobs/{digest:path}")
async def get_blob(team: str, project: str, digest: str):
"""Proxy request to ORAS registry to get blob content."""
registry_url = get_registry_url()
repo_path = f"{team}/{project}"
async with httpx.AsyncClient() as client:
try:
response = await client.get(
f"{registry_url}/v2/{repo_path}/blobs/{digest}",
timeout=30.0,
)
response.raise_for_status()
# Return raw blob content
return Response(
content=response.content,
media_type=response.headers.get(
"content-type", "application/octet-stream"
),
headers={
"Content-Disposition": response.headers.get(
"Content-Disposition", ""
),
},
)
except httpx.HTTPError as e:
raise HTTPException(status_code=500, detail=f"Failed to get blob: {e}")
81 changes: 81 additions & 0 deletions alphatrion/server/graphql/resolvers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import os
import uuid
from datetime import datetime

import httpx
import strawberry

from alphatrion.artifact import artifact
from alphatrion.storage import runtime
from alphatrion.storage.sql_models import Status

from .types import (
AddUserToTeamInput,
ArtifactContent,
ArtifactRepository,
ArtifactTag,
CreateTeamInput,
CreateUserInput,
Experiment,
Expand Down Expand Up @@ -288,6 +294,81 @@ def list_exps_by_timeframe(
for e in experiments
]

@staticmethod
async def list_artifact_repositories() -> list[ArtifactRepository]:
"""List all repositories in the ORAS registry."""

registry_url = artifact.get_registry_url()
async with httpx.AsyncClient() as client:
try:
response = await client.get(
f"{registry_url}/v2/_catalog",
timeout=30.0,
)
response.raise_for_status()
data = response.json()
repositories = data.get("repositories", [])
return [ArtifactRepository(name=repo) for repo in repositories]
except httpx.HTTPError as e:
raise RuntimeError(f"Registry request failed: {e}") from e

@staticmethod
async def list_artifact_tags(
team_id: str, project_id: str, repo_type: str | None = None
) -> list[ArtifactTag]:
"""List tags for a repository."""

arf = artifact.Artifact(team_id=team_id, insecure=True)
# Append repo_type suffix to project_id if provided
# (e.g., "project/execution" or "project/checkpoint")
repo_path = f"{project_id}/{repo_type}" if repo_type else project_id
return [ArtifactTag(name=tag) for tag in arf.list_versions(repo_path)]

@staticmethod
async def get_artifact_content(
team_id: str, project_id: str, tag: str, repo_type: str | None = None
) -> ArtifactContent:
"""Get artifact content from registry."""
try:
# Initialize artifact client
arf = artifact.Artifact(team_id=team_id, insecure=True)

# Construct repository path
repo_path = f"{project_id}/{repo_type}" if repo_type else project_id

# Pull the artifact - ORAS will manage temp directory
# Returns absolute paths to files in ORAS temp directory
# Note: One potential issue is if we download too many large files,
# it may fill up disk space. For now we assume artifacts are
# reasonably sized and/or users will manage their registry storage.
file_paths = arf.pull(repo_name=repo_path, version=tag)

if not file_paths:
raise RuntimeError("No files found in artifact")

# Read first file content (file_paths now contains absolute paths)
file_path = file_paths[0]
with open(file_path, encoding="utf-8") as f:
content = f.read()

# Get filename from path
filename = os.path.basename(file_path)

# Determine content type based on file extension
# TODO: for multiple files, this is not right.
if filename.endswith(".json"):
content_type = "application/json"
elif filename.endswith(".txt") or filename.endswith(".log"):
content_type = "text/plain"
else:
content_type = "text/plain"

return ArtifactContent(
filename=filename, content=content, content_type=content_type
)
except Exception as e:
raise RuntimeError(f"Failed to get artifact content: {e}") from e


class GraphQLMutations:
@staticmethod
Expand Down
31 changes: 31 additions & 0 deletions alphatrion/server/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from alphatrion.server.graphql.resolvers import GraphQLMutations, GraphQLResolvers
from alphatrion.server.graphql.types import (
AddUserToTeamInput,
ArtifactContent,
ArtifactRepository,
ArtifactTag,
CreateTeamInput,
CreateUserInput,
Experiment,
Expand Down Expand Up @@ -81,6 +84,34 @@ def runs(

run: Run | None = strawberry.field(resolver=GraphQLResolvers.get_run)

# Artifact queries
@strawberry.field
async def artifact_repos(self) -> list[ArtifactRepository]:
return await GraphQLResolvers.list_artifact_repositories()

@strawberry.field
async def artifact_tags(
self,
team_id: strawberry.ID,
project_id: strawberry.ID,
repo_type: str | None = None,
) -> list[ArtifactTag]:
return await GraphQLResolvers.list_artifact_tags(
str(team_id), str(project_id), repo_type
)

@strawberry.field
async def artifact_content(
self,
team_id: strawberry.ID,
project_id: strawberry.ID,
tag: str,
repo_type: str | None = None,
) -> ArtifactContent:
return await GraphQLResolvers.get_artifact_content(
str(team_id), str(project_id), tag, repo_type
)


@strawberry.type
class Mutation:
Expand Down
Loading
Loading