Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions cli-tools/gpu-dev-cli/gpu_dev_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3309,19 +3309,45 @@ def _show_diagnostics(connection_info: dict) -> None:
_print_recovery_hints(connection_info)


def _show_lambda_logs(reservation_mgr, reservation_id: str, user_id: str) -> None:
    """Fetch and render the raw lambda (CloudWatch) logs for a reservation.

    Args:
        reservation_mgr: Object exposing ``get_reservation_logs(reservation_id, user_id)``.
        reservation_id: Reservation id forwarded to the log backend.
        user_id: Caller's user id forwarded to the log backend.

    Output is printed to the console; nothing is returned. Three outcomes:
      * backend unreachable (``None`` result) -> a yellow hint, nothing else;
      * backend reported an error -> the error is shown, and the "no lines" hint is
        suppressed because the absence of lines is then explained by the error;
      * lines returned -> rendered in a panel, limited to the last ~16000 characters.
    """
    from rich.markup import escape
    from rich.text import Text

    max_chars = 16000  # upper bound on the rendered panel body

    rprint("\n[bold]Fetching lambda logs from CloudWatch…[/bold] [dim](a few seconds)[/dim]")
    result = reservation_mgr.get_reservation_logs(reservation_id, user_id)
    if result is None:
        rprint("[yellow]Could not reach the log backend (it may not be deployed yet, "
               "or you lack lambda:InvokeFunctionUrl access).[/yellow]")
        return

    lines = result.get("lines") or []
    error = result.get("error")
    if error:
        # Escape the backend text so stray "[...]" sequences are not parsed as markup.
        rprint(f"[yellow]Log query: {escape(str(error))}[/yellow]")
        if not lines:
            # The query itself failed; claiming "no lines were recorded" would mislead.
            return
    if not lines:
        rprint("[dim]No lambda log lines found for this reservation (outside the "
               "retention window, or none recorded).[/dim]")
        return

    body = "\n".join(f"{ln.get('timestamp','')} {ln.get('message','')}".rstrip()
                     for ln in lines)
    title = f"Lambda logs · {len(lines)} line(s)"
    if len(body) > max_chars:
        tail = body[-max_chars:]
        # Drop the partial first line left by the character cut, when there is one.
        _partial, newline, remainder = tail.partition("\n")
        body = remainder if (newline and remainder) else tail
        title += " · showing tail only"
    console.print(Panel(Text(body), title=title, border_style="cyan"))


@main.command()
@click.argument("reservation_id", required=False)
@click.option("--logs", "show_logs", is_flag=True,
help="Also fetch the raw lambda logs for this reservation from CloudWatch.")
@click.pass_context
def debug(ctx: click.Context, reservation_id: Optional[str]) -> None:
def debug(ctx: click.Context, reservation_id: Optional[str], show_logs: bool) -> None:
"""Diagnose your own reservation — why a box died or won't connect.

Shows the status timeline, failure reason, OOM events, and captured pod logs,
plus recovery steps — all without needing cluster or lambda access.
plus recovery steps — all without needing cluster or lambda access. Add --logs
to also pull the raw reservation/expiry lambda logs from CloudWatch.

\b
Examples:
gpu-dev debug # pick from your active reservations
gpu-dev debug abc12345 # a specific reservation (id prefix ok)
gpu-dev debug abc12345 --logs # + raw lambda logs from CloudWatch

For a recently failed/expired box, find its id with 'gpu-dev list' then
'gpu-dev debug <id>'.
Expand Down Expand Up @@ -3362,6 +3388,9 @@ def debug(ctx: click.Context, reservation_id: Optional[str]) -> None:

_show_single_reservation(connection_info)
_show_diagnostics(connection_info)
if show_logs:
_show_lambda_logs(reservation_mgr, connection_info["reservation_id"],
user_info["user_id"])

except RuntimeError as e:
rprint(f"[red]❌ {str(e)}[/red]")
Expand Down
20 changes: 18 additions & 2 deletions cli-tools/gpu-dev-cli/gpu_dev_cli/reservations.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def _get_direct_url(self) -> Optional[str]:
pass
return self._direct_url or None

def _signed_post(self, url: str, payload: dict) -> Optional[dict]:
def _signed_post(self, url: str, payload: dict, timeout: int = 20) -> Optional[dict]:
"""SigV4-signed POST to the Function URL. Returns parsed JSON or None."""
try:
creds = self.config.session.get_credentials()
Expand All @@ -623,13 +623,29 @@ def _signed_post(self, url: str, payload: dict) -> Optional[dict]:
aws_req = AWSRequest(method="POST", url=url, data=data,
headers={"Content-Type": "application/json"})
SigV4Auth(creds, "lambda", self.config.aws_region).add_auth(aws_req)
resp = requests.post(url, data=data, headers=dict(aws_req.headers), timeout=20)
resp = requests.post(url, data=data, headers=dict(aws_req.headers), timeout=timeout)
if resp.status_code != 200:
return None
return resp.json()
except Exception:
return None

def get_reservation_logs(self, reservation_id: str, user_id: str) -> Optional[Dict[str, Any]]:
    """Ask the processor Function URL for a reservation's lambda logs.

    Backs `gpu-dev debug --logs`. Sends a signed ``get_logs`` request and hands back
    whatever `_signed_post` returns: the parsed JSON reply (``{"lines": [...]}`` or
    ``{"error": ...}``), or ``None`` when no Function URL is configured or the
    request did not succeed.
    """
    function_url = self._get_direct_url()
    if function_url:
        request_body = dict(
            action="get_logs",
            reservation_id=reservation_id,
            user_id=user_id,
            version=get_version(),
        )
        # A CloudWatch Logs Insights query is slower than a claim, hence the ~70s budget.
        return self._signed_post(function_url, request_body, timeout=70)
    return None

def claim_direct(self, *, user_id: str, gpu_count: int, gpu_type: str,
duration_hours: Union[int, float], name: Optional[str] = None,
github_user: Optional[str] = None, ref: Optional[str] = None) -> Optional[Dict[str, Any]]:
Expand Down
37 changes: 37 additions & 0 deletions terraform-gpu-devservers/lambda.tf
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,29 @@ resource "aws_iam_role_policy" "reservation_processor_policy" {
"ecr:GetAuthorizationToken"
]
Resource = "*"
},
{
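# Read-only CloudWatch Logs access so `gpu-dev debug --logs` can return a
# reservation's own lambda logs via Logs Insights. StartQuery, GetLogEvents and
# FilterLogEvents are scoped to the two reservation log groups; GetQueryResults and
# StopQuery are not resource-scopable, so they are granted on "*" in the next statement.
# NOTE(review): the get_logs handler only calls StartQuery/GetQueryResults/StopQuery;
# confirm whether GetLogEvents/FilterLogEvents are actually needed.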
Effect = "Allow"
Action = [
"logs:StartQuery",
"logs:GetLogEvents",
"logs:FilterLogEvents"
]
Resource = [
"arn:aws:logs:*:*:log-group:/aws/lambda/${var.prefix}-reservation-processor:*",
"arn:aws:logs:*:*:log-group:/aws/lambda/${var.prefix}-reservation-expiry:*"
]
},
{
Effect = "Allow"
Action = [
"logs:GetQueryResults",
"logs:StopQuery"
]
Resource = "*"
}
]
})
Expand Down Expand Up @@ -245,6 +268,20 @@ resource "aws_cloudwatch_log_group" "reservation_processor_log_group" {
}
}

# CloudWatch Log Group for the expiry lambda. If Terraform does not manage it, AWS
# auto-creates the group with no retention limit (it had grown to multiple GB,
# unbounded). Managing it here caps retention.
# NOTE: the group already exists in AWS, so import it once per workspace before apply
# (substitute the workspace's actual prefix value for ${var.prefix} in the command):
# tofu import aws_cloudwatch_log_group.reservation_expiry_log_group /aws/lambda/${var.prefix}-reservation-expiry
resource "aws_cloudwatch_log_group" "reservation_expiry_log_group" {
# Must match the name Lambda logs to: /aws/lambda/<function name>.
name = "/aws/lambda/${var.prefix}-reservation-expiry"
# Log events older than 30 days are expired.
retention_in_days = 30

tags = {
Name = "${var.prefix}-reservation-expiry-logs"
Environment = local.current_config.environment
}
}

# Build Lambda package with dependencies and create zip in one step
resource "null_resource" "reservation_processor_build" {
triggers = {
Expand Down
88 changes: 88 additions & 0 deletions terraform-gpu-devservers/lambda/reservation_processor/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1915,6 +1915,88 @@ def _resp(payload):
}})


def _logs_iso_to_epoch(value):
    """Convert an ISO-8601 timestamp to integer epoch seconds; None if unparseable.

    A trailing "Z" is accepted. Timestamps without a UTC offset are interpreted as
    UTC rather than the process-local timezone, so the result does not depend on
    the host's TZ setting.
    """
    from datetime import timezone  # local import: only `datetime` is known to be in scope

    try:
        parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00"))
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return int(parsed.timestamp())
    except (TypeError, ValueError, OverflowError, OSError):
        return None


def _logs_query_window(reservation, now):
    """Return the (start, end) epoch-second window to query for one reservation.

    The window runs from shortly before the reservation was created/launched to
    shortly after it ended (or ``now`` if it has no end marker), clamped to the
    last 14 days and never extending past ``now``. ``start < end`` is guaranteed
    for any ``now`` greater than 14 days after the epoch.
    """
    earliest = now - 14 * 86400  # cap to processor retention (14d)
    start = (_logs_iso_to_epoch(reservation.get("created_at"))
             or _logs_iso_to_epoch(reservation.get("launched_at"))
             or earliest)
    end = (_logs_iso_to_epoch(reservation.get("reservation_ended"))
           or _logs_iso_to_epoch(reservation.get("expired_at"))
           or _logs_iso_to_epoch(reservation.get("cancelled_at"))
           or _logs_iso_to_epoch(reservation.get("failed_at"))
           or now)
    start = max(start - 120, earliest)  # small lead-in before creation
    end = min(end + 300, now)           # small tail after the end marker
    if end <= start:
        end = now
    if end <= start:
        # Start lies in the future (clock skew / bad data): query the whole window
        # instead of sending CloudWatch an inverted range.
        start = earliest
    return start, end


def handle_get_logs(body) -> dict:
    """Function URL handler for `gpu-dev debug --logs`.

    Runs a CloudWatch Logs Insights query for one reservation across the processor
    and expiry log groups and returns the matching lines.

    Args:
        body: Parsed JSON request body; must contain ``reservation_id`` (full id or
            prefix) and ``user_id``.

    Returns:
        A Function URL response dict (``statusCode``, ``headers``, JSON ``body``):
          * 400 if a required field is missing or the resolved id is unusable;
          * 404 if the reservation cannot be resolved for ``user_id``;
          * 500 if the lookup, query start, or result polling raised;
          * 200 with ``{"lines": [...], "reservation_id": ...}`` on success, or
            200 with ``{"error": ..., "lines": []}`` when the query ended without
            completing.

    The reservation is resolved via ``find_reservation_by_prefix(..., user_id=...)``.
    NOTE(review): ``user_id`` is taken from the request body; this function does not
    compare it with the authenticated caller — confirm the caller/auth layer does.
    """
    def _resp(payload, code=200):
        return {"statusCode": code, "headers": {"Content-Type": "application/json"},
                "body": json.dumps(payload)}

    if not isinstance(body, dict):
        body = {}
    reservation_id = str(body.get("reservation_id") or "").strip()
    user_id = str(body.get("user_id") or "").strip()
    if not reservation_id or not user_id:
        return _resp({"error": "reservation_id and user_id are required"}, 400)

    try:
        reservation = find_reservation_by_prefix(reservation_id, user_id=user_id)
    except ValueError as ve:
        return _resp({"error": str(ve)}, 404)
    except Exception as e:
        return _resp({"error": f"reservation lookup failed: {e}"}, 500)
    if not reservation:
        return _resp({"error": f"Reservation {reservation_id} not found"}, 404)

    full_id = str(reservation.get("reservation_id") or reservation_id)
    rid8 = full_id[:8]
    # rid8 is embedded in a quoted Insights string literal; refuse characters that
    # could terminate or corrupt that literal instead of trying to escape them.
    if any(ch in '"\\' or not ch.isprintable() for ch in rid8):
        return _resp({"error": "reservation id contains unsupported characters"}, 400)

    start, end = _logs_query_window(reservation, int(time.time()))

    # Derive the deployment prefix from this function's own name so the sibling
    # expiry log group can be addressed without extra configuration.
    fn = os.environ.get("AWS_LAMBDA_FUNCTION_NAME", "")
    suffix = "-reservation-processor"
    prefix = fn[:-len(suffix)] if fn.endswith(suffix) else "pytorch-gpu-dev"
    groups = [f"/aws/lambda/{prefix}-reservation-processor",
              f"/aws/lambda/{prefix}-reservation-expiry"]
    query = (f'fields @timestamp, @message | filter @message like "{rid8}" '
             f'| sort @timestamp asc | limit 1000')

    logs = boto3.client("logs")

    # A log group may not exist in this workspace — if querying both fails,
    # retry with just the processor group.
    qid = None
    start_error = None
    for candidate_groups in (groups, groups[:1]):
        try:
            qid = logs.start_query(logGroupNames=candidate_groups, startTime=start,
                                   endTime=end, queryString=query, limit=1000)["queryId"]
            break
        except Exception as exc:
            start_error = exc
    if qid is None:
        return _resp({"error": f"could not start log query: {start_error}", "lines": []}, 500)

    def _stop_quietly():
        # Best-effort cancel; failing to stop must not mask the primary error.
        try:
            logs.stop_query(queryId=qid)
        except Exception:
            pass

    result = None
    try:
        for _ in range(45):
            r = logs.get_query_results(queryId=qid)
            if r.get("status") in ("Complete", "Failed", "Cancelled", "Timeout"):
                result = r
                break
            time.sleep(1)
    except Exception as exc:
        _stop_quietly()
        return _resp({"error": f"could not fetch log query results: {exc}", "lines": []}, 500)

    if not result or result.get("status") != "Complete":
        if result is None:
            # Still running after the polling budget: cancel it.
            _stop_quietly()
        st = result.get("status") if result else "timeout"
        return _resp({"error": f"log query did not complete (status={st})", "lines": []})

    lines = []
    for row in result.get("results", []):
        d = {f.get("field"): f.get("value") for f in row}
        lines.append({"timestamp": d.get("@timestamp", ""),
                      "message": (d.get("@message", "") or "").rstrip("\n")})
    return _resp({"lines": lines, "reservation_id": full_id})


def handler(event, context):
"""Main Lambda handler"""
try:
Expand All @@ -1924,6 +2006,12 @@ def handler(event, context):
# Returns the active reservation in the HTTP response — no SQS, no poll.
# Only warm-eligible requests; anything else tells the CLI to use SQS.
if event.get("requestContext", {}).get("http") or event.get("rawPath"):
try:
_fu_body = json.loads(event.get("body") or "{}")
except Exception:
_fu_body = {}
if _fu_body.get("action") == "get_logs":
return handle_get_logs(_fu_body)
return handle_direct_claim(event)

# Scheduled tick to keep the warm pool topped up.
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/cli/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,33 @@ def test_debug_no_id_auto_selects_single(cli_runner):
out = _clean(res.output)
assert res.exit_code == 0
mgr.get_connection_info.assert_called_once_with("9b1466cc-aaa", "alice@example.com")


def test_debug_logs_flag_renders_lines(cli_runner):
    """`debug <id> --logs` shows the fetched lambda log lines and queries by full id."""
    log_lines = [
        {"timestamp": "2026-06-09T20:07:30", "message": "Creating pod gpu-dev-9b1466cc"},
        {"timestamp": "2026-06-09T20:55:00", "message": "Evicted: node low on memory"},
    ]
    manager = MagicMock()
    manager.get_connection_info.return_value = _ci(status="failed")
    manager.get_reservation_logs.return_value = {"lines": log_lines}

    result = _invoke(cli_runner, manager, ["9b1466cc", "--logs"])
    rendered = _clean(result.output)

    assert "Lambda logs" in rendered
    assert "node low on memory" in rendered
    manager.get_reservation_logs.assert_called_once_with(
        "9b1466cc-f272-40a6-90da-2bf0f4c1e599", "alice@example.com")


def test_debug_logs_flag_backend_unavailable(cli_runner):
    """A ``None`` result from the log backend produces the 'could not reach' hint."""
    manager = MagicMock()
    manager.get_connection_info.return_value = _ci(status="failed")
    manager.get_reservation_logs.return_value = None

    result = _invoke(cli_runner, manager, ["9b1466cc", "--logs"])

    rendered = _clean(result.output)
    assert "Could not reach the log backend" in rendered


def test_debug_without_logs_flag_does_not_query(cli_runner):
    """Without `--logs`, `debug` never asks the manager for lambda logs."""
    manager = MagicMock()
    manager.configure_mock(**{"get_connection_info.return_value": _ci(status="failed")})

    cli_args = ["9b1466cc"]
    _invoke(cli_runner, manager, cli_args)

    manager.get_reservation_logs.assert_not_called()
59 changes: 59 additions & 0 deletions tests/unit/lambda_fn/test_get_logs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Unit tests for handle_get_logs — the Function URL handler powering
`gpu-dev debug --logs` (CloudWatch Logs Insights query for one reservation)."""
import json
from unittest.mock import MagicMock


def test_get_logs_requires_fields(lambda_index):
    """Blank reservation_id / user_id are rejected with HTTP 400."""
    blank_request = {"reservation_id": "", "user_id": ""}
    response = lambda_index.handle_get_logs(blank_request)
    assert response["statusCode"] == 400


def test_get_logs_ownership_enforced(lambda_index, monkeypatch):
    """A lookup that raises ValueError (not the caller's reservation) maps to 404."""
    def _deny_lookup(rid, user_id=None):
        # Stand-in for find_reservation_by_prefix(user_id=...) rejecting a foreign id.
        raise ValueError(f"Reservation {rid} not found for user {user_id}")

    monkeypatch.setattr(lambda_index, "find_reservation_by_prefix", _deny_lookup)

    request = {"reservation_id": "abc", "user_id": "x@y.com"}
    response = lambda_index.handle_get_logs(request)

    assert response["statusCode"] == 404
    error_text = json.loads(response["body"])["error"]
    assert "not found" in error_text


def test_get_logs_happy_path(lambda_index, monkeypatch):
    """A completed query returns its rows and is scoped to the reservation's lifetime."""
    from datetime import datetime, timedelta, timezone

    # Use a creation time relative to "now" (explicit UTC). A hard-coded calendar
    # date makes the startTime < endTime assertion depend on when the suite runs:
    # while that date is still in the future the window is inverted.
    created_at = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat()
    monkeypatch.setattr(lambda_index, "find_reservation_by_prefix",
                        lambda rid, user_id=None: {
                            "reservation_id": "9b1466cc-f272-40a6-90da-2bf0f4c1e599",
                            "created_at": created_at})
    logs = MagicMock()
    logs.start_query.return_value = {"queryId": "q1"}
    logs.get_query_results.return_value = {"status": "Complete", "results": [
        [{"field": "@timestamp", "value": "2026-06-09 20:07:30.123"},
         {"field": "@message", "value": "Creating pod gpu-dev-9b1466cc\n"}],
    ]}
    monkeypatch.setattr(lambda_index.boto3, "client", lambda svc, *a, **k: logs)

    r = lambda_index.handle_get_logs({"reservation_id": "9b1466cc", "user_id": "x@y.com"})
    assert r["statusCode"] == 200
    body = json.loads(r["body"])
    assert len(body["lines"]) == 1
    assert body["lines"][0]["message"] == "Creating pod gpu-dev-9b1466cc"  # newline stripped
    # query filters on the 8-char id prefix, scoped to the reservation's lifetime
    kw = logs.start_query.call_args.kwargs
    assert "9b1466cc" in kw["queryString"]
    assert kw["startTime"] < kw["endTime"]
    # the window covers at least the hour since creation
    assert kw["endTime"] - kw["startTime"] >= 3600


def test_get_logs_query_incomplete_returns_error(lambda_index, monkeypatch):
    """A query that ends in a non-Complete status yields 200 with an error and no lines."""
    def _lookup(rid, user_id=None):
        return {"reservation_id": "9b1466cc-x", "created_at": "2026-06-09T19:51:00"}

    logs_client = MagicMock()
    logs_client.start_query.return_value = {"queryId": "q1"}
    logs_client.get_query_results.return_value = {"status": "Failed", "results": []}

    monkeypatch.setattr(lambda_index, "find_reservation_by_prefix", _lookup)
    monkeypatch.setattr(lambda_index.boto3, "client", lambda svc, *a, **k: logs_client)
    # Never actually wait between polls.
    monkeypatch.setattr(lambda_index.time, "sleep", lambda *_a, **_k: None)

    request = {"reservation_id": "9b1466cc", "user_id": "x@y.com"}
    response = lambda_index.handle_get_logs(request)

    assert response["statusCode"] == 200
    payload = json.loads(response["body"])
    assert payload["lines"] == []
    assert "did not complete" in payload["error"]
Loading