diff --git a/Cargo.lock b/Cargo.lock index 018f901..0e0e3f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,6 +83,7 @@ checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core", + "base64", "bytes", "futures-util", "http", @@ -94,6 +95,7 @@ dependencies = [ "matchit", "memchr", "mime", + "multer", "percent-encoding", "pin-project-lite", "rustversion", @@ -101,8 +103,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -151,6 +155,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bs58" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf88ba1141d185c399bee5288d850d63b8369520c1eafc32a0430b5b6c287bf4" +dependencies = [ + "tinyvec", +] + [[package]] name = "bumpalo" version = "3.19.1" @@ -325,6 +338,12 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "data-encoding" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" + [[package]] name = "deflate64" version = "0.1.10" @@ -467,6 +486,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -474,6 +508,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -482,6 +517,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.31" @@ -517,6 +563,7 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -1082,6 +1129,23 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "multer" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b" +dependencies = [ + "bytes", + "encoding_rs", + "futures-util", + "http", + "httparse", + "memchr", + "mime", + "spin", + "version_check", +] + [[package]] name = "native-tls" version = "0.2.16" @@ -1245,6 +1309,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -1279,6 +1352,36 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -1627,6 +1730,12 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -1717,14 +1826,16 @@ dependencies = [ [[package]] name = "term-executor" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anyhow", "axum", "base64", + "bs58", "chrono", "dashmap", "flate2", + "futures", "hex", "parking_lot", "reqwest", @@ -1734,8 +1845,9 @@ dependencies = [ "sha2", "tar", "tempfile", - "thiserror", + "thiserror 2.0.18", "tokio", + "tokio-stream", "tokio-test", "tower", "tower-http", @@ -1745,13 +1857,33 @@ dependencies = [ "zip", ] +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + [[package]] name = "thiserror" version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -1803,6 +1935,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinyvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.49.0" @@ -1873,6 +2020,18 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "tokio-tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -2001,6 +2160,24 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand", + "sha1", + "thiserror 1.0.69", + "utf-8", +] + [[package]] name = "typenum" version = "1.19.0" @@ -2043,6 +2220,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -2589,6 +2772,26 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zerocopy" +version = "0.8.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db6d35d663eadb6c932438e763b262fe1a70987f9ae936e60158176d710cae4a" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4122cd3169e94605190e77839c9a40d40ed048d305bfdc146e7df40ab0f3e517" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zerofrom" version = "0.1.6" @@ -2685,7 +2888,7 @@ dependencies = [ "memchr", "pbkdf2", "sha1", - "thiserror", + "thiserror 2.0.18", "time", "xz2", "zeroize", diff --git a/Cargo.toml b/Cargo.toml index 3f2f255..6e7db04 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "term-executor" version = "1.0.0" edition = "2021" -description = "Remote evaluation executor for term-challenge — runs on Basilica" +description = "Attested SWE-bench evaluation executor for term-challenge miners" authors = ["Platform Network"] license = "Apache-2.0" repository = "https://github.com/PlatformNetwork/term-executor" @@ -13,12 +13,14 @@ path = "src/main.rs" [dependencies] # Web framework -axum = { version = "0.7", features = ["json"] } +axum = { version = "0.7", features = ["json", "ws", "multipart"] } tower = "0.5" tower-http = { version = "0.6", features = ["cors", "trace"] } # Async runtime tokio = { version = "1", features = ["full", "process"] } +tokio-stream = "0.1" +futures = "0.3" # Serialization serde = { version = "1", features = ["derive"] } @@ -55,7 +57,8 @@ tempfile = "3" parking_lot = "0.12" dashmap = "6" - +# SS58 address validation +bs58 = "0.5" [dev-dependencies] tokio-test = "0.4" diff --git a/Dockerfile b/Dockerfile index 72f8a98..8422996 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,10 +11,12 @@ RUN cargo build --release && strip target/release/term-executor FROM debian:bookworm-slim RUN apt-get update && apt-get install -y --no-install-recommends \ ca-certificates git curl libssl3 \ + python3 python3-pip python3-venv \ + build-essential \ && rm -rf /var/lib/apt/lists/* COPY --from=builder /build/target/release/term-executor /usr/local/bin/ RUN mkdir -p /tmp/sessions EXPOSE 8080 HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \ CMD curl -f http://localhost:8080/health || exit 1 -CMD ["term-executor"] +ENTRYPOINT ["/usr/local/bin/term-executor"] diff --git a/README.md b/README.md index 441eaaa..ce98dba 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,18 @@ # term-executor -Remote evaluation executor for [term-challenge](https://github.com/PlatformNetwork/term-challenge). Runs on [Basilica](https://basilica.ai) as a containerized service that receives agent code, executes it against a task repository, and runs validation tests. +Remote evaluation executor for [term-challenge](https://github.com/PlatformNetwork/term-challenge). Runs as a containerized service that receives task archives, executes agent code against SWE-bench repositories, and runs validation tests. ## Architecture ```mermaid graph LR - PS[Platform Server] -->|POST /evaluate| TE[term-executor] - TE -->|1. Download| Task[(Task Archive)] + PS[Platform Server] -->|POST /submit| TE[term-executor] + TE -->|1. Extract archive| Tasks[(tasks/ + agent_code/)] TE -->|2. git clone| Repo[(Repository)] TE -->|3. Run agent| Agent[Agent Code] TE -->|4. Run tests| Tests[Test Scripts] - PS -->|GET /evaluate/id| TE - V[Validator] -->|GET /public/metadata| Basilica[Basilica API] - Basilica -->|Verify image + state| TE + PS -->|GET /batch/id| TE + PS -->|WS /ws?batch_id=id| TE ``` ## Evaluation Flow @@ -24,36 +23,30 @@ sequenceDiagram participant TE as term-executor participant GH as GitHub - PS->>TE: POST /evaluate {agent_code, task_url} - TE-->>PS: 202 {eval_id} + PS->>TE: POST /submit (multipart archive) + TE-->>PS: 202 {batch_id, ws_url} - TE->>TE: Download task archive - TE->>GH: git clone repo - TE->>TE: Install dependencies - TE->>TE: Write & run agent code - TE->>TE: Write test files - TE->>TE: Run test scripts + par For each task (up to 8 concurrent) + TE->>GH: git clone repo + TE->>TE: Install dependencies + TE->>TE: Write & run agent code + TE->>TE: Write test files + TE->>TE: Run test scripts + end - PS->>TE: GET /evaluate/{eval_id} - TE-->>PS: {status, passed, test_results} + PS->>TE: GET /batch/{batch_id} + TE-->>PS: {status, passed_tasks, aggregate_reward} ``` -## Session Lifecycle +## Authentication + +Only a single authorized hotkey can submit tasks. The hotkey must be sent via the `X-Hotkey` HTTP header: -```mermaid -stateDiagram-v2 - [*] --> Pending: POST /evaluate - Pending --> DownloadingTask - DownloadingTask --> CloningRepo - CloningRepo --> InstallingDeps - InstallingDeps --> RunningAgent - RunningAgent --> RunningTests - RunningTests --> Cleanup - Cleanup --> Completed: all tests pass - Cleanup --> Failed: tests fail or error - Pending --> Cancelled: timeout / cancel - RunningAgent --> Cancelled: timeout ``` +X-Hotkey: 5GziQCcRpN8NCJktX343brnfuVe3w6gUYieeStXPD1Dag2At +``` + +Health, status, and metrics endpoints are public. ## API Reference @@ -69,15 +62,15 @@ GET /health ``` GET /status → 200 { - "version": "0.1.0", + "version": "0.2.0", "uptime_secs": 3600, - "active_evals": 2, - "total_evals": 150, - "passed": 120, - "failed": 28, - "cancelled": 2, - "capacity": 4, - "available_slots": 2 + "active_batches": 1, + "total_batches": 10, + "completed_batches": 9, + "tasks_passed": 45, + "tasks_failed": 5, + "max_concurrent_tasks": 8, + "has_active_batch": true } ``` @@ -86,117 +79,138 @@ GET /status ``` GET /metrics → 200 (text/plain) - term_executor_evaluations_total 150 - term_executor_evaluations_passed 120 - term_executor_evaluations_failed 28 - term_executor_evaluations_active 2 + term_executor_batches_total 10 + term_executor_batches_active 1 + term_executor_batches_completed 9 + term_executor_tasks_passed 45 + term_executor_tasks_failed 5 ... ``` -### Submit Evaluation +### Submit Batch + +Upload a multipart archive containing `tasks/` and `agent_code/` directories. ``` -POST /evaluate -Authorization: Bearer -Content-Type: application/json +POST /submit +X-Hotkey: 5GziQCcRpN8NCJktX343brnfuVe3w6gUYieeStXPD1Dag2At +Content-Type: multipart/form-data + +Field: archive (file) -{ - "agent_code": "import os\n...", - "agent_language": "python", - "task_url": "https://example.com/task.tar.gz", - "timeout_secs": 600 -} +→ 202 { + "batch_id": "uuid", + "total_tasks": 5, + "concurrent_tasks": 8, + "ws_url": "/ws?batch_id=uuid" + } +→ 400 (invalid archive) +→ 401 (unauthorized) +→ 503 (batch already running) +``` -→ 202 {"eval_id": "uuid"} -→ 400 (invalid input) -→ 401 (bad token) -→ 503 (at capacity) +Optional query parameter: `?concurrent_tasks=4` to limit concurrency. + +### Get Batch Status + +``` +GET /batch/{batch_id} +→ 200 { + "batch_id": "uuid", + "status": "completed", + "total_tasks": 5, + "completed_tasks": 5, + "passed_tasks": 4, + "failed_tasks": 1, + "aggregate_reward": 0.8, + "error": null, + "duration_ms": 120000 + } +→ 404 (not found) ``` -### Poll Evaluation +### Get Batch Tasks ``` -GET /evaluate/{eval_id} +GET /batch/{batch_id}/tasks +→ 200 { + "batch_id": "uuid", + "tasks": [ + { + "task_id": "task-1", + "status": "completed", + "passed": true, + "reward": 1.0, + "test_output": "...", + "error": null, + "duration_ms": 25000 + } + ] + } +``` +### Get Single Task + +``` +GET /batch/{batch_id}/task/{task_id} → 200 { - "eval_id": "uuid", + "task_id": "task-1", "status": "completed", - "step": "done", "passed": true, + "reward": 1.0, "test_results": [ {"name": "test_1.sh", "passed": true, "exit_code": 0, "output": "..."} ], - "agent_output": "...", "test_output": "...", "error": null, - "duration_ms": 45000 + "duration_ms": 25000 } -→ 404 (not found) ``` -### List Evaluations +### List All Batches ``` -GET /evaluations -→ 200 [{"eval_id": "...", "task_url": "...", "language": "python", "created_at": "..."}] +GET /batches +→ 200 [ + {"batch_id": "uuid", "created_at": "2024-01-01T00:00:00Z", "status": "completed"} + ] ``` -## Deployment - -### Basilica (recommended) +### WebSocket (Real-time Updates) -```bash -basilica deploy ghcr.io/platformnetwork/term-executor:latest \ - --name my-executor \ - --port 8080 \ - --public-metadata \ - --health-path /health \ - --health-initial-delay 10 \ - --cpu 2 --memory 4Gi \ - --env AUTH_TOKEN=your-secret-token \ - --env MAX_CONCURRENT_EVALS=4 ``` - -### Docker - -```bash -docker run -d \ - -p 8080:8080 \ - -e AUTH_TOKEN=your-secret-token \ - ghcr.io/platformnetwork/term-executor:latest +WS /ws?batch_id={batch_id} ``` -## Configuration - -All configuration is via environment variables: +On connect, receives a `snapshot` event with current state. Then streams events: -| Variable | Default | Description | -|----------|---------|-------------| -| `PORT` | `8080` | HTTP listen port | -| `AUTH_TOKEN` | *(none)* | Bearer token for `/evaluate` endpoint. If unset, auth is disabled | -| `SESSION_TTL_SECS` | `1800` | Max session lifetime before reaping | -| `MAX_CONCURRENT_EVALS` | `4` | Maximum parallel evaluations | -| `DISK_QUOTA_MB` | `2048` | Max disk per session | -| `CLONE_TIMEOUT_SECS` | `120` | Git clone timeout | -| `AGENT_TIMEOUT_SECS` | `600` | Agent execution timeout | -| `TEST_TIMEOUT_SECS` | `300` | Test suite timeout | -| `MAX_AGENT_CODE_BYTES` | `5242880` | Max agent code payload (5MB) | -| `MAX_OUTPUT_BYTES` | `1048576` | Max captured output per command (1MB) | -| `WORKSPACE_BASE` | `/tmp/sessions` | Base directory for session workspaces | +```json +{"event": "task_started", "batch_id": "uuid", "task_id": "task-1", "data": {"task_id": "task-1"}} +{"event": "task_complete", "batch_id": "uuid", "task_id": "task-1", "data": {"task_id": "task-1", "status": "completed", "passed": true, "reward": 1.0}} +{"event": "batch_complete", "batch_id": "uuid", "data": {"status": "completed", "total": 5, "passed": 4, "failed": 1, "reward": 0.8, "duration_ms": 120000}} +``` -## SWE-Forge Task Format +## Archive Format -Tasks are archives (`.tar.gz` or `.zip`) containing: +Upload a `.zip` or `.tar.gz` archive with this structure: ``` -workspace.yaml # Required: repo URL, version, base_commit, install commands -prompt.md # Required: task description (no solution hints) -original_pr.md # Optional: raw PR body -checks.txt # Optional: flat list of test commands (one per line) -tests/ - fail_to_pass_1.sh # Test scripts (exit 0 = pass) - pass_to_pass_1.sh - test_utils.py # Test source files written to repo +archive/ +├── tasks/ +│ ├── task-1/ +│ │ ├── workspace.yaml # Required: repo URL, version, base_commit, install commands +│ │ ├── prompt.md # Required: task description +│ │ ├── checks.txt # Optional: test commands (one per line) +│ │ └── tests/ +│ │ ├── test_1.sh # Test scripts (exit 0 = pass) +│ │ └── helper.py # Non-.sh files written to repo +│ └── task-2/ +│ ├── workspace.yaml +│ ├── prompt.md +│ └── tests/ +│ └── test_1.sh +└── agent_code/ + └── agent.py # Agent code (never exposed in API responses) ``` ### workspace.yaml @@ -211,14 +225,38 @@ install: - "pip install pytest" ``` -## Security +## Reward -- **Auth**: Bearer token on `/evaluate`. Health/status/metrics are public. -- **Resource limits**: Per-process memory via `ulimit`, `nice` priority lowering, output truncation at 1MB. -- **Disk quota**: Checked per session, rejects if exceeded. -- **Timeouts**: Clone, agent, and test phases each have configurable timeouts. -- **Process isolation**: Each command runs in its own process group, killed on timeout. -- **Session reaping**: Expired sessions cleaned up every 60 seconds. +Binary reward per task: +- **1.0** if all test scripts exit with code 0 +- **0.0** otherwise + +Aggregate reward is the mean across all tasks in the batch. + +## Configuration + +All configuration is via environment variables: + +| Variable | Default | Description | +|----------|---------|-------------| +| `PORT` | `8080` | HTTP listen port | +| `SESSION_TTL_SECS` | `7200` | Max batch lifetime before reaping | +| `MAX_CONCURRENT_TASKS` | `8` | Maximum parallel task executions | +| `CLONE_TIMEOUT_SECS` | `180` | Git clone timeout | +| `AGENT_TIMEOUT_SECS` | `600` | Agent execution timeout | +| `TEST_TIMEOUT_SECS` | `300` | Test suite timeout | +| `MAX_ARCHIVE_BYTES` | `524288000` | Max upload archive size (500MB) | +| `MAX_OUTPUT_BYTES` | `1048576` | Max captured output per command (1MB) | +| `WORKSPACE_BASE` | `/tmp/sessions` | Base directory for task workspaces | + +## Docker + +```bash +docker build -t term-executor . +docker run -d -p 8080:8080 term-executor +``` + +The Dockerfile uses a multi-stage build. The only entrypoint is the `term-executor` binary. ## Development @@ -230,7 +268,7 @@ cargo build cargo test # Run locally -AUTH_TOKEN=test PORT=8080 cargo run +PORT=8080 cargo run # Format cargo +nightly fmt @@ -239,6 +277,15 @@ cargo +nightly fmt cargo +nightly clippy -- -D warnings ``` +## Security + +- **Auth**: Only the hardcoded authorized hotkey (SS58 address) can submit batches via `X-Hotkey` header. +- **Agent code hidden**: Agent code is never returned in any API response. +- **Timeouts**: Clone, agent, and test phases each have configurable timeouts. +- **Output truncation**: Command output is capped at 1MB. +- **Session reaping**: Expired batches are cleaned up every 60 seconds. +- **Single batch**: Only one batch can run at a time, preventing resource exhaustion. + ## License Apache-2.0 diff --git a/src/auth.rs b/src/auth.rs index 61353c0..a016d77 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,64 +1,26 @@ -use axum::{ - extract::Request, - http::{header, StatusCode}, - middleware::Next, - response::Response, -}; -use uuid::Uuid; +use crate::config::AUTHORIZED_HOTKEY; -#[allow(dead_code)] -pub async fn auth_middleware(request: Request, next: Next) -> Result { - let token = request - .extensions() - .get::>() - .cloned() - .flatten(); - - let Some(expected_token) = token else { - // No auth configured → pass through - let mut response = next.run(request).await; - inject_request_id(&mut response); - return Ok(response); - }; - - let auth_header = request - .headers() - .get(header::AUTHORIZATION) - .and_then(|v| v.to_str().ok()); - - match auth_header { - Some(h) if h.strip_prefix("Bearer ").unwrap_or(h) == expected_token => { - let mut response = next.run(request).await; - inject_request_id(&mut response); - Ok(response) - } - _ => { - tracing::warn!( - "Auth failed from {}", - request - .headers() - .get("x-forwarded-for") - .and_then(|v| v.to_str().ok()) - .unwrap_or("unknown") - ); - Err(StatusCode::UNAUTHORIZED) - } +pub fn verify_hotkey(hotkey: Option<&str>) -> bool { + match hotkey { + Some(k) => k == AUTHORIZED_HOTKEY, + None => false, } } -fn inject_request_id(response: &mut Response) { - let id = Uuid::new_v4().to_string(); - response - .headers_mut() - .insert("x-request-id", id.parse().unwrap()); +pub fn extract_hotkey(headers: &axum::http::HeaderMap) -> Option { + headers + .get("X-Hotkey") + .or_else(|| headers.get("x-hotkey")) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()) } -/// Simple token check function for endpoints that check auth directly. -pub fn check_token(auth_header: Option<&str>, expected: &str) -> bool { - match auth_header { - Some(h) => h.strip_prefix("Bearer ").unwrap_or(h) == expected, - None => false, +#[allow(dead_code)] +pub fn validate_ss58(address: &str) -> bool { + if address.len() < 2 || !address.starts_with('5') { + return false; } + bs58::decode(address).into_vec().is_ok() } #[cfg(test)] @@ -66,22 +28,20 @@ mod tests { use super::*; #[test] - fn test_check_token_bearer() { - assert!(check_token(Some("Bearer secret123"), "secret123")); - } - - #[test] - fn test_check_token_raw() { - assert!(check_token(Some("secret123"), "secret123")); + fn test_verify_hotkey_valid() { + assert!(verify_hotkey(Some(AUTHORIZED_HOTKEY))); } #[test] - fn test_check_token_wrong() { - assert!(!check_token(Some("Bearer wrong"), "secret123")); + fn test_verify_hotkey_invalid() { + assert!(!verify_hotkey(Some("5InvalidHotkey"))); + assert!(!verify_hotkey(None)); } #[test] - fn test_check_token_missing() { - assert!(!check_token(None, "secret123")); + fn test_validate_ss58() { + assert!(validate_ss58(AUTHORIZED_HOTKEY)); + assert!(!validate_ss58("")); + assert!(!validate_ss58("not-an-address")); } } diff --git a/src/config.rs b/src/config.rs index dd4b70a..9ceee06 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,25 +1,27 @@ use std::path::PathBuf; const DEFAULT_PORT: u16 = 8080; -const DEFAULT_SESSION_TTL: u64 = 1800; -const DEFAULT_MAX_CONCURRENT: usize = 4; -const DEFAULT_CLONE_TIMEOUT: u64 = 120; +const DEFAULT_SESSION_TTL: u64 = 7200; +const DEFAULT_MAX_CONCURRENT: usize = 8; +const DEFAULT_CLONE_TIMEOUT: u64 = 180; const DEFAULT_AGENT_TIMEOUT: u64 = 600; const DEFAULT_TEST_TIMEOUT: u64 = 300; -const DEFAULT_MAX_AGENT_CODE_BYTES: usize = 5 * 1024 * 1024; +const DEFAULT_MAX_ARCHIVE_BYTES: usize = 500 * 1024 * 1024; +#[allow(dead_code)] const DEFAULT_MAX_OUTPUT_BYTES: usize = 1024 * 1024; const DEFAULT_WORKSPACE_BASE: &str = "/tmp/sessions"; +pub const AUTHORIZED_HOTKEY: &str = "5GziQCcRpN8NCJktX343brnfuVe3w6gUYieeStXPD1Dag2At"; + #[derive(Debug, Clone)] pub struct Config { pub port: u16, - pub auth_token: Option, pub session_ttl_secs: u64, - pub max_concurrent_evals: usize, + pub max_concurrent_tasks: usize, pub clone_timeout_secs: u64, pub agent_timeout_secs: u64, pub test_timeout_secs: u64, - pub max_agent_code_bytes: usize, + pub max_archive_bytes: usize, #[allow(dead_code)] pub max_output_bytes: usize, pub workspace_base: PathBuf, @@ -29,13 +31,12 @@ impl Config { pub fn from_env() -> Self { Self { port: env_parse("PORT", DEFAULT_PORT), - auth_token: std::env::var("AUTH_TOKEN").ok(), session_ttl_secs: env_parse("SESSION_TTL_SECS", DEFAULT_SESSION_TTL), - max_concurrent_evals: env_parse("MAX_CONCURRENT_EVALS", DEFAULT_MAX_CONCURRENT), + max_concurrent_tasks: env_parse("MAX_CONCURRENT_TASKS", DEFAULT_MAX_CONCURRENT), clone_timeout_secs: env_parse("CLONE_TIMEOUT_SECS", DEFAULT_CLONE_TIMEOUT), agent_timeout_secs: env_parse("AGENT_TIMEOUT_SECS", DEFAULT_AGENT_TIMEOUT), test_timeout_secs: env_parse("TEST_TIMEOUT_SECS", DEFAULT_TEST_TIMEOUT), - max_agent_code_bytes: env_parse("MAX_AGENT_CODE_BYTES", DEFAULT_MAX_AGENT_CODE_BYTES), + max_archive_bytes: env_parse("MAX_ARCHIVE_BYTES", DEFAULT_MAX_ARCHIVE_BYTES), max_output_bytes: env_parse("MAX_OUTPUT_BYTES", DEFAULT_MAX_OUTPUT_BYTES), workspace_base: PathBuf::from( std::env::var("WORKSPACE_BASE").unwrap_or_else(|_| DEFAULT_WORKSPACE_BASE.into()), @@ -46,20 +47,13 @@ impl Config { pub fn print_banner(&self) { tracing::info!("╔══════════════════════════════════════════════════╗"); tracing::info!( - "║ term-executor v{} ║", + "║ term-executor v{} ║", env!("CARGO_PKG_VERSION") ); tracing::info!("╠══════════════════════════════════════════════════╣"); tracing::info!("║ Port: {:<28}║", self.port); - tracing::info!( - "║ Auth: {:<28}║", - if self.auth_token.is_some() { - "enabled" - } else { - "disabled" - } - ); - tracing::info!("║ Max concurrent: {:<28}║", self.max_concurrent_evals); + tracing::info!("║ Authorized hotkey: {}...║", &AUTHORIZED_HOTKEY[..10]); + tracing::info!("║ Max concurrent: {:<28}║", self.max_concurrent_tasks); tracing::info!("║ Session TTL: {:<25}s ║", self.session_ttl_secs); tracing::info!("║ Clone timeout: {:<25}s ║", self.clone_timeout_secs); tracing::info!("║ Agent timeout: {:<25}s ║", self.agent_timeout_secs); @@ -87,11 +81,17 @@ mod tests { fn test_config_defaults() { let cfg = Config::from_env(); assert_eq!(cfg.port, DEFAULT_PORT); - assert_eq!(cfg.max_concurrent_evals, DEFAULT_MAX_CONCURRENT); + assert_eq!(cfg.max_concurrent_tasks, DEFAULT_MAX_CONCURRENT); } #[test] fn test_env_parse_fallback() { assert_eq!(env_parse::("NONEXISTENT_VAR_XYZ", 42), 42); } + + #[test] + fn test_authorized_hotkey_valid() { + assert!(AUTHORIZED_HOTKEY.starts_with("5G")); + assert_eq!(AUTHORIZED_HOTKEY.len(), 48); + } } diff --git a/src/executor.rs b/src/executor.rs index eced7de..4335436 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -3,15 +3,15 @@ use std::path::Path; use std::sync::Arc; use std::time::Duration; use tokio::process::Command; -use tokio::sync::watch::Receiver; +use tokio::sync::Semaphore; use tracing::{debug, error, info, warn}; use crate::config::Config; use crate::metrics::Metrics; use crate::session::{ - EvalRequest, EvalResult, EvalStatus, EvalStep, Session, SessionManager, TaskTestResult, + Batch, BatchResult, BatchStatus, SessionManager, TaskResult, TaskStatus, TaskTestResult, }; -use crate::task; +use crate::task::{ExtractedArchive, SweForgeTask}; const MAX_OUTPUT: usize = 1024 * 1024; @@ -88,183 +88,314 @@ impl Executor { } } - pub fn spawn_eval(&self, session: Arc) { + pub fn spawn_batch( + &self, + batch: Arc, + archive: ExtractedArchive, + concurrent_limit: usize, + ) { let config = self.config.clone(); let sessions = self.sessions.clone(); let metrics = self.metrics.clone(); - let cancel_rx = session.cancel.subscribe(); tokio::spawn(async move { let start = std::time::Instant::now(); - metrics.start_eval(); + metrics.start_batch(); - let result = run_eval(&config, &session, cancel_rx).await; + let result = run_batch(&config, &batch, archive, concurrent_limit).await; let duration_ms = start.elapsed().as_millis() as u64; - let mut res = session.result.lock().await; + let mut res = batch.result.lock().await; match result { - Ok(eval) => { - let passed = eval.passed; - *res = eval; + Ok(batch_result) => { + let all_passed = batch_result.passed_tasks == batch_result.total_tasks; + *res = batch_result; res.duration_ms = Some(duration_ms); - metrics.finish_eval(passed, duration_ms); - if passed.unwrap_or(false) { - sessions.mark_completed(); - } else { - sessions.mark_failed(); - } + metrics.finish_batch(all_passed, duration_ms); + sessions.mark_completed(); } Err(e) => { - error!("Evaluation {} failed: {:#}", session.id, e); - res.status = EvalStatus::Failed; - res.step = EvalStep::Done; + error!("Batch {} failed: {:#}", batch.id, e); + res.status = BatchStatus::Failed; res.error = Some(format!("{:#}", e)); res.duration_ms = Some(duration_ms); - metrics.finish_eval(None, duration_ms); + metrics.finish_batch(false, duration_ms); sessions.mark_failed(); } } + + batch + .emit_event( + "batch_complete", + None, + serde_json::json!({ + "status": res.status, + "total": res.total_tasks, + "passed": res.passed_tasks, + "failed": res.failed_tasks, + "reward": res.aggregate_reward, + "duration_ms": res.duration_ms, + }), + ) + .await; }); } } -async fn run_eval( +async fn run_batch( config: &Config, - session: &Session, - cancel_rx: Receiver, -) -> Result { - let work_dir = config.workspace_base.join(&session.id); - tokio::fs::create_dir_all(&work_dir).await?; - - let result = async { - // 1. Download task - set_step(session, EvalStep::DownloadingTask).await; - if *cancel_rx.borrow() { - anyhow::bail!("Cancelled"); - } + batch: &Batch, + archive: ExtractedArchive, + concurrent_limit: usize, +) -> Result { + let total_tasks = archive.tasks.len(); + let agent_code = Arc::new(archive.agent_code); + let agent_language = Arc::new(archive.agent_language); + + { + let mut res = batch.result.lock().await; + res.status = BatchStatus::Running; + res.total_tasks = total_tasks; + } - let task_dir = work_dir.join("task"); - task::download_and_extract(&session.request.task_url, &task_dir).await?; - let task_root = task::find_task_root(&task_dir)?; - let swe_task = task::parse_task(&task_root)?; + batch + .emit_event( + "batch_started", + None, + serde_json::json!({ + "total_tasks": total_tasks, + "concurrent_limit": concurrent_limit, + }), + ) + .await; - // 2. Clone repository - set_step(session, EvalStep::CloningRepo).await; - if *cancel_rx.borrow() { - anyhow::bail!("Cancelled"); - } + let semaphore = Arc::new(Semaphore::new(concurrent_limit)); + let task_results: Arc>> = + Arc::new(tokio::sync::Mutex::new(Vec::new())); + + let mut handles = Vec::new(); + + for task in archive.tasks { + let config = config.clone(); + let batch_id = batch.id.clone(); + let events_tx = batch.events_tx.clone(); + let agent_code = agent_code.clone(); + let agent_language = agent_language.clone(); + let semaphore = semaphore.clone(); + let task_results = task_results.clone(); + let cancel_rx = batch.cancel.subscribe(); + + let handle = tokio::spawn(async move { + let _permit = semaphore.acquire().await.unwrap(); + + let task_id = task.id.clone(); + let _ = events_tx.send(crate::session::WsEvent { + event: "task_started".to_string(), + batch_id: batch_id.clone(), + task_id: Some(task_id.clone()), + data: serde_json::json!({ "task_id": task_id }), + }); + + let result = + run_single_task(&config, &task, &agent_code, &agent_language, cancel_rx).await; + + let _ = events_tx.send(crate::session::WsEvent { + event: "task_complete".to_string(), + batch_id: batch_id.clone(), + task_id: Some(task_id.clone()), + data: serde_json::json!({ + "task_id": task_id, + "status": result.status, + "passed": result.passed, + "reward": result.reward, + }), + }); + + task_results.lock().await.push(result); + }); - let repo_dir = work_dir.join("repo"); - clone_repo( - &swe_task.workspace.repo, - &repo_dir, - config.clone_timeout_secs, - ) - .await?; + handles.push(handle); + } - if let Some(ref commit) = swe_task.workspace.base_commit { - checkout_commit(&repo_dir, commit, config.clone_timeout_secs).await?; + for handle in handles { + if let Err(e) = handle.await { + warn!("Task handle panicked: {}", e); } + } - // 3. Install dependencies - set_step(session, EvalStep::InstallingDeps).await; - if *cancel_rx.borrow() { - anyhow::bail!("Cancelled"); - } + let results = task_results.lock().await; + let completed = results.len(); + let passed = results.iter().filter(|r| r.reward == 1.0).count(); + let failed = completed - passed; + let aggregate_reward = if total_tasks > 0 { + results.iter().map(|r| r.reward).sum::() / total_tasks as f64 + } else { + 0.0 + }; - if let Some(ref install_cmds) = swe_task.workspace.install { - for cmd in install_cmds { - info!("Running install command: {}", cmd); - let (_, stderr, exit) = run_shell( - cmd, - &repo_dir, - Duration::from_secs(config.clone_timeout_secs), - None, - ) - .await?; - if exit != 0 { - warn!( - "Install command failed (exit {}): {}", - exit, - &stderr[..stderr.len().min(500)] - ); - } - } - } + Ok(BatchResult { + batch_id: batch.id.clone(), + status: BatchStatus::Completed, + total_tasks, + completed_tasks: completed, + passed_tasks: passed, + failed_tasks: failed, + tasks: results.clone(), + aggregate_reward, + error: None, + duration_ms: None, + }) +} + +async fn run_single_task( + config: &Config, + task: &SweForgeTask, + agent_code: &str, + agent_language: &str, + cancel_rx: tokio::sync::watch::Receiver, +) -> TaskResult { + let start = std::time::Instant::now(); + let mut result = TaskResult::new(task.id.clone()); + + let work_dir = config.workspace_base.join(&task.id); + if let Err(e) = tokio::fs::create_dir_all(&work_dir).await { + result.status = TaskStatus::Failed; + result.error = Some(format!("Failed to create work dir: {}", e)); + return result; + } + + let eval_result = run_task_pipeline( + config, + task, + agent_code, + agent_language, + &work_dir, + &cancel_rx, + ) + .await; + + crate::cleanup::remove_work_dir(&work_dir).await; - // 4. Write + run agent code - set_step(session, EvalStep::RunningAgent).await; - if *cancel_rx.borrow() { - anyhow::bail!("Cancelled"); + let duration_ms = start.elapsed().as_millis() as u64; + + match eval_result { + Ok(mut r) => { + r.duration_ms = Some(duration_ms); + r + } + Err(e) => { + result.status = TaskStatus::Failed; + result.error = Some(format!("{:#}", e)); + result.duration_ms = Some(duration_ms); + result } + } +} - let agent_output = run_agent( - &session.request, - &swe_task.prompt, - &repo_dir, - config.agent_timeout_secs, - ) - .await?; +async fn run_task_pipeline( + config: &Config, + task: &SweForgeTask, + agent_code: &str, + agent_language: &str, + work_dir: &Path, + cancel_rx: &tokio::sync::watch::Receiver, +) -> Result { + let mut result = TaskResult::new(task.id.clone()); + + if *cancel_rx.borrow() { + anyhow::bail!("Cancelled"); + } + + result.status = TaskStatus::CloningRepo; + let repo_dir = work_dir.join("repo"); + clone_repo(&task.workspace.repo, &repo_dir, config.clone_timeout_secs).await?; - // 5. Write test source files - for (name, content) in &swe_task.test_source_files { - let dest = repo_dir.join(name); - if let Some(parent) = dest.parent() { - tokio::fs::create_dir_all(parent).await?; + if let Some(ref commit) = task.workspace.base_commit { + checkout_commit(&repo_dir, commit, config.clone_timeout_secs).await?; + } + + if *cancel_rx.borrow() { + anyhow::bail!("Cancelled"); + } + + result.status = TaskStatus::InstallingDeps; + if let Some(ref install_cmds) = task.workspace.install { + for cmd in install_cmds { + info!("[{}] Installing: {}", task.id, cmd); + let (_, stderr, exit) = run_shell( + cmd, + &repo_dir, + Duration::from_secs(config.clone_timeout_secs), + None, + ) + .await?; + if exit != 0 { + warn!( + "[{}] Install failed (exit {}): {}", + task.id, + exit, + &stderr[..stderr.len().min(500)] + ); } - tokio::fs::write(&dest, content).await?; } + } + + if *cancel_rx.borrow() { + anyhow::bail!("Cancelled"); + } - // 6. Run tests - set_step(session, EvalStep::RunningTests).await; - if *cancel_rx.borrow() { - anyhow::bail!("Cancelled"); + result.status = TaskStatus::RunningAgent; + let agent_output = run_agent( + agent_code, + agent_language, + &task.prompt, + &repo_dir, + config.agent_timeout_secs, + ) + .await?; + let _ = agent_output; + + for (name, content) in &task.test_source_files { + let dest = repo_dir.join(name); + if let Some(parent) = dest.parent() { + tokio::fs::create_dir_all(parent).await?; } + tokio::fs::write(&dest, content).await?; + } - let test_results = - run_tests(&swe_task.test_scripts, &repo_dir, config.test_timeout_secs).await?; - - let all_passed = test_results.iter().all(|t| t.passed); - let test_output_combined = test_results - .iter() - .map(|t| { - format!( - "=== {} (exit {}) ===\n{}\n{}", - t.name, - t.exit_code, - t.output, - if t.passed { "PASS" } else { "FAIL" } - ) - }) - .collect::>() - .join("\n\n"); - - Ok(EvalResult { - status: EvalStatus::Completed, - step: EvalStep::Done, - passed: Some(all_passed), - test_results, - agent_output, - test_output: test_output_combined, - error: None, - duration_ms: None, - }) + if *cancel_rx.borrow() { + anyhow::bail!("Cancelled"); } - .await; - // Cleanup - set_step(session, EvalStep::Cleanup).await; - crate::cleanup::remove_work_dir(&work_dir).await; + result.status = TaskStatus::RunningTests; + let test_results = run_tests(&task.test_scripts, &repo_dir, config.test_timeout_secs).await?; + + let all_passed = test_results.iter().all(|t| t.passed); + let test_output_combined = test_results + .iter() + .map(|t| { + format!( + "=== {} (exit {}) ===\n{}\n{}", + t.name, + t.exit_code, + t.output, + if t.passed { "PASS" } else { "FAIL" } + ) + }) + .collect::>() + .join("\n\n"); - result -} + result.status = if all_passed { + TaskStatus::Completed + } else { + TaskStatus::Failed + }; + result.passed = Some(all_passed); + result.reward = if all_passed { 1.0 } else { 0.0 }; + result.test_results = test_results; + result.test_output = test_output_combined; -async fn set_step(session: &Session, step: EvalStep) { - let mut res = session.result.lock().await; - res.step = step; - if res.status == EvalStatus::Pending { - res.status = EvalStatus::Running; - } + Ok(result) } async fn clone_repo(repo_url: &str, dest: &Path, timeout_secs: u64) -> Result<()> { @@ -338,20 +469,21 @@ fn agent_runner(language: &str, script_path: &str) -> Vec { } async fn run_agent( - request: &EvalRequest, + agent_code: &str, + agent_language: &str, prompt: &str, repo_dir: &Path, timeout_secs: u64, ) -> Result { - let ext = agent_extension(&request.agent_language); + let ext = agent_extension(agent_language); let script_name = format!("_agent_code{}", ext); let script_path = repo_dir.join(&script_name); - tokio::fs::write(&script_path, &request.agent_code).await?; + tokio::fs::write(&script_path, agent_code).await?; let prompt_path = repo_dir.join("_task_prompt.md"); tokio::fs::write(&prompt_path, prompt).await?; - let argv_owned = agent_runner(&request.agent_language, &script_name); + let argv_owned = agent_runner(agent_language, &script_name); let argv: Vec<&str> = argv_owned.iter().map(|s| s.as_str()).collect(); info!("Running agent: {:?}", argv); diff --git a/src/handlers.rs b/src/handlers.rs index f9050a3..4c19378 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -1,28 +1,27 @@ use axum::{ - extract::State, + extract::{Multipart, State}, http::StatusCode, response::{IntoResponse, Json, Response}, routing::{get, post}, Router, }; use chrono::Utc; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use std::sync::atomic::Ordering; use std::sync::Arc; -use tokio::sync::Semaphore; use crate::auth; use crate::config::Config; use crate::executor::Executor; use crate::metrics::Metrics; -use crate::session::{EvalRequest, SessionManager}; +use crate::session::SessionManager; +use crate::ws; pub struct AppState { pub config: Arc, pub sessions: Arc, pub metrics: Arc, pub executor: Arc, - pub semaphore: Arc, pub started_at: chrono::DateTime, } @@ -31,9 +30,12 @@ pub fn router(state: Arc) -> Router { .route("/health", get(health)) .route("/status", get(status)) .route("/metrics", get(metrics)) - .route("/evaluate", post(evaluate)) - .route("/evaluate/{id}", get(get_eval)) - .route("/evaluations", get(list_evals)) + .route("/submit", post(submit_batch)) + .route("/batch/{id}", get(get_batch)) + .route("/batch/{id}/tasks", get(get_batch_tasks)) + .route("/batch/{id}/task/{task_id}", get(get_task)) + .route("/batches", get(list_batches)) + .route("/ws", get(ws::ws_handler)) .with_state(state) } @@ -45,13 +47,13 @@ async fn health() -> impl IntoResponse { struct StatusResponse { version: String, uptime_secs: i64, - active_evals: u64, - total_evals: u64, - passed: u64, - failed: u64, - cancelled: u64, - capacity: usize, - available_slots: usize, + active_batches: u64, + total_batches: u64, + completed_batches: u64, + tasks_passed: u64, + tasks_failed: u64, + max_concurrent_tasks: usize, + has_active_batch: bool, } async fn status(State(state): State>) -> Json { @@ -59,13 +61,13 @@ async fn status(State(state): State>) -> Json { Json(StatusResponse { version: env!("CARGO_PKG_VERSION").to_string(), uptime_secs: uptime, - active_evals: state.metrics.evals_active.load(Ordering::Relaxed), - total_evals: state.metrics.evals_total.load(Ordering::Relaxed), - passed: state.metrics.evals_passed.load(Ordering::Relaxed), - failed: state.metrics.evals_failed.load(Ordering::Relaxed), - cancelled: state.metrics.evals_cancelled.load(Ordering::Relaxed), - capacity: state.config.max_concurrent_evals, - available_slots: state.semaphore.available_permits(), + active_batches: state.metrics.batches_active.load(Ordering::Relaxed), + total_batches: state.metrics.batches_total.load(Ordering::Relaxed), + completed_batches: state.metrics.batches_completed.load(Ordering::Relaxed), + tasks_passed: state.metrics.tasks_passed.load(Ordering::Relaxed), + tasks_failed: state.metrics.tasks_failed.load(Ordering::Relaxed), + max_concurrent_tasks: state.config.max_concurrent_tasks, + has_active_batch: state.sessions.has_active_batch(), }) } @@ -79,141 +81,210 @@ async fn metrics(State(state): State>) -> Response { .into_response() } -#[derive(Deserialize)] -struct EvalPayload { - agent_code: String, - #[serde(default = "default_language")] - agent_language: String, - task_url: String, - #[serde(default)] - timeout_secs: Option, +#[derive(serde::Deserialize)] +struct SubmitQuery { + #[serde(default = "default_concurrent")] + concurrent_tasks: Option, } -fn default_language() -> String { - "python".to_string() +fn default_concurrent() -> Option { + None } -async fn evaluate( +async fn submit_batch( State(state): State>, headers: axum::http::HeaderMap, - Json(payload): Json, -) -> Result { - // Auth check - if let Some(ref expected) = state.config.auth_token { - let auth_header = headers - .get(axum::http::header::AUTHORIZATION) - .and_then(|v| v.to_str().ok()); - if !auth::check_token(auth_header, expected) { - return Err((StatusCode::UNAUTHORIZED, "Invalid token".to_string())); - } - } - - // Validate payload - if payload.agent_code.len() > state.config.max_agent_code_bytes { + query: axum::extract::Query, + mut multipart: Multipart, +) -> Result)> { + let hotkey = auth::extract_hotkey(&headers); + if !auth::verify_hotkey(hotkey.as_deref()) { return Err(( - StatusCode::BAD_REQUEST, - format!( - "agent_code too large ({} bytes, max {})", - payload.agent_code.len(), - state.config.max_agent_code_bytes - ), + StatusCode::UNAUTHORIZED, + Json(serde_json::json!({ + "error": "unauthorized", + "message": "Invalid or missing X-Hotkey header. Only the authorized hotkey can submit tasks." + })), )); } - if payload.task_url.len() > 2048 { + if state.sessions.has_active_batch() { return Err(( - StatusCode::BAD_REQUEST, - "task_url too long (max 2048 chars)".to_string(), + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "error": "busy", + "message": "A batch is already running. Wait for it to complete." + })), )); } - if payload.task_url.is_empty() { - return Err((StatusCode::BAD_REQUEST, "task_url is required".to_string())); + let mut archive_data: Option> = None; + + while let Ok(Some(field)) = multipart.next_field().await { + let name = field.name().unwrap_or("").to_string(); + if name == "archive" || name == "file" { + let data = field.bytes().await.map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "upload_failed", + "message": format!("Failed to read upload: {}", e) + })), + ) + })?; + archive_data = Some(data.to_vec()); + } } - if payload.agent_code.is_empty() { - return Err(( + let archive_bytes = archive_data.ok_or_else(|| { + ( StatusCode::BAD_REQUEST, - "agent_code is required".to_string(), - )); - } + Json(serde_json::json!({ + "error": "missing_archive", + "message": "No archive file uploaded. Send a multipart form with field 'archive'." + })), + ) + })?; - // Capacity check - let permit = state.semaphore.clone().try_acquire_owned(); - if permit.is_err() { + if archive_bytes.len() > state.config.max_archive_bytes { return Err(( - StatusCode::SERVICE_UNAVAILABLE, - format!( - "At capacity ({}/{}). Try again later.", - state.config.max_concurrent_evals, state.config.max_concurrent_evals - ), + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "archive_too_large", + "message": format!("Archive is {} bytes, max is {}", archive_bytes.len(), state.config.max_archive_bytes) + })), )); } - let request = EvalRequest { - agent_code: payload.agent_code, - agent_language: payload.agent_language, - task_url: payload.task_url, - timeout_secs: payload.timeout_secs, - }; - - let session = state.sessions.create(request); - let id = session.id.clone(); - - // Spawn with permit held; permit is dropped when task completes - let executor = state.executor.clone(); - let permit = permit.unwrap(); - tokio::spawn(async move { - executor.spawn_eval(session); - // Hold the permit until the session manager marks it done - // We don't actually need to hold it since the semaphore tracks capacity - drop(permit); - }); + let extract_dir = state.config.workspace_base.join("_extract_tmp"); + let _ = tokio::fs::remove_dir_all(&extract_dir).await; + + let extracted = crate::task::extract_uploaded_archive(&archive_bytes, &extract_dir) + .await + .map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "extraction_failed", + "message": format!("Failed to extract archive: {}", e) + })), + ) + })?; + + let _ = tokio::fs::remove_dir_all(&extract_dir).await; + + let total_tasks = extracted.tasks.len(); + let concurrent = query + .concurrent_tasks + .unwrap_or(state.config.max_concurrent_tasks) + .min(state.config.max_concurrent_tasks); + + let batch = state.sessions.create_batch(total_tasks); + let batch_id = batch.id.clone(); + + state.executor.spawn_batch(batch, extracted, concurrent); Ok(( StatusCode::ACCEPTED, - Json(serde_json::json!({ "eval_id": id })), + Json(serde_json::json!({ + "batch_id": batch_id, + "total_tasks": total_tasks, + "concurrent_tasks": concurrent, + "ws_url": format!("/ws?batch_id={}", batch_id), + })), )) } -async fn get_eval( +async fn get_batch( State(state): State>, axum::extract::Path(id): axum::extract::Path, ) -> Result, StatusCode> { - let session = state.sessions.get(&id).ok_or(StatusCode::NOT_FOUND)?; - let result = session.result.lock().await; + let batch = state.sessions.get(&id).ok_or(StatusCode::NOT_FOUND)?; + let result = batch.result.lock().await; Ok(Json(serde_json::json!({ - "eval_id": session.id, + "batch_id": result.batch_id, "status": result.status, - "step": result.step, - "passed": result.passed, - "test_results": result.test_results, - "agent_output": result.agent_output, - "test_output": result.test_output, + "total_tasks": result.total_tasks, + "completed_tasks": result.completed_tasks, + "passed_tasks": result.passed_tasks, + "failed_tasks": result.failed_tasks, + "aggregate_reward": result.aggregate_reward, "error": result.error, "duration_ms": result.duration_ms, }))) } +async fn get_batch_tasks( + State(state): State>, + axum::extract::Path(id): axum::extract::Path, +) -> Result, StatusCode> { + let batch = state.sessions.get(&id).ok_or(StatusCode::NOT_FOUND)?; + let result = batch.result.lock().await; + + let tasks: Vec = result + .tasks + .iter() + .map(|t| { + serde_json::json!({ + "task_id": t.task_id, + "status": t.status, + "passed": t.passed, + "reward": t.reward, + "test_output": t.test_output, + "error": t.error, + "duration_ms": t.duration_ms, + }) + }) + .collect(); + + Ok(Json(serde_json::json!({ + "batch_id": result.batch_id, + "tasks": tasks, + }))) +} + +async fn get_task( + State(state): State>, + axum::extract::Path((batch_id, task_id)): axum::extract::Path<(String, String)>, +) -> Result, StatusCode> { + let batch = state.sessions.get(&batch_id).ok_or(StatusCode::NOT_FOUND)?; + let result = batch.result.lock().await; + + let task = result + .tasks + .iter() + .find(|t| t.task_id == task_id) + .ok_or(StatusCode::NOT_FOUND)?; + + Ok(Json(serde_json::json!({ + "task_id": task.task_id, + "status": task.status, + "passed": task.passed, + "reward": task.reward, + "test_results": task.test_results, + "test_output": task.test_output, + "error": task.error, + "duration_ms": task.duration_ms, + }))) +} + #[derive(Serialize)] -struct EvalListEntry { - eval_id: String, - task_url: String, - language: String, +struct BatchListEntry { + batch_id: String, created_at: String, + status: crate::session::BatchStatus, } -async fn list_evals(State(state): State>) -> Json> { - let sessions = state.sessions.list_sessions(); +async fn list_batches(State(state): State>) -> Json> { + let batches = state.sessions.list_batches(); Json( - sessions + batches .into_iter() - .map(|s| EvalListEntry { - eval_id: s.id, - task_url: s.task_url, - language: s.language, - created_at: s.created_at.to_rfc3339(), + .map(|b| BatchListEntry { + batch_id: b.batch_id, + created_at: b.created_at.to_rfc3339(), + status: b.status, }) .collect(), ) diff --git a/src/main.rs b/src/main.rs index f81ec19..7c92ae6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,12 +4,11 @@ mod config; mod executor; mod handlers; mod metrics; - mod session; mod task; +mod ws; use std::sync::Arc; -use tokio::sync::Semaphore; use tracing::info; #[tokio::main] @@ -24,14 +23,12 @@ async fn main() { let config = Arc::new(config::Config::from_env()); config.print_banner(); - // Create workspace base directory tokio::fs::create_dir_all(&config.workspace_base) .await .expect("Failed to create workspace directory"); let sessions = Arc::new(session::SessionManager::new(config.session_ttl_secs)); let metrics_store = metrics::Metrics::new(); - let semaphore = Arc::new(Semaphore::new(config.max_concurrent_evals)); let executor = Arc::new(executor::Executor::new( config.clone(), sessions.clone(), @@ -43,20 +40,17 @@ async fn main() { sessions: sessions.clone(), metrics: metrics_store, executor, - semaphore, started_at: chrono::Utc::now(), }); let app = handlers::router(state); let addr = format!("0.0.0.0:{}", config.port); - // Session reaper let sessions_reaper = sessions.clone(); tokio::spawn(async move { sessions_reaper.reaper_loop().await; }); - // Stale dir reaper let workspace = config.workspace_base.clone(); let ttl = config.session_ttl_secs; tokio::spawn(async move { @@ -70,7 +64,6 @@ async fn main() { info!("Listening on {}", addr); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); - // Graceful shutdown on SIGTERM let shutdown = async { tokio::signal::ctrl_c() .await diff --git a/src/metrics.rs b/src/metrics.rs index d48e999..e37c637 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -3,82 +3,91 @@ use std::sync::Arc; #[derive(Debug)] pub struct Metrics { - pub evals_total: AtomicU64, - pub evals_passed: AtomicU64, - pub evals_failed: AtomicU64, - pub evals_cancelled: AtomicU64, - pub evals_active: AtomicU64, - pub evals_duration_sum_ms: AtomicU64, + pub batches_total: AtomicU64, + pub batches_active: AtomicU64, + pub batches_completed: AtomicU64, + pub tasks_total: AtomicU64, + pub tasks_passed: AtomicU64, + pub tasks_failed: AtomicU64, + pub duration_sum_ms: AtomicU64, } impl Metrics { pub fn new() -> Arc { Arc::new(Self { - evals_total: AtomicU64::new(0), - evals_passed: AtomicU64::new(0), - evals_failed: AtomicU64::new(0), - evals_cancelled: AtomicU64::new(0), - evals_active: AtomicU64::new(0), - evals_duration_sum_ms: AtomicU64::new(0), + batches_total: AtomicU64::new(0), + batches_active: AtomicU64::new(0), + batches_completed: AtomicU64::new(0), + tasks_total: AtomicU64::new(0), + tasks_passed: AtomicU64::new(0), + tasks_failed: AtomicU64::new(0), + duration_sum_ms: AtomicU64::new(0), }) } - pub fn start_eval(&self) { - self.evals_total.fetch_add(1, Ordering::Relaxed); - self.evals_active.fetch_add(1, Ordering::Relaxed); + pub fn start_batch(&self) { + self.batches_total.fetch_add(1, Ordering::Relaxed); + self.batches_active.fetch_add(1, Ordering::Relaxed); } - pub fn finish_eval(&self, passed: Option, duration_ms: u64) { - self.evals_active.fetch_sub(1, Ordering::Relaxed); - self.evals_duration_sum_ms + pub fn finish_batch(&self, all_passed: bool, duration_ms: u64) { + self.batches_active.fetch_sub(1, Ordering::Relaxed); + self.batches_completed.fetch_add(1, Ordering::Relaxed); + self.duration_sum_ms .fetch_add(duration_ms, Ordering::Relaxed); - match passed { - Some(true) => { - self.evals_passed.fetch_add(1, Ordering::Relaxed); - } - Some(false) => { - self.evals_failed.fetch_add(1, Ordering::Relaxed); - } - None => { - self.evals_failed.fetch_add(1, Ordering::Relaxed); - } + if all_passed { + self.tasks_passed.fetch_add(1, Ordering::Relaxed); } } #[allow(dead_code)] - pub fn cancel_eval(&self) { - self.evals_active.fetch_sub(1, Ordering::Relaxed); - self.evals_cancelled.fetch_add(1, Ordering::Relaxed); + pub fn record_task_result(&self, passed: bool) { + self.tasks_total.fetch_add(1, Ordering::Relaxed); + if passed { + self.tasks_passed.fetch_add(1, Ordering::Relaxed); + } else { + self.tasks_failed.fetch_add(1, Ordering::Relaxed); + } } pub fn render_prometheus(&self) -> String { - let total = self.evals_total.load(Ordering::Relaxed); - let passed = self.evals_passed.load(Ordering::Relaxed); - let failed = self.evals_failed.load(Ordering::Relaxed); - let cancelled = self.evals_cancelled.load(Ordering::Relaxed); - let active = self.evals_active.load(Ordering::Relaxed); - let dur_sum = self.evals_duration_sum_ms.load(Ordering::Relaxed); + let batches_total = self.batches_total.load(Ordering::Relaxed); + let batches_active = self.batches_active.load(Ordering::Relaxed); + let batches_completed = self.batches_completed.load(Ordering::Relaxed); + let tasks_total = self.tasks_total.load(Ordering::Relaxed); + let tasks_passed = self.tasks_passed.load(Ordering::Relaxed); + let tasks_failed = self.tasks_failed.load(Ordering::Relaxed); + let dur_sum = self.duration_sum_ms.load(Ordering::Relaxed); format!( - "# HELP term_executor_evaluations_total Total evaluations started.\n\ - # TYPE term_executor_evaluations_total counter\n\ - term_executor_evaluations_total {}\n\ - # HELP term_executor_evaluations_passed Total evaluations that passed.\n\ - # TYPE term_executor_evaluations_passed counter\n\ - term_executor_evaluations_passed {}\n\ - # HELP term_executor_evaluations_failed Total evaluations that failed.\n\ - # TYPE term_executor_evaluations_failed counter\n\ - term_executor_evaluations_failed {}\n\ - # HELP term_executor_evaluations_cancelled Total evaluations cancelled.\n\ - # TYPE term_executor_evaluations_cancelled counter\n\ - term_executor_evaluations_cancelled {}\n\ - # HELP term_executor_evaluations_active Currently running evaluations.\n\ - # TYPE term_executor_evaluations_active gauge\n\ - term_executor_evaluations_active {}\n\ - # HELP term_executor_evaluations_duration_ms_sum Sum of evaluation durations in ms.\n\ - # TYPE term_executor_evaluations_duration_ms_sum counter\n\ - term_executor_evaluations_duration_ms_sum {}\n", - total, passed, failed, cancelled, active, dur_sum + "# HELP term_executor_batches_total Total batches submitted.\n\ + # TYPE term_executor_batches_total counter\n\ + term_executor_batches_total {}\n\ + # HELP term_executor_batches_active Currently running batches.\n\ + # TYPE term_executor_batches_active gauge\n\ + term_executor_batches_active {}\n\ + # HELP term_executor_batches_completed Completed batches.\n\ + # TYPE term_executor_batches_completed counter\n\ + term_executor_batches_completed {}\n\ + # HELP term_executor_tasks_total Total tasks evaluated.\n\ + # TYPE term_executor_tasks_total counter\n\ + term_executor_tasks_total {}\n\ + # HELP term_executor_tasks_passed Tasks that passed (reward=1).\n\ + # TYPE term_executor_tasks_passed counter\n\ + term_executor_tasks_passed {}\n\ + # HELP term_executor_tasks_failed Tasks that failed (reward=0).\n\ + # TYPE term_executor_tasks_failed counter\n\ + term_executor_tasks_failed {}\n\ + # HELP term_executor_duration_ms_sum Sum of batch durations in ms.\n\ + # TYPE term_executor_duration_ms_sum counter\n\ + term_executor_duration_ms_sum {}\n", + batches_total, + batches_active, + batches_completed, + tasks_total, + tasks_passed, + tasks_failed, + dur_sum ) } } @@ -90,32 +99,22 @@ mod tests { #[test] fn test_metrics_lifecycle() { let m = Metrics::new(); - m.start_eval(); - assert_eq!(m.evals_active.load(Ordering::Relaxed), 1); - assert_eq!(m.evals_total.load(Ordering::Relaxed), 1); - - m.finish_eval(Some(true), 5000); - assert_eq!(m.evals_active.load(Ordering::Relaxed), 0); - assert_eq!(m.evals_passed.load(Ordering::Relaxed), 1); - } + m.start_batch(); + assert_eq!(m.batches_active.load(Ordering::Relaxed), 1); + assert_eq!(m.batches_total.load(Ordering::Relaxed), 1); - #[test] - fn test_metrics_cancel() { - let m = Metrics::new(); - m.start_eval(); - m.cancel_eval(); - assert_eq!(m.evals_cancelled.load(Ordering::Relaxed), 1); - assert_eq!(m.evals_active.load(Ordering::Relaxed), 0); + m.finish_batch(true, 5000); + assert_eq!(m.batches_active.load(Ordering::Relaxed), 0); + assert_eq!(m.batches_completed.load(Ordering::Relaxed), 1); } #[test] fn test_prometheus_output() { let m = Metrics::new(); - m.start_eval(); - m.finish_eval(Some(false), 1234); + m.start_batch(); + m.finish_batch(false, 1234); let out = m.render_prometheus(); - assert!(out.contains("term_executor_evaluations_total 1")); - assert!(out.contains("term_executor_evaluations_failed 1")); - assert!(out.contains("term_executor_evaluations_duration_ms_sum 1234")); + assert!(out.contains("term_executor_batches_total 1")); + assert!(out.contains("term_executor_duration_ms_sum 1234")); } } diff --git a/src/session.rs b/src/session.rs index 0005ad5..3bca9a4 100644 --- a/src/session.rs +++ b/src/session.rs @@ -3,39 +3,29 @@ use dashmap::DashMap; use serde::{Deserialize, Serialize}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; -use tokio::sync::Mutex; +use tokio::sync::{broadcast, Mutex}; use tracing::info; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EvalRequest { - pub agent_code: String, - pub agent_language: String, - pub task_url: String, - #[serde(default)] - pub timeout_secs: Option, -} - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "snake_case")] -pub enum EvalStatus { +pub enum BatchStatus { Pending, + Extracting, Running, Completed, Failed, - Cancelled, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "snake_case")] -pub enum EvalStep { +pub enum TaskStatus { Queued, - DownloadingTask, CloningRepo, InstallingDeps, RunningAgent, RunningTests, - Cleanup, - Done, + Completed, + Failed, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -47,32 +37,80 @@ pub struct TaskTestResult { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EvalResult { - pub status: EvalStatus, - pub step: EvalStep, +pub struct TaskResult { + pub task_id: String, + pub status: TaskStatus, pub passed: Option, + pub reward: f64, pub test_results: Vec, - pub agent_output: String, pub test_output: String, pub error: Option, pub duration_ms: Option, } -pub struct Session { +impl TaskResult { + pub fn new(task_id: String) -> Self { + Self { + task_id, + status: TaskStatus::Queued, + passed: None, + reward: 0.0, + test_results: Vec::new(), + test_output: String::new(), + error: None, + duration_ms: None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BatchResult { + pub batch_id: String, + pub status: BatchStatus, + pub total_tasks: usize, + pub completed_tasks: usize, + pub passed_tasks: usize, + pub failed_tasks: usize, + pub tasks: Vec, + pub aggregate_reward: f64, + pub error: Option, + pub duration_ms: Option, +} + +#[derive(Debug, Clone, Serialize)] +pub struct WsEvent { + pub event: String, + pub batch_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub task_id: Option, + pub data: serde_json::Value, +} + +pub struct Batch { pub id: String, - pub request: EvalRequest, - pub result: Arc>, pub created_at: DateTime, + pub result: Arc>, + pub events_tx: broadcast::Sender, pub cancel: tokio::sync::watch::Sender, } -#[allow(dead_code)] +impl Batch { + pub async fn emit_event(&self, event: &str, task_id: Option<&str>, data: serde_json::Value) { + let ws_event = WsEvent { + event: event.to_string(), + batch_id: self.id.clone(), + task_id: task_id.map(|s| s.to_string()), + data, + }; + let _ = self.events_tx.send(ws_event); + } +} + pub struct SessionStats { pub created: AtomicU64, pub active: AtomicU64, pub completed: AtomicU64, pub failed: AtomicU64, - pub cancelled: AtomicU64, } impl SessionStats { @@ -82,13 +120,12 @@ impl SessionStats { active: AtomicU64::new(0), completed: AtomicU64::new(0), failed: AtomicU64::new(0), - cancelled: AtomicU64::new(0), } } } pub struct SessionManager { - sessions: DashMap>, + batches: DashMap>, ttl_secs: u64, pub stats: SessionStats, } @@ -96,63 +133,72 @@ pub struct SessionManager { impl SessionManager { pub fn new(ttl_secs: u64) -> Self { Self { - sessions: DashMap::new(), + batches: DashMap::new(), ttl_secs, stats: SessionStats::new(), } } - pub fn create(&self, request: EvalRequest) -> Arc { + pub fn create_batch(&self, total_tasks: usize) -> Arc { let id = uuid::Uuid::new_v4().to_string(); + let (events_tx, _) = broadcast::channel(256); let (cancel_tx, _) = tokio::sync::watch::channel(false); - let session = Arc::new(Session { + let batch = Arc::new(Batch { id: id.clone(), - request, - result: Arc::new(Mutex::new(EvalResult { - status: EvalStatus::Pending, - step: EvalStep::Queued, - passed: None, - test_results: Vec::new(), - agent_output: String::new(), - test_output: String::new(), + created_at: Utc::now(), + result: Arc::new(Mutex::new(BatchResult { + batch_id: id.clone(), + status: BatchStatus::Pending, + total_tasks, + completed_tasks: 0, + passed_tasks: 0, + failed_tasks: 0, + tasks: Vec::new(), + aggregate_reward: 0.0, error: None, duration_ms: None, })), - created_at: Utc::now(), + events_tx, cancel: cancel_tx, }); - self.sessions.insert(id, session.clone()); + self.batches.insert(id, batch.clone()); self.stats.created.fetch_add(1, Ordering::Relaxed); self.stats.active.fetch_add(1, Ordering::Relaxed); - session - } - - pub fn get(&self, id: &str) -> Option> { - self.sessions.get(id).map(|s| s.value().clone()) + batch } - #[allow(dead_code)] - pub fn remove(&self, id: &str) -> Option> { - self.sessions.remove(id).map(|(_, s)| s) + pub fn get(&self, id: &str) -> Option> { + self.batches.get(id).map(|b| b.value().clone()) } - #[allow(dead_code)] - pub fn active_count(&self) -> usize { - self.stats.active.load(Ordering::Relaxed) as usize + pub fn has_active_batch(&self) -> bool { + for entry in self.batches.iter() { + let result = entry.value().result.try_lock(); + if let Ok(r) = result { + if r.status == BatchStatus::Running || r.status == BatchStatus::Extracting { + return true; + } + } + } + false } - pub fn list_sessions(&self) -> Vec { - self.sessions + pub fn list_batches(&self) -> Vec { + self.batches .iter() .map(|entry| { - let s = entry.value(); - SessionSummary { - id: s.id.clone(), - task_url: s.request.task_url.clone(), - language: s.request.agent_language.clone(), - created_at: s.created_at, + let b = entry.value(); + let status = b + .result + .try_lock() + .map(|r| r.status.clone()) + .unwrap_or(BatchStatus::Running); + BatchSummary { + batch_id: b.id.clone(), + created_at: b.created_at, + status, } }) .collect() @@ -168,12 +214,6 @@ impl SessionManager { self.stats.failed.fetch_add(1, Ordering::Relaxed); } - #[allow(dead_code)] - pub fn mark_cancelled(&self) { - self.stats.active.fetch_sub(1, Ordering::Relaxed); - self.stats.cancelled.fetch_add(1, Ordering::Relaxed); - } - pub async fn reaper_loop(&self) { let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(60)); loop { @@ -181,7 +221,7 @@ impl SessionManager { let now = Utc::now(); let mut expired = Vec::new(); - for entry in self.sessions.iter() { + for entry in self.batches.iter() { let age = (now - entry.value().created_at).num_seconds() as u64; if age > self.ttl_secs { expired.push(entry.key().clone()); @@ -189,9 +229,9 @@ impl SessionManager { } for id in expired { - if let Some((_, session)) = self.sessions.remove(&id) { - let _ = session.cancel.send(true); - info!("Reaped expired session {}", id); + if let Some((_, batch)) = self.batches.remove(&id) { + let _ = batch.cancel.send(true); + info!("Reaped expired batch {}", id); } } } @@ -199,9 +239,8 @@ impl SessionManager { } #[derive(Debug, Clone, Serialize)] -pub struct SessionSummary { - pub id: String, - pub task_url: String, - pub language: String, +pub struct BatchSummary { + pub batch_id: String, pub created_at: DateTime, + pub status: BatchStatus, } diff --git a/src/task.rs b/src/task.rs index 1219bc1..41abd4b 100644 --- a/src/task.rs +++ b/src/task.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use std::path::{Path, PathBuf}; use tracing::{debug, info}; -const MAX_ARCHIVE_SIZE: usize = 100 * 1024 * 1024; // 100MB +const MAX_ARCHIVE_SIZE: usize = 500 * 1024 * 1024; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WorkspaceConfig { @@ -19,73 +19,189 @@ pub struct WorkspaceConfig { #[derive(Debug)] pub struct SweForgeTask { + pub id: String, pub workspace: WorkspaceConfig, pub prompt: String, pub test_scripts: Vec<(String, String)>, pub test_source_files: Vec<(String, String)>, } -pub async fn download_and_extract(url: &str, dest: &Path) -> Result<()> { - info!("Downloading task archive from {}", url); - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(120)) - .build()?; - - let resp = client - .get(url) - .send() - .await - .context("Failed to download task archive")?; +#[derive(Debug)] +pub struct ExtractedArchive { + pub tasks: Vec, + pub agent_code: String, + pub agent_language: String, +} - if !resp.status().is_success() { - anyhow::bail!( - "Task archive download failed: HTTP {}", - resp.status().as_u16() - ); +pub fn extract_archive_bytes(data: &[u8], dest: &Path) -> Result<()> { + if let Ok(mut archive) = zip::ZipArchive::new(std::io::Cursor::new(data)) { + debug!("Extracting ZIP archive ({} entries)", archive.len()); + archive + .extract(dest) + .context("Failed to extract ZIP archive")?; + return Ok(()); } - let bytes = resp.bytes().await.context("Failed to read response body")?; + let gz = flate2::read::GzDecoder::new(data); + let mut archive = tar::Archive::new(gz); + archive + .unpack(dest) + .context("Failed to extract tar.gz archive")?; + debug!("Extracted tar.gz archive"); - if bytes.len() > MAX_ARCHIVE_SIZE { + Ok(()) +} + +pub async fn extract_uploaded_archive(data: &[u8], dest: &Path) -> Result { + if data.len() > MAX_ARCHIVE_SIZE { anyhow::bail!( - "Task archive too large: {} bytes (max {})", - bytes.len(), + "Archive too large: {} bytes (max {})", + data.len(), MAX_ARCHIVE_SIZE ); } - info!("Downloaded {} bytes, extracting...", bytes.len()); + info!("Extracting {} bytes archive...", data.len()); tokio::fs::create_dir_all(dest) .await .context("Failed to create extraction directory")?; - let dest = dest.to_path_buf(); - let bytes_vec = bytes.to_vec(); - tokio::task::spawn_blocking(move || extract_archive(&bytes_vec, &dest)) + let dest_owned = dest.to_path_buf(); + let data_vec = data.to_vec(); + tokio::task::spawn_blocking(move || extract_archive_bytes(&data_vec, &dest_owned)) .await .context("Extract task panicked")??; - Ok(()) + let root = find_archive_root(dest)?; + + let agent_code = load_agent_code(&root)?; + let agent_language = detect_agent_language(&root); + let tasks = load_tasks(&root)?; + + info!( + "Extracted {} tasks, agent language: {}", + tasks.len(), + agent_language + ); + + Ok(ExtractedArchive { + tasks, + agent_code, + agent_language, + }) } -fn extract_archive(data: &[u8], dest: &Path) -> Result<()> { - if let Ok(mut archive) = zip::ZipArchive::new(std::io::Cursor::new(data)) { - debug!("Extracting ZIP archive ({} entries)", archive.len()); - archive - .extract(dest) - .context("Failed to extract ZIP archive")?; - return Ok(()); +fn find_archive_root(base: &Path) -> Result { + if base.join("tasks").exists() || base.join("agent_code").exists() { + return Ok(base.to_path_buf()); } - let gz = flate2::read::GzDecoder::new(data); - let mut archive = tar::Archive::new(gz); - archive - .unpack(dest) - .context("Failed to extract tar.gz archive")?; - debug!("Extracted tar.gz archive"); + for entry in std::fs::read_dir(base).context("Failed to read extracted directory")? { + let entry = entry?; + let path = entry.path(); + if path.is_dir() && (path.join("tasks").exists() || path.join("agent_code").exists()) { + return Ok(path); + } + } - Ok(()) + anyhow::bail!( + "No tasks/ or agent_code/ found in archive at {}", + base.display() + ) +} + +fn load_agent_code(root: &Path) -> Result { + let agent_dir = root.join("agent_code"); + if !agent_dir.exists() { + anyhow::bail!("agent_code/ directory not found in archive"); + } + + let mut agent_content = String::new(); + let mut files: Vec<_> = std::fs::read_dir(&agent_dir)? + .filter_map(|e| e.ok()) + .filter(|e| e.path().is_file()) + .collect(); + files.sort_by_key(|e| e.file_name()); + + for entry in &files { + let content = std::fs::read_to_string(entry.path()) + .with_context(|| format!("Failed to read agent file: {:?}", entry.path()))?; + if files.len() == 1 { + agent_content = content; + } else { + agent_content.push_str(&format!( + "# --- {} ---\n", + entry.file_name().to_string_lossy() + )); + agent_content.push_str(&content); + agent_content.push('\n'); + } + } + + if agent_content.is_empty() { + anyhow::bail!("agent_code/ directory is empty"); + } + + Ok(agent_content) +} + +fn detect_agent_language(root: &Path) -> String { + let agent_dir = root.join("agent_code"); + if let Ok(entries) = std::fs::read_dir(&agent_dir) { + for entry in entries.flatten() { + let name = entry.file_name().to_string_lossy().to_string(); + if name.ends_with(".py") { + return "python".to_string(); + } + if name.ends_with(".js") { + return "javascript".to_string(); + } + if name.ends_with(".ts") { + return "typescript".to_string(); + } + if name.ends_with(".sh") { + return "shell".to_string(); + } + if name.ends_with(".rs") { + return "rust".to_string(); + } + if name.ends_with(".go") { + return "go".to_string(); + } + } + } + "python".to_string() +} + +fn load_tasks(root: &Path) -> Result> { + let tasks_dir = root.join("tasks"); + if !tasks_dir.exists() { + anyhow::bail!("tasks/ directory not found in archive"); + } + + let mut tasks = Vec::new(); + let mut entries: Vec<_> = std::fs::read_dir(&tasks_dir)? + .filter_map(|e| e.ok()) + .filter(|e| e.path().is_dir()) + .collect(); + entries.sort_by_key(|e| e.file_name()); + + for entry in entries { + let task_dir = entry.path(); + match parse_task(&task_dir) { + Ok(task) => tasks.push(task), + Err(e) => { + tracing::warn!("Skipping task dir {}: {}", task_dir.display(), e); + } + } + } + + if tasks.is_empty() { + anyhow::bail!("No valid tasks found in tasks/ directory"); + } + + Ok(tasks) } pub fn parse_task(task_dir: &Path) -> Result { @@ -98,10 +214,14 @@ pub fn parse_task(task_dir: &Path) -> Result { let prompt_path = task_dir.join("prompt.md"); let prompt = std::fs::read_to_string(&prompt_path).context("Missing prompt.md")?; + let id = task_dir + .file_name() + .map(|n| n.to_string_lossy().to_string()) + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + let mut test_scripts = Vec::new(); let mut test_source_files = Vec::new(); - // Load from tests/ directory let tests_dir = task_dir.join("tests"); if tests_dir.exists() { load_tests_recursive( @@ -112,7 +232,6 @@ pub fn parse_task(task_dir: &Path) -> Result { )?; } - // Load from checks.txt (alternative flat format) let checks_path = task_dir.join("checks.txt"); if checks_path.exists() && test_scripts.is_empty() { let checks = std::fs::read_to_string(&checks_path).context("Failed to read checks.txt")?; @@ -125,15 +244,10 @@ pub fn parse_task(task_dir: &Path) -> Result { let content = format!("#!/bin/sh\nset -e\n{}\n", line); test_scripts.push((name, content)); } - if !test_scripts.is_empty() { - info!( - "Loaded {} test commands from checks.txt", - test_scripts.len() - ); - } } Ok(SweForgeTask { + id, workspace, prompt, test_scripts, @@ -175,23 +289,47 @@ fn load_tests_recursive( Ok(()) } -pub fn find_task_root(base: &Path) -> Result { - if base.join("workspace.yaml").exists() { - return Ok(base.to_path_buf()); +#[allow(dead_code)] +pub async fn download_and_extract(url: &str, dest: &Path) -> Result<()> { + info!("Downloading task archive from {}", url); + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(120)) + .build()?; + + let resp = client + .get(url) + .send() + .await + .context("Failed to download task archive")?; + + if !resp.status().is_success() { + anyhow::bail!( + "Task archive download failed: HTTP {}", + resp.status().as_u16() + ); } - for entry in std::fs::read_dir(base).context("Failed to read extracted directory")? { - let entry = entry?; - let path = entry.path(); - if path.is_dir() && path.join("workspace.yaml").exists() { - return Ok(path); - } + let bytes = resp.bytes().await.context("Failed to read response body")?; + + if bytes.len() > MAX_ARCHIVE_SIZE { + anyhow::bail!( + "Task archive too large: {} bytes (max {})", + bytes.len(), + MAX_ARCHIVE_SIZE + ); } - anyhow::bail!( - "No workspace.yaml found in extracted task archive at {}", - base.display() - ) + tokio::fs::create_dir_all(dest) + .await + .context("Failed to create extraction directory")?; + + let dest = dest.to_path_buf(); + let bytes_vec = bytes.to_vec(); + tokio::task::spawn_blocking(move || extract_archive_bytes(&bytes_vec, &dest)) + .await + .context("Extract task panicked")??; + + Ok(()) } #[cfg(test)] @@ -212,65 +350,30 @@ language: "python" assert_eq!(config.repo, "https://github.com/psf/requests"); assert_eq!(config.version, "v2.31.0"); assert_eq!(config.base_commit.as_deref(), Some("abc123")); - assert_eq!(config.install.as_ref().unwrap().len(), 1); - assert_eq!(config.language.as_deref(), Some("python")); } #[test] - fn test_parse_workspace_minimal() { - let yaml = r#" -repo: "https://github.com/psf/requests" -version: "v2.31.0" -"#; - let config: WorkspaceConfig = serde_yaml::from_str(yaml).unwrap(); - assert!(config.base_commit.is_none()); - assert!(config.install.is_none()); + fn test_detect_agent_language() { + let tmp = tempfile::tempdir().unwrap(); + let agent_dir = tmp.path().join("agent_code"); + std::fs::create_dir_all(&agent_dir).unwrap(); + std::fs::write(agent_dir.join("main.py"), "print('hello')").unwrap(); + assert_eq!(detect_agent_language(tmp.path()), "python"); } #[test] - fn test_parse_task_with_checks_txt() { + fn test_parse_task_with_checks() { let tmp = tempfile::tempdir().unwrap(); let dir = tmp.path(); - std::fs::write( dir.join("workspace.yaml"), "repo: https://github.com/test/repo\nversion: v1.0\n", ) .unwrap(); std::fs::write(dir.join("prompt.md"), "Fix the bug").unwrap(); - std::fs::write( - dir.join("checks.txt"), - "# comment\npython -m pytest tests/\ncargo test\n", - ) - .unwrap(); + std::fs::write(dir.join("checks.txt"), "pytest tests/\ncargo test\n").unwrap(); let task = parse_task(dir).unwrap(); assert_eq!(task.test_scripts.len(), 2); - assert!(task.test_scripts[0].1.contains("pytest")); - assert!(task.test_scripts[1].1.contains("cargo test")); - } - - #[test] - fn test_find_task_root_direct() { - let tmp = tempfile::tempdir().unwrap(); - std::fs::write(tmp.path().join("workspace.yaml"), "repo: x\nversion: v1\n").unwrap(); - let root = find_task_root(tmp.path()).unwrap(); - assert_eq!(root, tmp.path()); - } - - #[test] - fn test_find_task_root_nested() { - let tmp = tempfile::tempdir().unwrap(); - let nested = tmp.path().join("task-dir"); - std::fs::create_dir_all(&nested).unwrap(); - std::fs::write(nested.join("workspace.yaml"), "repo: x\nversion: v1\n").unwrap(); - let root = find_task_root(tmp.path()).unwrap(); - assert_eq!(root, nested); - } - - #[test] - fn test_find_task_root_missing() { - let tmp = tempfile::tempdir().unwrap(); - assert!(find_task_root(tmp.path()).is_err()); } } diff --git a/src/ws.rs b/src/ws.rs new file mode 100644 index 0000000..3f4dd91 --- /dev/null +++ b/src/ws.rs @@ -0,0 +1,130 @@ +use axum::{ + extract::{ + ws::{Message, WebSocket}, + Query, State, WebSocketUpgrade, + }, + response::Response, +}; +use futures::{SinkExt, StreamExt}; +use serde::Deserialize; +use std::sync::Arc; +use tokio::sync::broadcast; +use tracing::{debug, info, warn}; + +use crate::handlers::AppState; +use crate::session::WsEvent; + +#[derive(Deserialize)] +pub struct WsQuery { + pub batch_id: String, +} + +pub async fn ws_handler( + ws: WebSocketUpgrade, + State(state): State>, + Query(query): Query, +) -> Response { + let batch_id = query.batch_id; + ws.on_upgrade(move |socket| handle_ws(socket, state, batch_id)) +} + +async fn handle_ws(socket: WebSocket, state: Arc, batch_id: String) { + let batch = match state.sessions.get(&batch_id) { + Some(b) => b, + None => { + let (mut sender, _) = socket.split(); + let err = serde_json::json!({ + "error": "batch_not_found", + "batch_id": batch_id, + }); + let _ = sender + .send(Message::Text(serde_json::to_string(&err).unwrap())) + .await; + return; + } + }; + + info!("WebSocket connected for batch {}", batch_id); + + let mut rx: broadcast::Receiver = batch.events_tx.subscribe(); + let (mut sender, mut receiver) = socket.split(); + + let current_state = batch.result.lock().await; + let snapshot = serde_json::json!({ + "event": "snapshot", + "batch_id": batch_id, + "data": { + "status": current_state.status, + "total_tasks": current_state.total_tasks, + "completed_tasks": current_state.completed_tasks, + "passed_tasks": current_state.passed_tasks, + "failed_tasks": current_state.failed_tasks, + "aggregate_reward": current_state.aggregate_reward, + "tasks": current_state.tasks, + } + }); + drop(current_state); + + if sender + .send(Message::Text(serde_json::to_string(&snapshot).unwrap())) + .await + .is_err() + { + return; + } + + let batch_id_send = batch_id.clone(); + let send_task = tokio::spawn(async move { + loop { + match rx.recv().await { + Ok(event) => { + let json = match serde_json::to_string(&event) { + Ok(j) => j, + Err(_) => continue, + }; + if sender.send(Message::Text(json)).await.is_err() { + break; + } + } + Err(broadcast::error::RecvError::Lagged(n)) => { + debug!("WebSocket lagged by {} messages", n); + continue; + } + Err(broadcast::error::RecvError::Closed) => { + let close_msg = serde_json::json!({ + "event": "stream_closed", + "batch_id": batch_id_send, + }); + let _ = sender + .send(Message::Text(serde_json::to_string(&close_msg).unwrap())) + .await; + break; + } + } + } + }); + + let recv_task = tokio::spawn(async move { + while let Some(msg) = receiver.next().await { + match msg { + Ok(Message::Close(_)) => break, + Ok(Message::Ping(data)) => { + debug!("Received ping"); + let _ = data; + } + Err(e) => { + warn!("WebSocket receive error: {}", e); + break; + } + _ => {} + } + } + }); + + tokio::select! { + _ = send_task => {}, + _ = recv_task => {}, + } + + info!("WebSocket disconnected for batch {}", batch_id); +}