diff --git a/progress.md b/progress.md new file mode 100644 index 00000000..dbfbfe3b --- /dev/null +++ b/progress.md @@ -0,0 +1,469 @@ +# Cloud-Agnostic Architecture Progress + +## Overview + +This document tracks the progress of making ODC (Open Developer Cloud) cloud-agnostic, supporting AWS, GCP, and custom deployments. + +--- + +## Section 1: Current State (Dev Branch) + +The dev branch has successfully migrated from AWS Lambda/DynamoDB to a Kubernetes-native architecture: + +| Component | Old (main) | New (dev) | Status | +|-----------|------------|-----------|--------| +| Job Queue | AWS SQS | PostgreSQL PGMQ | ✅ Done | +| State Store | DynamoDB | PostgreSQL | ✅ Done | +| Reservation Processing | Lambda | K8s Pod (processor) | ✅ Done | +| Availability Updates | Lambda | K8s CronJob | ✅ Done | +| Expiry Handling | Lambda | K8s CronJob | ✅ Done | +| API Service | Lambda + API Gateway | FastAPI in K8s | ✅ Done | + +### Architecture Diagram + +``` +┌─────────────┐ ┌──────────────────────────────────────────────────────┐ +│ CLI │────▶│ API Service (FastAPI) │ +│ (gpu-dev) │ │ - AWS IAM Auth (to be replaced with OIDC) │ +└─────────────┘ └──────────────────┬───────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────────────────┐ +│ PostgreSQL + PGMQ │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │reservations │ │ disks │ │ api_users │ │ pgmq queues │ │ +│ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ │ +└──────────────────────────────────────┬───────────────────────────────────┘ + │ + ┌──────────────────────────┼──────────────────────────┐ + ▼ ▼ ▼ +┌───────────────────┐ ┌───────────────────┐ ┌───────────────────┐ +│ Reservation │ │ Availability │ │ Expiry │ +│ Processor Pod │ │ Updater CronJob │ │ Handler CronJob │ +│ - Polls PGMQ │ │ - Updates GPU │ │ - Sends warnings │ +│ - Creates pods │ │ availability │ │ - Expires old res │ +│ - Manages disks │ │ │ │ - Creates snapshots│ +└───────────────────┘ └───────────────────┘ └───────────────────┘ +``` + +--- + +## Section 2: AWS-Specific Dependencies + +### Block Storage (EBS) - HIGH PRIORITY + +| File | AWS Dependency | Lines | +|------|----------------|-------| +| `shared/disk_reconciler.py` | EC2 volume listing, tagging, snapshot operations | ~800 | +| `shared/snapshot_utils.py` | `ec2_client.create_snapshot()`, `describe_snapshots()` | ~200 | +| `reservation_handler.py` | Direct EBS volume attachment, cross-AZ migration | ~400 | +| `expiry/main.py` | EC2 snapshot tagging | ~50 | +| `database/schema/003_disks.sql` | `ebs_volume_id` column | Schema | + +**Current Flow:** +```python +# Pod spec uses direct EBS attachment (NOT PVC!) +client.V1Volume( + name="dev-home", + aws_elastic_block_store=client.V1AWSElasticBlockStoreVolumeSource( + volume_id=ebs_volume_id, + fs_type="ext4" + ) +) +``` + +**Target Flow (CSI-based):** +```python +# Use PersistentVolumeClaim instead +client.V1Volume( + name="dev-home", + persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource( + claim_name=f"gpu-dev-{user_id}-{disk_name}" + ) +) +``` + +### Snapshots - HIGH PRIORITY + +| Current (AWS) | Target (K8s Native) | +|---------------|---------------------| +| `ec2_client.create_snapshot(VolumeId)` | `VolumeSnapshot` CR | +| `ec2_client.describe_snapshots()` | `kubectl get volumesnapshots` | +| `ec2_client.delete_snapshot()` | `kubectl delete volumesnapshot` | +| Wait via `get_waiter("snapshot_completed")` | Watch VolumeSnapshot status | + +**Required Components:** +- Snapshot Controller (deploy as addon or standalone) +- VolumeSnapshotClass for each CSI driver +- Update all snapshot_utils.py to use K8s API + +### File Storage (EFS) - MEDIUM PRIORITY + +| File | AWS Dependency | +|------|----------------| +| `efs.tf` | EFS resources, mount targets | +| `reservation_handler.py` | `create_or_find_user_efs()`, EFS client API | + +**Options:** +1. **EFS CSI Driver** - AWS-specific but uses K8s primitives +2. **Generic NFS CSI** - Cloud-agnostic, works with any NFS +3. **GCP Filestore CSI** - GCP equivalent + +### DNS (Route53) - LOW PRIORITY + +| File | AWS Dependency | +|------|----------------| +| `shared/dns_utils.py` | Route53 record management | +| `route53.tf` | Hosted zone configuration | + +**Target:** Use `external-dns` controller with annotations + +### Authentication (IAM/STS) - HIGH PRIORITY + +| File | AWS Dependency | +|------|----------------| +| `api-service/app/main.py` | STS `get_caller_identity()` verification | +| `cli-tools/gpu-dev-cli/gpu_dev_cli/auth.py` | AWS credentials for API auth | + +**Target:** OIDC-based authentication (see Section 4) + +### Container Registry (ECR) - LOW PRIORITY + +| Current | Options | +|---------|---------| +| ECR with pull-through cache | GHCR (works everywhere) | +| | In-cluster registry with pull-through | +| | GCP Artifact Registry | + +--- + +## Section 3: Provider Interface Design + +### Directory Structure + +``` +terraform-gpu-devservers/ +├── providers/ +│ ├── __init__.py # Provider factory +│ ├── base.py # Abstract interfaces +│ ├── aws.py # AWS implementation +│ ├── gcp.py # GCP implementation +│ └── custom.py # Template for custom providers +``` + +### Base Interface + +```python +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, List, Dict, Any + +@dataclass +class VolumeInfo: + volume_id: str + size_gb: int + availability_zone: str + status: str + tags: Dict[str, str] + +@dataclass +class SnapshotInfo: + snapshot_id: str + volume_id: str + status: str + size_gb: int + created_at: str + tags: Dict[str, str] + +class CloudProvider(ABC): + @abstractmethod + def name(self) -> str: + """Provider name (aws, gcp, custom)""" + pass + + # === Block Storage === + @abstractmethod + def create_volume(self, size_gb: int, availability_zone: str, + volume_type: str = "ssd", tags: Dict[str, str] = None, + snapshot_id: Optional[str] = None) -> VolumeInfo: + pass + + @abstractmethod + def delete_volume(self, volume_id: str) -> bool: + pass + + @abstractmethod + def attach_volume(self, volume_id: str, instance_id: str, + device_path: str) -> bool: + pass + + @abstractmethod + def detach_volume(self, volume_id: str) -> bool: + pass + + # === Snapshots === + @abstractmethod + def create_snapshot(self, volume_id: str, description: str = "", + tags: Dict[str, str] = None) -> SnapshotInfo: + pass + + @abstractmethod + def delete_snapshot(self, snapshot_id: str) -> bool: + pass + + @abstractmethod + def list_snapshots(self, filters: Dict[str, str] = None) -> List[SnapshotInfo]: + pass + + @abstractmethod + def wait_for_snapshot(self, snapshot_id: str, + timeout_seconds: int = 600) -> bool: + pass + + # === Object Storage === + @abstractmethod + def upload_to_object_storage(self, bucket: str, key: str, + content: bytes) -> str: + pass + + @abstractmethod + def download_from_object_storage(self, bucket: str, + key: str) -> Optional[bytes]: + pass + + +class AuthProvider(ABC): + @abstractmethod + def verify_token(self, token: str) -> Optional[Dict[str, Any]]: + """Verify auth token, return user info or None""" + pass + + @abstractmethod + def create_api_key(self, user_id: str, scopes: List[str], + ttl_hours: int = 24) -> str: + pass +``` + +### Provider Factory + +```python +import os +from typing import Optional +from .base import CloudProvider + +_provider_instance: Optional[CloudProvider] = None + +def get_cloud_provider() -> CloudProvider: + global _provider_instance + if _provider_instance is not None: + return _provider_instance + + provider_name = os.environ.get("CLOUD_PROVIDER", "aws").lower() + + if provider_name == "aws": + from .aws import AWSProvider + _provider_instance = AWSProvider( + region=os.environ.get("AWS_REGION", "us-east-2") + ) + elif provider_name == "gcp": + from .gcp import GCPProvider + _provider_instance = GCPProvider( + project=os.environ.get("GCP_PROJECT"), + zone=os.environ.get("GCP_ZONE", "us-central1-a") + ) + elif provider_name == "custom": + from .custom import CustomProvider + _provider_instance = CustomProvider() + else: + raise ValueError(f"Unknown cloud provider: {provider_name}") + + return _provider_instance +``` + +--- + +## Section 4: OIDC Authentication Design + +### Current Flow (AWS IAM) + +``` +CLI -> AWS Credentials -> API Service -> STS Verify -> Issue API Key +``` + +### Target Flow (OIDC) + +``` +┌──────────┐ ┌──────────┐ ┌─────────────┐ ┌──────────┐ +│ User │────▶│ OIDC │────▶│ API Service │────▶│ Resource │ +│ (CLI) │ │ Provider │ │ (validates) │ │ Creation │ +└──────────┘ └──────────┘ └─────────────┘ └──────────┘ + │ + ▼ + ┌─────────────┐ + │ audit_log │ + │ (traceable) │ + └─────────────┘ +``` + +### Database Changes + +```sql +-- Add OIDC fields to api_users +ALTER TABLE api_users ADD COLUMN oidc_subject VARCHAR(255); +ALTER TABLE api_users ADD COLUMN oidc_issuer VARCHAR(512); +ALTER TABLE api_users ADD COLUMN oidc_claims JSONB; + +-- Audit log for traceability (including Bedrock/Claude usage) +CREATE TABLE audit_log ( + id SERIAL PRIMARY KEY, + timestamp TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + user_id INTEGER REFERENCES api_users(id), + username VARCHAR(255), + action VARCHAR(100), -- reserve, cancel, extend, etc. + resource_type VARCHAR(50), + resource_id VARCHAR(255), + request_metadata JSONB, + bedrock_request_id VARCHAR(255), + bedrock_tokens_used INTEGER +); +``` + +### OIDC Provider Options + +| Provider | Pros | Cons | +|----------|------|------| +| **GitHub** | Devs have accounts, matches SSH key auth | Limited enterprise features | +| **Google** | Universal, easy setup | May not fit enterprise policy | +| **Okta/Auth0** | Enterprise features, MFA | Cost, complexity | +| **Dex** | Self-hosted, multi-provider | Operational overhead | + +**Recommendation:** Start with GitHub OIDC (matches existing SSH key auth pattern), add enterprise options later. + +--- + +## Section 5: Migration Phases + +### Phase 1: Abstraction Layer (Days 1-2) - LOW RISK + +- [ ] Create `providers/` directory structure +- [ ] Implement `CloudProvider` base class +- [ ] Implement `AWSProvider` wrapping existing boto3 calls +- [ ] Add `get_cloud_provider()` factory +- [ ] **No production changes yet** + +### Phase 2: Refactor Storage Code (Days 3-7) - MEDIUM RISK + +- [ ] Refactor `shared/snapshot_utils.py` to use provider interface +- [ ] Refactor `reservation_handler.py` volume operations +- [ ] Refactor `shared/disk_reconciler.py` +- [ ] Update `expiry/main.py` snapshot operations +- [ ] Add comprehensive tests + +### Phase 3: K8s-Native Storage (Week 2) - MEDIUM RISK + +- [ ] Deploy Snapshot Controller +- [ ] Create VolumeSnapshotClass resources +- [ ] Add `K8sStorageProvider` implementation +- [ ] Support PVC-based volume attachment +- [ ] Make storage backend configurable + +### Phase 4: GCP Provider (Weeks 3-4) - MEDIUM-HIGH RISK + +- [ ] Implement `GCPProvider` class +- [ ] Create Terraform modules for GKE +- [ ] GCE Persistent Disk operations +- [ ] GCP Filestore for shared storage +- [ ] End-to-end testing on GCP + +### Phase 5: OIDC Authentication (Week 5) - HIGH RISK + +- [ ] Add OIDC token verification to API service +- [ ] Create user mapping (OIDC subject -> internal user) +- [ ] Add audit logging with full traceability +- [ ] Update CLI for OIDC login flow +- [ ] Dual-auth period (AWS IAM + OIDC) + +### Phase 6: DNS and Load Balancing (Week 6) - LOW-MEDIUM RISK + +- [ ] Deploy external-dns controller +- [ ] Replace Route53 calls with K8s annotations +- [ ] Make DNS optional/configurable +- [ ] Document DNS-free deployment option + +--- + +## Section 6: Files to Modify + +### Storage Abstraction + +| File | Changes | +|------|---------| +| `shared/disk_reconciler.py` | Replace EC2 API with provider interface | +| `shared/snapshot_utils.py` | Replace EC2 snapshot calls with provider/K8s API | +| `shared/disk_db.py` | Add `pvc_name`, `storage_class` columns | +| `reservation_handler.py` | Use PVC-based volumes | +| `expiry/main.py` | CSI-based snapshot cleanup | +| `database/schema/003_disks.sql` | Add PVC columns | +| `eks.tf` | Add snapshot controller addon | +| `monitoring.tf` | Add VolumeSnapshotClass resources | + +### Authentication + +| File | Changes | +|------|---------| +| `api-service/app/main.py` | Add OIDC verification endpoint | +| `cli-tools/gpu-dev-cli/gpu_dev_cli/auth.py` | OIDC login flow | +| `database/schema/` | Add audit_log table, OIDC fields | + +### DNS + +| File | Changes | +|------|---------| +| `shared/dns_utils.py` | Make optional, add external-dns option | +| `route53.tf` | Make conditional | + +--- + +## Section 7: Open Questions + +### Authentication +1. Which OIDC provider(s) to support initially? +2. How to handle AWS IAM → OIDC transition? +3. What user attributes needed from OIDC claims? +4. How to trace Bedrock/Claude token usage to users? + +### Storage +5. Acceptable latency for snapshot operations? +6. Should we support both direct EBS and PVC modes? + +### Registry +7. Single registry (GHCR) or multi-region? +8. How to handle custom images built per reservation? + +### General +9. Is external-dns required or optional? +10. Should provider interface be a separate package? +11. Priority: AWS improvements vs. multi-cloud? + +--- + +## Appendix: Key File References + +### Storage +- `terraform-gpu-devservers/shared/snapshot_utils.py` +- `terraform-gpu-devservers/shared/disk_reconciler.py` +- `terraform-gpu-devservers/shared/disk_db.py` +- `terraform-gpu-devservers/reservation-processor-service/processor/reservation_handler.py` + +### Authentication +- `terraform-gpu-devservers/api-service/app/main.py` +- `cli-tools/gpu-dev-cli/gpu_dev_cli/auth.py` + +### Database +- `terraform-gpu-devservers/database/schema/002_reservations.sql` +- `terraform-gpu-devservers/database/schema/003_disks.sql` + +### Infrastructure +- `terraform-gpu-devservers/eks.tf` +- `terraform-gpu-devservers/efs.tf` +- `terraform-gpu-devservers/monitoring.tf` diff --git a/terraform-gpu-devservers/database/migrations/006_add_oidc_auth.sql b/terraform-gpu-devservers/database/migrations/006_add_oidc_auth.sql new file mode 100644 index 00000000..a7fda2bb --- /dev/null +++ b/terraform-gpu-devservers/database/migrations/006_add_oidc_auth.sql @@ -0,0 +1,185 @@ +-- Migration: Add OIDC authentication support +-- Adds OIDC identity tracking to api_users and creates audit/token usage tables + +-- ============================================================================ +-- Add OIDC columns to api_users +-- ============================================================================ + +-- Add OIDC subject identifier (unique per issuer) +ALTER TABLE api_users ADD COLUMN IF NOT EXISTS oidc_subject VARCHAR(512); + +-- Add OIDC issuer URL (e.g., https://token.actions.githubusercontent.com) +ALTER TABLE api_users ADD COLUMN IF NOT EXISTS oidc_issuer VARCHAR(512); + +-- Create unique constraint for OIDC identity (subject + issuer combo) +CREATE UNIQUE INDEX IF NOT EXISTS idx_api_users_oidc_identity + ON api_users(oidc_subject, oidc_issuer) + WHERE oidc_subject IS NOT NULL AND oidc_issuer IS NOT NULL; + +-- Index for looking up users by OIDC issuer +CREATE INDEX IF NOT EXISTS idx_api_users_oidc_issuer + ON api_users(oidc_issuer) + WHERE oidc_issuer IS NOT NULL; + +-- Add comments for documentation +COMMENT ON COLUMN api_users.oidc_subject IS 'OIDC subject identifier (sub claim from JWT)'; +COMMENT ON COLUMN api_users.oidc_issuer IS 'OIDC issuer URL (iss claim from JWT)'; + +-- ============================================================================ +-- Create audit_log table +-- ============================================================================ + +CREATE TABLE IF NOT EXISTS audit_log ( + event_id SERIAL PRIMARY KEY, + + -- Who performed the action + user_id INTEGER REFERENCES api_users(user_id) ON DELETE SET NULL, + username VARCHAR(255), + + -- What action was performed + event_type VARCHAR(64) NOT NULL, + action TEXT NOT NULL, + + -- What resource was affected + resource_type VARCHAR(64), + resource_id VARCHAR(255), + + -- Additional details (JSON) + details JSONB DEFAULT '{}', + + -- Request context + ip_address INET, + user_agent TEXT, + + -- Timestamp + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); + +-- Index for querying user's audit history +CREATE INDEX IF NOT EXISTS idx_audit_log_user_id + ON audit_log(user_id, created_at DESC) + WHERE user_id IS NOT NULL; + +-- Index for querying by event type +CREATE INDEX IF NOT EXISTS idx_audit_log_event_type + ON audit_log(event_type, created_at DESC); + +-- Index for querying resource history +CREATE INDEX IF NOT EXISTS idx_audit_log_resource + ON audit_log(resource_type, resource_id, created_at DESC) + WHERE resource_type IS NOT NULL AND resource_id IS NOT NULL; + +-- Index for time-based queries (cleanup, reporting) +CREATE INDEX IF NOT EXISTS idx_audit_log_created_at + ON audit_log(created_at); + +-- Add comments +COMMENT ON TABLE audit_log IS 'Audit trail for all user actions and system events'; +COMMENT ON COLUMN audit_log.event_type IS 'Event category (auth.login, reservation.create, etc.)'; +COMMENT ON COLUMN audit_log.action IS 'Human-readable description of the action'; +COMMENT ON COLUMN audit_log.details IS 'Additional event details in JSON format'; + +-- ============================================================================ +-- Create token_usage table for LLM billing/monitoring +-- ============================================================================ + +CREATE TABLE IF NOT EXISTS token_usage ( + usage_id SERIAL PRIMARY KEY, + + -- Who used the tokens + user_id INTEGER NOT NULL REFERENCES api_users(user_id) ON DELETE CASCADE, + + -- What model was used + model VARCHAR(128) NOT NULL, + + -- Token counts + input_tokens INTEGER NOT NULL DEFAULT 0, + output_tokens INTEGER NOT NULL DEFAULT 0, + total_tokens INTEGER NOT NULL DEFAULT 0, + + -- Cost tracking (optional) + cost_usd DECIMAL(12, 6), + + -- Request correlation + request_id VARCHAR(255), + + -- Timestamp + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); + +-- Index for querying user's token usage +CREATE INDEX IF NOT EXISTS idx_token_usage_user_id + ON token_usage(user_id, created_at DESC); + +-- Index for querying by model +CREATE INDEX IF NOT EXISTS idx_token_usage_model + ON token_usage(model, created_at DESC); + +-- Index for time-based aggregation (billing reports) +CREATE INDEX IF NOT EXISTS idx_token_usage_created_at + ON token_usage(created_at); + +-- Index for correlating with requests +CREATE INDEX IF NOT EXISTS idx_token_usage_request_id + ON token_usage(request_id) + WHERE request_id IS NOT NULL; + +-- Add comments +COMMENT ON TABLE token_usage IS 'Tracks LLM/AI token usage for billing and monitoring'; +COMMENT ON COLUMN token_usage.model IS 'LLM model name (e.g., claude-3-opus, gpt-4)'; +COMMENT ON COLUMN token_usage.cost_usd IS 'Estimated cost in USD based on model pricing'; +COMMENT ON COLUMN token_usage.request_id IS 'Request ID for correlation with audit log'; + +-- ============================================================================ +-- Create view for user token usage summary +-- ============================================================================ + +CREATE OR REPLACE VIEW user_token_summary AS +SELECT + u.user_id, + u.username, + t.model, + COUNT(*) as request_count, + SUM(t.input_tokens) as total_input_tokens, + SUM(t.output_tokens) as total_output_tokens, + SUM(t.total_tokens) as total_tokens, + SUM(COALESCE(t.cost_usd, 0)) as total_cost_usd, + MIN(t.created_at) as first_usage, + MAX(t.created_at) as last_usage +FROM token_usage t +JOIN api_users u ON t.user_id = u.user_id +GROUP BY u.user_id, u.username, t.model; + +COMMENT ON VIEW user_token_summary IS 'Aggregated token usage per user per model'; + +-- ============================================================================ +-- Create function for audit log cleanup +-- ============================================================================ + +CREATE OR REPLACE FUNCTION cleanup_old_audit_logs(days_to_keep INTEGER DEFAULT 90) +RETURNS INTEGER AS $$ +DECLARE + deleted_count INTEGER; +BEGIN + DELETE FROM audit_log + WHERE created_at < NOW() - (days_to_keep || ' days')::INTERVAL; + + GET DIAGNOSTICS deleted_count = ROW_COUNT; + RETURN deleted_count; +END; +$$ LANGUAGE plpgsql; + +COMMENT ON FUNCTION cleanup_old_audit_logs(INTEGER) IS 'Deletes audit log entries older than specified days'; + +-- ============================================================================ +-- Migrate existing AWS-authenticated users to have placeholder OIDC info +-- ============================================================================ + +-- For existing users authenticated via AWS, we can optionally mark them +-- This allows gradual migration without breaking existing functionality +-- Uncomment if you want to track AWS SSO users separately: + +-- UPDATE api_users +-- SET oidc_issuer = 'aws-sts-legacy' +-- WHERE oidc_issuer IS NULL +-- AND username LIKE '%@%'; diff --git a/terraform-gpu-devservers/providers/__init__.py b/terraform-gpu-devservers/providers/__init__.py new file mode 100644 index 00000000..2c6f5a97 --- /dev/null +++ b/terraform-gpu-devservers/providers/__init__.py @@ -0,0 +1,195 @@ +""" +Cloud Provider Factory + +This module provides a factory function to get the appropriate cloud provider +based on configuration. The provider abstraction allows the GPU reservation +system to work with multiple cloud platforms without modifying core business logic. + +Usage: + from providers import get_cloud_provider + + provider = get_cloud_provider() + + # Storage operations + volume = provider.create_volume(size_gb=100, availability_zone='us-east-2a') + + # Snapshot operations + snapshot = provider.create_snapshot(volume.volume_id) + + # Object storage + uri = provider.upload_to_object_storage('bucket', 'key', b'content') + +Configuration: + Set CLOUD_PROVIDER environment variable: + - 'aws' (default): Amazon Web Services + - 'gcp': Google Cloud Platform + - 'custom': Custom/on-premises provider + + Provider-specific configuration via environment variables: + - AWS: AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY + - GCP: GCP_PROJECT, GCP_ZONE, GOOGLE_APPLICATION_CREDENTIALS + - Custom: CUSTOM_STORAGE_BACKEND, CUSTOM_AUTH_BACKEND +""" + +import logging +import os +from typing import Optional + +from .base import ( + AuthProvider, + AuthenticationError, + AuthorizationError, + CloudProvider, + NodeInfo, + ProviderError, + QuotaExceededError, + SnapshotInfo, + SnapshotNotFoundError, + VolumeInfo, + VolumeInUseError, + VolumeNotFoundError, +) + +logger = logging.getLogger(__name__) + +# Cached provider instance +_provider_instance: Optional[CloudProvider] = None + + +def get_cloud_provider( + provider_name: Optional[str] = None, + force_new: bool = False, + **kwargs +) -> CloudProvider: + """ + Get the configured cloud provider instance. + + This factory function returns the appropriate provider based on configuration. + The provider instance is cached for performance; use force_new=True to + create a new instance. + + Args: + provider_name: Override the provider (defaults to CLOUD_PROVIDER env var) + force_new: Force creation of new instance (bypass cache) + **kwargs: Provider-specific configuration options + + Returns: + CloudProvider instance (AWSProvider, GCPProvider, or CustomProvider) + + Raises: + ValueError: If provider name is not recognized + + Example: + # Use default provider from environment + provider = get_cloud_provider() + + # Override provider for testing + provider = get_cloud_provider('custom') + + # Force new instance with custom config + provider = get_cloud_provider('aws', force_new=True, region='us-west-2') + """ + global _provider_instance + + # Use cached instance if available and not forcing new + if _provider_instance is not None and not force_new and provider_name is None: + return _provider_instance + + # Determine provider name + name = provider_name or os.environ.get("CLOUD_PROVIDER", "aws") + name = name.lower() + + logger.info(f"Initializing cloud provider: {name}") + + if name == "aws": + from .aws import AWSProvider + region = kwargs.get("region") or os.environ.get("AWS_REGION", "us-east-2") + provider = AWSProvider(region=region) + + elif name == "gcp": + from .gcp import GCPProvider + project = kwargs.get("project") or os.environ.get("GCP_PROJECT", "") + zone = kwargs.get("zone") or os.environ.get("GCP_ZONE", "us-central1-a") + if not project: + raise ValueError( + "GCP_PROJECT environment variable must be set for GCP provider" + ) + provider = GCPProvider(project=project, zone=zone) + + elif name == "custom": + from .custom import CustomProvider + provider = CustomProvider() + + else: + raise ValueError( + f"Unknown cloud provider: {name}. " + f"Valid options: aws, gcp, custom" + ) + + # Cache the instance + if not force_new: + _provider_instance = provider + + return provider + + +def get_auth_provider( + provider_name: Optional[str] = None, + **kwargs +) -> AuthProvider: + """ + Get an authentication provider instance. + + Args: + provider_name: Override the provider (defaults to CLOUD_PROVIDER env var) + **kwargs: Provider-specific configuration options + + Returns: + AuthProvider instance + """ + name = provider_name or os.environ.get("CLOUD_PROVIDER", "aws") + name = name.lower() + + if name == "aws": + from .aws import AWSIAMAuthProvider + region = kwargs.get("region") or os.environ.get("AWS_REGION", "us-east-2") + return AWSIAMAuthProvider(region=region) + + elif name == "gcp": + raise NotImplementedError("GCP auth provider not implemented") + + elif name == "custom": + from .custom import CustomAuthProvider + return CustomAuthProvider() + + else: + raise ValueError(f"Unknown auth provider: {name}") + + +def clear_provider_cache(): + """Clear the cached provider instance.""" + global _provider_instance + _provider_instance = None + + +__all__ = [ + # Factory functions + "get_cloud_provider", + "get_auth_provider", + "clear_provider_cache", + # Base classes + "CloudProvider", + "AuthProvider", + # Data classes + "VolumeInfo", + "SnapshotInfo", + "NodeInfo", + # Exceptions + "ProviderError", + "VolumeNotFoundError", + "VolumeInUseError", + "SnapshotNotFoundError", + "QuotaExceededError", + "AuthenticationError", + "AuthorizationError", +] diff --git a/terraform-gpu-devservers/providers/aws.py b/terraform-gpu-devservers/providers/aws.py new file mode 100644 index 00000000..1679d907 --- /dev/null +++ b/terraform-gpu-devservers/providers/aws.py @@ -0,0 +1,357 @@ +""" +AWS Cloud Provider Implementation + +Wraps existing boto3 code to provide cloud-agnostic interface. +""" + +import logging +from typing import Dict, List, Optional, Any +from datetime import datetime, timezone + +import boto3 +from botocore.exceptions import ClientError + +from .base import ( + CloudProvider, + AuthProvider, + VolumeInfo, + SnapshotInfo, + NodeInfo, +) + +logger = logging.getLogger(__name__) + + +class AWSProvider(CloudProvider): + """AWS implementation of CloudProvider interface.""" + + def __init__(self, region: str = "us-east-2"): + self.region = region + self._ec2 = None + self._s3 = None + self._autoscaling = None + self._efs = None + + @property + def ec2(self): + if self._ec2 is None: + self._ec2 = boto3.client("ec2", region_name=self.region) + return self._ec2 + + @property + def s3(self): + if self._s3 is None: + self._s3 = boto3.client("s3", region_name=self.region) + return self._s3 + + @property + def efs(self): + if self._efs is None: + self._efs = boto3.client("efs", region_name=self.region) + return self._efs + + def name(self) -> str: + return "aws" + + # === Block Storage (EBS) === + + def create_volume( + self, + size_gb: int, + availability_zone: str, + volume_type: str = "ssd", + tags: Optional[Dict[str, str]] = None, + snapshot_id: Optional[str] = None, + ) -> VolumeInfo: + """Create an EBS volume.""" + aws_volume_type = {"ssd": "gp3", "hdd": "sc1", "io": "io2"}.get(volume_type, "gp3") + + params = { + "AvailabilityZone": availability_zone, + "Size": size_gb, + "VolumeType": aws_volume_type, + "Encrypted": True, + } + + if snapshot_id: + params["SnapshotId"] = snapshot_id + + if aws_volume_type == "gp3": + params["Iops"] = 3000 + params["Throughput"] = 125 + + response = self.ec2.create_volume(**params) + volume_id = response["VolumeId"] + + if tags: + self.ec2.create_tags( + Resources=[volume_id], + Tags=[{"Key": k, "Value": v} for k, v in tags.items()], + ) + + return VolumeInfo( + volume_id=volume_id, + size_gb=response["Size"], + availability_zone=response["AvailabilityZone"], + status=response["State"], + tags=tags or {}, + ) + + def delete_volume(self, volume_id: str) -> bool: + """Delete an EBS volume.""" + try: + self.ec2.delete_volume(VolumeId=volume_id) + return True + except ClientError as e: + logger.error(f"Failed to delete volume {volume_id}: {e}") + return False + + def attach_volume( + self, volume_id: str, instance_id: str, device_path: str + ) -> bool: + """Attach EBS volume to EC2 instance.""" + try: + self.ec2.attach_volume( + VolumeId=volume_id, + InstanceId=instance_id, + Device=device_path, + ) + return True + except ClientError as e: + logger.error(f"Failed to attach volume {volume_id}: {e}") + return False + + def detach_volume(self, volume_id: str) -> bool: + """Detach EBS volume from instance.""" + try: + self.ec2.detach_volume(VolumeId=volume_id) + return True + except ClientError as e: + logger.error(f"Failed to detach volume {volume_id}: {e}") + return False + + def get_volume(self, volume_id: str) -> Optional[VolumeInfo]: + """Get EBS volume information.""" + try: + response = self.ec2.describe_volumes(VolumeIds=[volume_id]) + if response["Volumes"]: + vol = response["Volumes"][0] + tags = {t["Key"]: t["Value"] for t in vol.get("Tags", [])} + return VolumeInfo( + volume_id=vol["VolumeId"], + size_gb=vol["Size"], + availability_zone=vol["AvailabilityZone"], + status=vol["State"], + tags=tags, + ) + except ClientError: + pass + return None + + def list_volumes( + self, filters: Optional[Dict[str, str]] = None + ) -> List[VolumeInfo]: + """List EBS volumes matching filters.""" + aws_filters = [] + if filters: + for key, value in filters.items(): + aws_filters.append({"Name": f"tag:{key}", "Values": [value]}) + + response = self.ec2.describe_volumes(Filters=aws_filters) + + volumes = [] + for vol in response.get("Volumes", []): + tags = {t["Key"]: t["Value"] for t in vol.get("Tags", [])} + volumes.append( + VolumeInfo( + volume_id=vol["VolumeId"], + size_gb=vol["Size"], + availability_zone=vol["AvailabilityZone"], + status=vol["State"], + tags=tags, + ) + ) + return volumes + + # === Snapshots === + + def create_snapshot( + self, + volume_id: str, + description: str = "", + tags: Optional[Dict[str, str]] = None, + ) -> SnapshotInfo: + """Create EBS snapshot.""" + response = self.ec2.create_snapshot( + VolumeId=volume_id, + Description=description, + ) + + snapshot_id = response["SnapshotId"] + + if tags: + self.ec2.create_tags( + Resources=[snapshot_id], + Tags=[{"Key": k, "Value": v} for k, v in tags.items()], + ) + + return SnapshotInfo( + snapshot_id=snapshot_id, + volume_id=volume_id, + status=response["State"], + size_gb=response["VolumeSize"], + created_at=response["StartTime"].isoformat(), + tags=tags or {}, + ) + + def delete_snapshot(self, snapshot_id: str) -> bool: + """Delete EBS snapshot.""" + try: + self.ec2.delete_snapshot(SnapshotId=snapshot_id) + return True + except ClientError as e: + logger.error(f"Failed to delete snapshot {snapshot_id}: {e}") + return False + + def get_snapshot(self, snapshot_id: str) -> Optional[SnapshotInfo]: + """Get EBS snapshot information.""" + try: + response = self.ec2.describe_snapshots(SnapshotIds=[snapshot_id]) + if response["Snapshots"]: + snap = response["Snapshots"][0] + tags = {t["Key"]: t["Value"] for t in snap.get("Tags", [])} + return SnapshotInfo( + snapshot_id=snap["SnapshotId"], + volume_id=snap["VolumeId"], + status=snap["State"], + size_gb=snap["VolumeSize"], + created_at=snap["StartTime"].isoformat(), + tags=tags, + ) + except ClientError: + pass + return None + + def list_snapshots( + self, filters: Optional[Dict[str, str]] = None + ) -> List[SnapshotInfo]: + """List EBS snapshots matching filters.""" + aws_filters = [{"Name": "owner-id", "Values": ["self"]}] + if filters: + for key, value in filters.items(): + aws_filters.append({"Name": f"tag:{key}", "Values": [value]}) + + response = self.ec2.describe_snapshots(Filters=aws_filters) + + snapshots = [] + for snap in response.get("Snapshots", []): + tags = {t["Key"]: t["Value"] for t in snap.get("Tags", [])} + snapshots.append( + SnapshotInfo( + snapshot_id=snap["SnapshotId"], + volume_id=snap["VolumeId"], + status=snap["State"], + size_gb=snap["VolumeSize"], + created_at=snap["StartTime"].isoformat(), + tags=tags, + ) + ) + return snapshots + + def wait_for_snapshot( + self, snapshot_id: str, timeout_seconds: int = 600 + ) -> bool: + """Wait for EBS snapshot to complete.""" + try: + waiter = self.ec2.get_waiter("snapshot_completed") + waiter.wait( + SnapshotIds=[snapshot_id], + WaiterConfig={"Delay": 15, "MaxAttempts": timeout_seconds // 15}, + ) + return True + except Exception as e: + logger.error(f"Snapshot {snapshot_id} did not complete: {e}") + return False + + # === Compute === + + def get_nodes_by_gpu_type(self, gpu_type: str) -> List[NodeInfo]: + """Get EC2 instances by GPU type label.""" + # This would typically query K8s nodes, not EC2 directly + # For now, return empty - K8s client handles this + return [] + + def get_node_availability(self) -> Dict[str, Dict[str, int]]: + """Get GPU availability - delegates to K8s.""" + # This is handled by the availability updater + return {} + + # === Object Storage (S3) === + + def upload_to_object_storage( + self, + bucket: str, + key: str, + content: bytes, + metadata: Optional[Dict[str, str]] = None, + ) -> str: + """Upload content to S3.""" + extra_args = {} + if metadata: + extra_args["Metadata"] = metadata + + self.s3.put_object( + Bucket=bucket, + Key=key, + Body=content, + **extra_args, + ) + + return f"s3://{bucket}/{key}" + + def download_from_object_storage( + self, bucket: str, key: str + ) -> Optional[bytes]: + """Download content from S3.""" + try: + response = self.s3.get_object(Bucket=bucket, Key=key) + return response["Body"].read() + except ClientError: + return None + + +class AWSIAMAuthProvider(AuthProvider): + """AWS IAM/STS based authentication (legacy).""" + + def __init__(self, region: str = "us-east-2"): + self.region = region + self._sts = None + + @property + def sts(self): + if self._sts is None: + self._sts = boto3.client("sts", region_name=self.region) + return self._sts + + def verify_token(self, token: str) -> Optional[Dict[str, Any]]: + """Verify AWS credentials (token contains access key info).""" + # This is handled differently - credentials are verified via STS + return None + + def get_user_info(self, token: str) -> Optional[Dict[str, Any]]: + """Get user info from AWS identity.""" + try: + response = self.sts.get_caller_identity() + return { + "user_id": response["UserId"], + "account": response["Account"], + "arn": response["Arn"], + } + except Exception: + return None + + def create_api_key( + self, user_id: str, scopes: List[str], ttl_hours: int = 24 + ) -> str: + """Create API key - handled by API service.""" + raise NotImplementedError("Use API service for API key creation") diff --git a/terraform-gpu-devservers/providers/base.py b/terraform-gpu-devservers/providers/base.py new file mode 100644 index 00000000..9be2d733 --- /dev/null +++ b/terraform-gpu-devservers/providers/base.py @@ -0,0 +1,270 @@ +""" +Abstract base classes for cloud provider interfaces. + +This module defines the abstract interfaces that all cloud providers must implement. +The abstraction allows the GPU reservation system to work with multiple cloud platforms +(AWS, GCP, custom on-prem) without modifying core business logic. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + + +@dataclass +class VolumeInfo: + """Standardized volume information across providers.""" + volume_id: str + size_gb: int + availability_zone: str + status: str # 'available', 'in-use', 'creating', 'deleting' + tags: dict[str, str] + + +@dataclass +class SnapshotInfo: + """Standardized snapshot information across providers.""" + snapshot_id: str + volume_id: str + status: str # 'pending', 'completed', 'error' + size_gb: int + created_at: str # ISO format timestamp + tags: dict[str, str] + + +@dataclass +class NodeInfo: + """Standardized node/instance information across providers.""" + node_id: str + name: str + instance_type: str + availability_zone: str + gpu_type: str | None + gpu_count: int + status: str # 'running', 'stopped', 'terminated' + labels: dict[str, str] + + +class CloudProvider(ABC): + """ + Abstract base class for cloud provider implementations. + + This is the main interface for cloud-specific functionality. + Each cloud implementation provides concrete implementations of + storage, snapshot, compute, and object storage operations. + + Example: + provider = get_cloud_provider() # Returns AWSProvider or GCPProvider + + # Storage operations + volume = provider.create_volume(size_gb=100, availability_zone='us-east-2a') + + # Snapshot operations + snapshot = provider.create_snapshot(volume.volume_id) + + # Object storage + uri = provider.upload_to_object_storage('bucket', 'key', b'content') + """ + + @abstractmethod + def name(self) -> str: + """Provider name (aws, gcp, custom).""" + pass + + # === Block Storage === + + @abstractmethod + def create_volume( + self, + size_gb: int, + availability_zone: str, + volume_type: str = "ssd", + tags: dict[str, str] | None = None, + snapshot_id: str | None = None, + ) -> VolumeInfo: + """ + Create a block storage volume. + + Args: + size_gb: Volume size in gigabytes + availability_zone: Zone for volume placement + volume_type: Storage class (ssd, hdd, io) + tags: Key-value tags for the volume + snapshot_id: Create volume from snapshot + + Returns: + VolumeInfo with created volume details + """ + pass + + @abstractmethod + def delete_volume(self, volume_id: str) -> bool: + """Delete a block storage volume.""" + pass + + @abstractmethod + def attach_volume( + self, volume_id: str, instance_id: str, device_path: str + ) -> bool: + """Attach volume to instance.""" + pass + + @abstractmethod + def detach_volume(self, volume_id: str) -> bool: + """Detach volume from instance.""" + pass + + @abstractmethod + def get_volume(self, volume_id: str) -> VolumeInfo | None: + """Get volume information.""" + pass + + @abstractmethod + def list_volumes( + self, filters: dict[str, str] | None = None + ) -> list[VolumeInfo]: + """List volumes matching filters (by tags).""" + pass + + # === Snapshots === + + @abstractmethod + def create_snapshot( + self, + volume_id: str, + description: str = "", + tags: dict[str, str] | None = None, + ) -> SnapshotInfo: + """Create a snapshot of a volume.""" + pass + + @abstractmethod + def delete_snapshot(self, snapshot_id: str) -> bool: + """Delete a snapshot.""" + pass + + @abstractmethod + def get_snapshot(self, snapshot_id: str) -> SnapshotInfo | None: + """Get snapshot information.""" + pass + + @abstractmethod + def list_snapshots( + self, filters: dict[str, str] | None = None + ) -> list[SnapshotInfo]: + """List snapshots matching filters (by tags).""" + pass + + @abstractmethod + def wait_for_snapshot( + self, snapshot_id: str, timeout_seconds: int = 600 + ) -> bool: + """Wait for snapshot to complete.""" + pass + + # === Compute === + + @abstractmethod + def get_nodes_by_gpu_type(self, gpu_type: str) -> list[NodeInfo]: + """Get nodes/instances by GPU type.""" + pass + + @abstractmethod + def get_node_availability(self) -> dict[str, dict[str, int]]: + """Get GPU availability by type.""" + pass + + # === Object Storage === + + @abstractmethod + def upload_to_object_storage( + self, + bucket: str, + key: str, + content: bytes, + metadata: dict[str, str] | None = None, + ) -> str: + """Upload content to object storage. Returns URI.""" + pass + + @abstractmethod + def download_from_object_storage( + self, bucket: str, key: str + ) -> bytes | None: + """Download content from object storage.""" + pass + + +class AuthProvider(ABC): + """ + Abstract interface for identity verification. + + Used for authenticating API requests and verifying user identity. + """ + + @abstractmethod + def verify_token(self, token: str) -> dict[str, Any] | None: + """ + Verify an authentication token. + + Returns user info dict if valid, None if invalid. + """ + pass + + @abstractmethod + def get_user_info(self, token: str) -> dict[str, Any] | None: + """Get user information from token.""" + pass + + @abstractmethod + def create_api_key( + self, user_id: str, scopes: list[str], ttl_hours: int = 24 + ) -> str: + """Create an API key for a user.""" + pass + + +class ProviderError(Exception): + """Base exception for provider errors.""" + + def __init__( + self, + message: str, + provider: str = "unknown", + operation: str = "unknown", + details: dict | None = None + ): + self.provider = provider + self.operation = operation + self.details = details or {} + super().__init__(f"[{provider}] {operation}: {message}") + + +class VolumeNotFoundError(ProviderError): + """Volume does not exist.""" + pass + + +class VolumeInUseError(ProviderError): + """Volume is attached and cannot be modified.""" + pass + + +class SnapshotNotFoundError(ProviderError): + """Snapshot does not exist.""" + pass + + +class QuotaExceededError(ProviderError): + """Resource quota exceeded.""" + pass + + +class AuthenticationError(ProviderError): + """Authentication failed.""" + pass + + +class AuthorizationError(ProviderError): + """User not authorized for operation.""" + pass diff --git a/terraform-gpu-devservers/providers/custom.py b/terraform-gpu-devservers/providers/custom.py new file mode 100644 index 00000000..87deacc1 --- /dev/null +++ b/terraform-gpu-devservers/providers/custom.py @@ -0,0 +1,403 @@ +""" +Custom Cloud Provider Template + +This module provides a template for implementing custom providers for: +- On-premises data centers +- Private clouds (OpenStack, VMware vSphere) +- Alternative cloud providers (DigitalOcean, Linode, etc.) +- Hybrid environments + +IMPLEMENTATION GUIDE +==================== + +1. Copy this file and rename it for your environment +2. Implement each abstract method +3. Register your provider in __init__.py +4. Set CLOUD_PROVIDER environment variable + +STORAGE INTEGRATION PATTERNS +============================ + +LVM (Linux Volume Manager): + - Create logical volumes in volume groups + - Use thin provisioning for snapshots + - Mount via device mapper + +Ceph RBD: + - Create RBD images in pools + - Use librbd or rbd CLI + - Map via rbd-nbd or krbd + +iSCSI: + - Create LUNs on storage array + - Map to initiator + - Discover and login to targets + +NFS: + - Create directories/quotas on NFS server + - Export via /etc/exports or storage array + +OBJECT STORAGE PATTERNS +======================= + +MinIO: + - S3-compatible API + - Use boto3 with custom endpoint + +Ceph RadosGW: + - S3-compatible API + - Use boto3 with custom endpoint + +Local filesystem: + - Use local directory as object store + - Simple for testing +""" + +import logging +import os +from typing import Any + +from .base import ( + AuthProvider, + CloudProvider, + NodeInfo, + SnapshotInfo, + VolumeInfo, +) + +logger = logging.getLogger(__name__) + + +class CustomProvider(CloudProvider): + """ + Template for custom cloud provider implementations. + + To implement: + 1. Replace each NotImplementedError with actual implementation + 2. Configure via environment variables + 3. Add any additional helper methods needed + """ + + def __init__(self): + # Configuration from environment + self.storage_backend = os.environ.get("CUSTOM_STORAGE_BACKEND", "lvm") + self.object_store_path = os.environ.get("CUSTOM_OBJECT_STORE", "/var/lib/gpu-dev/objects") + + def name(self) -> str: + return "custom" + + # === Block Storage === + + def create_volume( + self, + size_gb: int, + availability_zone: str, + volume_type: str = "ssd", + tags: dict[str, str] | None = None, + snapshot_id: str | None = None, + ) -> VolumeInfo: + """ + Create a block storage volume. + + Example LVM implementation: + import subprocess + import uuid + vol_name = f"gpudev-{uuid.uuid4().hex[:8]}" + cmd = ['lvcreate', '-L', f'{size_gb}G', '-n', vol_name, 'vg_gpudev'] + if snapshot_id: + cmd.extend(['--snapshot', snapshot_id]) + subprocess.run(cmd, check=True) + return VolumeInfo(volume_id=vol_name, ...) + """ + raise NotImplementedError( + f"Custom storage ({self.storage_backend}) not implemented. " + "Implement create_volume() for your storage backend." + ) + + def delete_volume(self, volume_id: str) -> bool: + """ + Delete a block storage volume. + + Example LVM implementation: + subprocess.run(['lvremove', '-f', f'vg_gpudev/{volume_id}'], check=True) + return True + """ + raise NotImplementedError( + f"Custom storage ({self.storage_backend}) not implemented. " + "Implement delete_volume() for your storage backend." + ) + + def attach_volume( + self, volume_id: str, instance_id: str, device_path: str + ) -> bool: + """ + Attach volume to instance. + + For Kubernetes-based workloads, this typically means: + 1. Make the volume accessible on the node (iSCSI login, RBD map, etc.) + 2. Create a PersistentVolume pointing to the device + 3. Let Kubernetes handle the pod mounting + """ + raise NotImplementedError( + f"Custom storage ({self.storage_backend}) not implemented. " + "Implement attach_volume() for your storage backend." + ) + + def detach_volume(self, volume_id: str) -> bool: + """ + Detach volume from instance. + + Ensure the volume is properly unmounted before detaching. + For iSCSI: logout from target + For RBD: unmap the device + For NFS: unmount the share + """ + raise NotImplementedError( + f"Custom storage ({self.storage_backend}) not implemented. " + "Implement detach_volume() for your storage backend." + ) + + def get_volume(self, volume_id: str) -> VolumeInfo | None: + """ + Get volume information. + + Example LVM implementation: + result = subprocess.run( + ['lvs', '--noheadings', '-o', 'lv_size,lv_attr', f'vg_gpudev/{volume_id}'], + capture_output=True, text=True + ) + if result.returncode != 0: + return None + # Parse output and return VolumeInfo + """ + raise NotImplementedError( + f"Custom storage ({self.storage_backend}) not implemented. " + "Implement get_volume() for your storage backend." + ) + + def list_volumes( + self, filters: dict[str, str] | None = None + ) -> list[VolumeInfo]: + """ + List volumes matching filters. + + Note: For backends without native tagging, store tags in a local database + or use naming conventions to encode metadata. + """ + raise NotImplementedError( + f"Custom storage ({self.storage_backend}) not implemented. " + "Implement list_volumes() for your storage backend." + ) + + # === Snapshots === + + def create_snapshot( + self, + volume_id: str, + description: str = "", + tags: dict[str, str] | None = None, + ) -> SnapshotInfo: + """ + Create a snapshot of a volume. + + Example LVM implementation: + import uuid + snap_name = f"snap-{uuid.uuid4().hex[:8]}" + subprocess.run([ + 'lvcreate', '--snapshot', + '-L', '10G', # COW pool size + '-n', snap_name, + f'vg_gpudev/{volume_id}' + ], check=True) + return SnapshotInfo(snapshot_id=snap_name, ...) + """ + raise NotImplementedError( + f"Custom snapshots ({self.storage_backend}) not implemented. " + "Implement create_snapshot() for your storage backend." + ) + + def delete_snapshot(self, snapshot_id: str) -> bool: + """Delete a snapshot.""" + raise NotImplementedError( + f"Custom snapshots ({self.storage_backend}) not implemented. " + "Implement delete_snapshot() for your storage backend." + ) + + def get_snapshot(self, snapshot_id: str) -> SnapshotInfo | None: + """Get snapshot information.""" + raise NotImplementedError( + f"Custom snapshots ({self.storage_backend}) not implemented. " + "Implement get_snapshot() for your storage backend." + ) + + def list_snapshots( + self, filters: dict[str, str] | None = None + ) -> list[SnapshotInfo]: + """List snapshots matching filters.""" + raise NotImplementedError( + f"Custom snapshots ({self.storage_backend}) not implemented. " + "Implement list_snapshots() for your storage backend." + ) + + def wait_for_snapshot( + self, snapshot_id: str, timeout_seconds: int = 600 + ) -> bool: + """ + Wait for snapshot to complete. + + For LVM/ZFS snapshots, this is typically instant. + For storage arrays, poll the API until complete. + """ + raise NotImplementedError( + f"Custom snapshots ({self.storage_backend}) not implemented. " + "Implement wait_for_snapshot() for your storage backend." + ) + + # === Compute === + + def get_nodes_by_gpu_type(self, gpu_type: str) -> list[NodeInfo]: + """ + Get nodes/instances by GPU type. + + For Kubernetes-based deployments, query K8s API: + from kubernetes import client + v1 = client.CoreV1Api() + nodes = v1.list_node(label_selector=f'gpu-type={gpu_type}') + """ + raise NotImplementedError( + "Custom compute not implemented. " + "Query via Kubernetes API instead." + ) + + def get_node_availability(self) -> dict[str, dict[str, int]]: + """ + Get GPU availability by type. + + This is typically handled by the availability-updater-service + which queries Kubernetes for GPU allocations. + """ + raise NotImplementedError( + "Handled by availability updater service." + ) + + # === Object Storage === + + def upload_to_object_storage( + self, + bucket: str, + key: str, + content: bytes, + metadata: dict[str, str] | None = None, + ) -> str: + """ + Upload content to object storage. + + Example MinIO/S3-compatible implementation: + import boto3 + s3 = boto3.client('s3', endpoint_url=os.environ['MINIO_ENDPOINT']) + s3.put_object(Bucket=bucket, Key=key, Body=content, Metadata=metadata or {}) + return f's3://{bucket}/{key}' + + Example filesystem implementation: + import os + path = os.path.join(self.object_store_path, bucket, key) + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'wb') as f: + f.write(content) + return f'file://{path}' + """ + raise NotImplementedError( + "Custom object storage not implemented. " + "Implement upload_to_object_storage() for your storage backend." + ) + + def download_from_object_storage( + self, bucket: str, key: str + ) -> bytes | None: + """ + Download content from object storage. + + Example filesystem implementation: + path = os.path.join(self.object_store_path, bucket, key) + if os.path.exists(path): + with open(path, 'rb') as f: + return f.read() + return None + """ + raise NotImplementedError( + "Custom object storage not implemented. " + "Implement download_from_object_storage() for your storage backend." + ) + + +class CustomAuthProvider(AuthProvider): + """ + Template for custom authentication provider. + + Common patterns: + + LDAP/Active Directory: + from ldap3 import Server, Connection, ALL + server = Server('ldap://ad.example.com', get_info=ALL) + conn = Connection(server, user=bind_dn, password=bind_pw) + conn.bind() + conn.search('dc=example,dc=com', f'(uid={username})', attributes=['memberOf']) + + OIDC (Keycloak, Okta): + from jose import jwt + payload = jwt.decode(token, key, algorithms=['RS256'], audience='gpu-dev') + return {'user_id': payload['sub'], 'email': payload['email'], ...} + + SAML: + from onelogin.saml2.auth import OneLogin_Saml2_Auth + auth = OneLogin_Saml2_Auth(request_data, saml_settings) + auth.process_response() + return {'user_id': auth.get_nameid(), ...} + """ + + def __init__(self): + self.backend = os.environ.get("CUSTOM_AUTH_BACKEND", "oidc") + + def verify_token(self, token: str) -> dict[str, Any] | None: + """ + Verify authentication token. + + Example OIDC implementation: + from jose import jwt + try: + payload = jwt.decode( + token, + self.public_key, + algorithms=['RS256'], + audience=os.environ.get('OIDC_AUDIENCE') + ) + return { + 'user_id': payload['sub'], + 'email': payload.get('email'), + 'groups': payload.get('groups', []) + } + except jwt.JWTError: + return None + """ + raise NotImplementedError( + f"Custom auth ({self.backend}) not implemented. " + "Implement verify_token() for your auth backend." + ) + + def get_user_info(self, token: str) -> dict[str, Any] | None: + """Get user information from token.""" + raise NotImplementedError( + f"Custom auth ({self.backend}) not implemented. " + "Implement get_user_info() for your auth backend." + ) + + def create_api_key( + self, user_id: str, scopes: list[str], ttl_hours: int = 24 + ) -> str: + """ + Create an API key for a user. + + This is typically handled by the API service using database-backed + API keys rather than cloud provider tokens. + """ + raise NotImplementedError("Use API service for API key creation") diff --git a/terraform-gpu-devservers/providers/gcp.py b/terraform-gpu-devservers/providers/gcp.py new file mode 100644 index 00000000..c087c785 --- /dev/null +++ b/terraform-gpu-devservers/providers/gcp.py @@ -0,0 +1,187 @@ +""" +GCP Cloud Provider Implementation (Stub) + +This is a template for GCP support. Implement the methods +using Google Cloud SDK. +""" + +import logging +from typing import Dict, List, Optional, Any + +from .base import ( + CloudProvider, + VolumeInfo, + SnapshotInfo, + NodeInfo, +) + +logger = logging.getLogger(__name__) + + +class GCPProvider(CloudProvider): + """ + GCP implementation of CloudProvider interface. + + TODO: Implement using google-cloud-compute SDK + """ + + def __init__(self, project: str, zone: str): + self.project = project + self.zone = zone + self.region = zone.rsplit("-", 1)[0] # us-central1-a -> us-central1 + + # Initialize GCP clients (uncomment when implementing) + # from google.cloud import compute_v1 + # self.disks_client = compute_v1.DisksClient() + # self.snapshots_client = compute_v1.SnapshotsClient() + # self.instances_client = compute_v1.InstancesClient() + + def name(self) -> str: + return "gcp" + + # === Block Storage (GCE Persistent Disk) === + + def create_volume( + self, + size_gb: int, + availability_zone: str, + volume_type: str = "ssd", + tags: Optional[Dict[str, str]] = None, + snapshot_id: Optional[str] = None, + ) -> VolumeInfo: + """ + Create a GCE Persistent Disk. + + volume_type mapping: + - ssd -> pd-ssd + - hdd -> pd-standard + - balanced -> pd-balanced + """ + raise NotImplementedError( + "GCP volume creation not implemented. " + "Use google.cloud.compute_v1.DisksClient.insert()" + ) + + def delete_volume(self, volume_id: str) -> bool: + """Delete a GCE Persistent Disk.""" + raise NotImplementedError( + "GCP volume deletion not implemented. " + "Use google.cloud.compute_v1.DisksClient.delete()" + ) + + def attach_volume( + self, volume_id: str, instance_id: str, device_path: str + ) -> bool: + """Attach disk to GCE instance.""" + raise NotImplementedError( + "GCP volume attachment not implemented. " + "Use google.cloud.compute_v1.InstancesClient.attach_disk()" + ) + + def detach_volume(self, volume_id: str) -> bool: + """Detach disk from GCE instance.""" + raise NotImplementedError( + "GCP volume detachment not implemented. " + "Use google.cloud.compute_v1.InstancesClient.detach_disk()" + ) + + def get_volume(self, volume_id: str) -> Optional[VolumeInfo]: + """Get disk information.""" + raise NotImplementedError( + "GCP volume get not implemented. " + "Use google.cloud.compute_v1.DisksClient.get()" + ) + + def list_volumes( + self, filters: Optional[Dict[str, str]] = None + ) -> List[VolumeInfo]: + """List disks matching filters (labels in GCP).""" + raise NotImplementedError( + "GCP volume list not implemented. " + "Use google.cloud.compute_v1.DisksClient.list()" + ) + + # === Snapshots === + + def create_snapshot( + self, + volume_id: str, + description: str = "", + tags: Optional[Dict[str, str]] = None, + ) -> SnapshotInfo: + """Create disk snapshot.""" + raise NotImplementedError( + "GCP snapshot creation not implemented. " + "Use google.cloud.compute_v1.SnapshotsClient.insert()" + ) + + def delete_snapshot(self, snapshot_id: str) -> bool: + """Delete snapshot.""" + raise NotImplementedError( + "GCP snapshot deletion not implemented. " + "Use google.cloud.compute_v1.SnapshotsClient.delete()" + ) + + def get_snapshot(self, snapshot_id: str) -> Optional[SnapshotInfo]: + """Get snapshot information.""" + raise NotImplementedError( + "GCP snapshot get not implemented. " + "Use google.cloud.compute_v1.SnapshotsClient.get()" + ) + + def list_snapshots( + self, filters: Optional[Dict[str, str]] = None + ) -> List[SnapshotInfo]: + """List snapshots.""" + raise NotImplementedError( + "GCP snapshot list not implemented. " + "Use google.cloud.compute_v1.SnapshotsClient.list()" + ) + + def wait_for_snapshot( + self, snapshot_id: str, timeout_seconds: int = 600 + ) -> bool: + """Wait for snapshot to complete.""" + raise NotImplementedError( + "GCP snapshot wait not implemented. " + "Poll SnapshotsClient.get() until status is READY" + ) + + # === Compute === + + def get_nodes_by_gpu_type(self, gpu_type: str) -> List[NodeInfo]: + """Get GCE instances by GPU type.""" + raise NotImplementedError( + "GCP node listing not implemented. " + "Query via Kubernetes API instead." + ) + + def get_node_availability(self) -> Dict[str, Dict[str, int]]: + """Get GPU availability.""" + raise NotImplementedError( + "Handled by availability updater service." + ) + + # === Object Storage (GCS) === + + def upload_to_object_storage( + self, + bucket: str, + key: str, + content: bytes, + metadata: Optional[Dict[str, str]] = None, + ) -> str: + """Upload to Google Cloud Storage.""" + raise NotImplementedError( + "GCS upload not implemented. " + "Use google.cloud.storage.Client().bucket().blob().upload_from_string()" + ) + + def download_from_object_storage( + self, bucket: str, key: str + ) -> Optional[bytes]: + """Download from Google Cloud Storage.""" + raise NotImplementedError( + "GCS download not implemented. " + "Use google.cloud.storage.Client().bucket().blob().download_as_bytes()" + ) diff --git a/terraform-gpu-devservers/shared/auth/__init__.py b/terraform-gpu-devservers/shared/auth/__init__.py new file mode 100644 index 00000000..26f4f758 --- /dev/null +++ b/terraform-gpu-devservers/shared/auth/__init__.py @@ -0,0 +1,70 @@ +""" +OIDC Authentication Module + +Provides cloud-agnostic authentication using OpenID Connect (OIDC) tokens +from multiple identity providers (GitHub, Google, Okta, etc.). + +Components: +- oidc: JWT token verification with JWKS caching +- api_keys: API key management from OIDC identities +- audit: User action and token usage logging +""" + +from .oidc import ( + OIDCVerifier, + OIDCProvider, + OIDCIdentity, + OIDCVerificationError, + JWKSFetchError, + TokenExpiredError, + InvalidIssuerError, + InvalidAudienceError, +) + +from .api_keys import ( + APIKeyManager, + APIKeyInfo, + create_api_key_from_oidc, + validate_api_key, + revoke_api_key, + get_user_api_keys, + cleanup_expired_keys, +) + +from .audit import ( + AuditLogger, + AuditEvent, + AuditEventType, + log_user_action, + log_token_usage, + get_user_audit_log, + get_resource_audit_log, +) + +__all__ = [ + # OIDC + "OIDCVerifier", + "OIDCProvider", + "OIDCIdentity", + "OIDCVerificationError", + "JWKSFetchError", + "TokenExpiredError", + "InvalidIssuerError", + "InvalidAudienceError", + # API Keys + "APIKeyManager", + "APIKeyInfo", + "create_api_key_from_oidc", + "validate_api_key", + "revoke_api_key", + "get_user_api_keys", + "cleanup_expired_keys", + # Audit + "AuditLogger", + "AuditEvent", + "AuditEventType", + "log_user_action", + "log_token_usage", + "get_user_audit_log", + "get_resource_audit_log", +] diff --git a/terraform-gpu-devservers/shared/auth/api_keys.py b/terraform-gpu-devservers/shared/auth/api_keys.py new file mode 100644 index 00000000..ec1e9620 --- /dev/null +++ b/terraform-gpu-devservers/shared/auth/api_keys.py @@ -0,0 +1,564 @@ +""" +API Key Management Module + +Creates and validates API keys from OIDC identities. +Tracks key usage for auditing and supports key revocation. + +Usage: + # Create key from OIDC identity + key_info = await create_api_key_from_oidc(identity, conn, ttl_hours=2) + print(f"API Key: {key_info.key}") + + # Validate key + user_info = await validate_api_key(api_key, conn) +""" + +import hashlib +import logging +import secrets +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from typing import Any + +import asyncpg + +from .oidc import OIDCIdentity + +logger = logging.getLogger(__name__) + +# API key length (64 bytes = 86 base64 characters) +API_KEY_LENGTH = 64 + +# Default TTL for API keys +DEFAULT_TTL_HOURS = 2 + +# Key prefix length for identification +KEY_PREFIX_LENGTH = 8 + + +@dataclass +class APIKeyInfo: + """ + Information about a created API key. + + Attributes: + key: The API key (only available at creation time) + key_id: Database ID of the key + key_prefix: First 8 characters for identification + user_id: Owner user ID + username: Owner username + expires_at: Key expiration timestamp + created_at: Key creation timestamp + """ + key: str + key_id: int + key_prefix: str + user_id: int + username: str + expires_at: datetime + created_at: datetime + + def to_response(self) -> dict[str, Any]: + """Convert to API response format.""" + return { + "api_key": self.key, + "key_prefix": self.key_prefix, + "user_id": self.user_id, + "username": self.username, + "expires_at": self.expires_at.isoformat(), + "created_at": self.created_at.isoformat(), + } + + +@dataclass +class UserInfo: + """ + User information from API key validation. + + Attributes: + user_id: Database user ID + username: Username + email: User email + oidc_subject: OIDC subject identifier + oidc_issuer: OIDC issuer URL + """ + user_id: int + username: str + email: str | None + oidc_subject: str | None + oidc_issuer: str | None + + +def hash_api_key(api_key: str) -> str: + """ + Hash API key for secure storage. + + Uses SHA-256 for consistent, irreversible hashing. + """ + return hashlib.sha256(api_key.encode()).hexdigest() + + +def generate_api_key() -> tuple[str, str]: + """ + Generate a new API key and its prefix. + + Returns: + Tuple of (full_key, prefix) + """ + key = secrets.token_urlsafe(API_KEY_LENGTH) + prefix = key[:KEY_PREFIX_LENGTH] + return key, prefix + + +async def get_or_create_user( + conn: asyncpg.Connection, + identity: OIDCIdentity, +) -> int: + """ + Get existing user or create new one from OIDC identity. + + Links OIDC identity to user via oidc_subject and oidc_issuer columns. + + Args: + conn: Database connection + identity: Verified OIDC identity + + Returns: + User ID + """ + # Try to find user by OIDC identity + existing = await conn.fetchrow( + """ + SELECT user_id FROM api_users + WHERE oidc_subject = $1 AND oidc_issuer = $2 + """, + identity.subject, + identity.issuer + ) + + if existing: + logger.debug( + f"Found existing user {existing['user_id']} for {identity.display_name}" + ) + return existing['user_id'] + + # Try to find by email (for linking existing accounts) + if identity.email: + existing = await conn.fetchrow( + """ + SELECT user_id FROM api_users + WHERE email = $1 AND oidc_subject IS NULL + """, + identity.email + ) + + if existing: + # Link OIDC identity to existing user + await conn.execute( + """ + UPDATE api_users + SET oidc_subject = $1, oidc_issuer = $2 + WHERE user_id = $3 + """, + identity.subject, + identity.issuer, + existing['user_id'] + ) + logger.info( + f"Linked OIDC identity to existing user {existing['user_id']}" + ) + return existing['user_id'] + + # Create new user + username = identity.username or identity.email or identity.subject + # Ensure username uniqueness by appending subject hash if needed + base_username = username[:200] + + user_id = await conn.fetchval( + """ + INSERT INTO api_users (username, email, oidc_subject, oidc_issuer) + VALUES ($1, $2, $3, $4) + ON CONFLICT (username) DO UPDATE + SET username = EXCLUDED.username || '-' || substr(md5(EXCLUDED.oidc_subject), 1, 8) + RETURNING user_id + """, + base_username, + identity.email, + identity.subject, + identity.issuer + ) + + logger.info(f"Created new user {user_id} for {identity.display_name}") + return user_id + + +async def create_api_key_from_oidc( + identity: OIDCIdentity, + conn: asyncpg.Connection, + ttl_hours: int = DEFAULT_TTL_HOURS, + description: str | None = None, +) -> APIKeyInfo: + """ + Create a new API key from verified OIDC identity. + + Args: + identity: Verified OIDC identity + conn: Database connection + ttl_hours: Key time-to-live in hours (default: 2) + description: Optional key description + + Returns: + APIKeyInfo with the new key (key only available here) + """ + # Get or create user + user_id = await get_or_create_user(conn, identity) + + # Generate key + api_key, key_prefix = generate_api_key() + key_hash = hash_api_key(api_key) + + # Calculate expiration + now = datetime.now(UTC) + expires_at = now + timedelta(hours=ttl_hours) + + # Build description + if not description: + description = f"OIDC key from {identity.provider_name}" + + # Store key + key_id = await conn.fetchval( + """ + INSERT INTO api_keys ( + user_id, key_hash, key_prefix, expires_at, description + ) VALUES ($1, $2, $3, $4, $5) + RETURNING key_id + """, + user_id, + key_hash, + key_prefix, + expires_at, + description + ) + + # Get username for response + username = await conn.fetchval( + "SELECT username FROM api_users WHERE user_id = $1", + user_id + ) + + logger.info( + f"Created API key {key_prefix}... for user {username} " + f"(expires: {expires_at.isoformat()})" + ) + + return APIKeyInfo( + key=api_key, + key_id=key_id, + key_prefix=key_prefix, + user_id=user_id, + username=username, + expires_at=expires_at, + created_at=now, + ) + + +async def validate_api_key( + api_key: str, + conn: asyncpg.Connection, + update_last_used: bool = True, +) -> UserInfo: + """ + Validate an API key and return user information. + + Args: + api_key: The API key to validate + conn: Database connection + update_last_used: Whether to update last_used_at timestamp + + Returns: + UserInfo for the key owner + + Raises: + ValueError: If key is invalid, expired, or revoked + """ + # Basic format validation + if not api_key or len(api_key) < 16 or len(api_key) > 256: + raise ValueError("Invalid API key format") + + key_hash = hash_api_key(api_key) + + # Look up key and user + row = await conn.fetchrow( + """ + SELECT + u.user_id, u.username, u.email, u.is_active as user_active, + u.oidc_subject, u.oidc_issuer, + k.key_id, k.expires_at, k.is_active as key_active + FROM api_keys k + JOIN api_users u ON k.user_id = u.user_id + WHERE k.key_hash = $1 + """, + key_hash + ) + + if not row: + raise ValueError("Invalid API key") + + # Check user status + if not row['user_active']: + raise ValueError("User account is disabled") + + # Check key status + if not row['key_active']: + raise ValueError("API key has been revoked") + + # Check expiration + expires_at = row['expires_at'] + if expires_at: + # Handle timezone-aware comparison + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + else: + expires_at = expires_at.astimezone(UTC) + + if expires_at < datetime.now(UTC): + raise ValueError("API key has expired") + + # Update last used timestamp + if update_last_used: + await conn.execute( + """ + UPDATE api_keys + SET last_used_at = CURRENT_TIMESTAMP + WHERE key_id = $1 + """, + row['key_id'] + ) + + return UserInfo( + user_id=row['user_id'], + username=row['username'], + email=row['email'], + oidc_subject=row['oidc_subject'], + oidc_issuer=row['oidc_issuer'], + ) + + +async def revoke_api_key( + key_prefix: str, + user_id: int, + conn: asyncpg.Connection, +) -> bool: + """ + Revoke an API key by prefix. + + Args: + key_prefix: Key prefix to revoke + user_id: Owner user ID (for authorization) + conn: Database connection + + Returns: + True if key was revoked, False if not found + """ + result = await conn.execute( + """ + UPDATE api_keys + SET is_active = false + WHERE key_prefix = $1 AND user_id = $2 AND is_active = true + """, + key_prefix, + user_id + ) + + # asyncpg returns "UPDATE N" where N is affected rows + affected = int(result.split()[-1]) + + if affected > 0: + logger.info(f"Revoked API key {key_prefix}... for user {user_id}") + return True + + return False + + +async def revoke_all_user_keys( + user_id: int, + conn: asyncpg.Connection, +) -> int: + """ + Revoke all API keys for a user. + + Args: + user_id: User ID + conn: Database connection + + Returns: + Number of keys revoked + """ + result = await conn.execute( + """ + UPDATE api_keys + SET is_active = false + WHERE user_id = $1 AND is_active = true + """, + user_id + ) + + affected = int(result.split()[-1]) + + if affected > 0: + logger.info(f"Revoked {affected} API keys for user {user_id}") + + return affected + + +async def get_user_api_keys( + user_id: int, + conn: asyncpg.Connection, + include_revoked: bool = False, +) -> list[dict[str, Any]]: + """ + List API keys for a user. + + Args: + user_id: User ID + conn: Database connection + include_revoked: Include revoked keys + + Returns: + List of key info dicts (without actual key values) + """ + where_clause = "WHERE user_id = $1" + if not include_revoked: + where_clause += " AND is_active = true" + + rows = await conn.fetch( + f""" + SELECT + key_id, key_prefix, created_at, expires_at, + last_used_at, is_active, description + FROM api_keys + {where_clause} + ORDER BY created_at DESC + """, + user_id + ) + + return [ + { + "key_id": row['key_id'], + "key_prefix": row['key_prefix'], + "created_at": row['created_at'].isoformat() if row['created_at'] else None, + "expires_at": row['expires_at'].isoformat() if row['expires_at'] else None, + "last_used_at": row['last_used_at'].isoformat() if row['last_used_at'] else None, + "is_active": row['is_active'], + "description": row['description'], + } + for row in rows + ] + + +async def cleanup_expired_keys( + conn: asyncpg.Connection, + older_than_hours: int = 24, +) -> int: + """ + Delete expired API keys older than specified hours. + + Args: + conn: Database connection + older_than_hours: Delete keys expired longer than this + + Returns: + Number of keys deleted + """ + cutoff = datetime.now(UTC) - timedelta(hours=older_than_hours) + + result = await conn.execute( + """ + DELETE FROM api_keys + WHERE expires_at < $1 + """, + cutoff + ) + + deleted = int(result.split()[-1]) + + if deleted > 0: + logger.info( + f"Cleaned up {deleted} expired API keys " + f"(older than {older_than_hours} hours)" + ) + + return deleted + + +class APIKeyManager: + """ + High-level API key management. + + Provides a convenient interface for key operations with connection pooling. + + Usage: + manager = APIKeyManager(pool) + key_info = await manager.create_from_oidc(identity) + user_info = await manager.validate(api_key) + """ + + def __init__(self, pool: asyncpg.Pool, default_ttl_hours: int = DEFAULT_TTL_HOURS): + """ + Initialize API key manager. + + Args: + pool: asyncpg connection pool + default_ttl_hours: Default TTL for new keys + """ + self._pool = pool + self._default_ttl = default_ttl_hours + + async def create_from_oidc( + self, + identity: OIDCIdentity, + ttl_hours: int | None = None, + description: str | None = None, + ) -> APIKeyInfo: + """Create API key from OIDC identity.""" + async with self._pool.acquire() as conn: + return await create_api_key_from_oidc( + identity, + conn, + ttl_hours=ttl_hours or self._default_ttl, + description=description, + ) + + async def validate( + self, + api_key: str, + update_last_used: bool = True, + ) -> UserInfo: + """Validate API key and return user info.""" + async with self._pool.acquire() as conn: + return await validate_api_key(api_key, conn, update_last_used) + + async def revoke(self, key_prefix: str, user_id: int) -> bool: + """Revoke an API key.""" + async with self._pool.acquire() as conn: + return await revoke_api_key(key_prefix, user_id, conn) + + async def revoke_all(self, user_id: int) -> int: + """Revoke all keys for a user.""" + async with self._pool.acquire() as conn: + return await revoke_all_user_keys(user_id, conn) + + async def list_keys( + self, + user_id: int, + include_revoked: bool = False, + ) -> list[dict[str, Any]]: + """List user's API keys.""" + async with self._pool.acquire() as conn: + return await get_user_api_keys(user_id, conn, include_revoked) + + async def cleanup(self, older_than_hours: int = 24) -> int: + """Clean up expired keys.""" + async with self._pool.acquire() as conn: + return await cleanup_expired_keys(conn, older_than_hours) diff --git a/terraform-gpu-devservers/shared/auth/audit.py b/terraform-gpu-devservers/shared/auth/audit.py new file mode 100644 index 00000000..9a790569 --- /dev/null +++ b/terraform-gpu-devservers/shared/auth/audit.py @@ -0,0 +1,710 @@ +""" +Audit Logging Module + +Logs all user actions and tracks Bedrock/Claude token usage for traceability. +Stores audit events in PostgreSQL for querying and compliance. + +Usage: + logger = AuditLogger(pool) + await logger.log_action( + user_id=123, + action="reservation.create", + resource_type="reservation", + resource_id="abc-123", + details={"gpu_type": "h100", "gpu_count": 4} + ) + + # Query user's actions + events = await logger.get_user_history(user_id=123, limit=50) +""" + +import json +import logging +from dataclasses import dataclass, field +from datetime import UTC, datetime, timedelta +from enum import Enum +from typing import Any + +import asyncpg + +logger = logging.getLogger(__name__) + + +class AuditEventType(str, Enum): + """Types of auditable events.""" + + # Authentication events + AUTH_LOGIN = "auth.login" + AUTH_LOGOUT = "auth.logout" + AUTH_KEY_CREATE = "auth.key_create" + AUTH_KEY_REVOKE = "auth.key_revoke" + AUTH_FAILED = "auth.failed" + + # Reservation events + RESERVATION_CREATE = "reservation.create" + RESERVATION_CANCEL = "reservation.cancel" + RESERVATION_EXTEND = "reservation.extend" + RESERVATION_EXPIRE = "reservation.expire" + RESERVATION_ADD_USER = "reservation.add_user" + + # Disk events + DISK_CREATE = "disk.create" + DISK_DELETE = "disk.delete" + DISK_ATTACH = "disk.attach" + DISK_DETACH = "disk.detach" + DISK_RENAME = "disk.rename" + + # LLM/AI events + LLM_REQUEST = "llm.request" + LLM_RESPONSE = "llm.response" + LLM_ERROR = "llm.error" + + # Admin events + ADMIN_USER_DISABLE = "admin.user_disable" + ADMIN_USER_ENABLE = "admin.user_enable" + ADMIN_CONFIG_CHANGE = "admin.config_change" + + # System events + SYSTEM_ERROR = "system.error" + SYSTEM_WARNING = "system.warning" + + +@dataclass +class AuditEvent: + """ + Represents an audit log event. + + Attributes: + event_id: Unique event identifier (assigned by database) + user_id: User who performed the action + username: Username for display + event_type: Type of event + resource_type: Type of resource affected (reservation, disk, etc.) + resource_id: ID of affected resource + action: Human-readable action description + details: Additional event details (JSON) + ip_address: Client IP address + user_agent: Client user agent + created_at: Event timestamp + """ + event_id: int | None + user_id: int | None + username: str | None + event_type: AuditEventType + resource_type: str | None + resource_id: str | None + action: str + details: dict[str, Any] = field(default_factory=dict) + ip_address: str | None = None + user_agent: str | None = None + created_at: datetime = field(default_factory=lambda: datetime.now(UTC)) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "event_id": self.event_id, + "user_id": self.user_id, + "username": self.username, + "event_type": self.event_type.value, + "resource_type": self.resource_type, + "resource_id": self.resource_id, + "action": self.action, + "details": self.details, + "ip_address": self.ip_address, + "user_agent": self.user_agent, + "created_at": self.created_at.isoformat(), + } + + +@dataclass +class TokenUsage: + """ + Tracks LLM token usage for billing and monitoring. + + Attributes: + usage_id: Unique usage record ID + user_id: User who made the request + model: LLM model name (e.g., "claude-3-opus") + input_tokens: Number of input tokens + output_tokens: Number of output tokens + total_tokens: Total tokens used + cost_usd: Estimated cost in USD + request_id: Associated request ID + created_at: Usage timestamp + """ + usage_id: int | None + user_id: int + model: str + input_tokens: int + output_tokens: int + total_tokens: int + cost_usd: float | None + request_id: str | None + created_at: datetime = field(default_factory=lambda: datetime.now(UTC)) + + +async def log_user_action( + conn: asyncpg.Connection, + user_id: int | None, + username: str | None, + event_type: AuditEventType, + action: str, + resource_type: str | None = None, + resource_id: str | None = None, + details: dict[str, Any] | None = None, + ip_address: str | None = None, + user_agent: str | None = None, +) -> int: + """ + Log a user action to the audit log. + + Args: + conn: Database connection + user_id: User performing the action (None for system events) + username: Username for display + event_type: Type of event + action: Human-readable action description + resource_type: Type of resource (reservation, disk, etc.) + resource_id: ID of affected resource + details: Additional details to log + ip_address: Client IP address + user_agent: Client user agent + + Returns: + Event ID of the created log entry + """ + details_json = json.dumps(details or {}) + + event_id = await conn.fetchval( + """ + INSERT INTO audit_log ( + user_id, username, event_type, action, + resource_type, resource_id, details, + ip_address, user_agent + ) VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9) + RETURNING event_id + """, + user_id, + username, + event_type.value, + action, + resource_type, + resource_id, + details_json, + ip_address, + user_agent, + ) + + logger.debug( + f"Audit log: {event_type.value} by {username or 'system'} - {action}" + ) + + return event_id + + +async def log_token_usage( + conn: asyncpg.Connection, + user_id: int, + model: str, + input_tokens: int, + output_tokens: int, + request_id: str | None = None, + cost_usd: float | None = None, +) -> int: + """ + Log LLM token usage. + + Args: + conn: Database connection + user_id: User who made the request + model: LLM model name + input_tokens: Number of input tokens + output_tokens: Number of output tokens + request_id: Request identifier for correlation + cost_usd: Estimated cost (optional) + + Returns: + Usage record ID + """ + total_tokens = input_tokens + output_tokens + + usage_id = await conn.fetchval( + """ + INSERT INTO token_usage ( + user_id, model, input_tokens, output_tokens, + total_tokens, cost_usd, request_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING usage_id + """, + user_id, + model, + input_tokens, + output_tokens, + total_tokens, + cost_usd, + request_id, + ) + + logger.debug( + f"Token usage: user={user_id} model={model} " + f"tokens={total_tokens} (in={input_tokens}, out={output_tokens})" + ) + + return usage_id + + +async def get_user_audit_log( + conn: asyncpg.Connection, + user_id: int, + limit: int = 100, + offset: int = 0, + event_types: list[AuditEventType] | None = None, + since: datetime | None = None, + until: datetime | None = None, +) -> list[AuditEvent]: + """ + Get audit log entries for a user. + + Args: + conn: Database connection + user_id: User ID to query + limit: Maximum entries to return + offset: Offset for pagination + event_types: Filter by event types + since: Only events after this time + until: Only events before this time + + Returns: + List of AuditEvent objects + """ + conditions = ["user_id = $1"] + params: list[Any] = [user_id] + param_idx = 2 + + if event_types: + placeholders = ", ".join(f"${i}" for i in range(param_idx, param_idx + len(event_types))) + conditions.append(f"event_type IN ({placeholders})") + params.extend(et.value for et in event_types) + param_idx += len(event_types) + + if since: + conditions.append(f"created_at >= ${param_idx}") + params.append(since) + param_idx += 1 + + if until: + conditions.append(f"created_at <= ${param_idx}") + params.append(until) + param_idx += 1 + + where_clause = " AND ".join(conditions) + + rows = await conn.fetch( + f""" + SELECT + event_id, user_id, username, event_type, action, + resource_type, resource_id, details, + ip_address, user_agent, created_at + FROM audit_log + WHERE {where_clause} + ORDER BY created_at DESC + LIMIT ${param_idx} OFFSET ${param_idx + 1} + """, + *params, + limit, + offset, + ) + + return [ + AuditEvent( + event_id=row['event_id'], + user_id=row['user_id'], + username=row['username'], + event_type=AuditEventType(row['event_type']), + resource_type=row['resource_type'], + resource_id=row['resource_id'], + action=row['action'], + details=row['details'] or {}, + ip_address=row['ip_address'], + user_agent=row['user_agent'], + created_at=row['created_at'], + ) + for row in rows + ] + + +async def get_resource_audit_log( + conn: asyncpg.Connection, + resource_type: str, + resource_id: str, + limit: int = 100, +) -> list[AuditEvent]: + """ + Get audit log entries for a specific resource. + + Args: + conn: Database connection + resource_type: Type of resource (reservation, disk, etc.) + resource_id: Resource identifier + limit: Maximum entries to return + + Returns: + List of AuditEvent objects + """ + rows = await conn.fetch( + """ + SELECT + event_id, user_id, username, event_type, action, + resource_type, resource_id, details, + ip_address, user_agent, created_at + FROM audit_log + WHERE resource_type = $1 AND resource_id = $2 + ORDER BY created_at DESC + LIMIT $3 + """, + resource_type, + resource_id, + limit, + ) + + return [ + AuditEvent( + event_id=row['event_id'], + user_id=row['user_id'], + username=row['username'], + event_type=AuditEventType(row['event_type']), + resource_type=row['resource_type'], + resource_id=row['resource_id'], + action=row['action'], + details=row['details'] or {}, + ip_address=row['ip_address'], + user_agent=row['user_agent'], + created_at=row['created_at'], + ) + for row in rows + ] + + +async def get_user_token_usage( + conn: asyncpg.Connection, + user_id: int, + since: datetime | None = None, + until: datetime | None = None, +) -> dict[str, Any]: + """ + Get token usage summary for a user. + + Args: + conn: Database connection + user_id: User ID + since: Start of period + until: End of period + + Returns: + Usage summary with totals by model + """ + conditions = ["user_id = $1"] + params: list[Any] = [user_id] + param_idx = 2 + + if since: + conditions.append(f"created_at >= ${param_idx}") + params.append(since) + param_idx += 1 + + if until: + conditions.append(f"created_at <= ${param_idx}") + params.append(until) + param_idx += 1 + + where_clause = " AND ".join(conditions) + + # Get totals by model + rows = await conn.fetch( + f""" + SELECT + model, + COUNT(*) as request_count, + SUM(input_tokens) as total_input_tokens, + SUM(output_tokens) as total_output_tokens, + SUM(total_tokens) as total_tokens, + SUM(COALESCE(cost_usd, 0)) as total_cost_usd + FROM token_usage + WHERE {where_clause} + GROUP BY model + ORDER BY total_tokens DESC + """, + *params, + ) + + by_model = { + row['model']: { + "request_count": row['request_count'], + "input_tokens": row['total_input_tokens'], + "output_tokens": row['total_output_tokens'], + "total_tokens": row['total_tokens'], + "cost_usd": float(row['total_cost_usd']) if row['total_cost_usd'] else 0.0, + } + for row in rows + } + + # Calculate grand totals + total_requests = sum(m['request_count'] for m in by_model.values()) + total_input = sum(m['input_tokens'] for m in by_model.values()) + total_output = sum(m['output_tokens'] for m in by_model.values()) + total_tokens = sum(m['total_tokens'] for m in by_model.values()) + total_cost = sum(m['cost_usd'] for m in by_model.values()) + + return { + "user_id": user_id, + "period": { + "since": since.isoformat() if since else None, + "until": until.isoformat() if until else None, + }, + "totals": { + "request_count": total_requests, + "input_tokens": total_input, + "output_tokens": total_output, + "total_tokens": total_tokens, + "cost_usd": total_cost, + }, + "by_model": by_model, + } + + +async def cleanup_old_audit_logs( + conn: asyncpg.Connection, + days_to_keep: int = 90, +) -> int: + """ + Delete audit log entries older than specified days. + + Args: + conn: Database connection + days_to_keep: Number of days to retain + + Returns: + Number of entries deleted + """ + cutoff = datetime.now(UTC) - timedelta(days=days_to_keep) + + result = await conn.execute( + """ + DELETE FROM audit_log + WHERE created_at < $1 + """, + cutoff, + ) + + deleted = int(result.split()[-1]) + + if deleted > 0: + logger.info( + f"Cleaned up {deleted} audit log entries " + f"older than {days_to_keep} days" + ) + + return deleted + + +class AuditLogger: + """ + High-level audit logging interface. + + Provides convenient methods for logging various event types. + + Usage: + audit = AuditLogger(pool) + await audit.log_login(user_id, username, ip="192.168.1.1") + await audit.log_reservation_create(user_id, username, reservation_id, details) + """ + + def __init__(self, pool: asyncpg.Pool): + """ + Initialize audit logger. + + Args: + pool: asyncpg connection pool + """ + self._pool = pool + + async def log_action( + self, + event_type: AuditEventType, + action: str, + user_id: int | None = None, + username: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + details: dict[str, Any] | None = None, + ip_address: str | None = None, + user_agent: str | None = None, + ) -> int: + """Log a generic action.""" + async with self._pool.acquire() as conn: + return await log_user_action( + conn, + user_id=user_id, + username=username, + event_type=event_type, + action=action, + resource_type=resource_type, + resource_id=resource_id, + details=details, + ip_address=ip_address, + user_agent=user_agent, + ) + + async def log_login( + self, + user_id: int, + username: str, + provider: str | None = None, + ip_address: str | None = None, + user_agent: str | None = None, + ) -> int: + """Log successful login.""" + return await self.log_action( + AuditEventType.AUTH_LOGIN, + f"User logged in via {provider or 'unknown'}", + user_id=user_id, + username=username, + details={"provider": provider} if provider else None, + ip_address=ip_address, + user_agent=user_agent, + ) + + async def log_auth_failed( + self, + username: str | None = None, + reason: str = "Unknown", + ip_address: str | None = None, + ) -> int: + """Log failed authentication attempt.""" + return await self.log_action( + AuditEventType.AUTH_FAILED, + f"Authentication failed: {reason}", + username=username, + details={"reason": reason}, + ip_address=ip_address, + ) + + async def log_key_create( + self, + user_id: int, + username: str, + key_prefix: str, + ttl_hours: int, + ) -> int: + """Log API key creation.""" + return await self.log_action( + AuditEventType.AUTH_KEY_CREATE, + f"Created API key {key_prefix}...", + user_id=user_id, + username=username, + details={"key_prefix": key_prefix, "ttl_hours": ttl_hours}, + ) + + async def log_reservation_create( + self, + user_id: int, + username: str, + reservation_id: str, + details: dict[str, Any] | None = None, + ) -> int: + """Log reservation creation.""" + return await self.log_action( + AuditEventType.RESERVATION_CREATE, + "Created GPU reservation", + user_id=user_id, + username=username, + resource_type="reservation", + resource_id=reservation_id, + details=details, + ) + + async def log_reservation_cancel( + self, + user_id: int, + username: str, + reservation_id: str, + reason: str | None = None, + ) -> int: + """Log reservation cancellation.""" + return await self.log_action( + AuditEventType.RESERVATION_CANCEL, + f"Cancelled reservation{f': {reason}' if reason else ''}", + user_id=user_id, + username=username, + resource_type="reservation", + resource_id=reservation_id, + details={"reason": reason} if reason else None, + ) + + async def log_token_usage( + self, + user_id: int, + model: str, + input_tokens: int, + output_tokens: int, + request_id: str | None = None, + cost_usd: float | None = None, + ) -> int: + """Log LLM token usage.""" + async with self._pool.acquire() as conn: + return await log_token_usage( + conn, + user_id=user_id, + model=model, + input_tokens=input_tokens, + output_tokens=output_tokens, + request_id=request_id, + cost_usd=cost_usd, + ) + + async def get_user_history( + self, + user_id: int, + limit: int = 100, + offset: int = 0, + event_types: list[AuditEventType] | None = None, + ) -> list[AuditEvent]: + """Get user's audit history.""" + async with self._pool.acquire() as conn: + return await get_user_audit_log( + conn, + user_id=user_id, + limit=limit, + offset=offset, + event_types=event_types, + ) + + async def get_resource_history( + self, + resource_type: str, + resource_id: str, + limit: int = 100, + ) -> list[AuditEvent]: + """Get audit history for a resource.""" + async with self._pool.acquire() as conn: + return await get_resource_audit_log( + conn, + resource_type=resource_type, + resource_id=resource_id, + limit=limit, + ) + + async def get_user_token_summary( + self, + user_id: int, + days: int = 30, + ) -> dict[str, Any]: + """Get token usage summary for user.""" + since = datetime.now(UTC) - timedelta(days=days) + async with self._pool.acquire() as conn: + return await get_user_token_usage(conn, user_id, since=since) + + async def cleanup(self, days_to_keep: int = 90) -> int: + """Clean up old audit logs.""" + async with self._pool.acquire() as conn: + return await cleanup_old_audit_logs(conn, days_to_keep) diff --git a/terraform-gpu-devservers/shared/auth/oidc.py b/terraform-gpu-devservers/shared/auth/oidc.py new file mode 100644 index 00000000..3c857363 --- /dev/null +++ b/terraform-gpu-devservers/shared/auth/oidc.py @@ -0,0 +1,573 @@ +""" +OIDC Token Verification Module + +Verifies JWT tokens from multiple OIDC providers (GitHub, Google, Okta). +Implements JWKS caching for performance and supports multiple issuers. + +Usage: + verifier = OIDCVerifier() + verifier.add_provider(OIDCProvider( + name="github", + issuer="https://token.actions.githubusercontent.com", + audience="gpu-dev-api" + )) + + identity = await verifier.verify_token(token) + print(f"Authenticated: {identity.subject} from {identity.issuer}") +""" + +import asyncio +import hashlib +import logging +import time +from dataclasses import dataclass, field +from datetime import UTC, datetime +from enum import Enum +from typing import Any + +import httpx +import jwt +from jwt import PyJWKClient, PyJWKClientError + +logger = logging.getLogger(__name__) + +# JWKS cache TTL (5 minutes) +JWKS_CACHE_TTL_SECONDS = 300 + +# Token clock skew tolerance (30 seconds) +CLOCK_SKEW_SECONDS = 30 + + +class OIDCVerificationError(Exception): + """Base exception for OIDC verification errors.""" + pass + + +class JWKSFetchError(OIDCVerificationError): + """Failed to fetch JWKS from issuer.""" + pass + + +class TokenExpiredError(OIDCVerificationError): + """Token has expired.""" + pass + + +class InvalidIssuerError(OIDCVerificationError): + """Token issuer is not trusted.""" + pass + + +class InvalidAudienceError(OIDCVerificationError): + """Token audience does not match expected value.""" + pass + + +class InvalidSignatureError(OIDCVerificationError): + """Token signature verification failed.""" + pass + + +class InvalidClaimsError(OIDCVerificationError): + """Token claims are invalid or missing.""" + pass + + +@dataclass +class OIDCProvider: + """ + Configuration for an OIDC identity provider. + + Attributes: + name: Human-readable provider name (e.g., "github", "google") + issuer: Token issuer URL (must match `iss` claim exactly) + audience: Expected audience (`aud` claim), can be string or list + jwks_uri: Optional custom JWKS URI (auto-discovered if not set) + additional_claims: Extra claims to extract from tokens + """ + name: str + issuer: str + audience: str | list[str] + jwks_uri: str | None = None + additional_claims: list[str] = field(default_factory=list) + + def __post_init__(self): + # Normalize issuer URL (remove trailing slash) + self.issuer = self.issuer.rstrip("/") + + +@dataclass +class OIDCIdentity: + """ + Verified identity extracted from an OIDC token. + + Attributes: + subject: Unique user identifier (`sub` claim) + issuer: Token issuer (`iss` claim) + email: User email if available + username: Username or preferred_username claim + groups: Group memberships if available + provider_name: Name of the OIDC provider that issued this token + claims: All verified claims from the token + verified_at: When the token was verified + """ + subject: str + issuer: str + email: str | None + username: str | None + groups: list[str] + provider_name: str + claims: dict[str, Any] + verified_at: datetime = field(default_factory=lambda: datetime.now(UTC)) + + @property + def display_name(self) -> str: + """Human-readable identifier for logging.""" + return self.email or self.username or self.subject + + +@dataclass +class CachedJWKS: + """Cached JWKS with expiration.""" + client: PyJWKClient + fetched_at: float + + def is_expired(self) -> bool: + return time.time() - self.fetched_at > JWKS_CACHE_TTL_SECONDS + + +class OIDCVerifier: + """ + Verifies OIDC tokens from multiple identity providers. + + Thread-safe and supports concurrent token verification. + Caches JWKS for performance. + + Usage: + verifier = OIDCVerifier() + verifier.add_provider(OIDCProvider(...)) + identity = await verifier.verify_token(token) + """ + + def __init__(self): + self._providers: dict[str, OIDCProvider] = {} + self._jwks_cache: dict[str, CachedJWKS] = {} + self._lock = asyncio.Lock() + self._http_client: httpx.AsyncClient | None = None + + def add_provider(self, provider: OIDCProvider) -> None: + """ + Register an OIDC provider. + + Args: + provider: Provider configuration + """ + self._providers[provider.issuer] = provider + logger.info(f"Registered OIDC provider: {provider.name} ({provider.issuer})") + + def remove_provider(self, issuer: str) -> bool: + """ + Remove an OIDC provider. + + Args: + issuer: Provider issuer URL + + Returns: + True if provider was removed, False if not found + """ + if issuer in self._providers: + del self._providers[issuer] + if issuer in self._jwks_cache: + del self._jwks_cache[issuer] + return True + return False + + @property + def providers(self) -> list[OIDCProvider]: + """List all registered providers.""" + return list(self._providers.values()) + + async def _get_http_client(self) -> httpx.AsyncClient: + """Get or create HTTP client for JWKS fetching.""" + if self._http_client is None or self._http_client.is_closed: + self._http_client = httpx.AsyncClient( + timeout=10.0, + follow_redirects=True, + headers={"User-Agent": "gpu-dev-oidc-verifier/1.0"} + ) + return self._http_client + + async def _discover_jwks_uri(self, issuer: str) -> str: + """ + Discover JWKS URI from OIDC discovery document. + + Args: + issuer: OIDC issuer URL + + Returns: + JWKS URI + + Raises: + JWKSFetchError: If discovery fails + """ + discovery_url = f"{issuer}/.well-known/openid-configuration" + + try: + client = await self._get_http_client() + response = await client.get(discovery_url) + response.raise_for_status() + + config = response.json() + jwks_uri = config.get("jwks_uri") + + if not jwks_uri: + raise JWKSFetchError( + f"No jwks_uri in discovery document for {issuer}" + ) + + logger.debug(f"Discovered JWKS URI for {issuer}: {jwks_uri}") + return jwks_uri + + except httpx.HTTPError as e: + raise JWKSFetchError( + f"Failed to fetch OIDC discovery document from {issuer}: {e}" + ) from e + + async def _get_jwks_client(self, provider: OIDCProvider) -> PyJWKClient: + """ + Get JWKS client for provider, using cache if available. + + Args: + provider: OIDC provider configuration + + Returns: + PyJWKClient for the provider + """ + async with self._lock: + cached = self._jwks_cache.get(provider.issuer) + + if cached and not cached.is_expired(): + return cached.client + + # Discover or use configured JWKS URI + jwks_uri = provider.jwks_uri + if not jwks_uri: + jwks_uri = await self._discover_jwks_uri(provider.issuer) + + # Create new JWKS client + try: + client = PyJWKClient( + jwks_uri, + cache_jwk_set=True, + lifespan=JWKS_CACHE_TTL_SECONDS + ) + + self._jwks_cache[provider.issuer] = CachedJWKS( + client=client, + fetched_at=time.time() + ) + + logger.debug(f"Created JWKS client for {provider.issuer}") + return client + + except Exception as e: + raise JWKSFetchError( + f"Failed to create JWKS client for {provider.issuer}: {e}" + ) from e + + def _find_provider_for_token(self, token: str) -> OIDCProvider: + """ + Find the appropriate provider for a token by decoding without verification. + + Args: + token: JWT token + + Returns: + Matching OIDCProvider + + Raises: + InvalidIssuerError: If no matching provider found + """ + try: + # Decode without verification to get issuer + unverified = jwt.decode( + token, + options={"verify_signature": False}, + algorithms=["RS256", "RS384", "RS512", "ES256", "ES384", "ES512"] + ) + + issuer = unverified.get("iss", "").rstrip("/") + + if issuer in self._providers: + return self._providers[issuer] + + raise InvalidIssuerError( + f"No trusted provider for issuer: {issuer}. " + f"Trusted issuers: {list(self._providers.keys())}" + ) + + except jwt.DecodeError as e: + raise OIDCVerificationError(f"Invalid token format: {e}") from e + + def _validate_audience( + self, + token_audience: str | list[str], + expected: str | list[str] + ) -> bool: + """ + Validate token audience against expected values. + + Args: + token_audience: Audience from token (aud claim) + expected: Expected audience value(s) + + Returns: + True if audience is valid + """ + # Normalize to lists + if isinstance(token_audience, str): + token_audiences = [token_audience] + else: + token_audiences = token_audience + + if isinstance(expected, str): + expected_audiences = [expected] + else: + expected_audiences = expected + + # Check if any token audience matches any expected audience + return bool(set(token_audiences) & set(expected_audiences)) + + async def verify_token(self, token: str) -> OIDCIdentity: + """ + Verify an OIDC token and extract identity. + + Args: + token: JWT token string + + Returns: + Verified OIDCIdentity + + Raises: + OIDCVerificationError: If verification fails + InvalidIssuerError: If issuer is not trusted + TokenExpiredError: If token has expired + InvalidAudienceError: If audience doesn't match + InvalidSignatureError: If signature is invalid + """ + # Find provider for this token + provider = self._find_provider_for_token(token) + + # Get JWKS client + jwks_client = await self._get_jwks_client(provider) + + try: + # Get signing key from JWKS + signing_key = jwks_client.get_signing_key_from_jwt(token) + + # Verify and decode token + claims = jwt.decode( + token, + signing_key.key, + algorithms=["RS256", "RS384", "RS512", "ES256", "ES384", "ES512"], + issuer=provider.issuer, + options={ + "verify_signature": True, + "verify_exp": True, + "verify_nbf": True, + "verify_iat": True, + "verify_iss": True, + "require": ["sub", "iss", "exp", "iat"], + }, + leeway=CLOCK_SKEW_SECONDS + ) + + # Validate audience manually for more control + token_audience = claims.get("aud") + if token_audience and not self._validate_audience( + token_audience, provider.audience + ): + raise InvalidAudienceError( + f"Token audience {token_audience} does not match " + f"expected {provider.audience}" + ) + + # Extract identity claims + identity = self._extract_identity(claims, provider) + + logger.info( + f"Verified OIDC token for {identity.display_name} " + f"from {provider.name}" + ) + + return identity + + except jwt.ExpiredSignatureError as e: + raise TokenExpiredError("Token has expired") from e + except jwt.InvalidIssuerError as e: + raise InvalidIssuerError(f"Invalid issuer: {e}") from e + except jwt.InvalidSignatureError as e: + raise InvalidSignatureError(f"Invalid signature: {e}") from e + except jwt.InvalidAudienceError as e: + raise InvalidAudienceError(f"Invalid audience: {e}") from e + except PyJWKClientError as e: + raise JWKSFetchError(f"Failed to get signing key: {e}") from e + except jwt.PyJWTError as e: + raise OIDCVerificationError(f"Token verification failed: {e}") from e + + def _extract_identity( + self, + claims: dict[str, Any], + provider: OIDCProvider + ) -> OIDCIdentity: + """ + Extract OIDCIdentity from verified claims. + + Handles provider-specific claim names. + """ + subject = claims["sub"] + issuer = claims["iss"] + + # Extract email (various claim names) + email = ( + claims.get("email") or + claims.get("preferred_username") or + claims.get("upn") # Azure AD + ) + + # Extract username (various claim names) + username = ( + claims.get("preferred_username") or + claims.get("name") or + claims.get("nickname") or # GitHub + claims.get("login") or # GitHub + claims.get("sub") # Fallback to subject + ) + + # Extract groups (various claim names) + groups = [] + for claim_name in ["groups", "roles", "cognito:groups", "custom:groups"]: + if claim_name in claims: + group_value = claims[claim_name] + if isinstance(group_value, list): + groups.extend(group_value) + elif isinstance(group_value, str): + groups.append(group_value) + + # Extract additional provider-specific claims + extracted_claims = dict(claims) + for claim_name in provider.additional_claims: + if claim_name in claims: + extracted_claims[claim_name] = claims[claim_name] + + return OIDCIdentity( + subject=subject, + issuer=issuer, + email=email if email and "@" in str(email) else None, + username=username, + groups=groups, + provider_name=provider.name, + claims=extracted_claims + ) + + async def close(self) -> None: + """Close HTTP client and cleanup resources.""" + if self._http_client: + await self._http_client.aclose() + self._http_client = None + self._jwks_cache.clear() + + +# Pre-configured providers for common OIDC issuers +def create_github_provider( + audience: str = "gpu-dev-api", + organization: str | None = None +) -> OIDCProvider: + """ + Create GitHub Actions OIDC provider configuration. + + Args: + audience: Expected audience (matches `aud` claim) + organization: Optional GitHub org to restrict access + + Returns: + Configured OIDCProvider for GitHub Actions + """ + return OIDCProvider( + name="github", + issuer="https://token.actions.githubusercontent.com", + audience=audience, + additional_claims=[ + "repository", + "repository_owner", + "actor", + "workflow", + "ref" + ] + ) + + +def create_google_provider( + client_id: str, + hd: str | None = None +) -> OIDCProvider: + """ + Create Google OIDC provider configuration. + + Args: + client_id: Google OAuth client ID (becomes audience) + hd: Optional hosted domain restriction + + Returns: + Configured OIDCProvider for Google + """ + return OIDCProvider( + name="google", + issuer="https://accounts.google.com", + audience=client_id, + additional_claims=["hd", "picture"] if hd else ["picture"] + ) + + +def create_okta_provider( + issuer: str, + audience: str +) -> OIDCProvider: + """ + Create Okta OIDC provider configuration. + + Args: + issuer: Okta issuer URL (e.g., https://your-domain.okta.com) + audience: Expected audience + + Returns: + Configured OIDCProvider for Okta + """ + return OIDCProvider( + name="okta", + issuer=issuer, + audience=audience, + additional_claims=["groups", "preferred_username"] + ) + + +def create_azure_ad_provider( + tenant_id: str, + client_id: str +) -> OIDCProvider: + """ + Create Azure AD OIDC provider configuration. + + Args: + tenant_id: Azure AD tenant ID + client_id: Application (client) ID + + Returns: + Configured OIDCProvider for Azure AD + """ + return OIDCProvider( + name="azure_ad", + issuer=f"https://login.microsoftonline.com/{tenant_id}/v2.0", + audience=client_id, + additional_claims=["groups", "roles", "upn", "tid"] + ) diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..a111716e --- /dev/null +++ b/tests/README.md @@ -0,0 +1,197 @@ +# ODC Test Suite + +Comprehensive tests for the ODC (GPU Dev Servers) system. + +## Test Structure + +``` +tests/ +├── conftest.py # Shared fixtures and AWS mocking +├── requirements-test.txt # Test dependencies +├── unit/ # Unit tests (mocked, fast) +│ ├── cli/ # CLI module tests +│ │ ├── test_config.py +│ │ ├── test_auth.py +│ │ ├── test_disks.py +│ │ └── test_reservations.py +│ └── lambda/ # Lambda function tests +│ └── test_reservation_processor.py +├── e2e/ # End-to-end tests (real AWS) +│ ├── test_reservation_flow.py +│ └── test_cli_commands.py +└── fixtures/ # Test data factories +``` + +## Running Tests + +### Prerequisites + +```bash +# Install test dependencies +pip install -r tests/requirements-test.txt + +# Or install with optional test group +pip install -e ".[test]" +``` + +### Unit Tests (Fast, Mocked) + +Unit tests use moto to mock AWS services. No AWS credentials required. + +```bash +# Run all unit tests +pytest tests/unit/ -v + +# Run with coverage +pytest tests/unit/ --cov=cli-tools/gpu-dev-cli --cov-report=html + +# Run specific test file +pytest tests/unit/cli/test_config.py -v + +# Run specific test class +pytest tests/unit/cli/test_config.py::TestConfigInit -v + +# Run specific test +pytest tests/unit/cli/test_config.py::TestConfigInit::test_config_creates_default_file_if_missing -v +``` + +### E2E Tests (Real AWS Dev Cluster) + +E2E tests run against the actual dev cluster in us-west-1. + +**Requirements:** +- AWS credentials with gpu-dev access +- GitHub username configured: `gpu-dev config set github_user ` +- Test environment enabled: `gpu-dev config environment test` + +```bash +# Run E2E tests +RUN_E2E_TESTS=1 pytest tests/e2e/ -v + +# Run with specific GitHub user +RUN_E2E_TESTS=1 E2E_GITHUB_USER=myuser pytest tests/e2e/ -v + +# Skip slow tests +RUN_E2E_TESTS=1 pytest tests/e2e/ -v -m "not slow" + +# Run only fast E2E tests (no actual reservations) +RUN_E2E_TESTS=1 pytest tests/e2e/test_cli_commands.py -v +``` + +### All Tests + +```bash +# Run everything (unit only by default) +pytest + +# Run with markers +pytest -m unit # Only unit tests +pytest -m e2e # Only E2E tests (requires RUN_E2E_TESTS=1) +pytest -m "not slow" # Skip slow tests +``` + +## Test Markers + +- `@pytest.mark.unit` - Unit tests (fast, mocked) +- `@pytest.mark.e2e` - End-to-end tests (require real AWS) +- `@pytest.mark.slow` - Slow tests (can be skipped) + +## Writing Tests + +### Unit Tests + +Use moto for AWS mocking: + +```python +from moto import mock_aws +import boto3 + +@mock_aws +def test_something(aws_credentials): + # Create mock resources + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + # ... test code +``` + +Use fixtures from conftest.py: + +```python +def test_with_fixtures(dynamodb_mock, reservation_factory): + # dynamodb_mock provides mock DynamoDB tables + # reservation_factory creates test reservation data + reservation = reservation_factory.create( + user_id="test-user", + gpu_count=2, + gpu_type="t4", + ) +``` + +### E2E Tests + +Use cleanup fixtures to avoid resource leaks: + +```python +@pytest.mark.e2e +def test_reservation(cleanup_reservations): + # Create reservation + result = subprocess.run(["gpu-dev", "reserve", ...]) + + # Track for cleanup + cleanup_reservations.append(reservation_id) +``` + +## Coverage + +Generate coverage report: + +```bash +# HTML report +pytest tests/unit/ --cov=cli-tools/gpu-dev-cli --cov-report=html +open htmlcov/index.html + +# Terminal report +pytest tests/unit/ --cov=cli-tools/gpu-dev-cli --cov-report=term-missing +``` + +## CI Integration + +For GitHub Actions: + +```yaml +- name: Run Unit Tests + run: pytest tests/unit/ -v --cov + +- name: Run E2E Tests + if: github.event_name == 'schedule' # Nightly only + env: + RUN_E2E_TESTS: "1" + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + run: pytest tests/e2e/ -v -m "not slow" +``` + +## Troubleshooting + +### "No module named 'gpu_dev_cli'" + +Install the package in development mode: + +```bash +pip install -e . +``` + +### "AWS credentials not found" in unit tests + +Unit tests shouldn't need real credentials - check that you're using `@mock_aws` decorator. + +### E2E tests timing out + +- Check your AWS credentials are valid +- Verify test environment: `gpu-dev config environment test` +- Ensure dev cluster is running + +### Tests pass locally but fail in CI + +- Check environment variables are set +- Verify CI has access to AWS (for E2E only) +- Check for timezone-sensitive assertions diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..0f261770 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""ODC Test Suite""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..7e0255e9 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,382 @@ +""" +Shared pytest fixtures for ODC test suite + +Provides: +- AWS mocking (DynamoDB, SQS, EC2, S3) +- Test data factories +- Configuration fixtures +""" + +import json +import os +import sys +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from typing import Dict, Any, Optional +from unittest.mock import MagicMock, patch + +import boto3 +import pytest +from moto import mock_aws + +# Add CLI and Lambda source to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "cli-tools", "gpu-dev-cli")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "terraform-gpu-devservers", "lambda", "reservation_processor")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "terraform-gpu-devservers", "lambda", "shared")) + + +# Test configuration +TEST_AWS_REGION = "us-west-1" +TEST_PREFIX = "pytorch-gpu-dev-test" + + +@pytest.fixture(scope="session") +def aws_credentials(): + """Mock AWS credentials for moto""" + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + os.environ["AWS_DEFAULT_REGION"] = TEST_AWS_REGION + + +@pytest.fixture +def mock_aws_env(aws_credentials): + """Set up mock AWS environment variables""" + env_vars = { + "AWS_DEFAULT_REGION": TEST_AWS_REGION, + "AWS_REGION": TEST_AWS_REGION, + "RESERVATIONS_TABLE": f"{TEST_PREFIX}-reservations", + "EKS_CLUSTER_NAME": f"{TEST_PREFIX}-cluster", + "REGION": TEST_AWS_REGION, + "MAX_RESERVATION_HOURS": "48", + "DEFAULT_TIMEOUT_HOURS": "8", + "QUEUE_URL": f"https://sqs.{TEST_AWS_REGION}.amazonaws.com/123456789012/{TEST_PREFIX}-reservation-queue", + "PRIMARY_AVAILABILITY_ZONE": f"{TEST_AWS_REGION}a", + "GPU_DEV_CONTAINER_IMAGE": "pytorch/pytorch:2.8.0-cuda12.9-cudnn9-devel", + "LAMBDA_VERSION": "0.3.5", + "MIN_CLI_VERSION": "0.3.0", + } + with patch.dict(os.environ, env_vars): + yield env_vars + + +@pytest.fixture +def dynamodb_mock(aws_credentials): + """Create mock DynamoDB tables""" + with mock_aws(): + dynamodb = boto3.resource("dynamodb", region_name=TEST_AWS_REGION) + + # Create reservations table + reservations_table = dynamodb.create_table( + TableName=f"{TEST_PREFIX}-reservations", + KeySchema=[ + {"AttributeName": "reservation_id", "KeyType": "HASH"}, + ], + AttributeDefinitions=[ + {"AttributeName": "reservation_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "status", "AttributeType": "S"}, + {"AttributeName": "created_at", "AttributeType": "S"}, + ], + GlobalSecondaryIndexes=[ + { + "IndexName": "UserIndex", + "KeySchema": [ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "reservation_id", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + }, + { + "IndexName": "StatusIndex", + "KeySchema": [ + {"AttributeName": "status", "KeyType": "HASH"}, + {"AttributeName": "created_at", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + }, + ], + BillingMode="PAY_PER_REQUEST", + ) + reservations_table.wait_until_exists() + + # Create disks table + disks_table = dynamodb.create_table( + TableName=f"{TEST_PREFIX}-disks", + KeySchema=[ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "disk_name", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "disk_name", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + disks_table.wait_until_exists() + + # Create availability table + availability_table = dynamodb.create_table( + TableName=f"{TEST_PREFIX}-gpu-availability", + KeySchema=[ + {"AttributeName": "gpu_type", "KeyType": "HASH"}, + ], + AttributeDefinitions=[ + {"AttributeName": "gpu_type", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + availability_table.wait_until_exists() + + yield dynamodb + + +@pytest.fixture +def sqs_mock(aws_credentials): + """Create mock SQS queue""" + with mock_aws(): + sqs = boto3.client("sqs", region_name=TEST_AWS_REGION) + queue = sqs.create_queue(QueueName=f"{TEST_PREFIX}-reservation-queue") + yield sqs, queue["QueueUrl"] + + +@pytest.fixture +def ec2_mock(aws_credentials): + """Create mock EC2 with test instances""" + with mock_aws(): + ec2 = boto3.client("ec2", region_name=TEST_AWS_REGION) + yield ec2 + + +@pytest.fixture +def s3_mock(aws_credentials): + """Create mock S3 bucket""" + with mock_aws(): + s3 = boto3.client("s3", region_name=TEST_AWS_REGION) + s3.create_bucket( + Bucket=f"{TEST_PREFIX}-snapshots", + CreateBucketConfiguration={"LocationConstraint": TEST_AWS_REGION}, + ) + yield s3 + + +# Data factories + +class ReservationFactory: + """Factory for creating test reservation data""" + + @staticmethod + def create( + reservation_id: Optional[str] = None, + user_id: str = "test-user", + status: str = "active", + gpu_count: int = 1, + gpu_type: str = "t4", + duration_hours: int = 4, + pod_name: Optional[str] = None, + **kwargs + ) -> Dict[str, Any]: + """Create a reservation dict with defaults""" + rid = reservation_id or f"res-{datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}" + now = datetime.now(timezone.utc) + + reservation = { + "reservation_id": rid, + "user_id": user_id, + "status": status, + "gpu_count": Decimal(gpu_count), + "gpu_type": gpu_type, + "duration_hours": Decimal(duration_hours), + "created_at": now.isoformat(), + "expires_at": (now + timedelta(hours=duration_hours)).isoformat(), + "pod_name": pod_name or f"gpu-dev-{rid[:8]}", + "namespace": "gpu-dev", + "node_port": Decimal(30000 + hash(rid) % 2767), + } + reservation.update(kwargs) + return reservation + + +class DiskFactory: + """Factory for creating test disk data""" + + @staticmethod + def create( + user_id: str = "test-user", + disk_name: str = "default", + size_gb: int = 100, + **kwargs + ) -> Dict[str, Any]: + """Create a disk dict with defaults""" + now = datetime.now(timezone.utc) + + disk = { + "user_id": user_id, + "disk_name": disk_name, + "size_gb": Decimal(size_gb), + "created_at": now.isoformat(), + "last_used": now.isoformat(), + "snapshot_count": Decimal(0), + "in_use": False, + "is_deleted": False, + } + disk.update(kwargs) + return disk + + +class GPUAvailabilityFactory: + """Factory for creating GPU availability data""" + + @staticmethod + def create( + gpu_type: str = "t4", + available_gpus: int = 4, + total_gpus: int = 8, + queue_length: int = 0, + **kwargs + ) -> Dict[str, Any]: + """Create GPU availability dict""" + return { + "gpu_type": gpu_type, + "available_gpus": Decimal(available_gpus), + "total_gpus": Decimal(total_gpus), + "queue_length": Decimal(queue_length), + "estimated_wait_minutes": Decimal(0), + "last_updated": datetime.now(timezone.utc).isoformat(), + **kwargs + } + + +@pytest.fixture +def reservation_factory(): + """Fixture for creating test reservations""" + return ReservationFactory() + + +@pytest.fixture +def disk_factory(): + """Fixture for creating test disks""" + return DiskFactory() + + +@pytest.fixture +def availability_factory(): + """Fixture for creating GPU availability data""" + return GPUAvailabilityFactory() + + +# Mock Kubernetes client + +@pytest.fixture +def mock_k8s_client(): + """Mock Kubernetes client for unit tests""" + mock_client = MagicMock() + + # Mock CoreV1Api + mock_v1 = MagicMock() + mock_client.CoreV1Api.return_value = mock_v1 + + # Mock pod operations + mock_v1.list_namespaced_pod.return_value = MagicMock(items=[]) + mock_v1.create_namespaced_pod.return_value = MagicMock() + mock_v1.delete_namespaced_pod.return_value = MagicMock() + mock_v1.read_namespaced_pod.return_value = MagicMock( + status=MagicMock(phase="Running"), + spec=MagicMock(node_name="test-node"), + ) + + # Mock node operations + mock_node = MagicMock() + mock_node.metadata.name = "test-gpu-node" + mock_node.metadata.labels = {"GpuType": "t4"} + mock_node.status.allocatable = {"nvidia.com/gpu": "4", "cpu": "48", "memory": "192Gi"} + mock_node.status.addresses = [MagicMock(type="ExternalIP", address="1.2.3.4")] + mock_v1.list_node.return_value = MagicMock(items=[mock_node]) + + # Mock service operations + mock_v1.create_namespaced_service.return_value = MagicMock() + mock_v1.delete_namespaced_service.return_value = MagicMock() + + return mock_client + + +# CLI Config fixture + +@pytest.fixture +def mock_cli_config(tmp_path, dynamodb_mock, sqs_mock): + """Create mock CLI Config object""" + config_dir = tmp_path / ".config" / "gpu-dev" + config_dir.mkdir(parents=True) + config_file = config_dir / "config.json" + config_file.write_text(json.dumps({ + "github_user": "testuser", + "environment": "test", + "region": TEST_AWS_REGION, + "workspace": "default", + })) + + # Patch Config class to use temp path + with patch("gpu_dev_cli.config.Config.CONFIG_FILE", config_file): + with patch("gpu_dev_cli.config.Config.LEGACY_CONFIG_FILE", tmp_path / ".gpu-dev-config"): + with patch("gpu_dev_cli.config.Config.LEGACY_ENVIRONMENT_FILE", tmp_path / ".gpu-dev-environment.json"): + from gpu_dev_cli.config import Config + + # Create config with mocked AWS + config = Config() + config.prefix = TEST_PREFIX + config.queue_name = f"{TEST_PREFIX}-reservation-queue" + config.reservations_table = f"{TEST_PREFIX}-reservations" + config.disks_table = f"{TEST_PREFIX}-disks" + config.availability_table = f"{TEST_PREFIX}-gpu-availability" + config.aws_region = TEST_AWS_REGION + + yield config + + +# E2E specific fixtures + +@pytest.fixture(scope="session") +def e2e_config(): + """ + Configuration for E2E tests against real dev cluster. + Requires AWS credentials and kubectl configured. + """ + # Skip if not running E2E tests + if not os.environ.get("RUN_E2E_TESTS"): + pytest.skip("E2E tests require RUN_E2E_TESTS=1 environment variable") + + return { + "region": os.environ.get("E2E_AWS_REGION", "us-west-1"), + "cluster_name": os.environ.get("E2E_CLUSTER_NAME", "pytorch-gpu-dev-cluster"), + "namespace": "gpu-dev", + "github_user": os.environ.get("E2E_GITHUB_USER", "testuser"), + } + + +@pytest.fixture +def e2e_cleanup(e2e_config): + """ + Fixture that tracks resources created during E2E tests for cleanup. + Yields a tracker dict, cleans up after test. + """ + created_resources = { + "reservations": [], + "disks": [], + } + + yield created_resources + + # Cleanup after test + if created_resources["reservations"]: + from gpu_dev_cli.config import Config + from gpu_dev_cli.reservations import ReservationManager + + config = Config() + manager = ReservationManager(config) + + for res_id in created_resources["reservations"]: + try: + manager.cancel_reservation(res_id, force=True) + except Exception as e: + print(f"Warning: Failed to cleanup reservation {res_id}: {e}") diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 00000000..7a346bf6 --- /dev/null +++ b/tests/e2e/__init__.py @@ -0,0 +1 @@ +"""End-to-end tests for ODC""" diff --git a/tests/e2e/test_cli_commands.py b/tests/e2e/test_cli_commands.py new file mode 100644 index 00000000..e035c367 --- /dev/null +++ b/tests/e2e/test_cli_commands.py @@ -0,0 +1,285 @@ +""" +End-to-end tests for CLI command behavior + +Tests CLI commands without actually creating reservations. +These tests verify the CLI interface works correctly. +""" + +import os +import subprocess + +import pytest + + +pytestmark = pytest.mark.skipif( + not os.environ.get("RUN_E2E_TESTS"), + reason="E2E tests require RUN_E2E_TESTS=1" +) + + +class TestCLIHelp: + """Tests for CLI help and version commands""" + + @pytest.mark.e2e + def test_cli_help(self): + """Should show help message""" + result = subprocess.run( + ["gpu-dev", "--help"], + capture_output=True, + text=True, + timeout=30, + ) + + assert result.returncode == 0 + assert "Usage" in result.stdout or "usage" in result.stdout + assert "reserve" in result.stdout + assert "list" in result.stdout + + @pytest.mark.e2e + def test_reserve_help(self): + """Should show reserve command help""" + result = subprocess.run( + ["gpu-dev", "reserve", "--help"], + capture_output=True, + text=True, + timeout=30, + ) + + assert result.returncode == 0 + assert "--gpus" in result.stdout + assert "--gpu-type" in result.stdout + assert "--hours" in result.stdout + + @pytest.mark.e2e + def test_list_help(self): + """Should show list command help""" + result = subprocess.run( + ["gpu-dev", "list", "--help"], + capture_output=True, + text=True, + timeout=30, + ) + + assert result.returncode == 0 + assert "--status" in result.stdout or "--user" in result.stdout + + @pytest.mark.e2e + def test_disk_help(self): + """Should show disk command help""" + result = subprocess.run( + ["gpu-dev", "disk", "--help"], + capture_output=True, + text=True, + timeout=30, + ) + + assert result.returncode == 0 + assert "list" in result.stdout + assert "create" in result.stdout + + +class TestConfigCommands: + """Tests for configuration commands""" + + @pytest.mark.e2e + def test_config_show(self): + """Should show current configuration""" + result = subprocess.run( + ["gpu-dev", "config"], + capture_output=True, + text=True, + timeout=30, + ) + + # Should show config or prompt for setup + assert result.returncode == 0 or "github" in result.stderr.lower() + + @pytest.mark.e2e + def test_config_environment_list(self): + """Should list available environments""" + result = subprocess.run( + ["gpu-dev", "config", "environment", "--help"], + capture_output=True, + text=True, + timeout=30, + ) + + assert result.returncode == 0 + + @pytest.mark.e2e + def test_config_environment_switch(self): + """Should switch between test and prod environments""" + # Switch to test + test_result = subprocess.run( + ["gpu-dev", "config", "environment", "test"], + capture_output=True, + text=True, + timeout=30, + ) + assert test_result.returncode == 0 + + # Switch back to prod + prod_result = subprocess.run( + ["gpu-dev", "config", "environment", "prod"], + capture_output=True, + text=True, + timeout=30, + ) + assert prod_result.returncode == 0 + + +class TestListCommand: + """Tests for list command functionality""" + + @pytest.mark.e2e + def test_list_all_reservations(self): + """Should list reservations with various filters""" + result = subprocess.run( + ["gpu-dev", "list", "--status", "all"], + capture_output=True, + text=True, + timeout=60, + ) + + assert result.returncode == 0 + + @pytest.mark.e2e + def test_list_active_only(self): + """Should filter to active reservations""" + result = subprocess.run( + ["gpu-dev", "list", "--status", "active"], + capture_output=True, + text=True, + timeout=60, + ) + + assert result.returncode == 0 + + @pytest.mark.e2e + def test_list_with_details(self): + """Should show detailed reservation info""" + result = subprocess.run( + ["gpu-dev", "list", "--details"], + capture_output=True, + text=True, + timeout=60, + ) + + assert result.returncode == 0 + + +class TestAvailCommand: + """Tests for availability command""" + + @pytest.mark.e2e + def test_avail_shows_gpu_types(self): + """Should show availability for all GPU types""" + result = subprocess.run( + ["gpu-dev", "avail"], + capture_output=True, + text=True, + timeout=60, + ) + + assert result.returncode == 0 + # Should contain GPU type names + output_lower = result.stdout.lower() + # At least one type should be mentioned + gpu_types = ["t4", "l4", "a100", "h100", "b200", "cpu"] + assert any(t in output_lower for t in gpu_types) + + +class TestStatusCommand: + """Tests for status command""" + + @pytest.mark.e2e + def test_status_command(self): + """Should show cluster status""" + result = subprocess.run( + ["gpu-dev", "status"], + capture_output=True, + text=True, + timeout=60, + ) + + assert result.returncode == 0 + + +class TestDiskCommands: + """Tests for disk subcommands""" + + @pytest.mark.e2e + def test_disk_list_command(self): + """Should list disks""" + result = subprocess.run( + ["gpu-dev", "disk", "list"], + capture_output=True, + text=True, + timeout=60, + ) + + assert result.returncode == 0 + + @pytest.mark.e2e + def test_disk_create_validates_name(self): + """Should validate disk name format""" + result = subprocess.run( + ["gpu-dev", "disk", "create", "invalid name with spaces"], + capture_output=True, + text=True, + timeout=30, + ) + + # Should fail due to invalid name + assert result.returncode != 0 or "invalid" in result.stderr.lower() + + +class TestShowCommand: + """Tests for show command""" + + @pytest.mark.e2e + def test_show_nonexistent_reservation(self): + """Should handle nonexistent reservation gracefully""" + result = subprocess.run( + ["gpu-dev", "show", "nonexistent-reservation-id"], + capture_output=True, + text=True, + timeout=30, + ) + + # Should fail gracefully with message + assert result.returncode != 0 or "not found" in result.stderr.lower() or "error" in result.stderr.lower() + + +class TestConnectCommand: + """Tests for connect command""" + + @pytest.mark.e2e + def test_connect_no_active_reservations(self): + """Should handle case with no active reservations""" + result = subprocess.run( + ["gpu-dev", "connect"], + capture_output=True, + text=True, + timeout=30, + ) + + # Should either connect or say no reservations + # We don't assert returncode since it depends on whether user has reservations + + +class TestCancelCommand: + """Tests for cancel command""" + + @pytest.mark.e2e + def test_cancel_nonexistent(self): + """Should handle canceling nonexistent reservation""" + result = subprocess.run( + ["gpu-dev", "cancel", "nonexistent-id"], + capture_output=True, + text=True, + timeout=30, + ) + + # Should fail gracefully + assert result.returncode != 0 or "not found" in result.stderr.lower() diff --git a/tests/e2e/test_full_flows.py b/tests/e2e/test_full_flows.py new file mode 100644 index 00000000..bc5c6667 --- /dev/null +++ b/tests/e2e/test_full_flows.py @@ -0,0 +1,284 @@ +""" +End-to-end test stubs for ODC + +These tests run against the real AWS us-west-1 test cluster. +They are skipped by default - run with RUN_E2E_TESTS=1 to enable. + +Test flows: +- Complete reservation lifecycle +- Disk management +- Multinode reservations +- Jupyter integration +""" + +import os +import subprocess +import time +from datetime import datetime, timezone + +import pytest + + +# Skip all tests if E2E not enabled +pytestmark = pytest.mark.skipif( + not os.environ.get("RUN_E2E_TESTS"), + reason="E2E tests require RUN_E2E_TESTS=1" +) + + +@pytest.fixture +def gpu_dev_cli(): + """Wrapper for gpu-dev CLI commands""" + def run(*args, timeout=60): + cmd = ["gpu-dev"] + list(args) + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + ) + return result + return run + + +@pytest.fixture +def cleanup_reservations(): + """Track and cleanup reservations after test""" + created = [] + yield created + + # Cleanup + for res_id in created: + try: + subprocess.run( + ["gpu-dev", "cancel", res_id, "--force"], + capture_output=True, + timeout=30, + ) + except Exception as e: + print(f"Warning: Failed to cleanup {res_id}: {e}") + + +@pytest.mark.e2e +@pytest.mark.slow +class TestReservationLifecycle: + """E2E tests for complete reservation lifecycle""" + + def test_reserve_wait_connect_cancel(self, gpu_dev_cli, cleanup_reservations): + """Should complete full reservation lifecycle""" + # Reserve + result = gpu_dev_cli( + "reserve", + "--gpu-type", "t4", + "--gpus", "1", + "--hours", "0.25", # 15 minutes + "--disk", "none", + "-y", # Skip confirmation + timeout=120, + ) + + assert result.returncode == 0, f"Reserve failed: {result.stderr}" + + # Extract reservation ID from output + # Expected format: "Reservation created: abc12345..." + res_id = None + for line in result.stdout.split("\n"): + if "reservation" in line.lower() and ":" in line: + parts = line.split(":") + if len(parts) >= 2: + res_id = parts[-1].strip()[:8] + break + + assert res_id is not None, "Could not extract reservation ID" + cleanup_reservations.append(res_id) + + # Wait for active status + max_wait = 180 # 3 minutes + start = time.time() + status = None + + while time.time() - start < max_wait: + list_result = gpu_dev_cli("list", "--json") + if list_result.returncode == 0: + import json + reservations = json.loads(list_result.stdout) + for res in reservations: + if res["reservation_id"].startswith(res_id): + status = res["status"] + if status == "active": + break + if status == "active": + break + time.sleep(5) + + assert status == "active", f"Reservation did not become active: {status}" + + # Get connection info + show_result = gpu_dev_cli("show", res_id) + assert show_result.returncode == 0 + assert "ssh" in show_result.stdout.lower() + + # Cancel + cancel_result = gpu_dev_cli("cancel", res_id, "--force") + assert cancel_result.returncode == 0 + + def test_reserve_with_disk(self, gpu_dev_cli, cleanup_reservations): + """Should create reservation with persistent disk""" + # First create a disk + disk_name = f"e2e-test-{int(time.time())}" + + create_result = gpu_dev_cli("disk", "create", disk_name) + assert create_result.returncode == 0 + + try: + # Reserve with disk + result = gpu_dev_cli( + "reserve", + "--gpu-type", "t4", + "--gpus", "1", + "--hours", "0.25", + "--disk", disk_name, + "-y", + timeout=120, + ) + + assert result.returncode == 0 + # Extract and track reservation ID for cleanup + # ... (similar to above) + + finally: + # Cleanup disk + gpu_dev_cli("disk", "delete", disk_name, "--force") + + +@pytest.mark.e2e +class TestDiskManagement: + """E2E tests for disk management""" + + def test_disk_create_list_delete(self, gpu_dev_cli): + """Should create, list, and delete a disk""" + disk_name = f"e2e-disk-{int(time.time())}" + + # Create + create_result = gpu_dev_cli("disk", "create", disk_name) + assert create_result.returncode == 0 + + try: + # List + list_result = gpu_dev_cli("disk", "list") + assert list_result.returncode == 0 + assert disk_name in list_result.stdout + + finally: + # Delete + delete_result = gpu_dev_cli("disk", "delete", disk_name, "--force") + assert delete_result.returncode == 0 + + def test_disk_rename(self, gpu_dev_cli): + """Should rename a disk""" + old_name = f"e2e-old-{int(time.time())}" + new_name = f"e2e-new-{int(time.time())}" + + # Create + gpu_dev_cli("disk", "create", old_name) + + try: + # Rename + rename_result = gpu_dev_cli("disk", "rename", old_name, new_name) + assert rename_result.returncode == 0 + + # Verify + list_result = gpu_dev_cli("disk", "list") + assert new_name in list_result.stdout + assert old_name not in list_result.stdout + + finally: + # Cleanup (use new name) + gpu_dev_cli("disk", "delete", new_name, "--force") + + +@pytest.mark.e2e +class TestAvailability: + """E2E tests for availability checking""" + + def test_avail_command(self, gpu_dev_cli): + """Should show GPU availability""" + result = gpu_dev_cli("avail") + + assert result.returncode == 0 + # Should show at least one GPU type + assert any(gpu in result.stdout.lower() for gpu in ["t4", "l4", "a100", "h100"]) + + +@pytest.mark.e2e +class TestCLICommands: + """E2E tests for basic CLI commands""" + + def test_config_show(self, gpu_dev_cli): + """Should show configuration""" + result = gpu_dev_cli("config", "show") + + assert result.returncode == 0 + assert "github" in result.stdout.lower() or "region" in result.stdout.lower() + + def test_list_command(self, gpu_dev_cli): + """Should list reservations""" + result = gpu_dev_cli("list") + + # Should succeed even with no reservations + assert result.returncode == 0 + + def test_help_command(self, gpu_dev_cli): + """Should show help""" + result = gpu_dev_cli("--help") + + assert result.returncode == 0 + assert "reserve" in result.stdout + assert "list" in result.stdout + + +@pytest.mark.e2e +@pytest.mark.slow +class TestMultinodeReservation: + """E2E tests for multinode reservations""" + + def test_multinode_reserve(self, gpu_dev_cli, cleanup_reservations): + """Should create multinode reservation (requires 16 GPUs)""" + pytest.skip("Multinode requires 16 GPUs - run manually when capacity available") + + result = gpu_dev_cli( + "reserve", + "--gpu-type", "h100", + "--gpus", "16", + "--distributed", + "--hours", "0.5", + "--disk", "none", + "-y", + timeout=300, + ) + + assert result.returncode == 0 + + +@pytest.mark.e2e +@pytest.mark.slow +class TestJupyterIntegration: + """E2E tests for Jupyter integration""" + + def test_reserve_with_jupyter(self, gpu_dev_cli, cleanup_reservations): + """Should create reservation with Jupyter enabled""" + result = gpu_dev_cli( + "reserve", + "--gpu-type", "t4", + "--gpus", "1", + "--hours", "0.25", + "--jupyter", + "--disk", "none", + "-y", + timeout=120, + ) + + assert result.returncode == 0 + + # Wait for active and check Jupyter URL in show output + # ... (implementation similar to lifecycle test) diff --git a/tests/e2e/test_reservation_flow.py b/tests/e2e/test_reservation_flow.py new file mode 100644 index 00000000..0dab7b34 --- /dev/null +++ b/tests/e2e/test_reservation_flow.py @@ -0,0 +1,484 @@ +""" +End-to-end tests for GPU reservation flow + +These tests run against a real AWS dev cluster (us-west-1). +Requires: +- RUN_E2E_TESTS=1 environment variable +- Valid AWS credentials with gpu-dev access +- E2E_GITHUB_USER set to a valid GitHub username +""" + +import os +import subprocess +import time +from datetime import datetime, timezone + +import pytest + + +# Skip all E2E tests if not explicitly enabled +pytestmark = pytest.mark.skipif( + not os.environ.get("RUN_E2E_TESTS"), + reason="E2E tests require RUN_E2E_TESTS=1" +) + + +@pytest.fixture(scope="module") +def cli_config(): + """Set up CLI to use test environment""" + # Switch to test environment + result = subprocess.run( + ["gpu-dev", "config", "environment", "test"], + capture_output=True, + text=True, + ) + assert result.returncode == 0, f"Failed to set test environment: {result.stderr}" + + yield { + "region": "us-west-1", + "environment": "test", + } + + +@pytest.fixture +def cleanup_reservations(): + """Track and cleanup reservations after tests""" + created_reservations = [] + + yield created_reservations + + # Cleanup + for res_id in created_reservations: + try: + subprocess.run( + ["gpu-dev", "cancel", res_id, "--force"], + capture_output=True, + timeout=60, + ) + except Exception as e: + print(f"Warning: Failed to cleanup reservation {res_id}: {e}") + + +class TestBasicReservation: + """Tests for basic single-GPU reservations""" + + @pytest.mark.e2e + @pytest.mark.timeout(300) + def test_reserve_single_gpu_t4(self, cli_config, cleanup_reservations): + """Should reserve 1 T4 GPU successfully""" + result = subprocess.run( + [ + "gpu-dev", "reserve", + "--gpus", "1", + "--gpu-type", "t4", + "--hours", "0.25", # 15 minutes + "--name", "e2e-test-t4", + "--no-wait", # Don't wait for pod to be ready + ], + capture_output=True, + text=True, + timeout=120, + ) + + assert result.returncode == 0, f"Reserve failed: {result.stderr}" + + # Extract reservation ID from output + output = result.stdout + assert "reservation" in output.lower() or "queued" in output.lower() + + # List reservations to get ID + list_result = subprocess.run( + ["gpu-dev", "list", "--status", "all"], + capture_output=True, + text=True, + ) + assert list_result.returncode == 0 + + # Find the test reservation + lines = list_result.stdout.split("\n") + for line in lines: + if "e2e-test-t4" in line or "queued" in line.lower() or "pending" in line.lower(): + # Found our reservation, extract ID if visible + parts = line.split() + if parts: + cleanup_reservations.append(parts[0]) + break + + @pytest.mark.e2e + @pytest.mark.timeout(300) + def test_reserve_multiple_gpus(self, cli_config, cleanup_reservations): + """Should reserve multiple GPUs on same node""" + result = subprocess.run( + [ + "gpu-dev", "reserve", + "--gpus", "2", + "--gpu-type", "t4", + "--hours", "0.25", + "--name", "e2e-test-multi-gpu", + "--no-wait", + ], + capture_output=True, + text=True, + timeout=120, + ) + + # Should either succeed or queue (depending on availability) + assert result.returncode == 0 or "queued" in result.stdout.lower() + + +class TestJupyterIntegration: + """Tests for Jupyter Lab integration""" + + @pytest.mark.e2e + @pytest.mark.timeout(600) + @pytest.mark.slow + def test_reserve_with_jupyter(self, cli_config, cleanup_reservations): + """Should enable Jupyter when requested""" + result = subprocess.run( + [ + "gpu-dev", "reserve", + "--gpus", "1", + "--gpu-type", "t4", + "--hours", "0.25", + "--jupyter", + "--name", "e2e-test-jupyter", + "--no-wait", + ], + capture_output=True, + text=True, + timeout=120, + ) + + assert result.returncode == 0 or "queued" in result.stdout.lower() + + +class TestDiskManagement: + """Tests for persistent disk functionality""" + + @pytest.mark.e2e + @pytest.mark.timeout(120) + def test_list_disks(self, cli_config): + """Should list user's disks""" + result = subprocess.run( + ["gpu-dev", "disk", "list"], + capture_output=True, + text=True, + timeout=60, + ) + + assert result.returncode == 0 + # Output should contain table headers or "no disks" message + assert "disk" in result.stdout.lower() or "no disks" in result.stdout.lower() + + @pytest.mark.e2e + @pytest.mark.timeout(120) + def test_create_and_delete_disk(self, cli_config): + """Should create and delete a disk""" + disk_name = f"e2e-test-{int(time.time())}" + + # Create disk + create_result = subprocess.run( + ["gpu-dev", "disk", "create", disk_name], + capture_output=True, + text=True, + timeout=60, + ) + + # Should succeed or say already exists + if create_result.returncode == 0: + # Wait for disk to appear + time.sleep(5) + + # Verify disk exists + list_result = subprocess.run( + ["gpu-dev", "disk", "list"], + capture_output=True, + text=True, + ) + assert disk_name in list_result.stdout + + # Delete disk + delete_result = subprocess.run( + ["gpu-dev", "disk", "delete", disk_name, "--yes"], + capture_output=True, + text=True, + timeout=60, + ) + assert delete_result.returncode == 0 + + +class TestAvailabilityChecks: + """Tests for GPU availability information""" + + @pytest.mark.e2e + @pytest.mark.timeout(60) + def test_check_availability(self, cli_config): + """Should show GPU availability by type""" + result = subprocess.run( + ["gpu-dev", "avail"], + capture_output=True, + text=True, + timeout=30, + ) + + assert result.returncode == 0 + # Should show GPU types + output_lower = result.stdout.lower() + assert "t4" in output_lower or "gpu" in output_lower + + @pytest.mark.e2e + @pytest.mark.timeout(60) + def test_cluster_status(self, cli_config): + """Should show cluster status""" + result = subprocess.run( + ["gpu-dev", "status"], + capture_output=True, + text=True, + timeout=30, + ) + + assert result.returncode == 0 + + +class TestCancellation: + """Tests for reservation cancellation""" + + @pytest.mark.e2e + @pytest.mark.timeout(300) + def test_cancel_reservation(self, cli_config): + """Should cancel a reservation""" + # First create a reservation + reserve_result = subprocess.run( + [ + "gpu-dev", "reserve", + "--gpus", "1", + "--gpu-type", "t4", + "--hours", "0.25", + "--name", "e2e-test-cancel", + "--no-wait", + ], + capture_output=True, + text=True, + timeout=120, + ) + + if reserve_result.returncode != 0: + pytest.skip("Could not create reservation to cancel") + + # Wait briefly + time.sleep(5) + + # List and get reservation ID + list_result = subprocess.run( + ["gpu-dev", "list"], + capture_output=True, + text=True, + ) + + # Cancel all test reservations + cancel_result = subprocess.run( + ["gpu-dev", "cancel", "--all", "--force"], + capture_output=True, + text=True, + timeout=120, + ) + + # Should succeed or say no reservations + assert cancel_result.returncode == 0 or "no" in cancel_result.stdout.lower() + + +class TestExtendReservation: + """Tests for reservation extension""" + + @pytest.mark.e2e + @pytest.mark.timeout(600) + @pytest.mark.slow + def test_extend_active_reservation(self, cli_config, cleanup_reservations): + """Should extend an active reservation""" + # Create a short reservation + reserve_result = subprocess.run( + [ + "gpu-dev", "reserve", + "--gpus", "1", + "--gpu-type", "t4", + "--hours", "0.5", + "--name", "e2e-test-extend", + ], + capture_output=True, + text=True, + timeout=300, + ) + + if reserve_result.returncode != 0: + pytest.skip("Could not create reservation to extend") + + # Wait for reservation to become active + time.sleep(10) + + # Get the reservation ID + list_result = subprocess.run( + ["gpu-dev", "list", "--details"], + capture_output=True, + text=True, + ) + + # Try to extend (may fail if not active yet) + extend_result = subprocess.run( + ["gpu-dev", "edit", "--extend", "1"], # Extend by 1 hour + capture_output=True, + text=True, + timeout=60, + ) + + # Should succeed or say reservation not found/not active + assert extend_result.returncode == 0 or "not active" in extend_result.stderr.lower() + + +class TestSSHAccess: + """Tests for SSH connectivity""" + + @pytest.mark.e2e + @pytest.mark.timeout(600) + @pytest.mark.slow + def test_ssh_connection_info(self, cli_config, cleanup_reservations): + """Should provide valid SSH connection info""" + # Create reservation and wait for it to be active + reserve_result = subprocess.run( + [ + "gpu-dev", "reserve", + "--gpus", "1", + "--gpu-type", "t4", + "--hours", "0.5", + "--name", "e2e-test-ssh", + ], + capture_output=True, + text=True, + timeout=300, + ) + + if reserve_result.returncode != 0: + pytest.skip("Could not create reservation for SSH test") + + # Output should contain SSH command + output = reserve_result.stdout + assert "ssh" in output.lower() or "connecting" in output.lower() + + @pytest.mark.e2e + @pytest.mark.timeout(30) + def test_ssh_config_generation(self, cli_config): + """Should generate SSH config file""" + # Check if config files exist + import os + from pathlib import Path + + devgpu_dir = Path.home() / ".devgpu" + if devgpu_dir.exists(): + configs = list(devgpu_dir.glob("*-sshconfig")) + # Just verify the directory structure is correct + assert devgpu_dir.is_dir() + + +class TestMultinodeReservation: + """Tests for multinode (distributed) reservations""" + + @pytest.mark.e2e + @pytest.mark.timeout(900) + @pytest.mark.slow + def test_multinode_reservation_creates_multiple_pods(self, cli_config, cleanup_reservations): + """Should create multiple pods for distributed reservation""" + # This test requires H100/B200 nodes which may not be available in test + result = subprocess.run( + [ + "gpu-dev", "reserve", + "--gpus", "8", + "--gpu-type", "h100", + "--distributed", + "--hours", "0.5", + "--name", "e2e-test-multinode", + "--no-wait", + ], + capture_output=True, + text=True, + timeout=120, + ) + + # May succeed, queue, or fail due to no H100 availability + # We just verify the command is accepted + assert result.returncode == 0 or "queued" in result.stdout.lower() or "not available" in result.stderr.lower() + + +class TestCustomDockerImage: + """Tests for custom Docker image support""" + + @pytest.mark.e2e + @pytest.mark.timeout(600) + @pytest.mark.slow + def test_reserve_with_custom_image(self, cli_config, cleanup_reservations): + """Should accept custom Docker image""" + result = subprocess.run( + [ + "gpu-dev", "reserve", + "--gpus", "1", + "--gpu-type", "t4", + "--dockerimage", "pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel", + "--hours", "0.25", + "--name", "e2e-test-custom-image", + "--no-wait", + ], + capture_output=True, + text=True, + timeout=120, + ) + + # Should accept the custom image + assert result.returncode == 0 or "queued" in result.stdout.lower() + + +class TestErrorHandling: + """Tests for error handling and validation""" + + @pytest.mark.e2e + @pytest.mark.timeout(60) + def test_invalid_gpu_type_rejected(self, cli_config): + """Should reject invalid GPU types""" + result = subprocess.run( + [ + "gpu-dev", "reserve", + "--gpus", "1", + "--gpu-type", "invalid_gpu_type", + "--hours", "1", + ], + capture_output=True, + text=True, + timeout=30, + ) + + # Should fail with error message + assert result.returncode != 0 or "invalid" in result.stderr.lower() or "error" in result.stderr.lower() + + @pytest.mark.e2e + @pytest.mark.timeout(60) + def test_excessive_gpus_rejected(self, cli_config): + """Should reject request for more GPUs than available on node""" + result = subprocess.run( + [ + "gpu-dev", "reserve", + "--gpus", "100", # Way more than any node has + "--gpu-type", "t4", + "--hours", "1", + ], + capture_output=True, + text=True, + timeout=30, + ) + + # Should fail or warn + assert result.returncode != 0 or "error" in result.stderr.lower() or "max" in result.stderr.lower() + + @pytest.mark.e2e + @pytest.mark.timeout(60) + def test_missing_github_user_rejected(self, cli_config, tmp_path): + """Should require GitHub username to be configured""" + # This test would require a fresh config, which is complex to set up + # Skip for now as it requires special setup + pytest.skip("Requires isolated config environment") diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..33bc9d8f --- /dev/null +++ b/tests/fixtures/__init__.py @@ -0,0 +1 @@ +"""Test fixtures for ODC""" diff --git a/tests/requirements-test.txt b/tests/requirements-test.txt new file mode 100644 index 00000000..116d3a38 --- /dev/null +++ b/tests/requirements-test.txt @@ -0,0 +1,21 @@ +# Test dependencies +pytest>=7.4.0 +pytest-asyncio>=0.21.0 +pytest-cov>=4.1.0 +pytest-mock>=3.11.0 + +# Mocking +unittest-mock>=1.4.0 + +# Type checking +mypy>=1.5.0 + +# Dependencies from services (for imports) +fastapi>=0.103.0 +asyncpg>=0.28.0 +aioboto3>=11.0.0 +pydantic>=2.3.0 +psycopg2-binary>=2.9.0 +kubernetes>=27.0.0 +boto3>=1.28.0 +botocore>=1.31.0 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..143330ed --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests for ODC""" diff --git a/tests/unit/cli/__init__.py b/tests/unit/cli/__init__.py new file mode 100644 index 00000000..d78c7be6 --- /dev/null +++ b/tests/unit/cli/__init__.py @@ -0,0 +1 @@ +"""Unit tests for GPU Dev CLI""" diff --git a/tests/unit/cli/test_auth.py b/tests/unit/cli/test_auth.py new file mode 100644 index 00000000..bdbd0b6d --- /dev/null +++ b/tests/unit/cli/test_auth.py @@ -0,0 +1,208 @@ +""" +Unit tests for gpu_dev_cli.auth module + +Tests: +- AWS authentication +- GitHub SSH key validation +""" + +import subprocess +from unittest.mock import MagicMock, patch + +import pytest + + +class TestAuthenticateUser: + """Tests for authenticate_user function""" + + def test_authenticate_returns_user_info_on_success(self): + """Should return user info when AWS auth succeeds""" + mock_config = MagicMock() + mock_config.get_user_identity.return_value = { + "user_id": "AIDAEXAMPLE", + "account": "123456789012", + "arn": "arn:aws:iam::123456789012:user/testuser", + } + mock_config.get_queue_url.return_value = "https://sqs.us-east-2.amazonaws.com/123456789012/queue" + mock_config.get_github_username.return_value = "githubuser" + + from gpu_dev_cli.auth import authenticate_user + result = authenticate_user(mock_config) + + assert result["user_id"] == "testuser" + assert result["github_user"] == "githubuser" + assert "arn" in result + + def test_authenticate_raises_when_github_not_configured(self): + """Should raise RuntimeError when github_user not set""" + mock_config = MagicMock() + mock_config.get_user_identity.return_value = { + "user_id": "AIDAEXAMPLE", + "account": "123456789012", + "arn": "arn:aws:iam::123456789012:user/testuser", + } + mock_config.get_queue_url.return_value = "https://sqs.us-east-2.amazonaws.com/123456789012/queue" + mock_config.get_github_username.return_value = None + + from gpu_dev_cli.auth import authenticate_user + + with pytest.raises(RuntimeError, match="GitHub username not configured"): + authenticate_user(mock_config) + + def test_authenticate_raises_on_aws_error(self): + """Should raise RuntimeError on AWS authentication failure""" + mock_config = MagicMock() + mock_config.get_user_identity.side_effect = Exception("Invalid credentials") + + from gpu_dev_cli.auth import authenticate_user + + with pytest.raises(RuntimeError, match="AWS authentication failed"): + authenticate_user(mock_config) + + +class TestValidateSshKeyMatchesGithubUser: + """Tests for validate_ssh_key_matches_github_user function""" + + def test_returns_valid_when_ssh_user_matches_config(self): + """Should return valid=True when SSH user matches configured user""" + mock_config = MagicMock() + mock_config.get_github_username.return_value = "myuser" + + # Mock subprocess to return successful GitHub SSH response + with patch("subprocess.run") as mock_run: + # Create a mock tempfile + with patch("tempfile.NamedTemporaryFile") as mock_tempfile: + mock_file = MagicMock() + mock_file.name = "/tmp/test" + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=False) + + # Simulate writing to temp file + class MockTempFile: + name = "/tmp/test" + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def seek(self, pos): + pass + + def read(self): + return "Hi myuser! You've successfully authenticated, but GitHub does not provide shell access." + + mock_tempfile.return_value = MockTempFile() + mock_run.return_value = MagicMock(returncode=1) + + with patch("os.unlink"): + from gpu_dev_cli.auth import validate_ssh_key_matches_github_user + result = validate_ssh_key_matches_github_user(mock_config) + + assert result["valid"] is True + assert result["configured_user"] == "myuser" + assert result["ssh_user"] == "myuser" + + def test_returns_invalid_when_ssh_user_differs(self): + """Should return valid=False when SSH user doesn't match""" + mock_config = MagicMock() + mock_config.get_github_username.return_value = "configureduser" + + with patch("subprocess.run") as mock_run: + with patch("tempfile.NamedTemporaryFile") as mock_tempfile: + class MockTempFile: + name = "/tmp/test" + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def seek(self, pos): + pass + + def read(self): + return "Hi differentuser! You've successfully authenticated, but GitHub does not provide shell access." + + mock_tempfile.return_value = MockTempFile() + mock_run.return_value = MagicMock(returncode=1) + + with patch("os.unlink"): + from gpu_dev_cli.auth import validate_ssh_key_matches_github_user + result = validate_ssh_key_matches_github_user(mock_config) + + assert result["valid"] is False + assert result["configured_user"] == "configureduser" + assert result["ssh_user"] == "differentuser" + assert "different" in result["error"].lower() + + def test_returns_error_when_github_not_configured(self): + """Should return error when GitHub username not configured""" + mock_config = MagicMock() + mock_config.get_github_username.return_value = None + + from gpu_dev_cli.auth import validate_ssh_key_matches_github_user + result = validate_ssh_key_matches_github_user(mock_config) + + assert result["valid"] is False + assert result["error"] is not None + assert "not configured" in result["error"] + + def test_returns_error_on_ssh_timeout(self): + """Should return error when SSH connection times out""" + mock_config = MagicMock() + mock_config.get_github_username.return_value = "myuser" + + with patch("subprocess.run") as mock_run: + with patch("tempfile.NamedTemporaryFile") as mock_tempfile: + class MockTempFile: + name = "/tmp/test" + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + mock_tempfile.return_value = MockTempFile() + mock_run.side_effect = subprocess.TimeoutExpired("ssh", 30) + + from gpu_dev_cli.auth import validate_ssh_key_matches_github_user + result = validate_ssh_key_matches_github_user(mock_config) + + assert result["valid"] is False + assert "timed out" in result["error"].lower() + + def test_case_insensitive_username_comparison(self): + """Should compare usernames case-insensitively""" + mock_config = MagicMock() + mock_config.get_github_username.return_value = "MyUser" + + with patch("subprocess.run") as mock_run: + with patch("tempfile.NamedTemporaryFile") as mock_tempfile: + class MockTempFile: + name = "/tmp/test" + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def seek(self, pos): + pass + + def read(self): + return "Hi myuser! You've successfully authenticated, but GitHub does not provide shell access." + + mock_tempfile.return_value = MockTempFile() + mock_run.return_value = MagicMock(returncode=1) + + with patch("os.unlink"): + from gpu_dev_cli.auth import validate_ssh_key_matches_github_user + result = validate_ssh_key_matches_github_user(mock_config) + + # Should be valid despite case difference + assert result["valid"] is True diff --git a/tests/unit/cli/test_availability.py b/tests/unit/cli/test_availability.py new file mode 100644 index 00000000..9ab427b6 --- /dev/null +++ b/tests/unit/cli/test_availability.py @@ -0,0 +1,358 @@ +""" +Unit tests for gpu_dev_cli availability command + +Tests: +- GPU availability fetching via API +- Display formatting +- Cluster status +""" + +import json +from datetime import datetime, timezone, timedelta +from unittest.mock import MagicMock, patch + +import pytest + + +class TestGPUAvailabilityFetching: + """Tests for fetching GPU availability via API""" + + def test_get_availability_calls_api(self): + """Should call API to get GPU availability""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_gpu_availability.return_value = { + "availability": { + "t4": { + "gpu_type": "t4", + "total": 8, + "available": 4, + "in_use": 4, + "queued": 0, + "max_per_node": 4, + }, + "h100": { + "gpu_type": "h100", + "total": 16, + "available": 0, + "in_use": 16, + "queued": 2, + "max_per_node": 8, + }, + }, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_gpu_availability_by_type() + + assert result is not None + mock_api_client.get_gpu_availability.assert_called_once() + + def test_availability_transforms_api_response(self): + """Should transform API response to expected format""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_gpu_availability.return_value = { + "availability": { + "h100": { + "gpu_type": "h100", + "total": 16, + "available": 8, + "in_use": 8, + "queued": 4, + "max_per_node": 8, + }, + }, + "timestamp": "2026-01-20T18:30:00Z", + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_gpu_availability_by_type() + + assert "h100" in result + h100 = result["h100"] + assert h100["available"] == 8 + assert h100["total"] == 16 + assert h100["queue_length"] == 4 + assert h100["max_reservable"] == 8 + assert h100["gpus_per_instance"] == 8 + + def test_availability_calculates_full_nodes(self): + """Should calculate number of full nodes available""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_gpu_availability.return_value = { + "availability": { + "h100": { + "gpu_type": "h100", + "total": 16, + "available": 16, + "in_use": 0, + "queued": 0, + "max_per_node": 8, + }, + }, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_gpu_availability_by_type() + + # 16 available / 8 max_per_node = 2 full nodes + assert result["h100"]["full_nodes_available"] == 2 + + def test_availability_calculates_estimated_wait(self): + """Should calculate estimated wait time from queue""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_gpu_availability.return_value = { + "availability": { + "h100": { + "gpu_type": "h100", + "total": 16, + "available": 0, + "in_use": 16, + "queued": 4, + "max_per_node": 8, + }, + }, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_gpu_availability_by_type() + + # 4 queued * 15 minutes = 60 minutes + assert result["h100"]["estimated_wait_minutes"] == 60 + + def test_availability_handles_api_error(self): + """Should return None on API error""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_gpu_availability.side_effect = Exception("API error") + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_gpu_availability_by_type() + + assert result is None + + +class TestStaticGPUConfig: + """Tests for static GPU configuration fallback""" + + def test_static_config_returns_known_types(self): + """Should return static config for known GPU types""" + mock_config = MagicMock() + mock_api_client = MagicMock() + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + + # Test A100 static config + config = manager._get_static_gpu_config("a100", queue_length=2, estimated_wait=30) + assert config["total"] == 16 + assert config["queue_length"] == 2 + assert config["estimated_wait_minutes"] == 30 + + # Test H100 static config + config = manager._get_static_gpu_config("h100", queue_length=0, estimated_wait=0) + assert config["total"] == 16 + + # Test T4 static config + config = manager._get_static_gpu_config("t4", queue_length=0, estimated_wait=0) + assert config["total"] == 8 + + def test_static_config_unknown_type(self): + """Should return zero values for unknown GPU types""" + mock_config = MagicMock() + mock_api_client = MagicMock() + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + config = manager._get_static_gpu_config("unknown-gpu", queue_length=0, estimated_wait=0) + + assert config["total"] == 0 + assert config["available"] == 0 + + +class TestClusterStatusAPI: + """Tests for cluster status via API""" + + def test_cluster_status_calls_api(self): + """Should call cluster status API endpoint""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_cluster_status.return_value = { + "total_gpus": 64, + "available_gpus": 32, + "in_use_gpus": 32, + "queued_gpus": 8, + "active_reservations": 10, + "queued_reservations": 2, + "pending_reservations": 1, + "preparing_reservations": 0, + "by_gpu_type": {}, + "timestamp": "2026-01-20T18:30:00Z", + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_cluster_status() + + mock_api_client.get_cluster_status.assert_called_once() + assert result["total_gpus"] == 64 + assert result["available_gpus"] == 32 + assert result["active_reservations"] == 10 + + def test_cluster_status_transforms_response(self): + """Should transform API response to CLI format""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_cluster_status.return_value = { + "total_gpus": 64, + "available_gpus": 24, + "in_use_gpus": 40, + "queued_gpus": 8, + "active_reservations": 8, + "queued_reservations": 3, + "pending_reservations": 1, + "preparing_reservations": 2, + "by_gpu_type": { + "h100": {"available": 8, "total": 16}, + "t4": {"available": 4, "total": 8}, + }, + "timestamp": "2026-01-20T18:30:00Z", + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_cluster_status() + + assert result["reserved_gpus"] == 40 + assert result["queue_length"] == 4 + assert result["queued_gpus"] == 8 + assert result["preparing_reservations"] == 2 + + +class TestAvailabilityMultipleGPUTypes: + """Tests for availability across multiple GPU types""" + + def test_availability_returns_all_gpu_types(self): + """Should return availability for all GPU types from API""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_gpu_availability.return_value = { + "availability": { + "t4": {"total": 8, "available": 4, "in_use": 4, "queued": 0, "max_per_node": 4}, + "l4": {"total": 8, "available": 8, "in_use": 0, "queued": 0, "max_per_node": 4}, + "a100": {"total": 16, "available": 8, "in_use": 8, "queued": 1, "max_per_node": 8}, + "h100": {"total": 16, "available": 0, "in_use": 16, "queued": 3, "max_per_node": 8}, + "h200": {"total": 16, "available": 16, "in_use": 0, "queued": 0, "max_per_node": 8}, + "b200": {"total": 8, "available": 8, "in_use": 0, "queued": 0, "max_per_node": 8}, + }, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_gpu_availability_by_type() + + assert len(result) == 6 + assert "t4" in result + assert "l4" in result + assert "a100" in result + assert "h100" in result + assert "h200" in result + assert "b200" in result + + def test_availability_empty_response(self): + """Should handle empty availability response""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_gpu_availability.return_value = { + "availability": {}, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_gpu_availability_by_type() + + assert result == {} + + +class TestAvailabilityCalculations: + """Tests for availability calculations""" + + def test_running_instances_calculation(self): + """Should calculate running instances from in_use GPUs""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_gpu_availability.return_value = { + "availability": { + "h100": { + "total": 16, + "available": 0, + "in_use": 16, + "queued": 0, + "max_per_node": 8, + }, + }, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_gpu_availability_by_type() + + # 16 in_use / 8 max_per_node = 2 running instances + assert result["h100"]["running_instances"] == 2 + + def test_zero_wait_when_available(self): + """Should show zero wait time when GPUs available""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_gpu_availability.return_value = { + "availability": { + "t4": { + "total": 8, + "available": 4, + "in_use": 4, + "queued": 0, + "max_per_node": 4, + }, + }, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_gpu_availability_by_type() + + assert result["t4"]["estimated_wait_minutes"] == 0 diff --git a/tests/unit/cli/test_cancel.py b/tests/unit/cli/test_cancel.py new file mode 100644 index 00000000..db162048 --- /dev/null +++ b/tests/unit/cli/test_cancel.py @@ -0,0 +1,176 @@ +""" +Unit tests for gpu_dev_cli cancel command + +Tests: +- Single reservation cancellation +- Cancel all reservations +- Validation and permissions +""" + +import json +from datetime import datetime, timezone +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest +from moto import mock_aws +import boto3 + + +class TestCancelSingle: + """Tests for cancelling a single reservation""" + + def test_cancel_builds_correct_message(self): + """Should build correct cancellation message""" + from gpu_dev_cli.reservations import build_cancel_request + + request = build_cancel_request( + reservation_id="res-123", + user_id="test-user", + ) + + assert request["type"] == "cancellation" + assert request["reservation_id"] == "res-123" + assert request["user_id"] == "test-user" + + def test_cancel_validates_ownership(self): + """Should only allow owner to cancel""" + from gpu_dev_cli.reservations import validate_cancel_permission + + reservation = { + "reservation_id": "res-123", + "user_id": "owner-user", + "status": "active", + } + + # Owner can cancel + validate_cancel_permission(reservation, "owner-user") + + # Non-owner cannot cancel + with pytest.raises(PermissionError): + validate_cancel_permission(reservation, "other-user") + + def test_cancel_validates_cancellable_status(self): + """Should only cancel reservations in cancellable state""" + from gpu_dev_cli.reservations import validate_cancel_permission + + # Already cancelled + cancelled_reservation = { + "reservation_id": "res-123", + "user_id": "test-user", + "status": "cancelled", + } + + with pytest.raises(ValueError, match="already"): + validate_cancel_permission(cancelled_reservation, "test-user") + + # Expired + expired_reservation = { + "reservation_id": "res-123", + "user_id": "test-user", + "status": "expired", + } + + with pytest.raises(ValueError, match="expired"): + validate_cancel_permission(expired_reservation, "test-user") + + +class TestCancelAll: + """Tests for cancelling all user reservations""" + + def test_cancel_all_finds_active_reservations(self): + """Should find all active/queued reservations for user""" + from gpu_dev_cli.reservations import get_cancellable_reservations + + reservations = [ + {"reservation_id": "res-1", "user_id": "test-user", "status": "active"}, + {"reservation_id": "res-2", "user_id": "test-user", "status": "queued"}, + {"reservation_id": "res-3", "user_id": "test-user", "status": "cancelled"}, + {"reservation_id": "res-4", "user_id": "test-user", "status": "expired"}, + {"reservation_id": "res-5", "user_id": "other-user", "status": "active"}, + ] + + cancellable = get_cancellable_reservations(reservations, "test-user") + + assert len(cancellable) == 2 + assert all(r["status"] in ["active", "queued"] for r in cancellable) + assert all(r["user_id"] == "test-user" for r in cancellable) + + def test_cancel_all_returns_empty_when_none_active(self): + """Should return empty list when no cancellable reservations""" + from gpu_dev_cli.reservations import get_cancellable_reservations + + reservations = [ + {"reservation_id": "res-1", "user_id": "test-user", "status": "cancelled"}, + {"reservation_id": "res-2", "user_id": "test-user", "status": "expired"}, + ] + + cancellable = get_cancellable_reservations(reservations, "test-user") + + assert len(cancellable) == 0 + + +class TestCancelConfirmation: + """Tests for cancel confirmation handling""" + + def test_force_flag_skips_confirmation(self): + """Should skip confirmation when --force is used""" + from gpu_dev_cli.reservations import should_prompt_confirmation + + assert should_prompt_confirmation(force=True) is False + assert should_prompt_confirmation(force=False) is True + + def test_cancel_message_includes_reservation_details(self): + """Should include reservation details in confirmation message""" + from gpu_dev_cli.reservations import format_cancel_confirmation + + reservation = { + "reservation_id": "abc12345-uuid", + "gpu_type": "h100", + "gpu_count": 4, + "status": "active", + } + + message = format_cancel_confirmation(reservation) + + assert "abc12345" in message + assert "h100" in message + assert "4" in message + + +class TestCancelResultHandling: + """Tests for handling cancel results""" + + def test_cancel_success_response(self): + """Should handle successful cancel response""" + from gpu_dev_cli.reservations import parse_cancel_response + + response = { + "statusCode": 200, + "body": json.dumps({ + "message": "Reservation cancelled", + "reservation_id": "res-123", + "status": "cancelling", + }), + } + + result = parse_cancel_response(response) + + assert result["success"] is True + assert result["status"] == "cancelling" + + def test_cancel_failure_response(self): + """Should handle failed cancel response""" + from gpu_dev_cli.reservations import parse_cancel_response + + response = { + "statusCode": 400, + "body": json.dumps({ + "error": "Reservation not found", + }), + } + + result = parse_cancel_response(response) + + assert result["success"] is False + assert "error" in result diff --git a/tests/unit/cli/test_config.py b/tests/unit/cli/test_config.py new file mode 100644 index 00000000..9d5a102f --- /dev/null +++ b/tests/unit/cli/test_config.py @@ -0,0 +1,250 @@ +""" +Unit tests for gpu_dev_cli.config module + +Tests: +- Config initialization and migration +- Environment switching +- Config file operations +- AWS session creation +""" + +import json +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + + +class TestConfigInit: + """Tests for Config class initialization""" + + def test_config_creates_default_file_if_missing(self, tmp_path): + """Config should create config.json with defaults if not exists""" + config_file = tmp_path / ".config" / "gpu-dev" / "config.json" + + with patch("gpu_dev_cli.config.Config.CONFIG_FILE", config_file): + with patch("gpu_dev_cli.config.Config.LEGACY_CONFIG_FILE", tmp_path / ".gpu-dev-config"): + with patch("gpu_dev_cli.config.Config.LEGACY_ENVIRONMENT_FILE", tmp_path / ".gpu-dev-environment.json"): + with patch("boto3.Session"): + from gpu_dev_cli.config import Config + config = Config() + + assert config_file.exists() + saved_config = json.loads(config_file.read_text()) + assert "region" in saved_config + assert "environment" in saved_config + + def test_config_loads_existing_file(self, tmp_path): + """Config should load existing config.json""" + config_file = tmp_path / ".config" / "gpu-dev" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text(json.dumps({ + "github_user": "myuser", + "region": "eu-west-1", + "environment": "prod", + "workspace": "prod", + })) + + with patch("gpu_dev_cli.config.Config.CONFIG_FILE", config_file): + with patch("gpu_dev_cli.config.Config.LEGACY_CONFIG_FILE", tmp_path / ".gpu-dev-config"): + with patch("gpu_dev_cli.config.Config.LEGACY_ENVIRONMENT_FILE", tmp_path / ".gpu-dev-environment.json"): + with patch("boto3.Session"): + from gpu_dev_cli.config import Config + config = Config() + + assert config.get_github_username() == "myuser" + assert config.aws_region == "eu-west-1" + + def test_config_migrates_legacy_files(self, tmp_path): + """Config should migrate from legacy config files""" + config_file = tmp_path / ".config" / "gpu-dev" / "config.json" + legacy_config = tmp_path / ".gpu-dev-config" + legacy_env = tmp_path / ".gpu-dev-environment.json" + + # Create legacy files + legacy_config.write_text(json.dumps({"github_user": "legacyuser"})) + legacy_env.write_text(json.dumps({"region": "ap-northeast-1"})) + + with patch("gpu_dev_cli.config.Config.CONFIG_FILE", config_file): + with patch("gpu_dev_cli.config.Config.LEGACY_CONFIG_FILE", legacy_config): + with patch("gpu_dev_cli.config.Config.LEGACY_ENVIRONMENT_FILE", legacy_env): + with patch("boto3.Session"): + from gpu_dev_cli.config import Config + config = Config() + + # Should have migrated values + assert config.get_github_username() == "legacyuser" + assert config.aws_region == "ap-northeast-1" + + +class TestConfigEnvironments: + """Tests for environment switching""" + + def test_set_environment_updates_region(self, tmp_path): + """set_environment should update region based on env""" + config_file = tmp_path / ".config" / "gpu-dev" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text(json.dumps({ + "region": "us-east-2", + "environment": "prod", + "workspace": "prod", + })) + + with patch("gpu_dev_cli.config.Config.CONFIG_FILE", config_file): + with patch("gpu_dev_cli.config.Config.LEGACY_CONFIG_FILE", tmp_path / ".gpu-dev-config"): + with patch("gpu_dev_cli.config.Config.LEGACY_ENVIRONMENT_FILE", tmp_path / ".gpu-dev-environment.json"): + with patch("boto3.Session"): + from gpu_dev_cli.config import Config + config = Config() + + result = config.set_environment("test") + + assert config.aws_region == "us-west-1" + assert result["region"] == "us-west-1" + + def test_set_invalid_environment_raises(self, tmp_path): + """set_environment should raise ValueError for invalid env""" + config_file = tmp_path / ".config" / "gpu-dev" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text(json.dumps({ + "region": "us-east-2", + "environment": "prod", + "workspace": "prod", + })) + + with patch("gpu_dev_cli.config.Config.CONFIG_FILE", config_file): + with patch("gpu_dev_cli.config.Config.LEGACY_CONFIG_FILE", tmp_path / ".gpu-dev-config"): + with patch("gpu_dev_cli.config.Config.LEGACY_ENVIRONMENT_FILE", tmp_path / ".gpu-dev-environment.json"): + with patch("boto3.Session"): + from gpu_dev_cli.config import Config + config = Config() + + with pytest.raises(ValueError, match="Invalid environment"): + config.set_environment("staging") + + +class TestConfigOperations: + """Tests for config save/load operations""" + + def test_save_config_persists_value(self, tmp_path): + """save_config should write to file""" + config_file = tmp_path / ".config" / "gpu-dev" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text(json.dumps({ + "region": "us-east-2", + "environment": "prod", + "workspace": "prod", + })) + + with patch("gpu_dev_cli.config.Config.CONFIG_FILE", config_file): + with patch("gpu_dev_cli.config.Config.LEGACY_CONFIG_FILE", tmp_path / ".gpu-dev-config"): + with patch("gpu_dev_cli.config.Config.LEGACY_ENVIRONMENT_FILE", tmp_path / ".gpu-dev-environment.json"): + with patch("boto3.Session"): + from gpu_dev_cli.config import Config + config = Config() + + config.save_config("github_user", "newuser") + + # Re-read file + saved = json.loads(config_file.read_text()) + assert saved["github_user"] == "newuser" + + def test_get_returns_saved_value(self, tmp_path): + """get should return saved config value""" + config_file = tmp_path / ".config" / "gpu-dev" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text(json.dumps({ + "github_user": "saveduser", + "region": "us-east-2", + "environment": "prod", + "workspace": "prod", + })) + + with patch("gpu_dev_cli.config.Config.CONFIG_FILE", config_file): + with patch("gpu_dev_cli.config.Config.LEGACY_CONFIG_FILE", tmp_path / ".gpu-dev-config"): + with patch("gpu_dev_cli.config.Config.LEGACY_ENVIRONMENT_FILE", tmp_path / ".gpu-dev-environment.json"): + with patch("boto3.Session"): + from gpu_dev_cli.config import Config + config = Config() + + assert config.get("github_user") == "saveduser" + assert config.get("nonexistent") is None + + +class TestAWSSession: + """Tests for AWS session creation""" + + def test_creates_session_with_gpu_dev_profile_if_available(self, tmp_path): + """Should try gpu-dev profile first""" + config_file = tmp_path / ".config" / "gpu-dev" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text(json.dumps({ + "region": "us-east-2", + "environment": "prod", + "workspace": "prod", + })) + + mock_session = MagicMock() + mock_session.get_credentials.return_value = MagicMock() + + with patch("gpu_dev_cli.config.Config.CONFIG_FILE", config_file): + with patch("gpu_dev_cli.config.Config.LEGACY_CONFIG_FILE", tmp_path / ".gpu-dev-config"): + with patch("gpu_dev_cli.config.Config.LEGACY_ENVIRONMENT_FILE", tmp_path / ".gpu-dev-environment.json"): + with patch("boto3.Session", return_value=mock_session) as mock_session_class: + from gpu_dev_cli.config import Config + config = Config() + + # First call should be with gpu-dev profile + mock_session_class.assert_any_call(profile_name="gpu-dev") + + def test_falls_back_to_default_session(self, tmp_path): + """Should fall back to default session if gpu-dev profile fails""" + config_file = tmp_path / ".config" / "gpu-dev" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text(json.dumps({ + "region": "us-east-2", + "environment": "prod", + "workspace": "prod", + })) + + def session_side_effect(*args, **kwargs): + if kwargs.get("profile_name") == "gpu-dev": + raise Exception("Profile not found") + return MagicMock() + + with patch("gpu_dev_cli.config.Config.CONFIG_FILE", config_file): + with patch("gpu_dev_cli.config.Config.LEGACY_CONFIG_FILE", tmp_path / ".gpu-dev-config"): + with patch("gpu_dev_cli.config.Config.LEGACY_ENVIRONMENT_FILE", tmp_path / ".gpu-dev-environment.json"): + with patch("boto3.Session", side_effect=session_side_effect): + from gpu_dev_cli.config import Config + config = Config() + + # Should have created a session (fell back to default) + assert config.session is not None + + +class TestResourceNaming: + """Tests for resource naming conventions""" + + def test_resource_names_use_prefix(self, tmp_path): + """Resource names should use pytorch-gpu-dev prefix""" + config_file = tmp_path / ".config" / "gpu-dev" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text(json.dumps({ + "region": "us-east-2", + "environment": "prod", + "workspace": "prod", + })) + + with patch("gpu_dev_cli.config.Config.CONFIG_FILE", config_file): + with patch("gpu_dev_cli.config.Config.LEGACY_CONFIG_FILE", tmp_path / ".gpu-dev-config"): + with patch("gpu_dev_cli.config.Config.LEGACY_ENVIRONMENT_FILE", tmp_path / ".gpu-dev-environment.json"): + with patch("boto3.Session"): + from gpu_dev_cli.config import Config + config = Config() + + assert config.prefix == "pytorch-gpu-dev" + # Note: queue_name removed - dev branch uses PGMQ (PostgreSQL queue) + # instead of SQS, managed by API service not CLI + assert "pytorch-gpu-dev" in config.cluster_name diff --git a/tests/unit/cli/test_connect_show.py b/tests/unit/cli/test_connect_show.py new file mode 100644 index 00000000..fb6a2e8d --- /dev/null +++ b/tests/unit/cli/test_connect_show.py @@ -0,0 +1,346 @@ +""" +Unit tests for gpu_dev_cli connect and show commands + +Tests: +- SSH config generation +- Connection info parsing +- Display formatting +""" + +import json +from datetime import datetime, timezone, timedelta +from unittest.mock import MagicMock, patch + +import pytest + + +class TestSSHConfigGeneration: + """Tests for SSH config file generation""" + + def test_generate_ssh_config(self): + """Should generate valid SSH config content""" + from gpu_dev_cli.reservations import _generate_ssh_config + + config = _generate_ssh_config("myhost.devservers.io", "gpu-dev-abc123") + + assert "Host gpu-dev-abc123" in config + assert "HostName myhost.devservers.io" in config + assert "User dev" in config + assert "ForwardAgent yes" in config + assert "StrictHostKeyChecking no" in config + + def test_generate_ssh_config_with_port(self): + """Should include port in SSH config""" + from gpu_dev_cli.reservations import _generate_ssh_config + + config = _generate_ssh_config("myhost.devservers.io", "gpu-dev-abc123", port=30123) + + assert "Port 30123" in config + + def test_get_ssh_config_path(self): + """Should return correct SSH config path""" + from gpu_dev_cli.reservations import get_ssh_config_path + + path = get_ssh_config_path("abc12345-uuid-uuid-uuid") + + assert ".gpu-dev" in path + assert "abc12345" in path + assert "sshconfig" in path + + def test_get_ssh_config_path_with_name(self): + """Should include name in SSH config path when provided""" + from gpu_dev_cli.reservations import get_ssh_config_path + + path = get_ssh_config_path("abc12345-uuid", name="my-experiment") + + assert "abc12345" in path + + +class TestConnectionInfo: + """Tests for connection info parsing and building""" + + def test_extract_ip_from_reservation_with_port(self): + """Should extract IP:Port from reservation data""" + from gpu_dev_cli.reservations import _extract_ip_from_reservation + + reservation = { + "node_ip": "10.0.1.100", + "node_port": 30123, + } + + result = _extract_ip_from_reservation(reservation) + assert result == "10.0.1.100:30123" + + def test_extract_ip_from_reservation_no_port(self): + """Should extract IP without port when not available""" + from gpu_dev_cli.reservations import _extract_ip_from_reservation + + reservation = { + "node_ip": "10.0.1.100", + } + + result = _extract_ip_from_reservation(reservation) + assert result == "10.0.1.100" + + def test_extract_ip_missing_data(self): + """Should return N/A when no IP data available""" + from gpu_dev_cli.reservations import _extract_ip_from_reservation + + assert _extract_ip_from_reservation({}) == "N/A" + assert _extract_ip_from_reservation({"status": "active"}) == "N/A" + + def test_build_ssh_command(self): + """Should build correct SSH command from reservation""" + from gpu_dev_cli.reservations import _build_ssh_command + + reservation = { + "node_ip": "10.0.1.100", + "node_port": 30123, + } + + cmd = _build_ssh_command(reservation) + + assert "ssh" in cmd + assert "10.0.1.100" in cmd + assert "30123" in cmd + assert "dev@" in cmd + + def test_add_agent_forwarding(self): + """Should add -A flag to SSH command""" + from gpu_dev_cli.reservations import _add_agent_forwarding_to_ssh + + result = _add_agent_forwarding_to_ssh("ssh dev@10.0.1.100 -p 30123") + + assert "-A" in result + assert result == "ssh -A dev@10.0.1.100 -p 30123" + + def test_add_agent_forwarding_no_duplicate(self): + """Should not add -A if already present""" + from gpu_dev_cli.reservations import _add_agent_forwarding_to_ssh + + result = _add_agent_forwarding_to_ssh("ssh -A dev@10.0.1.100 -p 30123") + assert result.count("-A") == 1 + + +class TestIDEIntegration: + """Tests for IDE link generation""" + + def test_make_vscode_link(self): + """Should generate VS Code remote SSH link""" + from gpu_dev_cli.reservations import _make_vscode_link + + link = _make_vscode_link("gpu-dev-abc123") + + assert link == "vscode://vscode-remote/ssh-remote+gpu-dev-abc123/home/dev" + assert link.startswith("vscode://") + + def test_make_cursor_link(self): + """Should generate Cursor IDE remote SSH link""" + from gpu_dev_cli.reservations import _make_cursor_link + + link = _make_cursor_link("gpu-dev-abc123") + + assert link == "cursor://vscode-remote/ssh-remote+gpu-dev-abc123/home/dev" + assert link.startswith("cursor://") + + def test_generate_vscode_command(self): + """Should generate VS Code CLI command""" + from gpu_dev_cli.reservations import _generate_vscode_command + + cmd = _generate_vscode_command("ssh dev@myhost.io -p 30001") + + assert cmd is not None + assert "code --remote" in cmd + assert "myhost.io" in cmd + assert "ForwardAgent=yes" in cmd + + def test_generate_vscode_command_invalid(self): + """Should return None for invalid SSH command""" + from gpu_dev_cli.reservations import _generate_vscode_command + + assert _generate_vscode_command("") is None + assert _generate_vscode_command("not-ssh") is None + + +class TestShowCommand: + """Tests for show command functionality""" + + def test_get_connection_info_active(self): + """Should get connection info for active reservation""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_job_status.return_value = { + "reservation_id": "abc12345-uuid-uuid", + "user_id": "testuser", + "status": "active", + "gpu_type": "h100", + "gpu_count": 4, + "duration_hours": 8, + "created_at": "2026-01-15T10:00:00Z", + "expires_at": "2026-01-15T18:00:00Z", + "pod_name": "gpu-dev-abc12345", + "node_ip": "10.0.1.100", + "node_port": 30123, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_connection_info( + reservation_id="abc12345-uuid-uuid", + user_id="testuser", + ) + + assert result is not None + assert result["status"] == "active" + assert result["node_ip"] == "10.0.1.100" + assert result["node_port"] == 30123 + + def test_get_connection_info_by_short_id(self): + """Should find reservation by short ID prefix""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_job_status.side_effect = RuntimeError("not found") + mock_api_client.list_jobs.return_value = { + "jobs": [ + { + "reservation_id": "abc12345-full-uuid", + "job_id": "abc12345-full-uuid", + "status": "active", + "node_ip": "10.0.1.100", + "node_port": 30123, + } + ], + "total": 1, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_connection_info( + reservation_id="abc12345", + user_id="testuser", + ) + + assert result is not None + assert result["reservation_id"] == "abc12345-full-uuid" + + def test_get_connection_info_not_found(self): + """Should return None when reservation not found""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_job_status.side_effect = RuntimeError("not found") + mock_api_client.list_jobs.return_value = {"jobs": [], "total": 0} + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_connection_info( + reservation_id="nonexistent", + user_id="testuser", + ) + + assert result is None + + +class TestConnectCommand: + """Tests for connect command functionality""" + + def test_get_active_reservation_for_connect(self): + """Should get active reservation details for connecting""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_job_status.return_value = { + "reservation_id": "abc12345", + "status": "active", + "pod_name": "gpu-dev-abc12345", + "node_ip": "10.0.1.100", + "node_port": 30123, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_connection_info( + reservation_id="abc12345", + user_id="testuser", + ) + + assert result["status"] == "active" + assert result["node_ip"] == "10.0.1.100" + + def test_connect_to_pending_reservation(self): + """Should handle connecting to pending reservation""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_job_status.return_value = { + "reservation_id": "abc12345", + "status": "pending", + "pod_name": None, + "node_ip": None, + "node_port": None, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_connection_info( + reservation_id="abc12345", + user_id="testuser", + ) + + # Connection info is returned but node_ip is None + assert result["status"] == "pending" + assert result["node_ip"] is None + + +class TestStatusColors: + """Tests for status color coding""" + + def test_status_to_color_mapping(self): + """Should map status to correct color""" + from gpu_dev_cli.reservations import _get_status_color + + assert _get_status_color("active") == "green" + assert _get_status_color("queued") == "yellow" + assert _get_status_color("pending") == "yellow" + assert _get_status_color("preparing") == "yellow" + assert _get_status_color("cancelled") == "red" + assert _get_status_color("expired") == "red" + assert _get_status_color("failed") == "red" + assert _get_status_color("unknown") == "white" + + +class TestDisplayFormatting: + """Tests for output formatting""" + + def test_format_time_remaining(self): + """Should format time remaining correctly""" + from gpu_dev_cli.reservations import _format_time_remaining + + # Test hours + assert "7h" in _format_time_remaining(7 * 60) + assert "1h" in _format_time_remaining(75) + + # Test minutes + assert "30m" in _format_time_remaining(30) + assert "5m" in _format_time_remaining(5) + + # Test expired + result = _format_time_remaining(-10) + assert "expired" in result.lower() or "-" in result + + def test_format_reservation_id_short(self): + """Should display short reservation ID""" + from gpu_dev_cli.reservations import _format_short_id + + full_id = "abc12345-1234-5678-9012-345678901234" + short_id = _format_short_id(full_id) + + assert short_id == "abc12345" + assert len(short_id) == 8 diff --git a/tests/unit/cli/test_disks.py b/tests/unit/cli/test_disks.py new file mode 100644 index 00000000..02bd0b8d --- /dev/null +++ b/tests/unit/cli/test_disks.py @@ -0,0 +1,643 @@ +""" +Unit tests for gpu_dev_cli.disks module + +Tests: +- Disk listing +- Disk creation +- Disk deletion +- Disk renaming +- In-use status checking +""" + +import json +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest +from moto import mock_aws +import boto3 + + +class TestDiskInUseStatus: + """Tests for get_disk_in_use_status function""" + + @mock_aws + def test_disk_not_in_use_returns_false(self, aws_credentials): + """Should return (False, None) when disk is not in use""" + # Setup + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + + # Create tables + reservations_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-reservations", + KeySchema=[{"AttributeName": "reservation_id", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "reservation_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + ], + GlobalSecondaryIndexes=[{ + "IndexName": "UserIndex", + "KeySchema": [ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "reservation_id", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + }], + BillingMode="PAY_PER_REQUEST", + ) + reservations_table.wait_until_exists() + + disks_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-disks", + KeySchema=[ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "disk_name", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "disk_name", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + disks_table.wait_until_exists() + + # Add disk that's not in use + disks_table.put_item(Item={ + "user_id": "test-user", + "disk_name": "mydata", + "in_use": False, + }) + + # Create mock config + mock_config = MagicMock() + mock_config.session.resource.return_value = dynamodb + mock_config.aws_region = "us-west-1" + mock_config.disks_table = "pytorch-gpu-dev-test-disks" + mock_config.reservations_table = "pytorch-gpu-dev-test-reservations" + + from gpu_dev_cli.disks import get_disk_in_use_status + is_in_use, res_id = get_disk_in_use_status("mydata", "test-user", mock_config) + + assert is_in_use is False + assert res_id is None + + @mock_aws + def test_disk_in_use_via_disks_table(self, aws_credentials): + """Should detect disk in use from disks table in_use field""" + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + + # Create tables + reservations_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-reservations", + KeySchema=[{"AttributeName": "reservation_id", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "reservation_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + ], + GlobalSecondaryIndexes=[{ + "IndexName": "UserIndex", + "KeySchema": [ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "reservation_id", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + }], + BillingMode="PAY_PER_REQUEST", + ) + reservations_table.wait_until_exists() + + disks_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-disks", + KeySchema=[ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "disk_name", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "disk_name", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + disks_table.wait_until_exists() + + # Add disk that's in use + disks_table.put_item(Item={ + "user_id": "test-user", + "disk_name": "mydata", + "in_use": True, + "attached_to_reservation": "res-123", + }) + + mock_config = MagicMock() + mock_config.session.resource.return_value = dynamodb + mock_config.aws_region = "us-west-1" + mock_config.disks_table = "pytorch-gpu-dev-test-disks" + mock_config.reservations_table = "pytorch-gpu-dev-test-reservations" + + from gpu_dev_cli.disks import get_disk_in_use_status + is_in_use, res_id = get_disk_in_use_status("mydata", "test-user", mock_config) + + assert is_in_use is True + assert res_id == "res-123" + + @mock_aws + def test_disk_in_use_via_active_reservation(self, aws_credentials): + """Should detect disk in use from active reservation""" + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + + # Create tables + reservations_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-reservations", + KeySchema=[{"AttributeName": "reservation_id", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "reservation_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + ], + GlobalSecondaryIndexes=[{ + "IndexName": "UserIndex", + "KeySchema": [ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "reservation_id", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + }], + BillingMode="PAY_PER_REQUEST", + ) + reservations_table.wait_until_exists() + + disks_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-disks", + KeySchema=[ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "disk_name", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "disk_name", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + disks_table.wait_until_exists() + + # Add disk not marked in_use in disks table + disks_table.put_item(Item={ + "user_id": "test-user", + "disk_name": "mydata", + "in_use": False, + }) + + # But add active reservation using that disk + reservations_table.put_item(Item={ + "reservation_id": "res-456", + "user_id": "test-user", + "disk_name": "mydata", + "status": "active", + }) + + mock_config = MagicMock() + mock_config.session.resource.return_value = dynamodb + mock_config.aws_region = "us-west-1" + mock_config.disks_table = "pytorch-gpu-dev-test-disks" + mock_config.reservations_table = "pytorch-gpu-dev-test-reservations" + + from gpu_dev_cli.disks import get_disk_in_use_status + is_in_use, res_id = get_disk_in_use_status("mydata", "test-user", mock_config) + + assert is_in_use is True + assert res_id == "res-456" + + +class TestListDisks: + """Tests for list_disks function""" + + @mock_aws + def test_list_disks_returns_all_user_disks(self, aws_credentials): + """Should return all disks for user sorted by last_used""" + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + + # Create tables + reservations_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-reservations", + KeySchema=[{"AttributeName": "reservation_id", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "reservation_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + ], + GlobalSecondaryIndexes=[{ + "IndexName": "UserIndex", + "KeySchema": [ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "reservation_id", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + }], + BillingMode="PAY_PER_REQUEST", + ) + reservations_table.wait_until_exists() + + disks_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-disks", + KeySchema=[ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "disk_name", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "disk_name", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + disks_table.wait_until_exists() + + now = datetime.now(timezone.utc) + + # Add disks + disks_table.put_item(Item={ + "user_id": "test-user", + "disk_name": "data1", + "size_gb": 100, + "created_at": (now - timedelta(days=5)).isoformat(), + "last_used": (now - timedelta(days=2)).isoformat(), + "snapshot_count": 3, + "in_use": False, + }) + disks_table.put_item(Item={ + "user_id": "test-user", + "disk_name": "data2", + "size_gb": 200, + "created_at": (now - timedelta(days=10)).isoformat(), + "last_used": (now - timedelta(hours=1)).isoformat(), # Most recent + "snapshot_count": 1, + "in_use": False, + }) + + mock_config = MagicMock() + mock_config.session.resource.return_value = dynamodb + mock_config.aws_region = "us-west-1" + mock_config.disks_table = "pytorch-gpu-dev-test-disks" + mock_config.reservations_table = "pytorch-gpu-dev-test-reservations" + mock_config.queue_name = "pytorch-gpu-dev-test-reservation-queue" + + from gpu_dev_cli.disks import list_disks + disks = list_disks("test-user", mock_config) + + assert len(disks) == 2 + # Should be sorted by last_used descending + assert disks[0]["name"] == "data2" # Most recent + assert disks[1]["name"] == "data1" + + @mock_aws + def test_list_disks_empty_for_new_user(self, aws_credentials): + """Should return empty list for user with no disks""" + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + + disks_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-disks", + KeySchema=[ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "disk_name", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "disk_name", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + disks_table.wait_until_exists() + + mock_config = MagicMock() + mock_config.session.resource.return_value = dynamodb + mock_config.aws_region = "us-west-1" + mock_config.disks_table = "pytorch-gpu-dev-test-disks" + mock_config.queue_name = "pytorch-gpu-dev-test-reservation-queue" + + from gpu_dev_cli.disks import list_disks + disks = list_disks("new-user", mock_config) + + assert disks == [] + + +class TestCreateDisk: + """Tests for create_disk function""" + + @mock_aws + def test_create_disk_sends_sqs_message(self, aws_credentials): + """Should send create_disk action to SQS queue""" + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + sqs = boto3.client("sqs", region_name="us-west-1") + + # Create queue + queue = sqs.create_queue(QueueName="pytorch-gpu-dev-test-reservation-queue") + queue_url = queue["QueueUrl"] + + # Create disks table + disks_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-disks", + KeySchema=[ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "disk_name", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "disk_name", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + disks_table.wait_until_exists() + + mock_config = MagicMock() + mock_config.session.resource.return_value = dynamodb + mock_config.session.client.return_value = sqs + mock_config.aws_region = "us-west-1" + mock_config.disks_table = "pytorch-gpu-dev-test-disks" + mock_config.queue_name = "pytorch-gpu-dev-test-reservation-queue" + mock_config.get_queue_url.return_value = queue_url + + from gpu_dev_cli.disks import create_disk + operation_id = create_disk("newdisk", "test-user", mock_config) + + # Should return operation_id + assert operation_id is not None + + # Check SQS message + messages = sqs.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=1) + assert "Messages" in messages + body = json.loads(messages["Messages"][0]["Body"]) + assert body["action"] == "create_disk" + assert body["disk_name"] == "newdisk" + assert body["user_id"] == "test-user" + + @mock_aws + def test_create_disk_rejects_existing_name(self, aws_credentials): + """Should reject creating disk with existing name""" + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + + disks_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-disks", + KeySchema=[ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "disk_name", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "disk_name", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + disks_table.wait_until_exists() + + # Add existing disk + disks_table.put_item(Item={ + "user_id": "test-user", + "disk_name": "existingdisk", + "size_gb": 100, + }) + + mock_config = MagicMock() + mock_config.session.resource.return_value = dynamodb + mock_config.aws_region = "us-west-1" + mock_config.disks_table = "pytorch-gpu-dev-test-disks" + mock_config.queue_name = "pytorch-gpu-dev-test-reservation-queue" + + from gpu_dev_cli.disks import create_disk + result = create_disk("existingdisk", "test-user", mock_config) + + assert result is None + + def test_create_disk_validates_name_format(self, aws_credentials): + """Should reject invalid disk names""" + mock_config = MagicMock() + mock_config.session.resource.return_value.Table.return_value.query.return_value = {"Items": []} + + from gpu_dev_cli.disks import create_disk + + # Invalid names + result = create_disk("disk with spaces", "test-user", mock_config) + assert result is None + + result = create_disk("disk@special#chars", "test-user", mock_config) + assert result is None + + +class TestDeleteDisk: + """Tests for delete_disk function""" + + @mock_aws + def test_delete_disk_sends_sqs_message(self, aws_credentials): + """Should send delete_disk action to SQS queue""" + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + sqs = boto3.client("sqs", region_name="us-west-1") + + queue = sqs.create_queue(QueueName="pytorch-gpu-dev-test-reservation-queue") + queue_url = queue["QueueUrl"] + + # Create tables + reservations_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-reservations", + KeySchema=[{"AttributeName": "reservation_id", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "reservation_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + ], + GlobalSecondaryIndexes=[{ + "IndexName": "UserIndex", + "KeySchema": [ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "reservation_id", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + }], + BillingMode="PAY_PER_REQUEST", + ) + reservations_table.wait_until_exists() + + disks_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-disks", + KeySchema=[ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "disk_name", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "disk_name", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + disks_table.wait_until_exists() + + # Add disk to delete + disks_table.put_item(Item={ + "user_id": "test-user", + "disk_name": "deleteme", + "size_gb": 100, + "in_use": False, + }) + + mock_config = MagicMock() + mock_config.session.resource.return_value = dynamodb + mock_config.session.client.return_value = sqs + mock_config.aws_region = "us-west-1" + mock_config.disks_table = "pytorch-gpu-dev-test-disks" + mock_config.reservations_table = "pytorch-gpu-dev-test-reservations" + mock_config.queue_name = "pytorch-gpu-dev-test-reservation-queue" + mock_config.get_queue_url.return_value = queue_url + + from gpu_dev_cli.disks import delete_disk + operation_id = delete_disk("deleteme", "test-user", mock_config) + + assert operation_id is not None + + # Check SQS message + messages = sqs.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=1) + body = json.loads(messages["Messages"][0]["Body"]) + assert body["action"] == "delete_disk" + assert body["disk_name"] == "deleteme" + + @mock_aws + def test_delete_disk_rejects_in_use_disk(self, aws_credentials): + """Should not allow deleting disk that's in use""" + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + + # Create tables + reservations_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-reservations", + KeySchema=[{"AttributeName": "reservation_id", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "reservation_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + ], + GlobalSecondaryIndexes=[{ + "IndexName": "UserIndex", + "KeySchema": [ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "reservation_id", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + }], + BillingMode="PAY_PER_REQUEST", + ) + reservations_table.wait_until_exists() + + disks_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-disks", + KeySchema=[ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "disk_name", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "disk_name", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + disks_table.wait_until_exists() + + # Add disk that's in use + disks_table.put_item(Item={ + "user_id": "test-user", + "disk_name": "inuse", + "size_gb": 100, + "in_use": True, + "attached_to_reservation": "res-123", + }) + + mock_config = MagicMock() + mock_config.session.resource.return_value = dynamodb + mock_config.aws_region = "us-west-1" + mock_config.disks_table = "pytorch-gpu-dev-test-disks" + mock_config.reservations_table = "pytorch-gpu-dev-test-reservations" + mock_config.queue_name = "pytorch-gpu-dev-test-reservation-queue" + + from gpu_dev_cli.disks import delete_disk + result = delete_disk("inuse", "test-user", mock_config) + + assert result is None + + +class TestRenameDisk: + """Tests for rename_disk function""" + + @mock_aws + def test_rename_disk_updates_snapshot_tags(self, aws_credentials): + """Should update disk_name tag on all snapshots""" + ec2 = boto3.client("ec2", region_name="us-west-1") + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + + # Create tables + reservations_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-reservations", + KeySchema=[{"AttributeName": "reservation_id", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "reservation_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + ], + GlobalSecondaryIndexes=[{ + "IndexName": "UserIndex", + "KeySchema": [ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "reservation_id", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + }], + BillingMode="PAY_PER_REQUEST", + ) + reservations_table.wait_until_exists() + + disks_table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-disks", + KeySchema=[ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "disk_name", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "user_id", "AttributeType": "S"}, + {"AttributeName": "disk_name", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + ) + disks_table.wait_until_exists() + + # Add disk + disks_table.put_item(Item={ + "user_id": "test-user", + "disk_name": "oldname", + "size_gb": 100, + "in_use": False, + }) + + # Create a volume and snapshot (moto doesn't have full snapshot support) + # So we'll mock the EC2 client instead + mock_ec2 = MagicMock() + mock_ec2.describe_snapshots.return_value = { + "Snapshots": [ + {"SnapshotId": "snap-123"}, + {"SnapshotId": "snap-456"}, + ] + } + mock_ec2.create_tags.return_value = {} + + mock_config = MagicMock() + mock_config.session.resource.return_value = dynamodb + mock_config.session.client.return_value = mock_ec2 + mock_config.aws_region = "us-west-1" + mock_config.disks_table = "pytorch-gpu-dev-test-disks" + mock_config.reservations_table = "pytorch-gpu-dev-test-reservations" + mock_config.queue_name = "pytorch-gpu-dev-test-reservation-queue" + + from gpu_dev_cli.disks import rename_disk + result = rename_disk("oldname", "newname", "test-user", mock_config) + + assert result is True + # Verify create_tags was called for both snapshots + assert mock_ec2.create_tags.call_count == 2 + + def test_rename_disk_validates_new_name(self, aws_credentials): + """Should reject invalid new disk names""" + mock_config = MagicMock() + mock_config.session.resource.return_value.Table.return_value.query.return_value = {"Items": []} + + from gpu_dev_cli.disks import rename_disk + + result = rename_disk("oldname", "invalid name!", "test-user", mock_config) + assert result is False diff --git a/tests/unit/cli/test_edit.py b/tests/unit/cli/test_edit.py new file mode 100644 index 00000000..b3112e1d --- /dev/null +++ b/tests/unit/cli/test_edit.py @@ -0,0 +1,461 @@ +""" +Unit tests for gpu_dev_cli edit command + +Tests: +- Extend reservation (duration extension) +- Add user (collaborator) +- Enable/disable Jupyter +- Cancel reservation +- List reservations +""" + +import json +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch + +import pytest + + +class TestExtendReservation: + """Tests for extending reservation duration via API""" + + def test_extend_calls_api(self): + """Should call API extend endpoint""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.extend_job.return_value = {"status": "extended"} + mock_api_client.get_job_status.return_value = { + "reservation_id": "res-123", + "status": "active", + "expires_at": (datetime.now(timezone.utc) + timedelta(hours=8)).isoformat(), + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + with patch("gpu_dev_cli.reservations.Live"): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + manager.extend_reservation( + reservation_id="res-123", + user_id="test-user", + extension_hours=4, + ) + + mock_api_client.extend_job.assert_called_once_with("res-123", 4) + + def test_extend_converts_float_to_int(self): + """Should convert extension hours to int for API""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.extend_job.return_value = {"status": "extended"} + mock_api_client.get_job_status.return_value = { + "reservation_id": "res-123", + "status": "active", + "expires_at": (datetime.now(timezone.utc) + timedelta(hours=12)).isoformat(), + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + with patch("gpu_dev_cli.reservations.Live"): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + manager.extend_reservation( + reservation_id="res-123", + user_id="test-user", + extension_hours=4.5, + ) + + mock_api_client.extend_job.assert_called_once_with("res-123", 4) + + def test_extend_handles_api_error(self): + """Should return False on API error""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.extend_job.side_effect = Exception("API error") + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.extend_reservation( + reservation_id="res-123", + user_id="test-user", + extension_hours=4, + ) + + assert result is False + + +class TestAddUser: + """Tests for adding collaborators via API""" + + def test_add_user_calls_api(self): + """Should call API add user endpoint""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.add_user.return_value = {"status": "added"} + mock_api_client.get_job_status.return_value = { + "reservation_id": "res-123", + "status": "active", + "secondary_users": ["newuser"], + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + with patch("gpu_dev_cli.reservations.Live"): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + manager.add_user( + reservation_id="res-123", + user_id="test-user", + github_username="newuser", + ) + + mock_api_client.add_user.assert_called_once_with("res-123", "newuser") + + def test_add_user_validates_username_format(self): + """Should reject invalid GitHub usernames""" + mock_config = MagicMock() + mock_api_client = MagicMock() + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + + # Empty username + result = manager.add_user("res-123", "test-user", "") + assert result is False + + # Username with spaces + result = manager.add_user("res-123", "test-user", "user with spaces") + assert result is False + + mock_api_client.add_user.assert_not_called() + + def test_add_user_allows_valid_usernames(self): + """Should allow valid GitHub username formats""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.add_user.return_value = {"status": "added"} + mock_api_client.get_job_status.return_value = { + "reservation_id": "res-123", + "status": "active", + "secondary_users": ["valid-user"], + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + with patch("gpu_dev_cli.reservations.Live"): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + + manager.add_user("res-123", "test-user", "valid-user") + mock_api_client.add_user.assert_called_with("res-123", "valid-user") + + manager.add_user("res-123", "test-user", "user_name") + mock_api_client.add_user.assert_called_with("res-123", "user_name") + + manager.add_user("res-123", "test-user", "user123") + mock_api_client.add_user.assert_called_with("res-123", "user123") + + def test_add_user_handles_api_error(self): + """Should return False on API error""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.add_user.side_effect = Exception("API error") + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.add_user( + reservation_id="res-123", + user_id="test-user", + github_username="newuser", + ) + + assert result is False + + +class TestJupyterToggle: + """Tests for enabling/disabling Jupyter""" + + def test_enable_jupyter_calls_api(self): + """Should call API enable jupyter endpoint""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.enable_jupyter.return_value = {"status": "enabled"} + mock_api_client.get_job_status.return_value = { + "reservation_id": "res-123", + "status": "active", + "jupyter_enabled": True, + "jupyter_url": "http://1.2.3.4:30888", + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + with patch("gpu_dev_cli.reservations.Live"): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + manager.enable_jupyter( + reservation_id="res-123", + user_id="test-user", + ) + + mock_api_client.enable_jupyter.assert_called_once_with("res-123") + + def test_disable_jupyter_calls_api(self): + """Should call API disable jupyter endpoint""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.disable_jupyter.return_value = {"status": "disabled"} + mock_api_client.get_job_status.return_value = { + "reservation_id": "res-123", + "status": "active", + "jupyter_enabled": False, + "jupyter_url": "", + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + with patch("gpu_dev_cli.reservations.Live"): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + manager.disable_jupyter( + reservation_id="res-123", + user_id="test-user", + ) + + mock_api_client.disable_jupyter.assert_called_once_with("res-123") + + def test_enable_jupyter_handles_api_error(self): + """Should return False on API error""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.enable_jupyter.side_effect = Exception("API error") + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.enable_jupyter( + reservation_id="res-123", + user_id="test-user", + ) + + assert result is False + + +class TestCancelReservation: + """Tests for cancelling reservations""" + + def test_cancel_calls_api(self): + """Should call API cancel endpoint""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.cancel_job.return_value = {"status": "cancelled"} + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.cancel_reservation( + reservation_id="res-123", + user_id="test-user", + ) + + assert result is True + mock_api_client.cancel_job.assert_called_once_with("res-123") + + def test_cancel_handles_api_error(self): + """Should return False on API error""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.cancel_job.side_effect = Exception("API error") + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.cancel_reservation( + reservation_id="res-123", + user_id="test-user", + ) + + assert result is False + + +class TestListReservations: + """Tests for listing reservations""" + + def test_list_calls_api(self): + """Should call API list jobs endpoint""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.list_jobs.return_value = { + "jobs": [ + {"reservation_id": "res-1", "status": "active"}, + {"reservation_id": "res-2", "status": "pending"}, + ], + "total": 2, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.list_reservations() + + assert len(result) == 2 + mock_api_client.list_jobs.assert_called_once() + + def test_list_with_status_filter(self): + """Should pass status filter to API""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.list_jobs.return_value = { + "jobs": [{"reservation_id": "res-1", "status": "active"}], + "total": 1, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + manager.list_reservations(statuses_to_include=["active", "pending"]) + + mock_api_client.list_jobs.assert_called_once_with( + status_filter="active,pending", + limit=500, + offset=0, + ) + + def test_list_handles_api_error(self): + """Should return empty list on API error""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.list_jobs.side_effect = Exception("API error") + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.list_reservations() + + assert result == [] + + +class TestGetConnectionInfo: + """Tests for getting connection info""" + + def test_get_connection_info_by_id(self): + """Should get job details by full ID""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_job_status.return_value = { + "reservation_id": "res-123-full-uuid", + "status": "active", + "node_ip": "1.2.3.4", + "node_port": 30001, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_connection_info( + reservation_id="res-123-full-uuid", + user_id="test-user", + ) + + assert result is not None + assert result["node_ip"] == "1.2.3.4" + mock_api_client.get_job_status.assert_called_with("res-123-full-uuid") + + def test_get_connection_info_by_prefix(self): + """Should find job by prefix when exact ID fails""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_job_status.side_effect = RuntimeError("not found") + mock_api_client.list_jobs.return_value = { + "jobs": [ + {"reservation_id": "abc12345-full-uuid", "job_id": "abc12345-full-uuid", "status": "active"}, + ], + "total": 1, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_connection_info( + reservation_id="abc12345", + user_id="test-user", + ) + + assert result is not None + assert result["reservation_id"] == "abc12345-full-uuid" + + def test_get_connection_info_handles_not_found(self): + """Should return None when not found""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_job_status.side_effect = RuntimeError("not found") + mock_api_client.list_jobs.return_value = {"jobs": [], "total": 0} + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_connection_info( + reservation_id="nonexistent", + user_id="test-user", + ) + + assert result is None + + +class TestClusterStatus: + """Tests for cluster status retrieval""" + + def test_get_cluster_status(self): + """Should call API and transform response""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_cluster_status.return_value = { + "total_gpus": 64, + "available_gpus": 32, + "in_use_gpus": 24, + "queued_gpus": 8, + "active_reservations": 5, + "queued_reservations": 2, + "pending_reservations": 0, + "preparing_reservations": 1, + } + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_cluster_status() + + assert result is not None + assert result["total_gpus"] == 64 + assert result["available_gpus"] == 32 + assert result["reserved_gpus"] == 24 + assert result["queue_length"] == 2 + + def test_get_cluster_status_handles_error(self): + """Should return None on API error""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.get_cluster_status.side_effect = Exception("API error") + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.get_cluster_status() + + assert result is None diff --git a/tests/unit/cli/test_reservations.py b/tests/unit/cli/test_reservations.py new file mode 100644 index 00000000..c31ea50c --- /dev/null +++ b/tests/unit/cli/test_reservations.py @@ -0,0 +1,370 @@ +""" +Unit tests for gpu_dev_cli.reservations module + +Tests: +- Reservation creation (SQS messaging) +- Reservation listing +- Reservation cancellation +- Connection info retrieval +- SSH config generation +- VS Code link generation +""" + +import json +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest +from moto import mock_aws +import boto3 + + +class TestVSCodeLinkGeneration: + """Tests for VS Code and Cursor link generation""" + + def test_make_vscode_link(self): + """Should generate correct vscode:// URL""" + from gpu_dev_cli.reservations import _make_vscode_link + + result = _make_vscode_link("gpu-dev-abc123") + + assert result == "vscode://vscode-remote/ssh-remote+gpu-dev-abc123/home/dev" + + def test_make_cursor_link(self): + """Should generate correct cursor:// URL""" + from gpu_dev_cli.reservations import _make_cursor_link + + result = _make_cursor_link("gpu-dev-abc123") + + assert result == "cursor://vscode-remote/ssh-remote+gpu-dev-abc123/home/dev" + + +class TestAgentForwardingSSH: + """Tests for SSH command modifications""" + + def test_add_agent_forwarding_to_ssh(self): + """Should add -A flag to SSH command""" + from gpu_dev_cli.reservations import _add_agent_forwarding_to_ssh + + result = _add_agent_forwarding_to_ssh("ssh -p 30001 dev@1.2.3.4") + + assert "-A" in result + assert "ssh" in result + + def test_add_agent_forwarding_preserves_existing_options(self): + """Should preserve existing SSH options""" + from gpu_dev_cli.reservations import _add_agent_forwarding_to_ssh + + result = _add_agent_forwarding_to_ssh("ssh -o StrictHostKeyChecking=no -p 30001 dev@1.2.3.4") + + assert "-A" in result + assert "StrictHostKeyChecking=no" in result + + +class TestReservationManager: + """Tests for ReservationManager class""" + + @mock_aws + def test_list_reservations_returns_user_reservations(self, aws_credentials): + """Should return all reservations for a user""" + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + + # Create table + table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-reservations", + KeySchema=[{"AttributeName": "reservation_id", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "reservation_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + ], + GlobalSecondaryIndexes=[{ + "IndexName": "UserIndex", + "KeySchema": [ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "reservation_id", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + }], + BillingMode="PAY_PER_REQUEST", + ) + table.wait_until_exists() + + now = datetime.now(timezone.utc) + + # Add reservations + table.put_item(Item={ + "reservation_id": "res-001", + "user_id": "test-user", + "status": "active", + "gpu_count": 2, + "gpu_type": "t4", + "created_at": now.isoformat(), + "expires_at": (now + timedelta(hours=4)).isoformat(), + }) + table.put_item(Item={ + "reservation_id": "res-002", + "user_id": "test-user", + "status": "completed", + "gpu_count": 1, + "gpu_type": "h100", + "created_at": (now - timedelta(days=1)).isoformat(), + "expires_at": (now - timedelta(hours=20)).isoformat(), + }) + # Different user's reservation + table.put_item(Item={ + "reservation_id": "res-003", + "user_id": "other-user", + "status": "active", + "gpu_count": 4, + "gpu_type": "a100", + "created_at": now.isoformat(), + "expires_at": (now + timedelta(hours=8)).isoformat(), + }) + + mock_config = MagicMock() + mock_config.dynamodb = dynamodb + mock_config.reservations_table = "pytorch-gpu-dev-test-reservations" + + from gpu_dev_cli.reservations import ReservationManager + manager = ReservationManager(mock_config) + reservations = manager.list_reservations("test-user") + + assert len(reservations) == 2 + assert all(r["user_id"] == "test-user" for r in reservations) + + @mock_aws + def test_list_reservations_filters_by_status_in_client(self, aws_credentials): + """Should be able to filter reservations by status after fetching""" + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + + table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-reservations", + KeySchema=[{"AttributeName": "reservation_id", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "reservation_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + ], + GlobalSecondaryIndexes=[{ + "IndexName": "UserIndex", + "KeySchema": [ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "reservation_id", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + }], + BillingMode="PAY_PER_REQUEST", + ) + table.wait_until_exists() + + now = datetime.now(timezone.utc) + + table.put_item(Item={ + "reservation_id": "res-001", + "user_id": "test-user", + "status": "active", + "created_at": now.isoformat(), + }) + table.put_item(Item={ + "reservation_id": "res-002", + "user_id": "test-user", + "status": "cancelled", + "created_at": now.isoformat(), + }) + + mock_config = MagicMock() + mock_config.dynamodb = dynamodb + mock_config.reservations_table = "pytorch-gpu-dev-test-reservations" + + from gpu_dev_cli.reservations import ReservationManager + manager = ReservationManager(mock_config) + all_reservations = manager.list_reservations("test-user") + + # Filter on client side (as CLI does) + active_reservations = [r for r in all_reservations if r["status"] == "active"] + + assert len(active_reservations) == 1 + assert active_reservations[0]["status"] == "active" + + +class TestReservationCreation: + """Tests for reservation creation via SQS""" + + def test_create_reservation_returns_id(self): + """Should return a reservation ID when creating a reservation""" + # Create mocks + mock_table = MagicMock() + mock_sqs = MagicMock() + mock_config = MagicMock() + mock_config.dynamodb.Table.return_value = mock_table + mock_config.sqs_client = mock_sqs + mock_config.get_queue_url.return_value = "https://sqs.test/queue" + + from gpu_dev_cli.reservations import ReservationManager + manager = ReservationManager(mock_config) + + reservation_id = manager.create_reservation( + user_id="test-user", + gpu_count=2, + gpu_type="t4", + duration_hours=4, + github_user="testgithub", + ) + + # Should return a UUID string + assert reservation_id is not None + assert len(reservation_id) == 36 # UUID format + + def test_create_reservation_sends_sqs_message(self): + """Should send reservation request to SQS""" + mock_table = MagicMock() + mock_sqs = MagicMock() + mock_config = MagicMock() + mock_config.dynamodb.Table.return_value = mock_table + mock_config.sqs_client = mock_sqs + mock_config.get_queue_url.return_value = "https://sqs.test/queue" + + from gpu_dev_cli.reservations import ReservationManager + manager = ReservationManager(mock_config) + + manager.create_reservation( + user_id="test-user", + gpu_count=2, + gpu_type="t4", + duration_hours=4, + github_user="testgithub", + ) + + # Verify SQS send_message was called + mock_sqs.send_message.assert_called_once() + call_args = mock_sqs.send_message.call_args + body = json.loads(call_args.kwargs["MessageBody"]) + + # Create reservation uses status=pending instead of action field + assert body["status"] == "pending" + assert body["gpu_count"] == 2 + assert body["gpu_type"] == "t4" + assert body["github_user"] == "testgithub" + + def test_create_reservation_includes_disk_name(self): + """Should include disk_name in SQS message when specified""" + mock_table = MagicMock() + mock_sqs = MagicMock() + mock_config = MagicMock() + mock_config.dynamodb.Table.return_value = mock_table + mock_config.sqs_client = mock_sqs + mock_config.get_queue_url.return_value = "https://sqs.test/queue" + + from gpu_dev_cli.reservations import ReservationManager + manager = ReservationManager(mock_config) + + manager.create_reservation( + user_id="test-user", + gpu_count=1, + gpu_type="t4", + duration_hours=2, + github_user="testgithub", + disk_name="mydata", + ) + + call_args = mock_sqs.send_message.call_args + body = json.loads(call_args.kwargs["MessageBody"]) + + assert body.get("disk_name") == "mydata" + + +class TestCancellation: + """Tests for reservation cancellation""" + + def test_cancel_sends_sqs_message(self): + """Should send cancel request to SQS""" + mock_table = MagicMock() + mock_sqs = MagicMock() + mock_config = MagicMock() + mock_config.dynamodb.Table.return_value = mock_table + mock_config.sqs_client = mock_sqs + mock_config.get_queue_url.return_value = "https://sqs.test/queue" + + from gpu_dev_cli.reservations import ReservationManager + manager = ReservationManager(mock_config) + + manager.cancel_reservation("res-to-cancel", "test-user") + + # Verify SQS send_message was called + mock_sqs.send_message.assert_called_once() + call_args = mock_sqs.send_message.call_args + body = json.loads(call_args.kwargs["MessageBody"]) + + # Cancellation uses "type": "cancellation" field + assert body["type"] == "cancellation" + assert body["reservation_id"] == "res-to-cancel" + + +class TestConnectionInfo: + """Tests for connection info retrieval""" + + @mock_aws + def test_get_connection_info_returns_ssh_command(self, aws_credentials): + """Should return SSH command for active reservation""" + dynamodb = boto3.resource("dynamodb", region_name="us-west-1") + + # Create table with UserIndex + table = dynamodb.create_table( + TableName="pytorch-gpu-dev-test-reservations", + KeySchema=[{"AttributeName": "reservation_id", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "reservation_id", "AttributeType": "S"}, + {"AttributeName": "user_id", "AttributeType": "S"}, + ], + GlobalSecondaryIndexes=[{ + "IndexName": "UserIndex", + "KeySchema": [ + {"AttributeName": "user_id", "KeyType": "HASH"}, + {"AttributeName": "reservation_id", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + }], + BillingMode="PAY_PER_REQUEST", + ) + table.wait_until_exists() + + # Add all required fields that get_connection_info expects + table.put_item(Item={ + "reservation_id": "res-123", + "user_id": "test-user", + "status": "active", + "pod_name": "gpu-dev-abc123", + "node_port": 30001, + "ssh_command": "ssh -p 30001 dev@1.2.3.4", + "gpu_count": 2, + "gpu_type": "t4", + "duration_hours": 4, + }) + + mock_config = MagicMock() + mock_config.dynamodb = dynamodb + mock_config.reservations_table = "pytorch-gpu-dev-test-reservations" + + from gpu_dev_cli.reservations import ReservationManager + manager = ReservationManager(mock_config) + + info = manager.get_connection_info("res-123", "test-user") + + assert info is not None + assert "ssh_command" in info + assert "dev@" in info["ssh_command"] + + +class TestSSHConfigGeneration: + """Tests for SSH config file generation""" + + def test_generate_ssh_config_function(self): + """Should generate valid SSH config content""" + from gpu_dev_cli.reservations import _generate_ssh_config + + config = _generate_ssh_config("test.devservers.io", "gpu-dev-abc123") + + assert "Host gpu-dev-abc123" in config + assert "HostName test.devservers.io" in config + assert "User dev" in config + assert "ForwardAgent yes" in config diff --git a/tests/unit/cli/test_reserve.py b/tests/unit/cli/test_reserve.py new file mode 100644 index 00000000..fe36eee6 --- /dev/null +++ b/tests/unit/cli/test_reserve.py @@ -0,0 +1,550 @@ +""" +Unit tests for gpu_dev_cli reserve command + +Tests: +- GPU type mapping and validation +- GPU count validation per type +- Duration validation +- All reservation options +- API request format transformation +""" + +import json +from datetime import datetime, timedelta, timezone +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest + + +class TestGPUTypeMapping: + """Tests for GPU type to instance type mapping""" + + def test_map_gpu_to_instance_type_valid_types(self): + """Should map all valid GPU types to correct instance types""" + from gpu_dev_cli.reservations import _map_gpu_to_instance_type + + assert _map_gpu_to_instance_type("t4", 4) == "g4dn.12xlarge" + assert _map_gpu_to_instance_type("t4-small", 1) == "g4dn.2xlarge" + assert _map_gpu_to_instance_type("l4", 4) == "g6.12xlarge" + assert _map_gpu_to_instance_type("a10g", 4) == "g5.12xlarge" + assert _map_gpu_to_instance_type("a100", 8) == "p4d.24xlarge" + assert _map_gpu_to_instance_type("h100", 8) == "p5.48xlarge" + assert _map_gpu_to_instance_type("h200", 8) == "p5e.48xlarge" + assert _map_gpu_to_instance_type("b200", 8) == "p6-b200.48xlarge" + + def test_map_gpu_type_case_insensitive(self): + """Should handle GPU types case-insensitively""" + from gpu_dev_cli.reservations import _map_gpu_to_instance_type + + assert _map_gpu_to_instance_type("H100", 1) == "p5.48xlarge" + assert _map_gpu_to_instance_type("T4", 1) == "g4dn.12xlarge" + assert _map_gpu_to_instance_type("A100", 1) == "p4d.24xlarge" + + def test_map_gpu_invalid_type_raises(self): + """Should raise ValueError for invalid GPU types""" + from gpu_dev_cli.reservations import _map_gpu_to_instance_type + + with pytest.raises(ValueError, match="Unsupported GPU type"): + _map_gpu_to_instance_type("invalid-gpu", 1) + + with pytest.raises(ValueError, match="Unsupported GPU type"): + _map_gpu_to_instance_type("v100", 1) + + def test_cpu_instance_types(self): + """Should support CPU-only instance types""" + from gpu_dev_cli.reservations import _map_gpu_to_instance_type + + assert _map_gpu_to_instance_type("cpu-arm", 0) == "c7g.8xlarge" + assert _map_gpu_to_instance_type("cpu-x86", 0) == "c7i.8xlarge" + + +class TestGPUCountValidation: + """Tests for GPU count validation per type""" + + def test_valid_gpu_counts_per_type(self): + """Should accept valid GPU counts for each type""" + from gpu_dev_cli.reservations import _map_gpu_to_instance_type + + # T4 max 4 GPUs + _map_gpu_to_instance_type("t4", 1) + _map_gpu_to_instance_type("t4", 2) + _map_gpu_to_instance_type("t4", 4) + + # H100/A100/B200/H200 max 8 GPUs + _map_gpu_to_instance_type("h100", 1) + _map_gpu_to_instance_type("h100", 4) + _map_gpu_to_instance_type("h100", 8) + + # T4-small max 1 GPU + _map_gpu_to_instance_type("t4-small", 1) + + def test_exceeds_max_gpus_raises(self): + """Should raise ValueError when GPU count exceeds maximum""" + from gpu_dev_cli.reservations import _map_gpu_to_instance_type + + # T4 max is 4 + with pytest.raises(ValueError, match="GPU count 5 exceeds maximum 4"): + _map_gpu_to_instance_type("t4", 5) + + # H100 max is 8 + with pytest.raises(ValueError, match="GPU count 16 exceeds maximum 8"): + _map_gpu_to_instance_type("h100", 16) + + # T4-small max is 1 + with pytest.raises(ValueError, match="GPU count 2 exceeds maximum 1"): + _map_gpu_to_instance_type("t4-small", 2) + + def test_zero_gpus_for_gpu_instance_raises(self): + """Should raise ValueError for 0 GPUs on GPU instances""" + from gpu_dev_cli.reservations import _map_gpu_to_instance_type + + with pytest.raises(ValueError, match="GPU count must be at least 1"): + _map_gpu_to_instance_type("t4", 0) + + +class TestAPIFormatTransformation: + """Tests for _transform_to_api_format function""" + + def test_transform_basic_reservation(self): + """Should transform basic reservation to API format""" + from gpu_dev_cli.reservations import _transform_to_api_format + + message = { + "reservation_id": "res-123", + "user_id": "test-user", + "gpu_type": "h100", + "gpu_count": 4, + "duration_hours": 8, + "github_user": "testgithub", + } + + result = _transform_to_api_format(message) + + assert result["instance_type"] == "p5.48xlarge" + assert result["duration_hours"] == 8 + assert "env_vars" in result + assert result["env_vars"]["GPU_TYPE"] == "h100" + assert result["env_vars"]["GPU_COUNT"] == "4" + assert result["env_vars"]["GITHUB_USER"] == "testgithub" + + def test_transform_includes_docker_image(self): + """Should include custom docker image when provided""" + from gpu_dev_cli.reservations import _transform_to_api_format + + message = { + "gpu_type": "t4", + "gpu_count": 1, + "duration_hours": 4, + "dockerimage": "pytorch/pytorch:2.0.0-cuda11.8", + } + + result = _transform_to_api_format(message) + + assert result["image"] == "pytorch/pytorch:2.0.0-cuda11.8" + + def test_transform_uses_default_image(self): + """Should use default image when no custom image provided""" + from gpu_dev_cli.reservations import _transform_to_api_format + + message = { + "gpu_type": "h100", + "gpu_count": 8, + "duration_hours": 24, + } + + result = _transform_to_api_format(message) + + assert "image" in result + assert "pytorch" in result["image"] + + def test_transform_includes_disk_name(self): + """Should include disk_name when provided""" + from gpu_dev_cli.reservations import _transform_to_api_format + + message = { + "gpu_type": "t4", + "gpu_count": 1, + "duration_hours": 4, + "disk_name": "my-project", + } + + result = _transform_to_api_format(message) + + assert result["disk_name"] == "my-project" + + def test_transform_includes_jupyter_flag(self): + """Should include jupyter_enabled in env_vars""" + from gpu_dev_cli.reservations import _transform_to_api_format + + message = { + "gpu_type": "t4", + "gpu_count": 1, + "duration_hours": 4, + "jupyter_enabled": True, + } + + result = _transform_to_api_format(message) + + assert result["env_vars"]["JUPYTER_ENABLED"] == "true" + + def test_transform_includes_preserve_entrypoint(self): + """Should include preserve_entrypoint in env_vars""" + from gpu_dev_cli.reservations import _transform_to_api_format + + message = { + "gpu_type": "t4", + "gpu_count": 1, + "duration_hours": 4, + "preserve_entrypoint": True, + } + + result = _transform_to_api_format(message) + + assert result["env_vars"]["PRESERVE_ENTRYPOINT"] == "true" + + def test_transform_includes_recreate_env(self): + """Should include recreate_env in env_vars""" + from gpu_dev_cli.reservations import _transform_to_api_format + + message = { + "gpu_type": "t4", + "gpu_count": 1, + "duration_hours": 4, + "recreate_env": True, + } + + result = _transform_to_api_format(message) + + assert result["env_vars"]["RECREATE_ENV"] == "true" + + def test_transform_includes_pod_name(self): + """Should include name as POD_NAME in env_vars""" + from gpu_dev_cli.reservations import _transform_to_api_format + + message = { + "gpu_type": "t4", + "gpu_count": 1, + "duration_hours": 4, + "name": "my-experiment", + } + + result = _transform_to_api_format(message) + + assert result["env_vars"]["POD_NAME"] == "my-experiment" + + def test_transform_raises_without_gpu_fields(self): + """Should raise ValueError if gpu_type or gpu_count missing""" + from gpu_dev_cli.reservations import _transform_to_api_format + + with pytest.raises(ValueError, match="missing required fields"): + _transform_to_api_format({"duration_hours": 4}) + + with pytest.raises(ValueError, match="missing required fields"): + _transform_to_api_format({"gpu_type": "t4"}) + + +class TestReservationManager: + """Tests for ReservationManager class""" + + def test_create_reservation_calls_api(self): + """Should call API client to submit job""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.submit_job.return_value = {"job_id": "job-123"} + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.create_reservation( + user_id="test-user", + gpu_count=4, + gpu_type="h100", + duration_hours=8, + github_user="testgithub", + ) + + assert result is not None + mock_api_client.submit_job.assert_called_once() + call_args = mock_api_client.submit_job.call_args[0][0] + assert call_args["instance_type"] == "p5.48xlarge" + assert call_args["duration_hours"] == 8 + + def test_create_reservation_normalizes_gpu_type(self): + """Should normalize GPU type to lowercase""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.submit_job.return_value = {"job_id": "job-123"} + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + manager.create_reservation( + user_id="test-user", + gpu_count=4, + gpu_type="H100", + duration_hours=8, + ) + + call_args = mock_api_client.submit_job.call_args[0][0] + assert call_args["env_vars"]["GPU_TYPE"] == "h100" + + def test_create_reservation_with_all_options(self): + """Should pass all options to API""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.submit_job.return_value = {"job_id": "job-123"} + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + manager.create_reservation( + user_id="test-user", + gpu_count=2, + gpu_type="t4", + duration_hours=4, + github_user="testgithub", + jupyter_enabled=True, + disk_name="my-disk", + dockerimage="custom/image:latest", + preserve_entrypoint=True, + recreate_env=True, + ) + + call_args = mock_api_client.submit_job.call_args[0][0] + assert call_args["image"] == "custom/image:latest" + assert call_args["disk_name"] == "my-disk" + assert call_args["env_vars"]["JUPYTER_ENABLED"] == "true" + assert call_args["env_vars"]["PRESERVE_ENTRYPOINT"] == "true" + assert call_args["env_vars"]["RECREATE_ENV"] == "true" + + def test_create_reservation_handles_api_error(self): + """Should return None on API error""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.submit_job.side_effect = Exception("API error") + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.create_reservation( + user_id="test-user", + gpu_count=1, + gpu_type="t4", + duration_hours=4, + ) + + assert result is None + + +class TestMultinodeReservation: + """Tests for multinode reservation creation""" + + def test_multinode_calculates_node_count(self): + """Should calculate correct number of nodes for multinode""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.submit_job.return_value = {"job_id": "job-123"} + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.create_multinode_reservation( + user_id="test-user", + gpu_count=16, + gpu_type="h100", + duration_hours=8, + ) + + assert result is not None + assert len(result) == 2 + assert mock_api_client.submit_job.call_count == 2 + + def test_multinode_rejects_invalid_gpu_count(self): + """Should reject GPU count not divisible by max per node""" + mock_config = MagicMock() + mock_api_client = MagicMock() + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + result = manager.create_multinode_reservation( + user_id="test-user", + gpu_count=12, + gpu_type="h100", + duration_hours=8, + ) + + assert result is None + mock_api_client.submit_job.assert_not_called() + + def test_multinode_jupyter_only_on_master(self): + """Should enable Jupyter only on master node (node 0)""" + mock_config = MagicMock() + mock_api_client = MagicMock() + mock_api_client.submit_job.return_value = {"job_id": "job-123"} + + with patch("gpu_dev_cli.reservations.APIClient", return_value=mock_api_client): + from gpu_dev_cli.reservations import ReservationManager + + manager = ReservationManager(mock_config) + manager.create_multinode_reservation( + user_id="test-user", + gpu_count=16, + gpu_type="h100", + duration_hours=8, + jupyter_enabled=True, + ) + + calls = mock_api_client.submit_job.call_args_list + first_call_env = calls[0][0][0]["env_vars"] + second_call_env = calls[1][0][0]["env_vars"] + + # First node (master) should have Jupyter enabled + assert first_call_env.get("JUPYTER_ENABLED") == "true" + # Second node should either have Jupyter disabled or not set + assert second_call_env.get("JUPYTER_ENABLED") in (None, "false") + + +class TestSSHHelpers: + """Tests for SSH-related helper functions""" + + def test_add_agent_forwarding_to_ssh(self): + """Should add -A flag to SSH command""" + from gpu_dev_cli.reservations import _add_agent_forwarding_to_ssh + + result = _add_agent_forwarding_to_ssh("ssh dev@1.2.3.4 -p 30001") + assert "-A" in result + assert result == "ssh -A dev@1.2.3.4 -p 30001" + + def test_add_agent_forwarding_already_present(self): + """Should not add -A if already present""" + from gpu_dev_cli.reservations import _add_agent_forwarding_to_ssh + + result = _add_agent_forwarding_to_ssh("ssh -A dev@1.2.3.4 -p 30001") + assert result.count("-A") == 1 + + def test_add_agent_forwarding_invalid_command(self): + """Should return unchanged for non-SSH commands""" + from gpu_dev_cli.reservations import _add_agent_forwarding_to_ssh + + assert _add_agent_forwarding_to_ssh("") == "" + assert _add_agent_forwarding_to_ssh("scp file user@host:") == "scp file user@host:" + + def test_generate_vscode_command(self): + """Should generate VS Code remote SSH command""" + from gpu_dev_cli.reservations import _generate_vscode_command + + result = _generate_vscode_command("ssh dev@myhost.io -p 30001") + + assert result is not None + assert "code --remote" in result + assert "myhost.io" in result + assert "ForwardAgent=yes" in result + + def test_generate_vscode_command_invalid_input(self): + """Should return None for invalid SSH commands""" + from gpu_dev_cli.reservations import _generate_vscode_command + + assert _generate_vscode_command("") is None + assert _generate_vscode_command("not-ssh-command") is None + + +class TestSSHConfigGeneration: + """Tests for SSH config file generation""" + + def test_generate_ssh_config(self): + """Should generate valid SSH config content""" + from gpu_dev_cli.reservations import _generate_ssh_config + + config = _generate_ssh_config("myhost.devservers.io", "gpu-dev-abc123") + + assert "Host gpu-dev-abc123" in config + assert "HostName myhost.devservers.io" in config + assert "User dev" in config + assert "ForwardAgent yes" in config + assert "StrictHostKeyChecking no" in config + + def test_get_ssh_config_path(self): + """Should return correct SSH config path""" + from gpu_dev_cli.reservations import get_ssh_config_path + + path = get_ssh_config_path("abc12345-full-uuid") + + assert ".gpu-dev" in path + assert "abc12345" in path + assert "sshconfig" in path + + def test_get_ssh_config_path_uses_short_id(self): + """Should use short ID regardless of name parameter""" + from gpu_dev_cli.reservations import get_ssh_config_path + + path1 = get_ssh_config_path("abc12345-full-uuid", name="my-experiment") + path2 = get_ssh_config_path("abc12345-full-uuid") + + assert "abc12345" in path1 + assert "abc12345" in path2 + + +class TestIDELinks: + """Tests for IDE URL generation""" + + def test_make_vscode_link(self): + """Should generate correct VS Code remote SSH link""" + from gpu_dev_cli.reservations import _make_vscode_link + + link = _make_vscode_link("gpu-dev-abc123") + + assert link == "vscode://vscode-remote/ssh-remote+gpu-dev-abc123/home/dev" + assert link.startswith("vscode://") + assert "ssh-remote" in link + + def test_make_cursor_link(self): + """Should generate correct Cursor IDE remote SSH link""" + from gpu_dev_cli.reservations import _make_cursor_link + + link = _make_cursor_link("gpu-dev-abc123") + + assert link == "cursor://vscode-remote/ssh-remote+gpu-dev-abc123/home/dev" + assert link.startswith("cursor://") + assert "ssh-remote" in link + + +class TestExtractIPFromReservation: + """Tests for IP extraction from reservation data""" + + def test_extract_ip_with_node_ip_and_port(self): + """Should return IP:Port format""" + from gpu_dev_cli.reservations import _extract_ip_from_reservation + + reservation = { + "node_ip": "1.2.3.4", + "node_port": 30001, + } + + result = _extract_ip_from_reservation(reservation) + assert result == "1.2.3.4:30001" + + def test_extract_ip_with_only_node_ip(self): + """Should return just IP when no port""" + from gpu_dev_cli.reservations import _extract_ip_from_reservation + + reservation = { + "node_ip": "1.2.3.4", + } + + result = _extract_ip_from_reservation(reservation) + assert result == "1.2.3.4" + + def test_extract_ip_missing_data(self): + """Should return N/A when no IP data""" + from gpu_dev_cli.reservations import _extract_ip_from_reservation + + assert _extract_ip_from_reservation({}) == "N/A" + assert _extract_ip_from_reservation({"status": "active"}) == "N/A" diff --git a/tests/unit/lambda/__init__.py b/tests/unit/lambda/__init__.py new file mode 100644 index 00000000..225518c7 --- /dev/null +++ b/tests/unit/lambda/__init__.py @@ -0,0 +1 @@ +"""Unit tests for Lambda functions""" diff --git a/tests/unit/lambda/test_availability.py b/tests/unit/lambda/test_availability.py new file mode 100644 index 00000000..80821052 --- /dev/null +++ b/tests/unit/lambda/test_availability.py @@ -0,0 +1,292 @@ +""" +Unit tests for Lambda availability_updater + +Tests: +- GPU availability calculation +- Node capacity detection +- Wait time estimation +""" + +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest + + +class TestGPUAvailabilityCalculation: + """Tests for calculating available GPUs""" + + def test_calculate_available_gpus_from_nodes(self): + """Should calculate available GPUs across all nodes of a type""" + nodes = [ + {"name": "node-1", "allocatable_gpus": 4, "used_gpus": 2}, + {"name": "node-2", "allocatable_gpus": 4, "used_gpus": 0}, + {"name": "node-3", "allocatable_gpus": 4, "used_gpus": 4}, + ] + + def calculate_available(nodes): + total_available = 0 + for node in nodes: + available = node["allocatable_gpus"] - node["used_gpus"] + total_available += available + return total_available + + assert calculate_available(nodes) == 6 + + def test_calculate_total_capacity(self): + """Should calculate total GPU capacity""" + nodes = [ + {"name": "node-1", "allocatable_gpus": 4}, + {"name": "node-2", "allocatable_gpus": 4}, + {"name": "node-3", "allocatable_gpus": 8}, + ] + + def calculate_total(nodes): + return sum(n["allocatable_gpus"] for n in nodes) + + assert calculate_total(nodes) == 16 + + +class TestNodeCapacityDetection: + """Tests for detecting node GPU capacity""" + + def test_extract_gpu_count_from_allocatable(self): + """Should extract GPU count from K8s allocatable resources""" + allocatable = { + "cpu": "48", + "memory": "192Gi", + "nvidia.com/gpu": "4", + "ephemeral-storage": "100Gi", + } + + def get_gpu_count(allocatable): + gpu_str = allocatable.get("nvidia.com/gpu", "0") + return int(gpu_str) + + assert get_gpu_count(allocatable) == 4 + + def test_handle_missing_gpu_resource(self): + """Should return 0 for nodes without GPUs""" + allocatable = { + "cpu": "32", + "memory": "64Gi", + } + + def get_gpu_count(allocatable): + gpu_str = allocatable.get("nvidia.com/gpu", "0") + return int(gpu_str) + + assert get_gpu_count(allocatable) == 0 + + def test_filter_nodes_by_gpu_type(self): + """Should filter nodes by GPU type label""" + nodes = [ + {"name": "t4-node-1", "labels": {"GpuType": "t4"}}, + {"name": "t4-node-2", "labels": {"GpuType": "t4"}}, + {"name": "h100-node-1", "labels": {"GpuType": "h100"}}, + {"name": "cpu-node-1", "labels": {}}, + ] + + def filter_by_gpu_type(nodes, gpu_type): + return [n for n in nodes if n.get("labels", {}).get("GpuType") == gpu_type] + + t4_nodes = filter_by_gpu_type(nodes, "t4") + h100_nodes = filter_by_gpu_type(nodes, "h100") + + assert len(t4_nodes) == 2 + assert len(h100_nodes) == 1 + + +class TestUsedGPUCalculation: + """Tests for calculating GPUs in use""" + + def test_sum_gpu_requests_from_pods(self): + """Should sum GPU requests from all pods on a node""" + pods = [ + {"name": "pod-1", "gpu_request": 2}, + {"name": "pod-2", "gpu_request": 1}, + {"name": "pod-3", "gpu_request": 0}, # CPU pod + ] + + def calculate_used(pods): + return sum(p.get("gpu_request", 0) for p in pods) + + assert calculate_used(pods) == 3 + + def test_filter_gpu_pods_only(self): + """Should only count pods with GPU requests""" + pods = [ + {"name": "gpu-pod-1", "resources": {"nvidia.com/gpu": "2"}}, + {"name": "gpu-pod-2", "resources": {"nvidia.com/gpu": "1"}}, + {"name": "cpu-pod", "resources": {"cpu": "4"}}, + ] + + def get_pod_gpu_request(pod): + resources = pod.get("resources", {}) + return int(resources.get("nvidia.com/gpu", "0")) + + def calculate_used(pods): + return sum(get_pod_gpu_request(p) for p in pods) + + assert calculate_used(pods) == 3 + + +class TestWaitTimeEstimation: + """Tests for queue wait time estimation""" + + def test_estimate_based_on_queue_length(self): + """Should estimate wait time based on queue length""" + def estimate_wait_minutes(queue_length, avg_reservation_hours=4): + if queue_length == 0: + return 0 + # Simple estimate: each queued reservation waits for avg duration + return queue_length * avg_reservation_hours * 60 + + assert estimate_wait_minutes(0) == 0 + assert estimate_wait_minutes(1) == 240 # 4 hours + assert estimate_wait_minutes(3) == 720 # 12 hours + + def test_estimate_considers_gpu_count(self): + """Should factor in GPU requirements for estimation""" + def estimate_wait_detailed(queue, available_gpus, avg_hours=4): + if not queue: + return 0 + + # Calculate how many queue items can be served now + remaining = list(queue) + wait_time = 0 + + while remaining: + # Find reservations that can fit in available capacity + can_serve = [] + cannot_serve = [] + + for item in remaining: + if item["gpu_count"] <= available_gpus: + can_serve.append(item) + available_gpus -= item["gpu_count"] + else: + cannot_serve.append(item) + + if not can_serve: + # Nothing can be served, wait for next cycle + wait_time += avg_hours * 60 + # Assume some GPUs free up + available_gpus = 4 # Reset assumption + + remaining = cannot_serve + + return wait_time + + queue = [ + {"reservation_id": "res-1", "gpu_count": 2}, + {"reservation_id": "res-2", "gpu_count": 2}, + ] + + # With 4 available GPUs, both can be served immediately + assert estimate_wait_detailed(queue, available_gpus=4) == 0 + + # With 2 GPUs, first serves now, second waits + queue = [ + {"reservation_id": "res-1", "gpu_count": 4}, + {"reservation_id": "res-2", "gpu_count": 4}, + ] + # First takes all 4, second waits + result = estimate_wait_detailed(queue, available_gpus=4) + assert result > 0 + + +class TestAvailabilityTableUpdate: + """Tests for DynamoDB availability table updates""" + + def test_availability_record_structure(self): + """Should create properly structured availability record""" + def create_availability_record(gpu_type, available, total, queue_length): + return { + "gpu_type": gpu_type, + "available_gpus": Decimal(available), + "total_gpus": Decimal(total), + "queue_length": Decimal(queue_length), + "estimated_wait_minutes": Decimal(queue_length * 240), + "last_updated": datetime.now(timezone.utc).isoformat(), + } + + record = create_availability_record("t4", 6, 8, 2) + + assert record["gpu_type"] == "t4" + assert record["available_gpus"] == Decimal(6) + assert record["total_gpus"] == Decimal(8) + assert record["queue_length"] == Decimal(2) + assert "last_updated" in record + + def test_all_gpu_types_tracked(self): + """Should track availability for all GPU types""" + GPU_TYPES = ["t4", "l4", "a10g", "a100", "h100", "h200", "b200"] + + def update_all_availability(tracker): + results = {} + for gpu_type in GPU_TYPES: + results[gpu_type] = { + "available": tracker.get(gpu_type, {}).get("available", 0), + "total": tracker.get(gpu_type, {}).get("total", 0), + } + return results + + mock_tracker = { + "t4": {"available": 4, "total": 8}, + "h100": {"available": 0, "total": 16}, + } + + results = update_all_availability(mock_tracker) + + assert len(results) == 7 + assert results["t4"]["available"] == 4 + assert results["h100"]["available"] == 0 + assert results["a100"]["available"] == 0 # Not in tracker + + +class TestNodeReadiness: + """Tests for node readiness checks""" + + def test_node_is_ready(self): + """Should detect ready nodes""" + def is_node_ready(conditions): + for condition in conditions: + if condition.get("type") == "Ready": + return condition.get("status") == "True" + return False + + ready_conditions = [ + {"type": "Ready", "status": "True"}, + {"type": "MemoryPressure", "status": "False"}, + ] + not_ready_conditions = [ + {"type": "Ready", "status": "False"}, + {"type": "MemoryPressure", "status": "True"}, + ] + + assert is_node_ready(ready_conditions) is True + assert is_node_ready(not_ready_conditions) is False + + def test_node_is_schedulable(self): + """Should detect schedulable nodes""" + def is_schedulable(spec): + return not spec.get("unschedulable", False) + + assert is_schedulable({}) is True + assert is_schedulable({"unschedulable": False}) is True + assert is_schedulable({"unschedulable": True}) is False + + def test_exclude_cordoned_nodes(self): + """Should exclude cordoned nodes from capacity""" + nodes = [ + {"name": "node-1", "unschedulable": False, "gpus": 4}, + {"name": "node-2", "unschedulable": True, "gpus": 4}, # Cordoned + {"name": "node-3", "unschedulable": False, "gpus": 4}, + ] + + def get_schedulable_capacity(nodes): + return sum(n["gpus"] for n in nodes if not n.get("unschedulable", False)) + + assert get_schedulable_capacity(nodes) == 8 diff --git a/tests/unit/lambda/test_expiry.py b/tests/unit/lambda/test_expiry.py new file mode 100644 index 00000000..f27c390c --- /dev/null +++ b/tests/unit/lambda/test_expiry.py @@ -0,0 +1,284 @@ +""" +Unit tests for Lambda reservation_expiry + +Tests: +- Expiration detection +- Warning generation +- Cleanup operations +""" + +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest + + +class TestExpirationDetection: + """Tests for detecting expired reservations""" + + def test_reservation_is_expired_when_past_expires_at(self): + """Should detect reservation as expired when expires_at is in the past""" + now = datetime.now(timezone.utc) + + reservation = { + "reservation_id": "res-123", + "status": "active", + "expires_at": (now - timedelta(minutes=5)).isoformat(), + } + + def is_expired(reservation): + expires_at_str = reservation.get("expires_at") + if not expires_at_str: + return False + expires_at = datetime.fromisoformat(expires_at_str) + return datetime.now(timezone.utc) > expires_at + + assert is_expired(reservation) is True + + def test_reservation_not_expired_when_future(self): + """Should not detect reservation as expired when expires_at is in future""" + now = datetime.now(timezone.utc) + + reservation = { + "reservation_id": "res-123", + "status": "active", + "expires_at": (now + timedelta(hours=2)).isoformat(), + } + + def is_expired(reservation): + expires_at_str = reservation.get("expires_at") + if not expires_at_str: + return False + expires_at = datetime.fromisoformat(expires_at_str) + return datetime.now(timezone.utc) > expires_at + + assert is_expired(reservation) is False + + def test_only_active_reservations_can_expire(self): + """Should only check expiration for active reservations""" + EXPIRABLE_STATUSES = ["active", "preparing"] + + def should_check_expiry(reservation): + return reservation.get("status") in EXPIRABLE_STATUSES + + assert should_check_expiry({"status": "active"}) is True + assert should_check_expiry({"status": "preparing"}) is True + assert should_check_expiry({"status": "completed"}) is False + assert should_check_expiry({"status": "cancelled"}) is False + assert should_check_expiry({"status": "queued"}) is False + + +class TestWarningGeneration: + """Tests for expiry warning generation""" + + def test_warning_at_30_minutes(self): + """Should generate warning at 30 minutes before expiry""" + now = datetime.now(timezone.utc) + + reservation = { + "reservation_id": "res-123", + "status": "active", + "expires_at": (now + timedelta(minutes=28)).isoformat(), + "warned_30min": False, + } + + def should_warn_30min(reservation): + if reservation.get("warned_30min"): + return False + expires_at = datetime.fromisoformat(reservation["expires_at"]) + time_remaining = expires_at - datetime.now(timezone.utc) + return time_remaining <= timedelta(minutes=30) and time_remaining > timedelta(minutes=15) + + assert should_warn_30min(reservation) is True + + def test_warning_at_15_minutes(self): + """Should generate warning at 15 minutes before expiry""" + now = datetime.now(timezone.utc) + + reservation = { + "reservation_id": "res-123", + "status": "active", + "expires_at": (now + timedelta(minutes=12)).isoformat(), + "warned_15min": False, + } + + def should_warn_15min(reservation): + if reservation.get("warned_15min"): + return False + expires_at = datetime.fromisoformat(reservation["expires_at"]) + time_remaining = expires_at - datetime.now(timezone.utc) + return time_remaining <= timedelta(minutes=15) and time_remaining > timedelta(minutes=5) + + assert should_warn_15min(reservation) is True + + def test_warning_at_5_minutes(self): + """Should generate warning at 5 minutes before expiry""" + now = datetime.now(timezone.utc) + + reservation = { + "reservation_id": "res-123", + "status": "active", + "expires_at": (now + timedelta(minutes=3)).isoformat(), + "warned_5min": False, + } + + def should_warn_5min(reservation): + if reservation.get("warned_5min"): + return False + expires_at = datetime.fromisoformat(reservation["expires_at"]) + time_remaining = expires_at - datetime.now(timezone.utc) + return time_remaining <= timedelta(minutes=5) and time_remaining > timedelta(0) + + assert should_warn_5min(reservation) is True + + def test_no_duplicate_warnings(self): + """Should not generate warning if already warned""" + now = datetime.now(timezone.utc) + + reservation = { + "reservation_id": "res-123", + "status": "active", + "expires_at": (now + timedelta(minutes=28)).isoformat(), + "warned_30min": True, # Already warned + } + + def should_warn_30min(reservation): + if reservation.get("warned_30min"): + return False + expires_at = datetime.fromisoformat(reservation["expires_at"]) + time_remaining = expires_at - datetime.now(timezone.utc) + return time_remaining <= timedelta(minutes=30) + + assert should_warn_30min(reservation) is False + + +class TestCleanupOperations: + """Tests for cleanup logic""" + + def test_cleanup_order(self): + """Should cleanup resources in correct order""" + cleanup_steps = [] + + def cleanup_reservation(reservation): + # Order matters: + # 1. Create final snapshot (before pod deletion) + cleanup_steps.append("snapshot") + # 2. Delete the pod + cleanup_steps.append("delete_pod") + # 3. Delete the service + cleanup_steps.append("delete_service") + # 4. Update reservation status + cleanup_steps.append("update_status") + # 5. Clear disk in_use flag + cleanup_steps.append("clear_disk") + + cleanup_reservation({}) + + assert cleanup_steps == [ + "snapshot", + "delete_pod", + "delete_service", + "update_status", + "clear_disk", + ] + + def test_snapshot_created_before_deletion(self): + """Should create snapshot before deleting pod""" + actions = [] + + def create_shutdown_snapshot(pod_name, volume_id): + actions.append(f"snapshot:{pod_name}") + + def delete_pod(pod_name): + actions.append(f"delete:{pod_name}") + + def cleanup_with_snapshot(pod_name, volume_id): + create_shutdown_snapshot(pod_name, volume_id) + delete_pod(pod_name) + + cleanup_with_snapshot("test-pod", "vol-123") + + assert actions[0].startswith("snapshot:") + assert actions[1].startswith("delete:") + + +class TestStalePendingCleanup: + """Tests for cleaning up stale pending reservations""" + + def test_stale_pending_threshold(self): + """Should identify reservations pending too long""" + STALE_THRESHOLD_DAYS = 7 + now = datetime.now(timezone.utc) + + def is_stale_pending(reservation): + if reservation.get("status") not in ["queued", "pending"]: + return False + created_at = datetime.fromisoformat(reservation["created_at"]) + age = now - created_at + return age > timedelta(days=STALE_THRESHOLD_DAYS) + + old_reservation = { + "status": "queued", + "created_at": (now - timedelta(days=10)).isoformat(), + } + recent_reservation = { + "status": "queued", + "created_at": (now - timedelta(days=2)).isoformat(), + } + active_old = { + "status": "active", + "created_at": (now - timedelta(days=10)).isoformat(), + } + + assert is_stale_pending(old_reservation) is True + assert is_stale_pending(recent_reservation) is False + assert is_stale_pending(active_old) is False # Active, not pending + + +class TestSnapshotRetention: + """Tests for snapshot retention policy""" + + def test_keep_recent_snapshots(self): + """Should keep the most recent N snapshots""" + KEEP_LATEST = 3 + now = datetime.now(timezone.utc) + + snapshots = [ + {"SnapshotId": "snap-1", "StartTime": now - timedelta(days=1)}, + {"SnapshotId": "snap-2", "StartTime": now - timedelta(days=5)}, + {"SnapshotId": "snap-3", "StartTime": now - timedelta(days=10)}, + {"SnapshotId": "snap-4", "StartTime": now - timedelta(days=15)}, + {"SnapshotId": "snap-5", "StartTime": now - timedelta(days=20)}, + ] + + def get_snapshots_to_delete(snapshots, keep_latest=3): + sorted_snaps = sorted(snapshots, key=lambda x: x["StartTime"], reverse=True) + return sorted_snaps[keep_latest:] + + to_delete = get_snapshots_to_delete(snapshots, KEEP_LATEST) + + assert len(to_delete) == 2 + assert to_delete[0]["SnapshotId"] == "snap-4" + assert to_delete[1]["SnapshotId"] == "snap-5" + + def test_delete_snapshots_older_than_30_days(self): + """Should delete snapshots older than retention period""" + RETENTION_DAYS = 30 + now = datetime.now(timezone.utc) + + snapshots = [ + {"SnapshotId": "snap-recent", "StartTime": now - timedelta(days=5)}, + {"SnapshotId": "snap-old", "StartTime": now - timedelta(days=35)}, + {"SnapshotId": "snap-very-old", "StartTime": now - timedelta(days=60)}, + ] + + def get_old_snapshots(snapshots, retention_days=30): + cutoff = datetime.now(timezone.utc) - timedelta(days=retention_days) + return [s for s in snapshots if s["StartTime"] < cutoff] + + old = get_old_snapshots(snapshots, RETENTION_DAYS) + + assert len(old) == 2 + assert "snap-old" in [s["SnapshotId"] for s in old] + assert "snap-very-old" in [s["SnapshotId"] for s in old] diff --git a/tests/unit/lambda/test_reservation_processor.py b/tests/unit/lambda/test_reservation_processor.py new file mode 100644 index 00000000..5d866875 --- /dev/null +++ b/tests/unit/lambda/test_reservation_processor.py @@ -0,0 +1,386 @@ +""" +Unit tests for Lambda reservation_processor + +Tests: +- CLI version validation +- GPU configuration +- Retry with backoff +- Resource calculation +""" + +import os +import sys +from datetime import datetime, timezone +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest + + +# Set up environment before importing Lambda code +@pytest.fixture(autouse=True) +def lambda_env(): + """Set required environment variables for Lambda""" + env_vars = { + "RESERVATIONS_TABLE": "pytorch-gpu-dev-test-reservations", + "EKS_CLUSTER_NAME": "pytorch-gpu-dev-test-cluster", + "REGION": "us-west-1", + "MAX_RESERVATION_HOURS": "48", + "DEFAULT_TIMEOUT_HOURS": "8", + "QUEUE_URL": "https://sqs.us-west-1.amazonaws.com/123456789012/test-queue", + "PRIMARY_AVAILABILITY_ZONE": "us-west-1a", + "GPU_DEV_CONTAINER_IMAGE": "pytorch/pytorch:2.8.0-cuda12.9-cudnn9-devel", + "LAMBDA_VERSION": "0.3.5", + "MIN_CLI_VERSION": "0.3.0", + } + with patch.dict(os.environ, env_vars): + yield env_vars + + +class TestCLIVersionValidation: + """Tests for CLI version validation""" + + def test_validate_version_passes_for_equal_version(self, lambda_env): + """Should pass when CLI version equals minimum""" + # Need to mock the imports that require K8s + with patch.dict(sys.modules, { + 'shared': MagicMock(), + 'shared.snapshot_utils': MagicMock(), + 'buildkit_job': MagicMock(), + 'shared.dns_utils': MagicMock(), + 'kubernetes': MagicMock(), + 'kubernetes.client': MagicMock(), + 'kubernetes.stream': MagicMock(), + }): + # Re-import with mocks + import importlib + if 'index' in sys.modules: + del sys.modules['index'] + + # Directly test the version parsing logic + def parse_version(version_str): + try: + return tuple(map(int, version_str.split('.'))) + except (ValueError, AttributeError): + return (0, 0, 0) + + cli_ver = parse_version("0.3.0") + min_ver = parse_version("0.3.0") + + assert cli_ver >= min_ver + + def test_validate_version_passes_for_newer_version(self, lambda_env): + """Should pass when CLI version is newer than minimum""" + def parse_version(version_str): + try: + return tuple(map(int, version_str.split('.'))) + except (ValueError, AttributeError): + return (0, 0, 0) + + cli_ver = parse_version("0.4.0") + min_ver = parse_version("0.3.0") + + assert cli_ver >= min_ver + + def test_validate_version_fails_for_older_version(self, lambda_env): + """Should fail when CLI version is older than minimum""" + def parse_version(version_str): + try: + return tuple(map(int, version_str.split('.'))) + except (ValueError, AttributeError): + return (0, 0, 0) + + cli_ver = parse_version("0.2.9") + min_ver = parse_version("0.3.0") + + assert cli_ver < min_ver + + def test_validate_version_handles_patch_versions(self, lambda_env): + """Should correctly compare patch versions""" + def parse_version(version_str): + try: + return tuple(map(int, version_str.split('.'))) + except (ValueError, AttributeError): + return (0, 0, 0) + + assert parse_version("0.3.5") > parse_version("0.3.0") + assert parse_version("0.3.10") > parse_version("0.3.9") + assert parse_version("1.0.0") > parse_version("0.99.99") + + +class TestGPUConfiguration: + """Tests for GPU_CONFIG structure""" + + def test_gpu_config_has_required_types(self, lambda_env): + """Should have all expected GPU types configured""" + GPU_CONFIG = { + "t4": {"instance_type": "g4dn.12xlarge", "max_gpus": 4, "cpus": 48, "memory_gb": 192}, + "l4": {"instance_type": "g6.12xlarge", "max_gpus": 4, "cpus": 48, "memory_gb": 192}, + "a10g": {"instance_type": "g5.12xlarge", "max_gpus": 4, "cpus": 48, "memory_gb": 192}, + "a100": {"instance_type": "p4d.24xlarge", "max_gpus": 8, "cpus": 96, "memory_gb": 1152}, + "h100": {"instance_type": "p5.48xlarge", "max_gpus": 8, "cpus": 192, "memory_gb": 2048}, + "h200": {"instance_type": "p5e.48xlarge", "max_gpus": 8, "cpus": 192, "memory_gb": 2048}, + "b200": {"instance_type": "p6-b200.48xlarge", "max_gpus": 8, "cpus": 192, "memory_gb": 2048}, + "cpu-arm": {"instance_type": "c7g.8xlarge", "max_gpus": 0, "cpus": 32, "memory_gb": 64}, + "cpu-x86": {"instance_type": "c7i.8xlarge", "max_gpus": 0, "cpus": 32, "memory_gb": 64}, + } + + expected_types = ["t4", "l4", "a10g", "a100", "h100", "h200", "b200", "cpu-arm", "cpu-x86"] + for gpu_type in expected_types: + assert gpu_type in GPU_CONFIG + + def test_gpu_config_has_required_fields(self, lambda_env): + """Each GPU config should have required fields""" + GPU_CONFIG = { + "t4": {"instance_type": "g4dn.12xlarge", "max_gpus": 4, "cpus": 48, "memory_gb": 192}, + "h100": {"instance_type": "p5.48xlarge", "max_gpus": 8, "cpus": 192, "memory_gb": 2048}, + } + + required_fields = ["instance_type", "max_gpus", "cpus", "memory_gb"] + + for gpu_type, config in GPU_CONFIG.items(): + for field in required_fields: + assert field in config, f"Missing {field} in {gpu_type} config" + + def test_cpu_types_have_zero_gpus(self, lambda_env): + """CPU instance types should have max_gpus=0""" + GPU_CONFIG = { + "cpu-arm": {"instance_type": "c7g.8xlarge", "max_gpus": 0, "cpus": 32, "memory_gb": 64}, + "cpu-x86": {"instance_type": "c7i.8xlarge", "max_gpus": 0, "cpus": 32, "memory_gb": 64}, + } + + assert GPU_CONFIG["cpu-arm"]["max_gpus"] == 0 + assert GPU_CONFIG["cpu-x86"]["max_gpus"] == 0 + + +class TestRetryWithBackoff: + """Tests for retry_with_backoff function""" + + def test_retry_returns_on_success(self, lambda_env): + """Should return immediately on success""" + call_count = 0 + + def successful_func(): + nonlocal call_count + call_count += 1 + return "success" + + # Implement retry logic inline for testing + def retry_with_backoff(func, max_retries=5): + return func() + + result = retry_with_backoff(successful_func) + + assert result == "success" + assert call_count == 1 + + def test_retry_retries_on_throttling(self, lambda_env): + """Should retry on throttling errors""" + import botocore.exceptions + + call_count = 0 + + def throttled_then_success(): + nonlocal call_count + call_count += 1 + if call_count < 3: + error_response = {'Error': {'Code': 'Throttling'}} + raise botocore.exceptions.ClientError(error_response, 'test') + return "success" + + def retry_with_backoff(func, max_retries=5, initial_delay=0.01): + import time + delay = initial_delay + for attempt in range(max_retries): + try: + return func() + except botocore.exceptions.ClientError as e: + error_code = e.response.get('Error', {}).get('Code', '') + if error_code not in ['Throttling', 'RequestLimitExceeded']: + raise + if attempt < max_retries - 1: + time.sleep(delay) + delay *= 2 + else: + raise + + result = retry_with_backoff(throttled_then_success) + + assert result == "success" + assert call_count == 3 + + def test_retry_raises_non_throttling_errors(self, lambda_env): + """Should not retry on non-throttling errors""" + import botocore.exceptions + + def failing_func(): + error_response = {'Error': {'Code': 'ValidationError'}} + raise botocore.exceptions.ClientError(error_response, 'test') + + def retry_with_backoff(func, max_retries=5): + try: + return func() + except botocore.exceptions.ClientError as e: + error_code = e.response.get('Error', {}).get('Code', '') + if error_code not in ['Throttling', 'RequestLimitExceeded']: + raise + raise + + with pytest.raises(botocore.exceptions.ClientError): + retry_with_backoff(failing_func) + + +class TestResourceCalculation: + """Tests for resource limit/request calculations""" + + def test_calculate_cpu_limits(self, lambda_env): + """Should calculate correct CPU limits based on GPU ratio""" + GPU_CONFIG = { + "t4": {"max_gpus": 4, "cpus": 48, "memory_gb": 192}, + "h100": {"max_gpus": 8, "cpus": 192, "memory_gb": 2048}, + } + + def get_pod_resource_limits(gpu_count, gpu_type): + gpu_count = int(gpu_count) + config = GPU_CONFIG.get(gpu_type, GPU_CONFIG["t4"]) + max_gpus = config["max_gpus"] + total_cpus = config["cpus"] + + # Scale CPU based on GPU ratio + if max_gpus > 0: + cpu_per_gpu = total_cpus / max_gpus + allocated_cpus = int(cpu_per_gpu * gpu_count) + else: + allocated_cpus = total_cpus + + return {"cpu": f"{allocated_cpus}"} + + # Test T4: 4 GPUs on 48 CPUs = 12 CPUs per GPU + result = get_pod_resource_limits(2, "t4") + assert result["cpu"] == "24" + + # Test H100: 8 GPUs on 192 CPUs = 24 CPUs per GPU + result = get_pod_resource_limits(4, "h100") + assert result["cpu"] == "96" + + def test_calculate_memory_limits(self, lambda_env): + """Should calculate correct memory limits based on GPU ratio""" + GPU_CONFIG = { + "t4": {"max_gpus": 4, "cpus": 48, "memory_gb": 192}, + "h100": {"max_gpus": 8, "cpus": 192, "memory_gb": 2048}, + } + + def get_pod_resource_limits(gpu_count, gpu_type): + gpu_count = int(gpu_count) + config = GPU_CONFIG.get(gpu_type, GPU_CONFIG["t4"]) + max_gpus = config["max_gpus"] + total_memory = config["memory_gb"] + + if max_gpus > 0: + memory_per_gpu = total_memory / max_gpus + allocated_memory = int(memory_per_gpu * gpu_count) + else: + allocated_memory = total_memory + + return {"memory": f"{allocated_memory}Gi"} + + # Test T4: 4 GPUs sharing 192GB = 48GB per GPU + result = get_pod_resource_limits(2, "t4") + assert result["memory"] == "96Gi" + + # Test H100: 8 GPUs sharing 2048GB = 256GB per GPU + result = get_pod_resource_limits(1, "h100") + assert result["memory"] == "256Gi" + + def test_handles_decimal_gpu_count(self, lambda_env): + """Should handle Decimal type from DynamoDB""" + GPU_CONFIG = {"t4": {"max_gpus": 4, "cpus": 48, "memory_gb": 192}} + + def get_pod_resource_limits(gpu_count, gpu_type): + gpu_count = int(gpu_count) # Convert Decimal to int + config = GPU_CONFIG.get(gpu_type, GPU_CONFIG["t4"]) + max_gpus = config["max_gpus"] + total_cpus = config["cpus"] + cpu_per_gpu = total_cpus / max_gpus + allocated_cpus = int(cpu_per_gpu * gpu_count) + return {"cpu": f"{allocated_cpus}"} + + # Pass Decimal like DynamoDB would + result = get_pod_resource_limits(Decimal("2"), "t4") + assert result["cpu"] == "24" + + +class TestQueuePositionCalculation: + """Tests for queue position and wait time estimation""" + + def test_calculate_queue_position(self, lambda_env): + """Should calculate correct queue position""" + # Simulate queue items + queue_items = [ + {"reservation_id": "res-1", "created_at": "2024-01-01T10:00:00+00:00"}, + {"reservation_id": "res-2", "created_at": "2024-01-01T10:01:00+00:00"}, + {"reservation_id": "res-3", "created_at": "2024-01-01T10:02:00+00:00"}, + ] + + def get_queue_position(reservation_id, queue_items): + sorted_items = sorted(queue_items, key=lambda x: x["created_at"]) + for i, item in enumerate(sorted_items): + if item["reservation_id"] == reservation_id: + return i + 1 + return None + + assert get_queue_position("res-1", queue_items) == 1 + assert get_queue_position("res-2", queue_items) == 2 + assert get_queue_position("res-3", queue_items) == 3 + + def test_estimate_wait_time_based_on_queue(self, lambda_env): + """Should estimate wait time based on queue position and average duration""" + def estimate_wait_time(queue_position, avg_duration_hours=4): + if queue_position <= 0: + return 0 + # Rough estimate: each reservation ahead takes avg_duration_hours + return queue_position * avg_duration_hours * 60 # minutes + + # First in queue, no wait + assert estimate_wait_time(0) == 0 + + # Second in queue + assert estimate_wait_time(1) == 240 # 4 hours + + # Third in queue + assert estimate_wait_time(2) == 480 # 8 hours + + +class TestReservationStatusTransitions: + """Tests for reservation status state machine""" + + def test_valid_status_transitions(self, lambda_env): + """Should validate allowed status transitions""" + VALID_TRANSITIONS = { + "queued": ["pending", "cancelled", "failed"], + "pending": ["preparing", "cancelled", "failed"], + "preparing": ["active", "cancelled", "failed"], + "active": ["completed", "cancelled", "failed"], + "completed": [], + "cancelled": [], + "failed": [], + } + + def is_valid_transition(from_status, to_status): + return to_status in VALID_TRANSITIONS.get(from_status, []) + + # Valid transitions + assert is_valid_transition("queued", "pending") + assert is_valid_transition("pending", "preparing") + assert is_valid_transition("preparing", "active") + assert is_valid_transition("active", "completed") + + # Cancellation always valid from active states + assert is_valid_transition("queued", "cancelled") + assert is_valid_transition("pending", "cancelled") + assert is_valid_transition("active", "cancelled") + + # Invalid transitions + assert not is_valid_transition("completed", "active") + assert not is_valid_transition("cancelled", "active") + assert not is_valid_transition("failed", "active") diff --git a/tests/unit/services/__init__.py b/tests/unit/services/__init__.py new file mode 100644 index 00000000..3c1506a4 --- /dev/null +++ b/tests/unit/services/__init__.py @@ -0,0 +1 @@ +# Service unit tests diff --git a/tests/unit/services/conftest.py b/tests/unit/services/conftest.py new file mode 100644 index 00000000..933314c3 --- /dev/null +++ b/tests/unit/services/conftest.py @@ -0,0 +1,324 @@ +""" +Shared pytest fixtures for service unit tests. +""" +import json +import uuid +from datetime import UTC, datetime, timedelta +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +# ============================================================================ +# Mock PostgreSQL Connection Fixtures +# ============================================================================ + +@pytest.fixture +def mock_asyncpg_pool(): + """Mock asyncpg connection pool for API tests.""" + pool = AsyncMock() + conn = AsyncMock() + pool.acquire.return_value.__aenter__.return_value = conn + pool.acquire.return_value.__aexit__.return_value = None + return pool, conn + + +@pytest.fixture +def mock_db_cursor(): + """Mock psycopg2 cursor with RealDictCursor behavior.""" + cursor = MagicMock() + cursor.fetchone.return_value = None + cursor.fetchall.return_value = [] + cursor.__enter__ = MagicMock(return_value=cursor) + cursor.__exit__ = MagicMock(return_value=None) + return cursor + + +@pytest.fixture +def mock_db_connection(mock_db_cursor): + """Mock psycopg2 connection.""" + conn = MagicMock() + conn.cursor.return_value = mock_db_cursor + conn.__enter__ = MagicMock(return_value=conn) + conn.__exit__ = MagicMock(return_value=None) + return conn + + +@pytest.fixture +def mock_connection_pool(mock_db_connection): + """Mock psycopg2 ThreadedConnectionPool.""" + pool = MagicMock() + pool.getconn.return_value = mock_db_connection + pool.putconn.return_value = None + pool.minconn = 1 + pool.maxconn = 10 + pool.closed = False + return pool + + +# ============================================================================ +# Mock Kubernetes Client Fixtures +# ============================================================================ + +@pytest.fixture +def mock_k8s_batch_api(): + """Mock Kubernetes BatchV1Api.""" + api = MagicMock() + api.create_namespaced_job.return_value = None + api.delete_namespaced_job.return_value = None + api.read_namespaced_job_status.return_value = MagicMock( + status=MagicMock( + active=0, + succeeded=1, + failed=0, + start_time=datetime.now(UTC), + completion_time=datetime.now(UTC) + ) + ) + api.list_namespaced_job.return_value = MagicMock(items=[]) + return api + + +@pytest.fixture +def mock_k8s_core_api(): + """Mock Kubernetes CoreV1Api.""" + api = MagicMock() + api.list_namespaced_pod.return_value = MagicMock(items=[]) + api.read_namespaced_pod_log.return_value = "pod logs here" + api.create_namespaced_pod.return_value = None + api.delete_namespaced_pod.return_value = None + api.list_node.return_value = MagicMock(items=[]) + return api + + +# ============================================================================ +# Mock AWS Client Fixtures +# ============================================================================ + +@pytest.fixture +def mock_ec2_client(): + """Mock boto3 EC2 client.""" + client = MagicMock() + # describe_volumes mock + client.describe_volumes.return_value = {"Volumes": []} + # describe_snapshots mock + client.describe_snapshots.return_value = {"Snapshots": []} + # create_volume mock + client.create_volume.return_value = {"VolumeId": "vol-12345678"} + # delete_volume mock + client.delete_volume.return_value = {} + # create_snapshot mock + client.create_snapshot.return_value = {"SnapshotId": "snap-12345678"} + # create_tags mock + client.create_tags.return_value = {} + # delete_tags mock + client.delete_tags.return_value = {} + # Paginator mock + paginator = MagicMock() + paginator.paginate.return_value = [{"Volumes": []}] + client.get_paginator.return_value = paginator + return client + + +@pytest.fixture +def mock_sts_client(): + """Mock boto3 STS client.""" + client = MagicMock() + client.get_caller_identity.return_value = { + "Account": "123456789012", + "UserId": "AIDAEXAMPLE", + "Arn": "arn:aws:sts::123456789012:assumed-role/SSOCloudDevGpuReservation/testuser" + } + return client + + +# ============================================================================ +# Test Data Factory Fixtures +# ============================================================================ + +@pytest.fixture +def sample_reservation(): + """Factory for sample reservation data.""" + def _create( + reservation_id: str | None = None, + user_id: str = "testuser", + status: str = "active", + gpu_type: str = "a100", + gpu_count: int = 4, + duration_hours: int = 4, + **kwargs + ) -> dict[str, Any]: + return { + "reservation_id": reservation_id or str(uuid.uuid4()), + "user_id": user_id, + "status": status, + "gpu_type": gpu_type, + "gpu_count": gpu_count, + "instance_type": "p4d.24xlarge", + "duration_hours": duration_hours, + "created_at": datetime.now(UTC), + "expires_at": datetime.now(UTC) + timedelta(hours=duration_hours), + "name": kwargs.get("name", "test-pod"), + "pod_name": kwargs.get("pod_name", "gpu-dev-test"), + "node_ip": kwargs.get("node_ip", "10.0.0.1"), + "node_port": kwargs.get("node_port", 30001), + "jupyter_enabled": kwargs.get("jupyter_enabled", False), + "jupyter_url": kwargs.get("jupyter_url"), + "jupyter_token": kwargs.get("jupyter_token"), + "github_user": kwargs.get("github_user", "testghuser"), + **kwargs + } + return _create + + +@pytest.fixture +def sample_disk(): + """Factory for sample disk data.""" + def _create( + disk_name: str = "test-disk", + user_id: str = "testuser", + size_gb: int = 100, + **kwargs + ) -> dict[str, Any]: + return { + "disk_id": kwargs.get("disk_id", 1), + "disk_name": disk_name, + "user_id": user_id, + "ebs_volume_id": kwargs.get("ebs_volume_id", "vol-12345678"), + "size_gb": size_gb, + "created_at": kwargs.get("created_at", datetime.now(UTC)), + "last_used": kwargs.get("last_used"), + "in_use": kwargs.get("in_use", False), + "reservation_id": kwargs.get("reservation_id"), + "is_backing_up": kwargs.get("is_backing_up", False), + "is_deleted": kwargs.get("is_deleted", False), + "snapshot_count": kwargs.get("snapshot_count", 0), + "last_snapshot_at": kwargs.get("last_snapshot_at"), + **kwargs + } + return _create + + +@pytest.fixture +def sample_aws_volume(): + """Factory for sample AWS EBS volume data.""" + def _create( + volume_id: str = "vol-12345678", + user_id: str = "testuser", + disk_name: str = "test-disk", + size_gb: int = 100, + **kwargs + ) -> dict[str, Any]: + return { + "VolumeId": volume_id, + "Size": size_gb, + "State": kwargs.get("state", "available"), + "AvailabilityZone": kwargs.get("availability_zone", "us-east-1a"), + "CreateTime": kwargs.get("created_at", datetime.now(UTC)), + "Attachments": kwargs.get("attachments", []), + "Tags": [ + {"Key": "gpu-dev-user", "Value": user_id}, + {"Key": "disk-name", "Value": disk_name}, + *kwargs.get("extra_tags", []) + ] + } + return _create + + +@pytest.fixture +def sample_pgmq_message(): + """Factory for sample PGMQ message data.""" + def _create( + msg_id: int = 1, + action: str = "create_reservation", + user_id: str = "testuser", + **kwargs + ) -> dict[str, Any]: + message_body = { + "action": action, + "user_id": user_id, + "reservation_id": kwargs.get("reservation_id", str(uuid.uuid4())), + "gpu_type": kwargs.get("gpu_type", "a100"), + "gpu_count": kwargs.get("gpu_count", 4), + "github_user": kwargs.get("github_user", "testghuser"), + "duration_hours": kwargs.get("duration_hours", 4), + "version": kwargs.get("version", "0.4.0"), + "_metadata": { + "retry_count": kwargs.get("retry_count", 0), + "max_retries": kwargs.get("max_retries", 3), + "created_at": datetime.now(UTC).isoformat() + }, + **{k: v for k, v in kwargs.items() if k not in [ + "reservation_id", "gpu_type", "gpu_count", "github_user", + "duration_hours", "version", "retry_count", "max_retries" + ]} + } + return { + "msg_id": msg_id, + "read_ct": kwargs.get("read_ct", 1), + "enqueued_at": datetime.now(UTC), + "vt": datetime.now(UTC) + timedelta(seconds=300), + "message": message_body + } + return _create + + +@pytest.fixture +def sample_user_info(): + """Factory for sample authenticated user info.""" + def _create( + user_id: int = 1, + username: str = "testuser", + email: str = "testuser@example.com", + **kwargs + ) -> dict[str, Any]: + return { + "user_id": user_id, + "username": username, + "email": email, + **kwargs + } + return _create + + +# ============================================================================ +# API Test Client Fixtures +# ============================================================================ + +@pytest.fixture +def mock_verify_api_key(sample_user_info): + """Mock API key verification.""" + async def _verify(*args, **kwargs): + return sample_user_info() + return _verify + + +# ============================================================================ +# Environment Variable Fixtures +# ============================================================================ + +@pytest.fixture +def mock_env_vars(): + """Mock environment variables for services.""" + env = { + "POSTGRES_HOST": "localhost", + "POSTGRES_PORT": "5432", + "POSTGRES_USER": "testuser", + "POSTGRES_PASSWORD": "testpass", + "POSTGRES_DB": "testdb", + "QUEUE_NAME": "test_queue", + "DISK_QUEUE_NAME": "test_disk_queue", + "KUBE_NAMESPACE": "test-namespace", + "WORKER_IMAGE": "test-image:latest", + "SERVICE_ACCOUNT": "test-sa", + "REGION": "us-east-1", + "EKS_CLUSTER_NAME": "test-cluster", + "PRIMARY_AVAILABILITY_ZONE": "us-east-1a", + "MAX_RESERVATION_HOURS": "48", + "DEFAULT_TIMEOUT_HOURS": "4", + "API_KEY_TTL_HOURS": "2", + "ALLOWED_AWS_ROLE": "SSOCloudDevGpuReservation" + } + with patch.dict("os.environ", env, clear=False): + yield env diff --git a/tests/unit/services/test_api_endpoints.py b/tests/unit/services/test_api_endpoints.py new file mode 100644 index 00000000..17c1e2c4 --- /dev/null +++ b/tests/unit/services/test_api_endpoints.py @@ -0,0 +1,617 @@ +""" +Unit tests for FastAPI API service endpoints. +""" +import json +import uuid +from datetime import UTC, datetime, timedelta +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + + +# ============================================================================ +# Authentication Tests +# ============================================================================ + +class TestVerifyApiKey: + """Tests for API key verification.""" + + @pytest.mark.asyncio + async def test_verify_api_key_valid(self, mock_asyncpg_pool): + """Valid API key returns user info.""" + pool, conn = mock_asyncpg_pool + expires_at = datetime.now(UTC) + timedelta(hours=1) + conn.fetchrow.return_value = { + "user_id": 1, + "username": "testuser", + "email": "testuser@example.com", + "user_active": True, + "key_id": 123, + "expires_at": expires_at, + "key_active": True + } + conn.execute.return_value = None + + with patch("app.main.db_pool", pool): + with patch("app.main.hash_api_key", return_value="hashedkey"): + from app.main import verify_api_key + credentials = MagicMock() + credentials.credentials = "valid_api_key_12345678" + result = await verify_api_key(credentials) + + assert result["username"] == "testuser" + assert result["user_id"] == 1 + + @pytest.mark.asyncio + async def test_verify_api_key_expired(self, mock_asyncpg_pool): + """Expired API key raises 403.""" + pool, conn = mock_asyncpg_pool + expires_at = datetime.now(UTC) - timedelta(hours=1) + conn.fetchrow.return_value = { + "user_id": 1, + "username": "testuser", + "email": "testuser@example.com", + "user_active": True, + "key_id": 123, + "expires_at": expires_at, + "key_active": True + } + + with patch("app.main.db_pool", pool): + with patch("app.main.hash_api_key", return_value="hashedkey"): + from app.main import verify_api_key + credentials = MagicMock() + credentials.credentials = "expired_api_key_123456" + with pytest.raises(HTTPException) as exc: + await verify_api_key(credentials) + assert exc.value.status_code == 403 + assert "expired" in exc.value.detail.lower() + + @pytest.mark.asyncio + async def test_verify_api_key_invalid(self, mock_asyncpg_pool): + """Invalid API key raises 401.""" + pool, conn = mock_asyncpg_pool + conn.fetchrow.return_value = None + + with patch("app.main.db_pool", pool): + with patch("app.main.hash_api_key", return_value="hashedkey"): + from app.main import verify_api_key + credentials = MagicMock() + credentials.credentials = "invalid_api_key_123456" + with pytest.raises(HTTPException) as exc: + await verify_api_key(credentials) + assert exc.value.status_code == 401 + + @pytest.mark.asyncio + async def test_verify_api_key_inactive_user(self, mock_asyncpg_pool): + """Inactive user raises 403.""" + pool, conn = mock_asyncpg_pool + expires_at = datetime.now(UTC) + timedelta(hours=1) + conn.fetchrow.return_value = { + "user_id": 1, + "username": "testuser", + "email": "testuser@example.com", + "user_active": False, + "key_id": 123, + "expires_at": expires_at, + "key_active": True + } + + with patch("app.main.db_pool", pool): + with patch("app.main.hash_api_key", return_value="hashedkey"): + from app.main import verify_api_key + credentials = MagicMock() + credentials.credentials = "valid_api_key_12345678" + with pytest.raises(HTTPException) as exc: + await verify_api_key(credentials) + assert exc.value.status_code == 403 + assert "disabled" in exc.value.detail.lower() + + +# ============================================================================ +# Health Check Tests +# ============================================================================ + +class TestHealthEndpoint: + """Tests for /health endpoint.""" + + @pytest.mark.asyncio + async def test_health_check_healthy(self, mock_asyncpg_pool): + """Health check returns healthy when DB and queue are OK.""" + pool, conn = mock_asyncpg_pool + conn.fetchval.return_value = 1 + conn.fetch.return_value = [{"queue_name": "gpu_reservations"}] + + with patch("app.main.db_pool", pool): + with patch("app.main.QUEUE_NAME", "gpu_reservations"): + from app.main import health_check + result = await health_check() + + assert result["status"] == "healthy" + assert result["database"] == "healthy" + assert result["queue"] == "healthy" + + @pytest.mark.asyncio + async def test_health_check_missing_queue(self, mock_asyncpg_pool): + """Health check returns unhealthy when queue is missing.""" + pool, conn = mock_asyncpg_pool + conn.fetchval.return_value = 1 + conn.fetch.return_value = [{"queue_name": "other_queue"}] + + with patch("app.main.db_pool", pool): + with patch("app.main.QUEUE_NAME", "gpu_reservations"): + from app.main import health_check + result = await health_check() + + assert result["status"] == "unhealthy" + assert result["database"] == "healthy" + assert result["queue"] == "missing" + + @pytest.mark.asyncio + async def test_health_check_db_not_initialized(self): + """Health check returns unhealthy when DB pool is None.""" + with patch("app.main.db_pool", None): + from app.main import health_check + result = await health_check() + + assert result["status"] == "unhealthy" + assert result["database"] == "not initialized" + + +# ============================================================================ +# Job Submission Tests +# ============================================================================ + +class TestSubmitJobEndpoint: + """Tests for POST /v1/jobs/submit endpoint.""" + + @pytest.mark.asyncio + async def test_submit_job_success(self, mock_asyncpg_pool, sample_user_info): + """Submitting a job returns queued status.""" + pool, conn = mock_asyncpg_pool + conn.fetchval.return_value = 123 # msg_id + + with patch("app.main.db_pool", pool): + with patch("app.main.QUEUE_NAME", "gpu_reservations"): + from app.main import submit_job, JobSubmissionRequest + request = JobSubmissionRequest( + image="pytorch/pytorch:2.1.0", + instance_type="p4d.24xlarge", + duration_hours=4 + ) + user_info = sample_user_info() + result = await submit_job(request, user_info) + + assert result.status == "queued" + assert "123" in result.message + + @pytest.mark.asyncio + async def test_submit_job_with_existing_reservation_id( + self, mock_asyncpg_pool, sample_user_info + ): + """Job uses reservation_id from env_vars if provided.""" + pool, conn = mock_asyncpg_pool + conn.fetchval.return_value = 456 + + existing_id = str(uuid.uuid4()) + with patch("app.main.db_pool", pool): + with patch("app.main.QUEUE_NAME", "gpu_reservations"): + from app.main import submit_job, JobSubmissionRequest + request = JobSubmissionRequest( + image="pytorch/pytorch:2.1.0", + instance_type="p4d.24xlarge", + duration_hours=4, + env_vars={"RESERVATION_ID": existing_id} + ) + user_info = sample_user_info() + result = await submit_job(request, user_info) + + assert result.job_id == existing_id + + +# ============================================================================ +# Get Job Tests +# ============================================================================ + +class TestGetJobEndpoint: + """Tests for GET /v1/jobs/{job_id} endpoint.""" + + @pytest.mark.asyncio + async def test_get_job_success(self, mock_asyncpg_pool, sample_user_info): + """Get job returns job details.""" + pool, conn = mock_asyncpg_pool + job_id = str(uuid.uuid4()) + conn.fetchrow.return_value = { + "reservation_id": job_id, + "user_id": "testuser", + "status": "active", + "gpu_type": "a100", + "gpu_count": 4, + "instance_type": "p4d.24xlarge", + "duration_hours": 4, + "created_at": datetime.now(UTC), + "expires_at": datetime.now(UTC) + timedelta(hours=4), + "name": "test-pod", + "pod_name": "gpu-dev-test", + "node_ip": "10.0.0.1", + "node_port": 30001, + "jupyter_enabled": False, + "jupyter_url": None, + "jupyter_token": None, + "github_user": "testghuser" + } + + with patch("app.main.db_pool", pool): + from app.main import get_job_status + user_info = sample_user_info() + result = await get_job_status(job_id, user_info) + + assert result.job_id == job_id + assert result.status == "active" + assert result.gpu_count == 4 + + @pytest.mark.asyncio + async def test_get_job_not_found(self, mock_asyncpg_pool, sample_user_info): + """Get non-existent job raises 404.""" + pool, conn = mock_asyncpg_pool + conn.fetchrow.return_value = None + + with patch("app.main.db_pool", pool): + from app.main import get_job_status + with pytest.raises(HTTPException) as exc: + await get_job_status("non-existent-id", sample_user_info()) + assert exc.value.status_code == 404 + + @pytest.mark.asyncio + async def test_get_job_unauthorized(self, mock_asyncpg_pool, sample_user_info): + """Get another user's job raises 403.""" + pool, conn = mock_asyncpg_pool + job_id = str(uuid.uuid4()) + conn.fetchrow.return_value = { + "reservation_id": job_id, + "user_id": "otheruser", # Different user + "status": "active", + "gpu_type": "a100", + "gpu_count": 4, + "instance_type": "p4d.24xlarge", + "duration_hours": 4, + "created_at": datetime.now(UTC), + "expires_at": datetime.now(UTC) + timedelta(hours=4), + "name": "test-pod", + "pod_name": "gpu-dev-test", + "node_ip": "10.0.0.1", + "node_port": 30001, + "jupyter_enabled": False, + "jupyter_url": None, + "jupyter_token": None, + "github_user": "testghuser" + } + + with patch("app.main.db_pool", pool): + from app.main import get_job_status + with pytest.raises(HTTPException) as exc: + await get_job_status(job_id, sample_user_info()) + assert exc.value.status_code == 403 + + +# ============================================================================ +# List Jobs Tests +# ============================================================================ + +class TestListJobsEndpoint: + """Tests for GET /v1/jobs endpoint.""" + + @pytest.mark.asyncio + async def test_list_jobs_success(self, mock_asyncpg_pool, sample_user_info): + """List jobs returns user's jobs.""" + pool, conn = mock_asyncpg_pool + conn.fetchval.return_value = 2 # total count + conn.fetch.return_value = [ + { + "reservation_id": str(uuid.uuid4()), + "user_id": "testuser", + "status": "active", + "gpu_type": "a100", + "gpu_count": 4, + "instance_type": "p4d.24xlarge", + "duration_hours": 4, + "created_at": datetime.now(UTC), + "expires_at": datetime.now(UTC) + timedelta(hours=4), + "name": "test-pod-1", + "pod_name": "gpu-dev-test-1", + "node_ip": "10.0.0.1", + "node_port": 30001, + "jupyter_enabled": False, + "jupyter_url": None, + "jupyter_token": None, + "github_user": "testghuser" + }, + { + "reservation_id": str(uuid.uuid4()), + "user_id": "testuser", + "status": "queued", + "gpu_type": "h100", + "gpu_count": 8, + "instance_type": "p5.48xlarge", + "duration_hours": 2, + "created_at": datetime.now(UTC), + "expires_at": datetime.now(UTC) + timedelta(hours=2), + "name": "test-pod-2", + "pod_name": None, + "node_ip": None, + "node_port": None, + "jupyter_enabled": False, + "jupyter_url": None, + "jupyter_token": None, + "github_user": "testghuser" + } + ] + + with patch("app.main.db_pool", pool): + from app.main import list_jobs + result = await list_jobs(sample_user_info()) + + assert result.total == 2 + assert len(result.jobs) == 2 + assert result.jobs[0].status == "active" + + @pytest.mark.asyncio + async def test_list_jobs_with_status_filter( + self, mock_asyncpg_pool, sample_user_info + ): + """List jobs with status filter returns filtered results.""" + pool, conn = mock_asyncpg_pool + conn.fetchval.return_value = 1 + conn.fetch.return_value = [ + { + "reservation_id": str(uuid.uuid4()), + "user_id": "testuser", + "status": "active", + "gpu_type": "a100", + "gpu_count": 4, + "instance_type": "p4d.24xlarge", + "duration_hours": 4, + "created_at": datetime.now(UTC), + "expires_at": datetime.now(UTC) + timedelta(hours=4), + "name": "test-pod", + "pod_name": "gpu-dev-test", + "node_ip": "10.0.0.1", + "node_port": 30001, + "jupyter_enabled": False, + "jupyter_url": None, + "jupyter_token": None, + "github_user": "testghuser" + } + ] + + with patch("app.main.db_pool", pool): + from app.main import list_jobs + result = await list_jobs(sample_user_info(), status_filter="active") + + assert result.total == 1 + assert result.jobs[0].status == "active" + + +# ============================================================================ +# Cancel Job Tests +# ============================================================================ + +class TestCancelJobEndpoint: + """Tests for POST /v1/jobs/{job_id}/cancel endpoint.""" + + @pytest.mark.asyncio + async def test_cancel_job_success(self, mock_asyncpg_pool, sample_user_info): + """Cancel job sends message to queue.""" + pool, conn = mock_asyncpg_pool + conn.fetchval.return_value = 789 # msg_id + + with patch("app.main.db_pool", pool): + with patch("app.main.QUEUE_NAME", "gpu_reservations"): + from app.main import cancel_job + job_id = str(uuid.uuid4()) + result = await cancel_job(job_id, sample_user_info()) + + assert result.action == "cancel" + assert result.status == "requested" + assert result.job_id == job_id + conn.fetchval.assert_called_once() + + +# ============================================================================ +# Extend Job Tests +# ============================================================================ + +class TestExtendJobEndpoint: + """Tests for POST /v1/jobs/{job_id}/extend endpoint.""" + + @pytest.mark.asyncio + async def test_extend_job_success(self, mock_asyncpg_pool, sample_user_info): + """Extend job sends message to queue.""" + pool, conn = mock_asyncpg_pool + conn.fetchval.return_value = 101 # msg_id + + with patch("app.main.db_pool", pool): + with patch("app.main.QUEUE_NAME", "gpu_reservations"): + from app.main import extend_job, ExtendJobRequest + job_id = str(uuid.uuid4()) + request = ExtendJobRequest(extension_hours=2) + result = await extend_job(job_id, request, sample_user_info()) + + assert result.action == "extend" + assert result.status == "requested" + assert "2 hours" in result.message + + +# ============================================================================ +# GPU Availability Tests +# ============================================================================ + +class TestGPUAvailabilityEndpoint: + """Tests for GET /v1/gpu/availability endpoint.""" + + @pytest.mark.asyncio + async def test_get_gpu_availability_success( + self, mock_asyncpg_pool, sample_user_info + ): + """Get GPU availability returns availability data.""" + pool, conn = mock_asyncpg_pool + # GPU config query + gpu_config = [ + {"gpu_type": "a100", "total_cluster_gpus": 16, "max_per_node": 8}, + {"gpu_type": "h100", "total_cluster_gpus": 16, "max_per_node": 8} + ] + # In-use query + in_use = [{"gpu_type": "a100", "count": 4}] + # Queued query + queued = [{"gpu_type": "h100", "count": 8}] + + conn.fetch.side_effect = [gpu_config, in_use, queued] + + with patch("app.main.db_pool", pool): + from app.main import get_gpu_availability + result = await get_gpu_availability(sample_user_info()) + + assert "a100" in result.availability + assert result.availability["a100"].total == 16 + assert result.availability["a100"].in_use == 4 + assert result.availability["a100"].available == 12 + + @pytest.mark.asyncio + async def test_get_gpu_availability_empty( + self, mock_asyncpg_pool, sample_user_info + ): + """Get GPU availability returns empty when no GPU config.""" + pool, conn = mock_asyncpg_pool + conn.fetch.return_value = [] # No GPU config + + with patch("app.main.db_pool", pool): + from app.main import get_gpu_availability + result = await get_gpu_availability(sample_user_info()) + + assert result.availability == {} + + +# ============================================================================ +# AWS Login Tests +# ============================================================================ + +class TestAWSLoginEndpoint: + """Tests for POST /v1/auth/aws-login endpoint.""" + + @pytest.mark.asyncio + async def test_aws_login_success(self, mock_asyncpg_pool, mock_sts_client): + """AWS login with valid credentials returns API key.""" + pool, conn = mock_asyncpg_pool + # User lookup (not found -> create new) + conn.fetchrow.side_effect = [None, {"user_id": 1}] + conn.fetchval.return_value = 1 # user_id after insert + + with patch("app.main.db_pool", pool): + with patch("app.main.verify_aws_credentials") as mock_verify: + mock_verify.return_value = { + "account": "123456789012", + "user_id": "AIDAEXAMPLE", + "arn": "arn:aws:sts::123456789012:assumed-role/SSOCloudDevGpuReservation/testuser" + } + with patch("app.main.ALLOWED_AWS_ROLE", "SSOCloudDevGpuReservation"): + with patch("app.main.create_api_key_for_user") as mock_create_key: + mock_create_key.return_value = ( + "test_api_key", + "test_pre", + datetime.now(UTC) + timedelta(hours=2) + ) + from app.main import aws_login, AWSLoginRequest + request = AWSLoginRequest( + aws_access_key_id="AKIAIOSFODNN7EXAMPLE", + aws_secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + aws_session_token="session_token_123" * 10 + ) + result = await aws_login(request) + + assert result.api_key == "test_api_key" + assert "testuser" in result.aws_arn + + @pytest.mark.asyncio + async def test_aws_login_wrong_role(self, mock_asyncpg_pool): + """AWS login with wrong role raises 403.""" + pool, conn = mock_asyncpg_pool + + with patch("app.main.db_pool", pool): + with patch("app.main.verify_aws_credentials") as mock_verify: + mock_verify.return_value = { + "account": "123456789012", + "user_id": "AIDAEXAMPLE", + "arn": "arn:aws:sts::123456789012:assumed-role/WrongRole/testuser" + } + with patch("app.main.ALLOWED_AWS_ROLE", "SSOCloudDevGpuReservation"): + from app.main import aws_login, AWSLoginRequest + request = AWSLoginRequest( + aws_access_key_id="AKIAIOSFODNN7EXAMPLE", + aws_secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + ) + with pytest.raises(HTTPException) as exc: + await aws_login(request) + assert exc.value.status_code == 403 + + +# ============================================================================ +# Helper Function Tests +# ============================================================================ + +class TestHelperFunctions: + """Tests for helper functions.""" + + def test_extract_username_from_arn_assumed_role(self): + """Extract username from assumed-role ARN.""" + from app.main import extract_username_from_arn + arn = "arn:aws:sts::123456789:assumed-role/SSOCloudDevGpuReservation/john" + assert extract_username_from_arn(arn) == "john" + + def test_extract_username_from_arn_iam_user(self): + """Extract username from IAM user ARN.""" + from app.main import extract_username_from_arn + arn = "arn:aws:iam::123456789:user/jane" + assert extract_username_from_arn(arn) == "jane" + + def test_extract_role_from_arn_assumed_role(self): + """Extract role name from assumed-role ARN.""" + from app.main import extract_role_from_arn + arn = "arn:aws:sts::123456789:assumed-role/SSOCloudDevGpuReservation/john" + assert extract_role_from_arn(arn) == "SSOCloudDevGpuReservation" + + def test_extract_role_from_arn_user(self): + """Extract role name from user ARN returns empty.""" + from app.main import extract_role_from_arn + arn = "arn:aws:iam::123456789:user/john" + assert extract_role_from_arn(arn) == "" + + def test_hash_api_key(self): + """Hash API key produces consistent hash.""" + from app.main import hash_api_key + key = "test_api_key_12345" + hash1 = hash_api_key(key) + hash2 = hash_api_key(key) + assert hash1 == hash2 + assert len(hash1) == 64 # SHA256 hex + + def test_ensure_utc_naive_datetime(self): + """ensure_utc adds UTC timezone to naive datetime.""" + from app.main import ensure_utc + naive = datetime(2024, 1, 1, 12, 0, 0) + result = ensure_utc(naive) + assert result.tzinfo is not None + + def test_ensure_utc_none(self): + """ensure_utc returns None for None input.""" + from app.main import ensure_utc + assert ensure_utc(None) is None + + def test_create_message_metadata(self): + """create_message_metadata returns proper structure.""" + from app.main import create_message_metadata + metadata = create_message_metadata(max_retries=5) + assert metadata["retry_count"] == 0 + assert metadata["max_retries"] == 5 + assert "created_at" in metadata diff --git a/tests/unit/services/test_disk_reconciler.py b/tests/unit/services/test_disk_reconciler.py new file mode 100644 index 00000000..c39c7036 --- /dev/null +++ b/tests/unit/services/test_disk_reconciler.py @@ -0,0 +1,681 @@ +""" +Unit tests for disk reconciler. +""" +import random +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, patch, call + +import pytest +from botocore.exceptions import ClientError + + +# ============================================================================ +# Volume Discovery Tests +# ============================================================================ + +class TestVolumeDiscovery: + """Tests for AWS volume discovery.""" + + def test_get_all_gpudev_volumes_success(self, mock_ec2_client): + """get_all_gpudev_volumes returns parsed volumes.""" + mock_ec2_client.get_paginator.return_value.paginate.return_value = [ + { + "Volumes": [ + { + "VolumeId": "vol-12345678", + "Size": 100, + "State": "available", + "AvailabilityZone": "us-east-1a", + "CreateTime": datetime.now(UTC), + "Attachments": [], + "Tags": [ + {"Key": "gpu-dev-user", "Value": "testuser"}, + {"Key": "disk-name", "Value": "test-disk"} + ] + } + ] + } + ] + + from shared.disk_reconciler import get_all_gpudev_volumes + volumes, error = get_all_gpudev_volumes(mock_ec2_client) + + assert error is None + assert len(volumes) == 1 + assert volumes[0]["volume_id"] == "vol-12345678" + assert volumes[0]["user_id"] == "testuser" + assert volumes[0]["disk_name"] == "test-disk" + + def test_get_all_gpudev_volumes_empty(self, mock_ec2_client): + """get_all_gpudev_volumes returns empty list when no volumes.""" + mock_ec2_client.get_paginator.return_value.paginate.return_value = [ + {"Volumes": []} + ] + + from shared.disk_reconciler import get_all_gpudev_volumes + volumes, error = get_all_gpudev_volumes(mock_ec2_client) + + assert error is None + assert volumes == [] + + def test_get_all_gpudev_volumes_skips_quarantined(self, mock_ec2_client): + """get_all_gpudev_volumes skips quarantined volumes.""" + mock_ec2_client.get_paginator.return_value.paginate.return_value = [ + { + "Volumes": [ + { + "VolumeId": "vol-12345678", + "Size": 100, + "State": "available", + "AvailabilityZone": "us-east-1a", + "CreateTime": datetime.now(UTC), + "Attachments": [], + "Tags": [ + {"Key": "gpu-dev-user", "Value": "testuser"}, + {"Key": "disk-name", "Value": "test-disk"}, + {"Key": "gpu-dev-quarantined", "Value": "2024-01-01T00:00:00Z"} + ] + } + ] + } + ] + + from shared.disk_reconciler import get_all_gpudev_volumes + volumes, error = get_all_gpudev_volumes(mock_ec2_client) + + assert error is None + assert volumes == [] + + def test_get_all_gpudev_volumes_retry_on_throttling(self, mock_ec2_client): + """get_all_gpudev_volumes retries on throttling.""" + throttle_error = ClientError( + {"Error": {"Code": "Throttling", "Message": "Rate exceeded"}}, + "DescribeVolumes" + ) + mock_ec2_client.get_paginator.return_value.paginate.side_effect = [ + throttle_error, + [{"Volumes": []}] + ] + + from shared.disk_reconciler import get_all_gpudev_volumes + with patch("time.sleep"): # Skip actual sleep + volumes, error = get_all_gpudev_volumes(mock_ec2_client, max_retries=2) + + assert error is None + assert volumes == [] + + def test_get_all_gpudev_volumes_error_on_max_retries(self, mock_ec2_client): + """get_all_gpudev_volumes returns error after max retries.""" + throttle_error = ClientError( + {"Error": {"Code": "Throttling", "Message": "Rate exceeded"}}, + "DescribeVolumes" + ) + mock_ec2_client.get_paginator.return_value.paginate.side_effect = throttle_error + + from shared.disk_reconciler import get_all_gpudev_volumes + with patch("time.sleep"): + volumes, error = get_all_gpudev_volumes(mock_ec2_client, max_retries=2) + + assert error is not None + assert "throttling" in error.lower() + assert volumes == [] + + +# ============================================================================ +# Volume Parsing Tests +# ============================================================================ + +class TestVolumeParsing: + """Tests for AWS volume response parsing.""" + + def test_parse_volume_from_aws_basic(self): + """parse_volume_from_aws parses basic volume.""" + aws_volume = { + "VolumeId": "vol-12345678", + "Size": 100, + "State": "available", + "AvailabilityZone": "us-east-1a", + "CreateTime": datetime.now(UTC), + "Attachments": [], + "Tags": [ + {"Key": "gpu-dev-user", "Value": "testuser"}, + {"Key": "disk-name", "Value": "test-disk"} + ] + } + + from shared.disk_reconciler import parse_volume_from_aws + result = parse_volume_from_aws(aws_volume) + + assert result["volume_id"] == "vol-12345678" + assert result["size_gb"] == 100 + assert result["user_id"] == "testuser" + assert result["disk_name"] == "test-disk" + assert result["is_attached"] is False + + def test_parse_volume_from_aws_attached(self): + """parse_volume_from_aws parses attached volume.""" + aws_volume = { + "VolumeId": "vol-12345678", + "Size": 100, + "State": "in-use", + "AvailabilityZone": "us-east-1a", + "CreateTime": datetime.now(UTC), + "Attachments": [ + {"InstanceId": "i-12345678", "State": "attached"} + ], + "Tags": [ + {"Key": "gpu-dev-user", "Value": "testuser"}, + {"Key": "disk-name", "Value": "test-disk"} + ] + } + + from shared.disk_reconciler import parse_volume_from_aws + result = parse_volume_from_aws(aws_volume) + + assert result["is_attached"] is True + assert result["attached_instance"] == "i-12345678" + + def test_parse_volume_from_aws_skips_missing_tags(self): + """parse_volume_from_aws skips volumes without required tags.""" + aws_volume = { + "VolumeId": "vol-12345678", + "Size": 100, + "State": "available", + "AvailabilityZone": "us-east-1a", + "CreateTime": datetime.now(UTC), + "Attachments": [], + "Tags": [ + {"Key": "gpu-dev-user", "Value": "testuser"} + # Missing disk-name + ] + } + + from shared.disk_reconciler import parse_volume_from_aws + result = parse_volume_from_aws(aws_volume) + + assert result is None + + def test_parse_volume_from_aws_skips_transient_states(self): + """parse_volume_from_aws skips volumes in transient states.""" + aws_volume = { + "VolumeId": "vol-12345678", + "Size": 100, + "State": "creating", + "AvailabilityZone": "us-east-1a", + "CreateTime": datetime.now(UTC), + "Attachments": [], + "Tags": [ + {"Key": "gpu-dev-user", "Value": "testuser"}, + {"Key": "disk-name", "Value": "test-disk"} + ] + } + + from shared.disk_reconciler import parse_volume_from_aws + result = parse_volume_from_aws(aws_volume) + + assert result is None + + +# ============================================================================ +# Orphan Detection Tests +# ============================================================================ + +class TestOrphanDetection: + """Tests for orphan volume/record detection.""" + + def test_orphaned_aws_volume_imported( + self, mock_ec2_client, mock_db_cursor, mock_connection_pool + ): + """Orphaned AWS volume is imported to database.""" + aws_volume = { + "volume_id": "vol-12345678", + "size_gb": 100, + "state": "available", + "availability_zone": "us-east-1a", + "created_at": datetime.now(UTC), + "is_attached": False, + "attached_instance": None, + "disk_name": "new-disk", + "user_id": "testuser", + "reservation_id": None, + "tags": {} + } + + mock_ec2_client.describe_snapshots.return_value = {"Snapshots": []} + + with patch("shared.disk_reconciler.create_disk") as mock_create: + mock_create.return_value = True + from shared.disk_reconciler import import_volume_to_db + result = import_volume_to_db(aws_volume, mock_ec2_client) + + assert result is True + mock_create.assert_called_once() + call_args = mock_create.call_args[0][0] + assert call_args["disk_name"] == "new-disk" + assert call_args["ebs_volume_id"] == "vol-12345678" + + +# ============================================================================ +# Snapshot Management Tests +# ============================================================================ + +class TestSnapshotManagement: + """Tests for snapshot information gathering.""" + + def test_get_snapshot_info_no_snapshots(self, mock_ec2_client): + """get_snapshot_info returns defaults when no snapshots.""" + mock_ec2_client.describe_snapshots.return_value = {"Snapshots": []} + + from shared.disk_reconciler import get_snapshot_info + result = get_snapshot_info(mock_ec2_client, "vol-12345678", "testuser") + + assert result["count"] == 0 + assert result["is_backing_up"] is False + assert result["last_snapshot_at"] is None + + def test_get_snapshot_info_with_completed_snapshots(self, mock_ec2_client): + """get_snapshot_info returns info for completed snapshots.""" + snapshot_time = datetime.now(UTC) + mock_ec2_client.describe_snapshots.side_effect = [ + {"Snapshots": []}, # pending check + {"Snapshots": [ + {"SnapshotId": "snap-1", "StartTime": snapshot_time - timedelta(days=1)}, + {"SnapshotId": "snap-2", "StartTime": snapshot_time} + ]} # completed check + ] + + from shared.disk_reconciler import get_snapshot_info + result = get_snapshot_info(mock_ec2_client, "vol-12345678", "testuser") + + assert result["count"] == 2 + assert result["is_backing_up"] is False + assert result["last_snapshot_at"] == snapshot_time + + def test_get_snapshot_info_with_pending_snapshot(self, mock_ec2_client): + """get_snapshot_info detects in-progress backup.""" + mock_ec2_client.describe_snapshots.side_effect = [ + {"Snapshots": [{"SnapshotId": "snap-pending", "Status": "pending"}]}, + {"Snapshots": []} + ] + + from shared.disk_reconciler import get_snapshot_info + result = get_snapshot_info(mock_ec2_client, "vol-12345678", "testuser") + + assert result["is_backing_up"] is True + + +# ============================================================================ +# Conflict Resolution Tests +# ============================================================================ + +class TestConflictResolution: + """Tests for duplicate volume conflict resolution.""" + + def test_resolve_conflict_one_attached(self, mock_ec2_client): + """Attached volume is chosen when one is attached.""" + volumes = [ + { + "volume_id": "vol-1", + "is_attached": True, + "size_gb": 100, + "created_at": datetime.now(UTC) - timedelta(days=1) + }, + { + "volume_id": "vol-2", + "is_attached": False, + "size_gb": 100, + "created_at": datetime.now(UTC) + } + ] + + mock_ec2_client.describe_volumes.return_value = { + "Volumes": [{"VolumeId": "vol-2", "Attachments": []}] + } + mock_ec2_client.create_tags.return_value = {} + + from shared.disk_reconciler import resolve_volume_conflict_with_quarantine + current, quarantined = resolve_volume_conflict_with_quarantine( + mock_ec2_client, + "testuser", + "test-disk", + volumes, + None + ) + + assert current["volume_id"] == "vol-1" + assert "vol-2" in quarantined + + def test_resolve_conflict_multiple_attached_fails(self, mock_ec2_client): + """Multiple attached volumes returns None (error state).""" + volumes = [ + { + "volume_id": "vol-1", + "is_attached": True, + "size_gb": 100, + "created_at": datetime.now(UTC) + }, + { + "volume_id": "vol-2", + "is_attached": True, + "size_gb": 100, + "created_at": datetime.now(UTC) + } + ] + + from shared.disk_reconciler import resolve_volume_conflict_with_quarantine + current, quarantined = resolve_volume_conflict_with_quarantine( + mock_ec2_client, + "testuser", + "test-disk", + volumes, + None + ) + + assert current is None + assert quarantined == [] + + def test_resolve_conflict_uses_db_preference(self, mock_ec2_client): + """DB-referenced volume is preferred when all detached.""" + volumes = [ + { + "volume_id": "vol-1", + "is_attached": False, + "size_gb": 100, + "created_at": datetime.now(UTC) + }, + { + "volume_id": "vol-2", + "is_attached": False, + "size_gb": 100, + "created_at": datetime.now(UTC) - timedelta(days=1) + } + ] + + db_record = {"ebs_volume_id": "vol-2"} + + mock_ec2_client.describe_volumes.return_value = { + "Volumes": [{"VolumeId": "vol-1", "Attachments": []}] + } + mock_ec2_client.create_tags.return_value = {} + + from shared.disk_reconciler import resolve_volume_conflict_with_quarantine + current, quarantined = resolve_volume_conflict_with_quarantine( + mock_ec2_client, + "testuser", + "test-disk", + volumes, + db_record + ) + + assert current["volume_id"] == "vol-2" + assert "vol-1" in quarantined + + +# ============================================================================ +# Cross-AZ Migration Tests +# ============================================================================ + +class TestCrossAZMigration: + """Tests for cross-AZ volume migration logic.""" + + def test_sync_volume_detects_size_change( + self, mock_ec2_client, mock_db_cursor + ): + """sync_volume_to_db detects volume size change.""" + aws_vol = { + "volume_id": "vol-12345678", + "size_gb": 200, # Changed from 100 + "is_attached": False, + "created_at": datetime.now(UTC) + } + db_disk = { + "disk_id": 1, + "disk_name": "test-disk", + "user_id": "testuser", + "ebs_volume_id": "vol-12345678", + "size_gb": 100, + "in_use": False, + "snapshot_count": 0, + "is_backing_up": False, + "last_snapshot_at": None + } + + mock_ec2_client.describe_snapshots.return_value = {"Snapshots": []} + + with patch("shared.disk_reconciler.update_disk") as mock_update: + mock_update.return_value = True + from shared.disk_reconciler import sync_volume_to_db + result = sync_volume_to_db(aws_vol, db_disk, mock_ec2_client) + + assert result == "updated" + mock_update.assert_called_once() + call_args = mock_update.call_args[0][2] + assert call_args["size_gb"] == 200 + + def test_sync_volume_detects_attachment_change( + self, mock_ec2_client, mock_db_cursor + ): + """sync_volume_to_db detects attachment status change.""" + aws_vol = { + "volume_id": "vol-12345678", + "size_gb": 100, + "is_attached": True, # Changed from False + "created_at": datetime.now(UTC) + } + db_disk = { + "disk_id": 1, + "disk_name": "test-disk", + "user_id": "testuser", + "ebs_volume_id": "vol-12345678", + "size_gb": 100, + "in_use": False, # DB shows not in use + "snapshot_count": 0, + "is_backing_up": False, + "last_snapshot_at": None + } + + mock_ec2_client.describe_snapshots.return_value = {"Snapshots": []} + + with patch("shared.disk_reconciler.update_disk") as mock_update: + mock_update.return_value = True + from shared.disk_reconciler import sync_volume_to_db + result = sync_volume_to_db(aws_vol, db_disk, mock_ec2_client) + + assert result == "updated" + call_args = mock_update.call_args[0][2] + assert call_args["in_use"] is True + + def test_sync_volume_no_changes(self, mock_ec2_client, mock_db_cursor): + """sync_volume_to_db returns synced when no changes.""" + aws_vol = { + "volume_id": "vol-12345678", + "size_gb": 100, + "is_attached": False, + "created_at": datetime.now(UTC) + } + db_disk = { + "disk_id": 1, + "disk_name": "test-disk", + "user_id": "testuser", + "ebs_volume_id": "vol-12345678", + "size_gb": 100, + "in_use": False, + "snapshot_count": 0, + "is_backing_up": False, + "last_snapshot_at": None + } + + mock_ec2_client.describe_snapshots.return_value = {"Snapshots": []} + + with patch("shared.disk_reconciler.update_disk") as mock_update: + from shared.disk_reconciler import sync_volume_to_db + result = sync_volume_to_db(aws_vol, db_disk, mock_ec2_client) + + assert result == "synced" + mock_update.assert_not_called() + + +# ============================================================================ +# Quarantine Cleanup Tests +# ============================================================================ + +class TestQuarantineCleanup: + """Tests for quarantined volume cleanup.""" + + def test_cleanup_old_quarantined_volumes_deletes_old(self, mock_ec2_client): + """Old quarantined volumes are deleted.""" + old_timestamp = (datetime.now(UTC) - timedelta(days=35)).isoformat() + mock_ec2_client.get_paginator.return_value.paginate.return_value = [ + { + "Volumes": [ + { + "VolumeId": "vol-old", + "Size": 100, + "Attachments": [], + "Tags": [ + {"Key": "gpu-dev-quarantined", "Value": old_timestamp}, + {"Key": "gpu-dev-user", "Value": "testuser"}, + {"Key": "disk-name", "Value": "old-disk"} + ] + } + ] + } + ] + mock_ec2_client.create_snapshot.return_value = {"SnapshotId": "snap-123"} + mock_ec2_client.delete_volume.return_value = {} + + from shared.disk_reconciler import cleanup_old_quarantined_volumes + stats = cleanup_old_quarantined_volumes(mock_ec2_client, max_age_days=30) + + assert stats["deleted"] == 1 + mock_ec2_client.create_snapshot.assert_called_once() + mock_ec2_client.delete_volume.assert_called_once() + + def test_cleanup_old_quarantined_volumes_skips_recent(self, mock_ec2_client): + """Recent quarantined volumes are not deleted.""" + recent_timestamp = (datetime.now(UTC) - timedelta(days=5)).isoformat() + mock_ec2_client.get_paginator.return_value.paginate.return_value = [ + { + "Volumes": [ + { + "VolumeId": "vol-recent", + "Size": 100, + "Attachments": [], + "Tags": [ + {"Key": "gpu-dev-quarantined", "Value": recent_timestamp}, + {"Key": "gpu-dev-user", "Value": "testuser"}, + {"Key": "disk-name", "Value": "recent-disk"} + ] + } + ] + } + ] + + from shared.disk_reconciler import cleanup_old_quarantined_volumes + stats = cleanup_old_quarantined_volumes(mock_ec2_client, max_age_days=30) + + assert stats["deleted"] == 0 + assert stats["skipped_too_recent"] == 1 + mock_ec2_client.delete_volume.assert_not_called() + + def test_cleanup_old_quarantined_volumes_skips_attached(self, mock_ec2_client): + """Attached quarantined volumes are not deleted (safety).""" + old_timestamp = (datetime.now(UTC) - timedelta(days=35)).isoformat() + mock_ec2_client.get_paginator.return_value.paginate.return_value = [ + { + "Volumes": [ + { + "VolumeId": "vol-attached", + "Size": 100, + "Attachments": [{"State": "attached"}], + "Tags": [ + {"Key": "gpu-dev-quarantined", "Value": old_timestamp}, + {"Key": "gpu-dev-user", "Value": "testuser"}, + {"Key": "disk-name", "Value": "attached-disk"} + ] + } + ] + } + ] + + from shared.disk_reconciler import cleanup_old_quarantined_volumes + stats = cleanup_old_quarantined_volumes(mock_ec2_client, max_age_days=30) + + assert stats["deleted"] == 0 + assert stats["errors"] == 1 + mock_ec2_client.delete_volume.assert_not_called() + + +# ============================================================================ +# Reconciliation Lock Tests +# ============================================================================ + +class TestReconciliationLock: + """Tests for reconciliation advisory lock.""" + + def test_reconcile_acquires_lock( + self, mock_ec2_client, mock_db_cursor, mock_connection_pool + ): + """reconcile_all_disks acquires advisory lock.""" + mock_db_cursor.fetchone.return_value = {"locked": True} + mock_db_cursor.fetchall.return_value = [] + + mock_ec2_client.get_paginator.return_value.paginate.return_value = [ + {"Volumes": []} + ] + + with patch("shared.disk_reconciler.get_db_cursor") as mock_get_cursor: + mock_get_cursor.return_value.__enter__.return_value = mock_db_cursor + with patch("shared.disk_reconciler.get_db_transaction"): + from shared.disk_reconciler import reconcile_all_disks + stats = reconcile_all_disks(mock_ec2_client) + + # Check lock was acquired + assert mock_db_cursor.execute.call_count >= 2 + lock_call = mock_db_cursor.execute.call_args_list[0] + assert "pg_try_advisory_lock" in lock_call[0][0] + + def test_reconcile_skips_if_lock_held( + self, mock_ec2_client, mock_db_cursor, mock_connection_pool + ): + """reconcile_all_disks skips if lock is already held.""" + mock_db_cursor.fetchone.return_value = {"locked": False} + + with patch("shared.disk_reconciler.get_db_cursor") as mock_get_cursor: + mock_get_cursor.return_value.__enter__.return_value = mock_db_cursor + from shared.disk_reconciler import reconcile_all_disks + stats = reconcile_all_disks(mock_ec2_client) + + assert stats["skipped_concurrent_run"] is True + mock_ec2_client.get_paginator.assert_not_called() + + +# ============================================================================ +# Timezone Handling Tests +# ============================================================================ + +class TestTimezoneHandling: + """Tests for timezone-aware datetime handling.""" + + def test_ensure_utc_with_naive_datetime(self): + """ensure_utc handles naive datetime.""" + from shared.disk_reconciler import ensure_utc + naive = datetime(2024, 1, 1, 12, 0, 0) + result = ensure_utc(naive) + + assert result.tzinfo is not None + assert result.tzinfo == UTC + + def test_ensure_utc_with_aware_datetime(self): + """ensure_utc handles aware datetime.""" + from shared.disk_reconciler import ensure_utc + aware = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC) + result = ensure_utc(aware) + + assert result.tzinfo is not None + assert result == aware + + def test_ensure_utc_with_none(self): + """ensure_utc handles None.""" + from shared.disk_reconciler import ensure_utc + assert ensure_utc(None) is None diff --git a/tests/unit/services/test_job_processor.py b/tests/unit/services/test_job_processor.py new file mode 100644 index 00000000..6c77cfa5 --- /dev/null +++ b/tests/unit/services/test_job_processor.py @@ -0,0 +1,544 @@ +""" +Unit tests for reservation processor (poller and job manager). +""" +import json +import os +import time +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, patch, call + +import pytest +from kubernetes.client.rest import ApiException + + +# ============================================================================ +# JobManager Tests +# ============================================================================ + +class TestJobManager: + """Tests for JobManager class.""" + + @pytest.fixture + def job_manager(self, mock_k8s_batch_api, mock_k8s_core_api): + """Create JobManager instance with mocks.""" + with patch.dict(os.environ, { + "KUBE_NAMESPACE": "test-namespace", + "WORKER_IMAGE": "test-image:latest", + "SERVICE_ACCOUNT": "test-sa", + "QUEUE_NAME": "test_queue" + }): + from processor.job_manager import JobManager + return JobManager(mock_k8s_batch_api, mock_k8s_core_api) + + def test_job_manager_init(self, mock_k8s_batch_api, mock_k8s_core_api): + """JobManager initializes with correct config.""" + with patch.dict(os.environ, { + "KUBE_NAMESPACE": "my-namespace", + "WORKER_IMAGE": "my-image:v1", + "SERVICE_ACCOUNT": "my-sa" + }): + from processor.job_manager import JobManager + manager = JobManager(mock_k8s_batch_api, mock_k8s_core_api) + + assert manager.namespace == "my-namespace" + assert manager.worker_image == "my-image:v1" + assert manager.service_account == "my-sa" + + def test_create_job_success(self, job_manager, mock_k8s_batch_api): + """create_job creates K8s job successfully.""" + message = { + "action": "create_reservation", + "user_id": "testuser", + "reservation_id": "res-123", + "_metadata": {"retry_count": 0} + } + + job_name = job_manager.create_job(msg_id=100, message=message) + + assert job_name == "reservation-worker-100" + mock_k8s_batch_api.create_namespaced_job.assert_called_once() + call_args = mock_k8s_batch_api.create_namespaced_job.call_args + assert call_args.kwargs["namespace"] == "test-namespace" + + def test_create_job_idempotent_on_409( + self, job_manager, mock_k8s_batch_api + ): + """create_job handles 409 conflict (job already exists).""" + api_exception = ApiException(status=409, reason="Conflict") + mock_k8s_batch_api.create_namespaced_job.side_effect = api_exception + + message = {"action": "test", "user_id": "testuser"} + job_name = job_manager.create_job(msg_id=100, message=message) + + assert job_name == "reservation-worker-100" + + def test_create_job_raises_on_other_errors( + self, job_manager, mock_k8s_batch_api + ): + """create_job raises on non-409 errors.""" + api_exception = ApiException(status=500, reason="Internal Error") + mock_k8s_batch_api.create_namespaced_job.side_effect = api_exception + + message = {"action": "test", "user_id": "testuser"} + with pytest.raises(ApiException): + job_manager.create_job(msg_id=100, message=message) + + def test_get_job_status_succeeded(self, job_manager, mock_k8s_batch_api): + """get_job_status returns Succeeded for completed job.""" + mock_k8s_batch_api.read_namespaced_job_status.return_value = MagicMock( + status=MagicMock( + active=0, + succeeded=1, + failed=0, + start_time=datetime.now(UTC), + completion_time=datetime.now(UTC) + ) + ) + + status = job_manager.get_job_status("test-job") + + assert status["phase"] == "Succeeded" + assert status["succeeded"] == 1 + + def test_get_job_status_failed(self, job_manager, mock_k8s_batch_api): + """get_job_status returns Failed for failed job.""" + mock_k8s_batch_api.read_namespaced_job_status.return_value = MagicMock( + status=MagicMock( + active=0, + succeeded=0, + failed=1, + start_time=datetime.now(UTC), + completion_time=datetime.now(UTC) + ) + ) + + status = job_manager.get_job_status("test-job") + + assert status["phase"] == "Failed" + assert status["failed"] == 1 + + def test_get_job_status_running(self, job_manager, mock_k8s_batch_api): + """get_job_status returns Running for active job.""" + mock_k8s_batch_api.read_namespaced_job_status.return_value = MagicMock( + status=MagicMock( + active=1, + succeeded=0, + failed=0, + start_time=datetime.now(UTC), + completion_time=None + ) + ) + + status = job_manager.get_job_status("test-job") + + assert status["phase"] == "Running" + assert status["active"] == 1 + + def test_get_job_status_pending(self, job_manager, mock_k8s_batch_api): + """get_job_status returns Pending for pending job.""" + mock_k8s_batch_api.read_namespaced_job_status.return_value = MagicMock( + status=MagicMock( + active=0, + succeeded=0, + failed=0, + start_time=None, + completion_time=None + ) + ) + + status = job_manager.get_job_status("test-job") + + assert status["phase"] == "Pending" + + def test_get_job_status_not_found(self, job_manager, mock_k8s_batch_api): + """get_job_status returns None for non-existent job.""" + api_exception = ApiException(status=404, reason="Not Found") + mock_k8s_batch_api.read_namespaced_job_status.side_effect = api_exception + + status = job_manager.get_job_status("non-existent-job") + + assert status is None + + def test_delete_job_success(self, job_manager, mock_k8s_batch_api): + """delete_job deletes job successfully.""" + job_manager.delete_job("test-job") + + mock_k8s_batch_api.delete_namespaced_job.assert_called_once_with( + name="test-job", + namespace="test-namespace", + propagation_policy="Background" + ) + + def test_delete_job_already_deleted(self, job_manager, mock_k8s_batch_api): + """delete_job handles 404 (already deleted).""" + api_exception = ApiException(status=404, reason="Not Found") + mock_k8s_batch_api.delete_namespaced_job.side_effect = api_exception + + # Should not raise + job_manager.delete_job("test-job") + + def test_get_job_logs_success( + self, job_manager, mock_k8s_batch_api, mock_k8s_core_api + ): + """get_job_logs returns pod logs.""" + mock_k8s_core_api.list_namespaced_pod.return_value = MagicMock( + items=[MagicMock(metadata=MagicMock(name="test-pod"))] + ) + mock_k8s_core_api.read_namespaced_pod_log.return_value = "log output" + + logs = job_manager.get_job_logs("test-job", tail_lines=50) + + assert logs == "log output" + + def test_get_job_logs_no_pod(self, job_manager, mock_k8s_core_api): + """get_job_logs returns None when no pod found.""" + mock_k8s_core_api.list_namespaced_pod.return_value = MagicMock(items=[]) + + logs = job_manager.get_job_logs("test-job") + + assert logs is None + + def test_list_active_jobs(self, job_manager, mock_k8s_batch_api): + """list_active_jobs returns active job names.""" + mock_k8s_batch_api.list_namespaced_job.return_value = MagicMock( + items=[ + MagicMock( + metadata=MagicMock(name="job-1"), + status=MagicMock(active=1) + ), + MagicMock( + metadata=MagicMock(name="job-2"), + status=MagicMock(active=0) + ), + MagicMock( + metadata=MagicMock(name="job-3"), + status=MagicMock(active=1) + ) + ] + ) + + active_jobs = job_manager.list_active_jobs() + + assert "job-1" in active_jobs + assert "job-2" not in active_jobs + assert "job-3" in active_jobs + + def test_get_worker_env_includes_message_body(self, job_manager): + """_get_worker_env includes MESSAGE_BODY.""" + message_json = json.dumps({"action": "test"}) + env_vars = job_manager._get_worker_env(message_json) + + message_body_var = next( + (v for v in env_vars if v.name == "MESSAGE_BODY"), + None + ) + assert message_body_var is not None + assert message_body_var.value == message_json + + +# ============================================================================ +# Poller Message Processing Tests +# ============================================================================ + +class TestPollerMessageProcessing: + """Tests for poller message processing logic.""" + + @pytest.fixture + def mock_db_env(self): + """Set up database environment.""" + with patch.dict(os.environ, { + "POSTGRES_HOST": "localhost", + "POSTGRES_PORT": "5432", + "POSTGRES_USER": "testuser", + "POSTGRES_PASSWORD": "testpass", + "POSTGRES_DB": "testdb", + "QUEUE_NAME": "test_queue", + "POLL_INTERVAL_SECONDS": "1", + "VISIBILITY_TIMEOUT_SECONDS": "300", + "BATCH_SIZE": "1", + "MAX_CONCURRENT_JOBS": "10" + }): + yield + + def test_poll_messages_success( + self, mock_db_cursor, mock_db_env, mock_connection_pool + ): + """poll_messages returns messages from queue.""" + mock_db_cursor.fetchall.return_value = [ + { + "msg_id": 1, + "read_ct": 1, + "enqueued_at": datetime.now(UTC), + "vt": datetime.now(UTC), + "message": {"action": "test", "user_id": "testuser"} + } + ] + + with patch("processor.poller.get_db_cursor") as mock_get_cursor: + mock_get_cursor.return_value.__enter__.return_value = mock_db_cursor + from processor.poller import poll_messages + messages = poll_messages(batch_size=1) + + assert len(messages) == 1 + assert messages[0]["msg_id"] == 1 + + def test_poll_messages_empty(self, mock_db_cursor, mock_db_env): + """poll_messages returns empty list when no messages.""" + mock_db_cursor.fetchall.return_value = [] + + with patch("processor.poller.get_db_cursor") as mock_get_cursor: + mock_get_cursor.return_value.__enter__.return_value = mock_db_cursor + from processor.poller import poll_messages + messages = poll_messages(batch_size=1) + + assert messages == [] + + def test_poll_messages_error(self, mock_db_cursor, mock_db_env): + """poll_messages returns empty list on error.""" + mock_db_cursor.execute.side_effect = Exception("DB error") + + with patch("processor.poller.get_db_cursor") as mock_get_cursor: + mock_get_cursor.return_value.__enter__.return_value = mock_db_cursor + from processor.poller import poll_messages + messages = poll_messages(batch_size=1) + + assert messages == [] + + def test_archive_message_success(self, mock_db_cursor, mock_db_env): + """archive_message archives message successfully.""" + mock_db_cursor.fetchone.return_value = {"archived": True} + + with patch("processor.poller.get_db_cursor") as mock_get_cursor: + mock_get_cursor.return_value.__enter__.return_value = mock_db_cursor + from processor.poller import archive_message + result = archive_message(msg_id=123, reason="test failure") + + assert result is True + mock_db_cursor.execute.assert_called() + + +# ============================================================================ +# Retry Logic Tests +# ============================================================================ + +class TestRetryLogic: + """Tests for retry handling.""" + + def test_max_retries_exceeded(self, mock_db_env): + """Message archived when max retries exceeded.""" + message = { + "msg_id": 1, + "read_ct": 4, # Exceeds MAX_RETRIES (3) + "message": {"action": "test"} + } + + with patch("processor.poller.MAX_RETRIES", 3): + with patch("processor.poller.archive_message") as mock_archive: + mock_archive.return_value = True + # The actual check happens in process_loop + # Here we verify the logic + assert message["read_ct"] >= 3 + + +# ============================================================================ +# Active Jobs Tracking Tests +# ============================================================================ + +class TestActiveJobsTracking: + """Tests for active jobs tracking.""" + + def test_check_job_status_succeeded( + self, mock_k8s_batch_api, mock_k8s_core_api + ): + """Succeeded job is removed from tracking.""" + with patch.dict(os.environ, { + "KUBE_NAMESPACE": "test-namespace", + "WORKER_IMAGE": "test-image:latest" + }): + from processor.job_manager import JobManager + from processor.poller import active_jobs, check_job_status + + manager = JobManager(mock_k8s_batch_api, mock_k8s_core_api) + mock_k8s_batch_api.read_namespaced_job_status.return_value = MagicMock( + status=MagicMock( + active=0, + succeeded=1, + failed=0, + start_time=datetime.now(UTC), + completion_time=datetime.now(UTC) + ) + ) + + active_jobs[1] = { + "job_name": "test-job", + "created_at": time.time() + } + + check_job_status(manager, 1, active_jobs[1]) + + assert 1 not in active_jobs + + def test_check_job_status_failed( + self, mock_k8s_batch_api, mock_k8s_core_api + ): + """Failed job is removed from tracking.""" + with patch.dict(os.environ, { + "KUBE_NAMESPACE": "test-namespace", + "WORKER_IMAGE": "test-image:latest" + }): + from processor.job_manager import JobManager + from processor.poller import active_jobs, check_job_status + + manager = JobManager(mock_k8s_batch_api, mock_k8s_core_api) + mock_k8s_batch_api.read_namespaced_job_status.return_value = MagicMock( + status=MagicMock( + active=0, + succeeded=0, + failed=1, + start_time=datetime.now(UTC), + completion_time=datetime.now(UTC) + ) + ) + + active_jobs[2] = { + "job_name": "test-job-2", + "created_at": time.time() + } + + check_job_status(manager, 2, active_jobs[2]) + + assert 2 not in active_jobs + + def test_rebuild_active_jobs_from_k8s( + self, mock_k8s_batch_api, mock_k8s_core_api + ): + """Active jobs rebuilt from K8s on startup.""" + with patch.dict(os.environ, { + "KUBE_NAMESPACE": "test-namespace", + "WORKER_IMAGE": "test-image:latest" + }): + from processor.job_manager import JobManager + from processor.poller import rebuild_active_jobs_from_k8s, active_jobs + + active_jobs.clear() + + manager = JobManager(mock_k8s_batch_api, mock_k8s_core_api) + mock_k8s_batch_api.list_namespaced_job.return_value = MagicMock( + items=[ + MagicMock( + metadata=MagicMock( + name="reservation-worker-100", + creation_timestamp=datetime.now(UTC), + labels={"action": "create_reservation"}, + annotations={"user_id": "testuser"} + ), + status=MagicMock(active=1) + ), + MagicMock( + metadata=MagicMock( + name="reservation-worker-101", + creation_timestamp=datetime.now(UTC), + labels={"action": "cancel"}, + annotations={"user_id": "testuser2"} + ), + status=MagicMock(active=1) + ) + ] + ) + + recovered = rebuild_active_jobs_from_k8s(manager) + + assert recovered == 2 + assert 100 in active_jobs + assert 101 in active_jobs + + +# ============================================================================ +# Job Environment Tests +# ============================================================================ + +class TestJobEnvironment: + """Tests for job environment configuration.""" + + def test_worker_env_includes_db_config( + self, mock_k8s_batch_api, mock_k8s_core_api + ): + """Worker environment includes database configuration.""" + with patch.dict(os.environ, { + "KUBE_NAMESPACE": "test-namespace", + "WORKER_IMAGE": "test-image:latest", + "POSTGRES_HOST": "db.example.com", + "POSTGRES_PORT": "5432", + "POSTGRES_USER": "gpudev", + "POSTGRES_DB": "gpudev" + }): + from processor.job_manager import JobManager + + manager = JobManager(mock_k8s_batch_api, mock_k8s_core_api) + env_vars = manager._get_worker_env() + + env_names = [v.name for v in env_vars] + assert "POSTGRES_HOST" in env_names + assert "POSTGRES_PORT" in env_names + assert "POSTGRES_USER" in env_names + assert "POSTGRES_DB" in env_names + assert "POSTGRES_PASSWORD" in env_names + + def test_worker_env_includes_aws_config( + self, mock_k8s_batch_api, mock_k8s_core_api + ): + """Worker environment includes AWS configuration.""" + with patch.dict(os.environ, { + "KUBE_NAMESPACE": "test-namespace", + "WORKER_IMAGE": "test-image:latest", + "REGION": "us-east-1", + "EKS_CLUSTER_NAME": "test-cluster", + "PRIMARY_AVAILABILITY_ZONE": "us-east-1a" + }): + from processor.job_manager import JobManager + + manager = JobManager(mock_k8s_batch_api, mock_k8s_core_api) + env_vars = manager._get_worker_env() + + env_names = [v.name for v in env_vars] + assert "REGION" in env_names + assert "EKS_CLUSTER_NAME" in env_names + assert "PRIMARY_AVAILABILITY_ZONE" in env_names + + +# ============================================================================ +# State Transition Tests +# ============================================================================ + +class TestReservationStateTransitions: + """Tests for reservation state transitions.""" + + def test_queued_to_preparing(self): + """Reservation transitions from queued to preparing.""" + # This tests the expected state machine behavior + valid_transitions = { + "queued": ["preparing", "cancelled", "failed"], + "preparing": ["active", "failed", "cancelled"], + "active": ["completed", "cancelled", "failed", "expired"], + "completed": [], + "cancelled": [], + "failed": [], + "expired": [] + } + + assert "preparing" in valid_transitions["queued"] + assert "active" in valid_transitions["preparing"] + assert "completed" in valid_transitions["active"] + + def test_invalid_state_transition(self): + """Invalid state transitions are rejected.""" + valid_transitions = { + "completed": [], + "cancelled": [], + "failed": [] + } + + # Terminal states cannot transition + assert valid_transitions["completed"] == [] + assert valid_transitions["cancelled"] == [] + assert valid_transitions["failed"] == []