diff --git a/terraform-gpu-devservers/providers/__init__.py b/terraform-gpu-devservers/providers/__init__.py new file mode 100644 index 00000000..2c6f5a97 --- /dev/null +++ b/terraform-gpu-devservers/providers/__init__.py @@ -0,0 +1,195 @@ +""" +Cloud Provider Factory + +This module provides a factory function to get the appropriate cloud provider +based on configuration. The provider abstraction allows the GPU reservation +system to work with multiple cloud platforms without modifying core business logic. + +Usage: + from providers import get_cloud_provider + + provider = get_cloud_provider() + + # Storage operations + volume = provider.create_volume(size_gb=100, availability_zone='us-east-2a') + + # Snapshot operations + snapshot = provider.create_snapshot(volume.volume_id) + + # Object storage + uri = provider.upload_to_object_storage('bucket', 'key', b'content') + +Configuration: + Set CLOUD_PROVIDER environment variable: + - 'aws' (default): Amazon Web Services + - 'gcp': Google Cloud Platform + - 'custom': Custom/on-premises provider + + Provider-specific configuration via environment variables: + - AWS: AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY + - GCP: GCP_PROJECT, GCP_ZONE, GOOGLE_APPLICATION_CREDENTIALS + - Custom: CUSTOM_STORAGE_BACKEND, CUSTOM_AUTH_BACKEND +""" + +import logging +import os +from typing import Optional + +from .base import ( + AuthProvider, + AuthenticationError, + AuthorizationError, + CloudProvider, + NodeInfo, + ProviderError, + QuotaExceededError, + SnapshotInfo, + SnapshotNotFoundError, + VolumeInfo, + VolumeInUseError, + VolumeNotFoundError, +) + +logger = logging.getLogger(__name__) + +# Cached provider instance +_provider_instance: Optional[CloudProvider] = None + + +def get_cloud_provider( + provider_name: Optional[str] = None, + force_new: bool = False, + **kwargs +) -> CloudProvider: + """ + Get the configured cloud provider instance. + + This factory function returns the appropriate provider based on configuration. + The provider instance is cached for performance; use force_new=True to + create a new instance. + + Args: + provider_name: Override the provider (defaults to CLOUD_PROVIDER env var) + force_new: Force creation of new instance (bypass cache) + **kwargs: Provider-specific configuration options + + Returns: + CloudProvider instance (AWSProvider, GCPProvider, or CustomProvider) + + Raises: + ValueError: If provider name is not recognized + + Example: + # Use default provider from environment + provider = get_cloud_provider() + + # Override provider for testing + provider = get_cloud_provider('custom') + + # Force new instance with custom config + provider = get_cloud_provider('aws', force_new=True, region='us-west-2') + """ + global _provider_instance + + # Use cached instance if available and not forcing new + if _provider_instance is not None and not force_new and provider_name is None: + return _provider_instance + + # Determine provider name + name = provider_name or os.environ.get("CLOUD_PROVIDER", "aws") + name = name.lower() + + logger.info(f"Initializing cloud provider: {name}") + + if name == "aws": + from .aws import AWSProvider + region = kwargs.get("region") or os.environ.get("AWS_REGION", "us-east-2") + provider = AWSProvider(region=region) + + elif name == "gcp": + from .gcp import GCPProvider + project = kwargs.get("project") or os.environ.get("GCP_PROJECT", "") + zone = kwargs.get("zone") or os.environ.get("GCP_ZONE", "us-central1-a") + if not project: + raise ValueError( + "GCP_PROJECT environment variable must be set for GCP provider" + ) + provider = GCPProvider(project=project, zone=zone) + + elif name == "custom": + from .custom import CustomProvider + provider = CustomProvider() + + else: + raise ValueError( + f"Unknown cloud provider: {name}. " + f"Valid options: aws, gcp, custom" + ) + + # Cache the instance + if not force_new: + _provider_instance = provider + + return provider + + +def get_auth_provider( + provider_name: Optional[str] = None, + **kwargs +) -> AuthProvider: + """ + Get an authentication provider instance. + + Args: + provider_name: Override the provider (defaults to CLOUD_PROVIDER env var) + **kwargs: Provider-specific configuration options + + Returns: + AuthProvider instance + """ + name = provider_name or os.environ.get("CLOUD_PROVIDER", "aws") + name = name.lower() + + if name == "aws": + from .aws import AWSIAMAuthProvider + region = kwargs.get("region") or os.environ.get("AWS_REGION", "us-east-2") + return AWSIAMAuthProvider(region=region) + + elif name == "gcp": + raise NotImplementedError("GCP auth provider not implemented") + + elif name == "custom": + from .custom import CustomAuthProvider + return CustomAuthProvider() + + else: + raise ValueError(f"Unknown auth provider: {name}") + + +def clear_provider_cache(): + """Clear the cached provider instance.""" + global _provider_instance + _provider_instance = None + + +__all__ = [ + # Factory functions + "get_cloud_provider", + "get_auth_provider", + "clear_provider_cache", + # Base classes + "CloudProvider", + "AuthProvider", + # Data classes + "VolumeInfo", + "SnapshotInfo", + "NodeInfo", + # Exceptions + "ProviderError", + "VolumeNotFoundError", + "VolumeInUseError", + "SnapshotNotFoundError", + "QuotaExceededError", + "AuthenticationError", + "AuthorizationError", +] diff --git a/terraform-gpu-devservers/providers/aws.py b/terraform-gpu-devservers/providers/aws.py new file mode 100644 index 00000000..5584081a --- /dev/null +++ b/terraform-gpu-devservers/providers/aws.py @@ -0,0 +1,403 @@ +""" +AWS Cloud Provider Implementation + +Wraps existing boto3 code to provide cloud-agnostic interface. +""" + +import logging +from typing import Dict, List, Optional, Any +from datetime import datetime, timezone + +import boto3 +from botocore.exceptions import ClientError + +from .base import ( + CloudProvider, + AuthProvider, + VolumeInfo, + SnapshotInfo, + NodeInfo, +) + +logger = logging.getLogger(__name__) + + +class AWSProvider(CloudProvider): + """AWS implementation of CloudProvider interface.""" + + def __init__(self, region: str = "us-east-2"): + self.region = region + self._ec2 = None + self._s3 = None + self._autoscaling = None + self._efs = None + + @property + def ec2(self): + if self._ec2 is None: + self._ec2 = boto3.client("ec2", region_name=self.region) + return self._ec2 + + @property + def s3(self): + if self._s3 is None: + self._s3 = boto3.client("s3", region_name=self.region) + return self._s3 + + @property + def efs(self): + if self._efs is None: + self._efs = boto3.client("efs", region_name=self.region) + return self._efs + + def name(self) -> str: + return "aws" + + # === Block Storage (EBS) === + + def create_volume( + self, + size_gb: int, + availability_zone: str, + volume_type: str = "ssd", + tags: Optional[Dict[str, str]] = None, + snapshot_id: Optional[str] = None, + ) -> VolumeInfo: + """Create an EBS volume.""" + aws_volume_type = {"ssd": "gp3", "hdd": "sc1", "io": "io2"}.get(volume_type, "gp3") + + params = { + "AvailabilityZone": availability_zone, + "Size": size_gb, + "VolumeType": aws_volume_type, + "Encrypted": True, + } + + if snapshot_id: + params["SnapshotId"] = snapshot_id + + if aws_volume_type == "gp3": + params["Iops"] = 3000 + params["Throughput"] = 125 + + response = self.ec2.create_volume(**params) + volume_id = response["VolumeId"] + + if tags: + self.ec2.create_tags( + Resources=[volume_id], + Tags=[{"Key": k, "Value": v} for k, v in tags.items()], + ) + + return VolumeInfo( + volume_id=volume_id, + size_gb=response["Size"], + availability_zone=response["AvailabilityZone"], + status=response["State"], + tags=tags or {}, + ) + + def delete_volume(self, volume_id: str) -> bool: + """Delete an EBS volume.""" + try: + self.ec2.delete_volume(VolumeId=volume_id) + return True + except ClientError as e: + logger.error(f"Failed to delete volume {volume_id}: {e}") + return False + + def attach_volume( + self, volume_id: str, instance_id: str, device_path: str + ) -> bool: + """Attach EBS volume to EC2 instance.""" + try: + self.ec2.attach_volume( + VolumeId=volume_id, + InstanceId=instance_id, + Device=device_path, + ) + return True + except ClientError as e: + logger.error(f"Failed to attach volume {volume_id}: {e}") + return False + + def detach_volume(self, volume_id: str) -> bool: + """Detach EBS volume from instance.""" + try: + self.ec2.detach_volume(VolumeId=volume_id) + return True + except ClientError as e: + logger.error(f"Failed to detach volume {volume_id}: {e}") + return False + + def get_volume(self, volume_id: str) -> Optional[VolumeInfo]: + """Get EBS volume information.""" + try: + response = self.ec2.describe_volumes(VolumeIds=[volume_id]) + if response["Volumes"]: + vol = response["Volumes"][0] + tags = {t["Key"]: t["Value"] for t in vol.get("Tags", [])} + return VolumeInfo( + volume_id=vol["VolumeId"], + size_gb=vol["Size"], + availability_zone=vol["AvailabilityZone"], + status=vol["State"], + tags=tags, + ) + except ClientError: + pass + return None + + def list_volumes( + self, filters: Optional[Dict[str, str]] = None + ) -> List[VolumeInfo]: + """List EBS volumes matching filters.""" + aws_filters = [] + if filters: + for key, value in filters.items(): + aws_filters.append({"Name": f"tag:{key}", "Values": [value]}) + + response = self.ec2.describe_volumes(Filters=aws_filters) + + volumes = [] + for vol in response.get("Volumes", []): + tags = {t["Key"]: t["Value"] for t in vol.get("Tags", [])} + volumes.append( + VolumeInfo( + volume_id=vol["VolumeId"], + size_gb=vol["Size"], + availability_zone=vol["AvailabilityZone"], + status=vol["State"], + tags=tags, + ) + ) + return volumes + + # === Snapshots === + + def create_snapshot( + self, + volume_id: str, + description: str = "", + tags: Optional[Dict[str, str]] = None, + ) -> SnapshotInfo: + """Create EBS snapshot.""" + response = self.ec2.create_snapshot( + VolumeId=volume_id, + Description=description, + ) + + snapshot_id = response["SnapshotId"] + + if tags: + self.ec2.create_tags( + Resources=[snapshot_id], + Tags=[{"Key": k, "Value": v} for k, v in tags.items()], + ) + + return SnapshotInfo( + snapshot_id=snapshot_id, + volume_id=volume_id, + status=response["State"], + size_gb=response["VolumeSize"], + created_at=response["StartTime"].isoformat(), + tags=tags or {}, + ) + + def delete_snapshot(self, snapshot_id: str) -> bool: + """Delete EBS snapshot.""" + try: + self.ec2.delete_snapshot(SnapshotId=snapshot_id) + return True + except ClientError as e: + logger.error(f"Failed to delete snapshot {snapshot_id}: {e}") + return False + + def get_snapshot(self, snapshot_id: str) -> Optional[SnapshotInfo]: + """Get EBS snapshot information.""" + try: + response = self.ec2.describe_snapshots(SnapshotIds=[snapshot_id]) + if response["Snapshots"]: + snap = response["Snapshots"][0] + tags = {t["Key"]: t["Value"] for t in snap.get("Tags", [])} + return SnapshotInfo( + snapshot_id=snap["SnapshotId"], + volume_id=snap["VolumeId"], + status=snap["State"], + size_gb=snap["VolumeSize"], + created_at=snap["StartTime"].isoformat(), + tags=tags, + ) + except ClientError: + pass + return None + + def list_snapshots( + self, + filters: Optional[Dict[str, str]] = None, + volume_id: Optional[str] = None, + status: Optional[List[str]] = None, + use_pagination: bool = True, + ) -> List[SnapshotInfo]: + """ + List EBS snapshots matching filters. + + Args: + filters: Tag-based filters as key-value pairs (e.g., {"gpu-dev-user": "john"}) + volume_id: Filter by specific volume ID + status: Filter by status (e.g., ["pending", "completed"]) + use_pagination: Whether to use pagination for large result sets + """ + aws_filters = [] + + # Tag-based filters + if filters: + for key, value in filters.items(): + aws_filters.append({"Name": f"tag:{key}", "Values": [value]}) + + # Volume ID filter + if volume_id: + aws_filters.append({"Name": "volume-id", "Values": [volume_id]}) + + # Status filter + if status: + aws_filters.append({"Name": "status", "Values": status}) + + snapshots = [] + + if use_pagination: + paginator = self.ec2.get_paginator('describe_snapshots') + page_iterator = paginator.paginate( + OwnerIds=["self"], + Filters=aws_filters if aws_filters else [], + PaginationConfig={'PageSize': 100} + ) + + for page in page_iterator: + for snap in page.get("Snapshots", []): + tags = {t["Key"]: t["Value"] for t in snap.get("Tags", [])} + snapshots.append( + SnapshotInfo( + snapshot_id=snap["SnapshotId"], + volume_id=snap["VolumeId"], + status=snap["State"], + size_gb=snap["VolumeSize"], + created_at=snap["StartTime"].isoformat(), + tags=tags, + ) + ) + else: + params = {"OwnerIds": ["self"]} + if aws_filters: + params["Filters"] = aws_filters + response = self.ec2.describe_snapshots(**params) + + for snap in response.get("Snapshots", []): + tags = {t["Key"]: t["Value"] for t in snap.get("Tags", [])} + snapshots.append( + SnapshotInfo( + snapshot_id=snap["SnapshotId"], + volume_id=snap["VolumeId"], + status=snap["State"], + size_gb=snap["VolumeSize"], + created_at=snap["StartTime"].isoformat(), + tags=tags, + ) + ) + return snapshots + + def wait_for_snapshot( + self, snapshot_id: str, timeout_seconds: int = 600 + ) -> bool: + """Wait for EBS snapshot to complete.""" + try: + waiter = self.ec2.get_waiter("snapshot_completed") + waiter.wait( + SnapshotIds=[snapshot_id], + WaiterConfig={"Delay": 15, "MaxAttempts": timeout_seconds // 15}, + ) + return True + except Exception as e: + logger.error(f"Snapshot {snapshot_id} did not complete: {e}") + return False + + # === Compute === + + def get_nodes_by_gpu_type(self, gpu_type: str) -> List[NodeInfo]: + """Get EC2 instances by GPU type label.""" + # This would typically query K8s nodes, not EC2 directly + # For now, return empty - K8s client handles this + return [] + + def get_node_availability(self) -> Dict[str, Dict[str, int]]: + """Get GPU availability - delegates to K8s.""" + # This is handled by the availability updater + return {} + + # === Object Storage (S3) === + + def upload_to_object_storage( + self, + bucket: str, + key: str, + content: bytes, + metadata: Optional[Dict[str, str]] = None, + content_type: str = "application/octet-stream", + ) -> str: + """Upload content to S3.""" + self.s3.put_object( + Bucket=bucket, + Key=key, + Body=content, + ContentType=content_type, + **({"Metadata": metadata} if metadata else {}), + ) + + return f"s3://{bucket}/{key}" + + def download_from_object_storage( + self, bucket: str, key: str + ) -> Optional[bytes]: + """Download content from S3.""" + try: + response = self.s3.get_object(Bucket=bucket, Key=key) + return response["Body"].read() + except ClientError: + return None + + +class AWSIAMAuthProvider(AuthProvider): + """AWS IAM/STS based authentication (legacy).""" + + def __init__(self, region: str = "us-east-2"): + self.region = region + self._sts = None + + @property + def sts(self): + if self._sts is None: + self._sts = boto3.client("sts", region_name=self.region) + return self._sts + + def verify_token(self, token: str) -> Optional[Dict[str, Any]]: + """Verify AWS credentials (token contains access key info).""" + # This is handled differently - credentials are verified via STS + return None + + def get_user_info(self, token: str) -> Optional[Dict[str, Any]]: + """Get user info from AWS identity.""" + try: + response = self.sts.get_caller_identity() + return { + "user_id": response["UserId"], + "account": response["Account"], + "arn": response["Arn"], + } + except Exception: + return None + + def create_api_key( + self, user_id: str, scopes: List[str], ttl_hours: int = 24 + ) -> str: + """Create API key - handled by API service.""" + raise NotImplementedError("Use API service for API key creation") diff --git a/terraform-gpu-devservers/providers/base.py b/terraform-gpu-devservers/providers/base.py new file mode 100644 index 00000000..827026f0 --- /dev/null +++ b/terraform-gpu-devservers/providers/base.py @@ -0,0 +1,283 @@ +""" +Abstract base classes for cloud provider interfaces. + +This module defines the abstract interfaces that all cloud providers must implement. +The abstraction allows the GPU reservation system to work with multiple cloud platforms +(AWS, GCP, custom on-prem) without modifying core business logic. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + + +@dataclass +class VolumeInfo: + """Standardized volume information across providers.""" + volume_id: str + size_gb: int + availability_zone: str + status: str # 'available', 'in-use', 'creating', 'deleting' + tags: dict[str, str] + + +@dataclass +class SnapshotInfo: + """Standardized snapshot information across providers.""" + snapshot_id: str + volume_id: str + status: str # 'pending', 'completed', 'error' + size_gb: int + created_at: str # ISO format timestamp + tags: dict[str, str] + + +@dataclass +class NodeInfo: + """Standardized node/instance information across providers.""" + node_id: str + name: str + instance_type: str + availability_zone: str + gpu_type: str | None + gpu_count: int + status: str # 'running', 'stopped', 'terminated' + labels: dict[str, str] + + +class CloudProvider(ABC): + """ + Abstract base class for cloud provider implementations. + + This is the main interface for cloud-specific functionality. + Each cloud implementation provides concrete implementations of + storage, snapshot, compute, and object storage operations. + + Example: + provider = get_cloud_provider() # Returns AWSProvider or GCPProvider + + # Storage operations + volume = provider.create_volume(size_gb=100, availability_zone='us-east-2a') + + # Snapshot operations + snapshot = provider.create_snapshot(volume.volume_id) + + # Object storage + uri = provider.upload_to_object_storage('bucket', 'key', b'content') + """ + + @abstractmethod + def name(self) -> str: + """Provider name (aws, gcp, custom).""" + pass + + # === Block Storage === + + @abstractmethod + def create_volume( + self, + size_gb: int, + availability_zone: str, + volume_type: str = "ssd", + tags: dict[str, str] | None = None, + snapshot_id: str | None = None, + ) -> VolumeInfo: + """ + Create a block storage volume. + + Args: + size_gb: Volume size in gigabytes + availability_zone: Zone for volume placement + volume_type: Storage class (ssd, hdd, io) + tags: Key-value tags for the volume + snapshot_id: Create volume from snapshot + + Returns: + VolumeInfo with created volume details + """ + pass + + @abstractmethod + def delete_volume(self, volume_id: str) -> bool: + """Delete a block storage volume.""" + pass + + @abstractmethod + def attach_volume( + self, volume_id: str, instance_id: str, device_path: str + ) -> bool: + """Attach volume to instance.""" + pass + + @abstractmethod + def detach_volume(self, volume_id: str) -> bool: + """Detach volume from instance.""" + pass + + @abstractmethod + def get_volume(self, volume_id: str) -> VolumeInfo | None: + """Get volume information.""" + pass + + @abstractmethod + def list_volumes( + self, filters: dict[str, str] | None = None + ) -> list[VolumeInfo]: + """List volumes matching filters (by tags).""" + pass + + # === Snapshots === + + @abstractmethod + def create_snapshot( + self, + volume_id: str, + description: str = "", + tags: dict[str, str] | None = None, + ) -> SnapshotInfo: + """Create a snapshot of a volume.""" + pass + + @abstractmethod + def delete_snapshot(self, snapshot_id: str) -> bool: + """Delete a snapshot.""" + pass + + @abstractmethod + def get_snapshot(self, snapshot_id: str) -> SnapshotInfo | None: + """Get snapshot information.""" + pass + + @abstractmethod + def list_snapshots( + self, + filters: dict[str, str] | None = None, + volume_id: str | None = None, + status: list[str] | None = None, + use_pagination: bool = True, + ) -> list[SnapshotInfo]: + """ + List snapshots matching filters. + + Args: + filters: Tag-based filters as key-value pairs + volume_id: Filter by specific volume ID + status: Filter by status (e.g., ["pending", "completed"]) + use_pagination: Whether to use pagination for large result sets + """ + pass + + @abstractmethod + def wait_for_snapshot( + self, snapshot_id: str, timeout_seconds: int = 600 + ) -> bool: + """Wait for snapshot to complete.""" + pass + + # === Compute === + + @abstractmethod + def get_nodes_by_gpu_type(self, gpu_type: str) -> list[NodeInfo]: + """Get nodes/instances by GPU type.""" + pass + + @abstractmethod + def get_node_availability(self) -> dict[str, dict[str, int]]: + """Get GPU availability by type.""" + pass + + # === Object Storage === + + @abstractmethod + def upload_to_object_storage( + self, + bucket: str, + key: str, + content: bytes, + metadata: dict[str, str] | None = None, + content_type: str = "application/octet-stream", + ) -> str: + """Upload content to object storage. Returns URI.""" + pass + + @abstractmethod + def download_from_object_storage( + self, bucket: str, key: str + ) -> bytes | None: + """Download content from object storage.""" + pass + + +class AuthProvider(ABC): + """ + Abstract interface for identity verification. + + Used for authenticating API requests and verifying user identity. + """ + + @abstractmethod + def verify_token(self, token: str) -> dict[str, Any] | None: + """ + Verify an authentication token. + + Returns user info dict if valid, None if invalid. + """ + pass + + @abstractmethod + def get_user_info(self, token: str) -> dict[str, Any] | None: + """Get user information from token.""" + pass + + @abstractmethod + def create_api_key( + self, user_id: str, scopes: list[str], ttl_hours: int = 24 + ) -> str: + """Create an API key for a user.""" + pass + + +class ProviderError(Exception): + """Base exception for provider errors.""" + + def __init__( + self, + message: str, + provider: str = "unknown", + operation: str = "unknown", + details: dict | None = None + ): + self.provider = provider + self.operation = operation + self.details = details or {} + super().__init__(f"[{provider}] {operation}: {message}") + + +class VolumeNotFoundError(ProviderError): + """Volume does not exist.""" + pass + + +class VolumeInUseError(ProviderError): + """Volume is attached and cannot be modified.""" + pass + + +class SnapshotNotFoundError(ProviderError): + """Snapshot does not exist.""" + pass + + +class QuotaExceededError(ProviderError): + """Resource quota exceeded.""" + pass + + +class AuthenticationError(ProviderError): + """Authentication failed.""" + pass + + +class AuthorizationError(ProviderError): + """User not authorized for operation.""" + pass diff --git a/terraform-gpu-devservers/providers/custom.py b/terraform-gpu-devservers/providers/custom.py new file mode 100644 index 00000000..2a4ac23a --- /dev/null +++ b/terraform-gpu-devservers/providers/custom.py @@ -0,0 +1,409 @@ +""" +Custom Cloud Provider Template + +This module provides a template for implementing custom providers for: +- On-premises data centers +- Private clouds (OpenStack, VMware vSphere) +- Alternative cloud providers (DigitalOcean, Linode, etc.) +- Hybrid environments + +IMPLEMENTATION GUIDE +==================== + +1. Copy this file and rename it for your environment +2. Implement each abstract method +3. Register your provider in __init__.py +4. Set CLOUD_PROVIDER environment variable + +STORAGE INTEGRATION PATTERNS +============================ + +LVM (Linux Volume Manager): + - Create logical volumes in volume groups + - Use thin provisioning for snapshots + - Mount via device mapper + +Ceph RBD: + - Create RBD images in pools + - Use librbd or rbd CLI + - Map via rbd-nbd or krbd + +iSCSI: + - Create LUNs on storage array + - Map to initiator + - Discover and login to targets + +NFS: + - Create directories/quotas on NFS server + - Export via /etc/exports or storage array + +OBJECT STORAGE PATTERNS +======================= + +MinIO: + - S3-compatible API + - Use boto3 with custom endpoint + +Ceph RadosGW: + - S3-compatible API + - Use boto3 with custom endpoint + +Local filesystem: + - Use local directory as object store + - Simple for testing +""" + +import logging +import os +from typing import Any + +from .base import ( + AuthProvider, + CloudProvider, + NodeInfo, + SnapshotInfo, + VolumeInfo, +) + +logger = logging.getLogger(__name__) + + +class CustomProvider(CloudProvider): + """ + Template for custom cloud provider implementations. + + To implement: + 1. Replace each NotImplementedError with actual implementation + 2. Configure via environment variables + 3. Add any additional helper methods needed + """ + + def __init__(self): + # Configuration from environment + self.storage_backend = os.environ.get("CUSTOM_STORAGE_BACKEND", "lvm") + self.object_store_path = os.environ.get("CUSTOM_OBJECT_STORE", "/var/lib/gpu-dev/objects") + + def name(self) -> str: + return "custom" + + # === Block Storage === + + def create_volume( + self, + size_gb: int, + availability_zone: str, + volume_type: str = "ssd", + tags: dict[str, str] | None = None, + snapshot_id: str | None = None, + ) -> VolumeInfo: + """ + Create a block storage volume. + + Example LVM implementation: + import subprocess + import uuid + vol_name = f"gpudev-{uuid.uuid4().hex[:8]}" + cmd = ['lvcreate', '-L', f'{size_gb}G', '-n', vol_name, 'vg_gpudev'] + if snapshot_id: + cmd.extend(['--snapshot', snapshot_id]) + subprocess.run(cmd, check=True) + return VolumeInfo(volume_id=vol_name, ...) + """ + raise NotImplementedError( + f"Custom storage ({self.storage_backend}) not implemented. " + "Implement create_volume() for your storage backend." + ) + + def delete_volume(self, volume_id: str) -> bool: + """ + Delete a block storage volume. + + Example LVM implementation: + subprocess.run(['lvremove', '-f', f'vg_gpudev/{volume_id}'], check=True) + return True + """ + raise NotImplementedError( + f"Custom storage ({self.storage_backend}) not implemented. " + "Implement delete_volume() for your storage backend." + ) + + def attach_volume( + self, volume_id: str, instance_id: str, device_path: str + ) -> bool: + """ + Attach volume to instance. + + For Kubernetes-based workloads, this typically means: + 1. Make the volume accessible on the node (iSCSI login, RBD map, etc.) + 2. Create a PersistentVolume pointing to the device + 3. Let Kubernetes handle the pod mounting + """ + raise NotImplementedError( + f"Custom storage ({self.storage_backend}) not implemented. " + "Implement attach_volume() for your storage backend." + ) + + def detach_volume(self, volume_id: str) -> bool: + """ + Detach volume from instance. + + Ensure the volume is properly unmounted before detaching. + For iSCSI: logout from target + For RBD: unmap the device + For NFS: unmount the share + """ + raise NotImplementedError( + f"Custom storage ({self.storage_backend}) not implemented. " + "Implement detach_volume() for your storage backend." + ) + + def get_volume(self, volume_id: str) -> VolumeInfo | None: + """ + Get volume information. + + Example LVM implementation: + result = subprocess.run( + ['lvs', '--noheadings', '-o', 'lv_size,lv_attr', f'vg_gpudev/{volume_id}'], + capture_output=True, text=True + ) + if result.returncode != 0: + return None + # Parse output and return VolumeInfo + """ + raise NotImplementedError( + f"Custom storage ({self.storage_backend}) not implemented. " + "Implement get_volume() for your storage backend." + ) + + def list_volumes( + self, filters: dict[str, str] | None = None + ) -> list[VolumeInfo]: + """ + List volumes matching filters. + + Note: For backends without native tagging, store tags in a local database + or use naming conventions to encode metadata. + """ + raise NotImplementedError( + f"Custom storage ({self.storage_backend}) not implemented. " + "Implement list_volumes() for your storage backend." + ) + + # === Snapshots === + + def create_snapshot( + self, + volume_id: str, + description: str = "", + tags: dict[str, str] | None = None, + ) -> SnapshotInfo: + """ + Create a snapshot of a volume. + + Example LVM implementation: + import uuid + snap_name = f"snap-{uuid.uuid4().hex[:8]}" + subprocess.run([ + 'lvcreate', '--snapshot', + '-L', '10G', # COW pool size + '-n', snap_name, + f'vg_gpudev/{volume_id}' + ], check=True) + return SnapshotInfo(snapshot_id=snap_name, ...) + """ + raise NotImplementedError( + f"Custom snapshots ({self.storage_backend}) not implemented. " + "Implement create_snapshot() for your storage backend." + ) + + def delete_snapshot(self, snapshot_id: str) -> bool: + """Delete a snapshot.""" + raise NotImplementedError( + f"Custom snapshots ({self.storage_backend}) not implemented. " + "Implement delete_snapshot() for your storage backend." + ) + + def get_snapshot(self, snapshot_id: str) -> SnapshotInfo | None: + """Get snapshot information.""" + raise NotImplementedError( + f"Custom snapshots ({self.storage_backend}) not implemented. " + "Implement get_snapshot() for your storage backend." + ) + + def list_snapshots( + self, + filters: dict[str, str] | None = None, + volume_id: str | None = None, + status: list[str] | None = None, + use_pagination: bool = True, + ) -> list[SnapshotInfo]: + """List snapshots matching filters.""" + raise NotImplementedError( + f"Custom snapshots ({self.storage_backend}) not implemented. " + "Implement list_snapshots() for your storage backend." + ) + + def wait_for_snapshot( + self, snapshot_id: str, timeout_seconds: int = 600 + ) -> bool: + """ + Wait for snapshot to complete. + + For LVM/ZFS snapshots, this is typically instant. + For storage arrays, poll the API until complete. + """ + raise NotImplementedError( + f"Custom snapshots ({self.storage_backend}) not implemented. " + "Implement wait_for_snapshot() for your storage backend." + ) + + # === Compute === + + def get_nodes_by_gpu_type(self, gpu_type: str) -> list[NodeInfo]: + """ + Get nodes/instances by GPU type. + + For Kubernetes-based deployments, query K8s API: + from kubernetes import client + v1 = client.CoreV1Api() + nodes = v1.list_node(label_selector=f'gpu-type={gpu_type}') + """ + raise NotImplementedError( + "Custom compute not implemented. " + "Query via Kubernetes API instead." + ) + + def get_node_availability(self) -> dict[str, dict[str, int]]: + """ + Get GPU availability by type. + + This is typically handled by the availability-updater-service + which queries Kubernetes for GPU allocations. + """ + raise NotImplementedError( + "Handled by availability updater service." + ) + + # === Object Storage === + + def upload_to_object_storage( + self, + bucket: str, + key: str, + content: bytes, + metadata: dict[str, str] | None = None, + content_type: str = "application/octet-stream", + ) -> str: + """ + Upload content to object storage. + + Example MinIO/S3-compatible implementation: + import boto3 + s3 = boto3.client('s3', endpoint_url=os.environ['MINIO_ENDPOINT']) + s3.put_object(Bucket=bucket, Key=key, Body=content, + ContentType=content_type, Metadata=metadata or {}) + return f's3://{bucket}/{key}' + + Example filesystem implementation: + import os + path = os.path.join(self.object_store_path, bucket, key) + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'wb') as f: + f.write(content) + return f'file://{path}' + """ + raise NotImplementedError( + "Custom object storage not implemented. " + "Implement upload_to_object_storage() for your storage backend." + ) + + def download_from_object_storage( + self, bucket: str, key: str + ) -> bytes | None: + """ + Download content from object storage. + + Example filesystem implementation: + path = os.path.join(self.object_store_path, bucket, key) + if os.path.exists(path): + with open(path, 'rb') as f: + return f.read() + return None + """ + raise NotImplementedError( + "Custom object storage not implemented. " + "Implement download_from_object_storage() for your storage backend." + ) + + +class CustomAuthProvider(AuthProvider): + """ + Template for custom authentication provider. + + Common patterns: + + LDAP/Active Directory: + from ldap3 import Server, Connection, ALL + server = Server('ldap://ad.example.com', get_info=ALL) + conn = Connection(server, user=bind_dn, password=bind_pw) + conn.bind() + conn.search('dc=example,dc=com', f'(uid={username})', attributes=['memberOf']) + + OIDC (Keycloak, Okta): + from jose import jwt + payload = jwt.decode(token, key, algorithms=['RS256'], audience='gpu-dev') + return {'user_id': payload['sub'], 'email': payload['email'], ...} + + SAML: + from onelogin.saml2.auth import OneLogin_Saml2_Auth + auth = OneLogin_Saml2_Auth(request_data, saml_settings) + auth.process_response() + return {'user_id': auth.get_nameid(), ...} + """ + + def __init__(self): + self.backend = os.environ.get("CUSTOM_AUTH_BACKEND", "oidc") + + def verify_token(self, token: str) -> dict[str, Any] | None: + """ + Verify authentication token. + + Example OIDC implementation: + from jose import jwt + try: + payload = jwt.decode( + token, + self.public_key, + algorithms=['RS256'], + audience=os.environ.get('OIDC_AUDIENCE') + ) + return { + 'user_id': payload['sub'], + 'email': payload.get('email'), + 'groups': payload.get('groups', []) + } + except jwt.JWTError: + return None + """ + raise NotImplementedError( + f"Custom auth ({self.backend}) not implemented. " + "Implement verify_token() for your auth backend." + ) + + def get_user_info(self, token: str) -> dict[str, Any] | None: + """Get user information from token.""" + raise NotImplementedError( + f"Custom auth ({self.backend}) not implemented. " + "Implement get_user_info() for your auth backend." + ) + + def create_api_key( + self, user_id: str, scopes: list[str], ttl_hours: int = 24 + ) -> str: + """ + Create an API key for a user. + + This is typically handled by the API service using database-backed + API keys rather than cloud provider tokens. + """ + raise NotImplementedError("Use API service for API key creation") diff --git a/terraform-gpu-devservers/providers/gcp.py b/terraform-gpu-devservers/providers/gcp.py new file mode 100644 index 00000000..0acfdaca --- /dev/null +++ b/terraform-gpu-devservers/providers/gcp.py @@ -0,0 +1,192 @@ +""" +GCP Cloud Provider Implementation (Stub) + +This is a template for GCP support. Implement the methods +using Google Cloud SDK. +""" + +import logging +from typing import Dict, List, Optional, Any + +from .base import ( + CloudProvider, + VolumeInfo, + SnapshotInfo, + NodeInfo, +) + +logger = logging.getLogger(__name__) + + +class GCPProvider(CloudProvider): + """ + GCP implementation of CloudProvider interface. + + TODO: Implement using google-cloud-compute SDK + """ + + def __init__(self, project: str, zone: str): + self.project = project + self.zone = zone + self.region = zone.rsplit("-", 1)[0] # us-central1-a -> us-central1 + + # Initialize GCP clients (uncomment when implementing) + # from google.cloud import compute_v1 + # self.disks_client = compute_v1.DisksClient() + # self.snapshots_client = compute_v1.SnapshotsClient() + # self.instances_client = compute_v1.InstancesClient() + + def name(self) -> str: + return "gcp" + + # === Block Storage (GCE Persistent Disk) === + + def create_volume( + self, + size_gb: int, + availability_zone: str, + volume_type: str = "ssd", + tags: Optional[Dict[str, str]] = None, + snapshot_id: Optional[str] = None, + ) -> VolumeInfo: + """ + Create a GCE Persistent Disk. + + volume_type mapping: + - ssd -> pd-ssd + - hdd -> pd-standard + - balanced -> pd-balanced + """ + raise NotImplementedError( + "GCP volume creation not implemented. " + "Use google.cloud.compute_v1.DisksClient.insert()" + ) + + def delete_volume(self, volume_id: str) -> bool: + """Delete a GCE Persistent Disk.""" + raise NotImplementedError( + "GCP volume deletion not implemented. " + "Use google.cloud.compute_v1.DisksClient.delete()" + ) + + def attach_volume( + self, volume_id: str, instance_id: str, device_path: str + ) -> bool: + """Attach disk to GCE instance.""" + raise NotImplementedError( + "GCP volume attachment not implemented. " + "Use google.cloud.compute_v1.InstancesClient.attach_disk()" + ) + + def detach_volume(self, volume_id: str) -> bool: + """Detach disk from GCE instance.""" + raise NotImplementedError( + "GCP volume detachment not implemented. " + "Use google.cloud.compute_v1.InstancesClient.detach_disk()" + ) + + def get_volume(self, volume_id: str) -> Optional[VolumeInfo]: + """Get disk information.""" + raise NotImplementedError( + "GCP volume get not implemented. " + "Use google.cloud.compute_v1.DisksClient.get()" + ) + + def list_volumes( + self, filters: Optional[Dict[str, str]] = None + ) -> List[VolumeInfo]: + """List disks matching filters (labels in GCP).""" + raise NotImplementedError( + "GCP volume list not implemented. " + "Use google.cloud.compute_v1.DisksClient.list()" + ) + + # === Snapshots === + + def create_snapshot( + self, + volume_id: str, + description: str = "", + tags: Optional[Dict[str, str]] = None, + ) -> SnapshotInfo: + """Create disk snapshot.""" + raise NotImplementedError( + "GCP snapshot creation not implemented. " + "Use google.cloud.compute_v1.SnapshotsClient.insert()" + ) + + def delete_snapshot(self, snapshot_id: str) -> bool: + """Delete snapshot.""" + raise NotImplementedError( + "GCP snapshot deletion not implemented. " + "Use google.cloud.compute_v1.SnapshotsClient.delete()" + ) + + def get_snapshot(self, snapshot_id: str) -> Optional[SnapshotInfo]: + """Get snapshot information.""" + raise NotImplementedError( + "GCP snapshot get not implemented. " + "Use google.cloud.compute_v1.SnapshotsClient.get()" + ) + + def list_snapshots( + self, + filters: Optional[Dict[str, str]] = None, + volume_id: Optional[str] = None, + status: Optional[List[str]] = None, + use_pagination: bool = True, + ) -> List[SnapshotInfo]: + """List snapshots.""" + raise NotImplementedError( + "GCP snapshot list not implemented. " + "Use google.cloud.compute_v1.SnapshotsClient.list()" + ) + + def wait_for_snapshot( + self, snapshot_id: str, timeout_seconds: int = 600 + ) -> bool: + """Wait for snapshot to complete.""" + raise NotImplementedError( + "GCP snapshot wait not implemented. " + "Poll SnapshotsClient.get() until status is READY" + ) + + # === Compute === + + def get_nodes_by_gpu_type(self, gpu_type: str) -> List[NodeInfo]: + """Get GCE instances by GPU type.""" + raise NotImplementedError( + "GCP node listing not implemented. " + "Query via Kubernetes API instead." + ) + + def get_node_availability(self) -> Dict[str, Dict[str, int]]: + """Get GPU availability.""" + raise NotImplementedError( + "Handled by availability updater service." + ) + + # === Object Storage (GCS) === + + def upload_to_object_storage( + self, + bucket: str, + key: str, + content: bytes, + metadata: Optional[Dict[str, str]] = None, + content_type: str = "application/octet-stream", + ) -> str: + """Upload to Google Cloud Storage.""" + raise NotImplementedError( + "GCS upload not implemented. " + "Use google.cloud.storage.Client().bucket().blob().upload_from_string()" + ) + + def download_from_object_storage( + self, bucket: str, key: str + ) -> Optional[bytes]: + """Download from Google Cloud Storage.""" + raise NotImplementedError( + "GCS download not implemented. " + "Use google.cloud.storage.Client().bucket().blob().download_as_bytes()" + ) diff --git a/terraform-gpu-devservers/shared/snapshot_utils.py b/terraform-gpu-devservers/shared/snapshot_utils.py index f44a2c4f..14bd2a11 100644 --- a/terraform-gpu-devservers/shared/snapshot_utils.py +++ b/terraform-gpu-devservers/shared/snapshot_utils.py @@ -1,101 +1,115 @@ """ Shared snapshot utilities for GPU development server services + +This module provides cloud-agnostic snapshot management using the provider +abstraction layer. It supports AWS, GCP, and custom storage backends. """ -import boto3 import time import logging import os -import subprocess -import json +from datetime import datetime, timedelta, UTC from kubernetes import client from kubernetes.stream import stream -from decimal import Decimal from .db_pool import get_db_cursor +# Import provider interface - lazy loaded to avoid circular imports +_provider = None + +def _get_provider(): + """Get the cloud provider instance (lazy initialization).""" + global _provider + if _provider is None: + import sys + # Add parent directory to path if providers module not found + parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + if parent_dir not in sys.path: + sys.path.insert(0, parent_dir) + from providers import get_cloud_provider + _provider = get_cloud_provider() + return _provider + logger = logging.getLogger(__name__) -ec2_client = boto3.client("ec2") -s3_client = boto3.client("s3") def safe_create_snapshot(volume_id, user_id, snapshot_type="shutdown", disk_name=None, content_s3_path=None, disk_size=None): """ Safely create snapshot, avoiding duplicates if one is already in progress. - + Returns (snapshot_id, was_created) on success. - + IMPORTANT: If snapshot creation succeeds but database update fails, this function will attempt to delete the snapshot and raise an exception to prevent inconsistent state. - The operation is atomic: both AWS snapshot AND database update must succeed. + The operation is atomic: both cloud snapshot AND database update must succeed. Args: - volume_id: EBS volume ID + volume_id: Volume ID (cloud-provider-specific format) user_id: User identifier (email or username) snapshot_type: Type of snapshot (shutdown, migration, etc.) disk_name: Named disk identifier (for tagged disks) - if provided, database will be updated - content_s3_path: S3 path to disk contents listing + content_s3_path: Object storage path to disk contents listing disk_size: Disk usage size (e.g., "1.2G") from du -sh - + Returns: tuple: (snapshot_id, was_created) where was_created is True for new snapshots, False for existing - + Raises: Exception: If snapshot creation fails, or if database update fails (after attempting cleanup) """ + provider = _get_provider() + try: logger.info(f"Checking for existing snapshots for volume {volume_id}") # Check for any in-progress snapshots for this volume - ongoing_response = ec2_client.describe_snapshots( - OwnerIds=["self"], - Filters=[ - {"Name": "volume-id", "Values": [volume_id]}, - {"Name": "status", "Values": ["pending"]} - ] + ongoing_snapshots = provider.list_snapshots( + volume_id=volume_id, + status=["pending"], + use_pagination=False # Small result set expected ) - ongoing_snapshots = ongoing_response.get('Snapshots', []) if ongoing_snapshots: - latest_ongoing = max(ongoing_snapshots, key=lambda s: s['StartTime']) - logger.info(f"Found ongoing snapshot {latest_ongoing['SnapshotId']} for volume {volume_id}") - return latest_ongoing['SnapshotId'], False + # Sort by created_at and get latest + latest_ongoing = max(ongoing_snapshots, key=lambda s: s.created_at) + logger.info(f"Found ongoing snapshot {latest_ongoing.snapshot_id} for volume {volume_id}") + return latest_ongoing.snapshot_id, False # No ongoing snapshots - create a new one logger.info(f"Creating new {snapshot_type} snapshot for volume {volume_id}") timestamp = int(time.time()) - tags = [ - {"Key": "Name", "Value": f"gpu-dev-{snapshot_type}-{user_id.split('@')[0]}-{timestamp}"}, - {"Key": "gpu-dev-user", "Value": user_id}, - {"Key": "gpu-dev-snapshot-type", "Value": snapshot_type}, - {"Key": "SnapshotType", "Value": snapshot_type}, - {"Key": "created_at", "Value": str(timestamp)}, - ] + # Build tags dict for provider + tags = { + "Name": f"gpu-dev-{snapshot_type}-{user_id.split('@')[0]}-{timestamp}", + "gpu-dev-user": user_id, + "gpu-dev-snapshot-type": snapshot_type, + "SnapshotType": snapshot_type, + "created_at": str(timestamp), + } - # Add disk_name tag if provided + # Add optional tags if disk_name: - tags.append({"Key": "disk_name", "Value": disk_name}) - - # Add content_s3_path tag if provided + tags["disk_name"] = disk_name if content_s3_path: - tags.append({"Key": "snapshot_content_s3", "Value": content_s3_path}) + tags["snapshot_content_s3"] = content_s3_path + if disk_size: + tags["disk_size"] = disk_size - # Add disk_size tag if provided + description = f"gpu-dev {snapshot_type} snapshot for {user_id}" + if disk_name: + description += f" (disk: {disk_name})" if disk_size: - tags.append({"Key": "disk_size", "Value": disk_size}) - - snapshot_response = ec2_client.create_snapshot( - VolumeId=volume_id, - Description=f"gpu-dev {snapshot_type} snapshot for {user_id}" + (f" (disk: {disk_name})" if disk_name else "") + (f" ({disk_size})" if disk_size else ""), - TagSpecifications=[{ - "ResourceType": "snapshot", - "Tags": tags - }] + description += f" ({disk_size})" + + snapshot_info = provider.create_snapshot( + volume_id=volume_id, + description=description, + tags=tags, ) - snapshot_id = snapshot_response["SnapshotId"] + snapshot_id = snapshot_info.snapshot_id logger.info(f"Created new snapshot {snapshot_id} for volume {volume_id}" + (f" (disk: {disk_name})" if disk_name else "") + (f" size: {disk_size}" if disk_size else "")) # Update PostgreSQL to mark disk as backing up @@ -110,31 +124,31 @@ def safe_create_snapshot(volume_id, user_id, snapshot_type="shutdown", disk_name pending_snapshot_count = COALESCE(pending_snapshot_count, 0) + 1 WHERE user_id = %s AND disk_name = %s """, (user_id, disk_name)) - + # Verify the update actually affected a row if cur.rowcount == 0: raise Exception(f"Disk '{disk_name}' not found in database for user {user_id}") - + logger.debug(f"Updated database for disk '{disk_name}' - marked as backing up") except Exception as db_error: # Database update failed - snapshot created but database state is inconsistent - # This typically means the disk is orphaned (exists in AWS but not in database) + # This typically means the disk is orphaned (exists in cloud but not in database) logger.error( f"CRITICAL: Snapshot {snapshot_id} created successfully, " f"but database update failed for disk '{disk_name}': {db_error}" ) - - # Clean up both the snapshot and the orphaned volume + + # Clean up both the snapshot and the orphaned volume using provider try: logger.warning(f"Attempting to delete snapshot {snapshot_id} to maintain consistency") - ec2_client.delete_snapshot(SnapshotId=snapshot_id) + provider.delete_snapshot(snapshot_id) logger.info(f"Successfully deleted snapshot {snapshot_id}") except Exception as cleanup_error: logger.error( f"Failed to delete snapshot {snapshot_id}: {cleanup_error}. " f"Snapshot exists but is not tracked in database. Manual cleanup required!" ) - + # If disk not found in database, also delete the orphaned volume if "not found in database" in str(db_error).lower(): try: @@ -142,14 +156,14 @@ def safe_create_snapshot(volume_id, user_id, snapshot_type="shutdown", disk_name f"Disk '{disk_name}' not found in database - " f"deleting orphaned volume {volume_id}" ) - ec2_client.delete_volume(VolumeId=volume_id) + provider.delete_volume(volume_id) logger.info(f"Successfully deleted orphaned volume {volume_id}") except Exception as volume_cleanup_error: logger.error( f"Failed to delete orphaned volume {volume_id}: {volume_cleanup_error}. " f"Manual cleanup may be required." ) - + # Propagate the error so caller knows the operation failed raise Exception( f"Snapshot creation failed: database update error for disk '{disk_name}': {db_error}" @@ -259,34 +273,27 @@ def cleanup_old_snapshots(user_id, keep_count=3, max_age_days=7, max_deletions_p """ Clean up old snapshots for a user, keeping only the most recent ones. Keeps 'keep_count' newest snapshots and deletes any older than max_age_days. - Limited to max_deletions_per_run to prevent lambda timeouts. + Limited to max_deletions_per_run to prevent service timeouts. Returns number of snapshots deleted. """ - try: - from datetime import datetime, timedelta, UTC + provider = _get_provider() + try: logger.info(f"Cleaning up old snapshots for user {user_id}") - # Get all snapshots for this user (with pagination) - paginator = ec2_client.get_paginator('describe_snapshots') - page_iterator = paginator.paginate( - OwnerIds=["self"], - Filters=[ - {"Name": "tag:gpu-dev-user", "Values": [user_id]}, - {"Name": "status", "Values": ["completed"]} - ], - PaginationConfig={'PageSize': 100} + # Get all completed snapshots for this user using provider + snapshots = provider.list_snapshots( + filters={"gpu-dev-user": user_id}, + status=["completed"], + use_pagination=True ) - snapshots = [] - for page in page_iterator: - snapshots.extend(page.get('Snapshots', [])) if len(snapshots) <= keep_count: logger.debug(f"User {user_id} has {len(snapshots)} snapshots, no cleanup needed") return 0 - # Sort by creation time (newest first) - snapshots.sort(key=lambda s: s['StartTime'], reverse=True) + # Sort by creation time (newest first) - created_at is ISO format string + snapshots.sort(key=lambda s: s.created_at, reverse=True) cutoff_date = datetime.now(UTC) - timedelta(days=max_age_days) deleted_count = 0 @@ -297,8 +304,9 @@ def cleanup_old_snapshots(user_id, keep_count=3, max_age_days=7, max_deletions_p logger.info(f"Reached max deletions per run ({max_deletions_per_run}) for user {user_id}") break - snapshot_id = snapshot['SnapshotId'] - snapshot_date = snapshot['StartTime'].replace(tzinfo=None) + snapshot_id = snapshot.snapshot_id + # Parse ISO format timestamp + snapshot_date = datetime.fromisoformat(snapshot.created_at.replace('Z', '+00:00')) # Keep the newest 'keep_count' snapshots if i < keep_count: @@ -309,7 +317,7 @@ def cleanup_old_snapshots(user_id, keep_count=3, max_age_days=7, max_deletions_p if snapshot_date < cutoff_date or i >= keep_count: try: logger.info(f"Deleting old snapshot {snapshot_id} from {snapshot_date}") - ec2_client.delete_snapshot(SnapshotId=snapshot_id) + provider.delete_snapshot(snapshot_id) deleted_count += 1 except Exception as delete_error: logger.warning(f"Could not delete snapshot {snapshot_id}: {delete_error}") @@ -327,49 +335,38 @@ def get_latest_snapshot(user_id, volume_id=None, include_pending=False): Get the most recent snapshot for a user. If volume_id provided, gets snapshots for that specific volume. If include_pending is True, includes pending snapshots. - Returns the latest snapshot dict or None. + Returns the latest SnapshotInfo or None. """ + provider = _get_provider() + try: status_values = ["completed"] if include_pending: - status_values.extend(["pending"]) - - filters = [ - {"Name": "tag:gpu-dev-user", "Values": [user_id]}, - {"Name": "status", "Values": status_values}, - ] - - if volume_id: - filters.append({"Name": "volume-id", "Values": [volume_id]}) - - # Use pagination to handle users with many snapshots - paginator = ec2_client.get_paginator('describe_snapshots') - page_iterator = paginator.paginate( - OwnerIds=["self"], - Filters=filters, - PaginationConfig={'PageSize': 100} + status_values.append("pending") + + # Get snapshots using provider + snapshots = provider.list_snapshots( + filters={"gpu-dev-user": user_id}, + volume_id=volume_id, + status=status_values, + use_pagination=True ) - snapshots = [] - for page in page_iterator: - snapshots.extend(page.get('Snapshots', [])) - # Filter out soft-deleted snapshots (those with delete-date tag) - active_snapshots = [] - for snap in snapshots: - tags = {tag['Key']: tag['Value'] for tag in snap.get('Tags', [])} - if 'delete-date' not in tags: - active_snapshots.append(snap) + active_snapshots = [ + snap for snap in snapshots + if 'delete-date' not in snap.tags + ] if not active_snapshots: status_desc = "completed or pending" if include_pending else "completed" logger.info(f"No {status_desc} snapshots found for user {user_id}") return None - # Get most recent snapshot by start time - latest_snapshot = max(active_snapshots, key=lambda s: s['StartTime']) + # Get most recent snapshot by creation time + latest_snapshot = max(active_snapshots, key=lambda s: s.created_at) logger.info( - f"Found latest snapshot {latest_snapshot['SnapshotId']} ({latest_snapshot['State']}) for user {user_id}") + f"Found latest snapshot {latest_snapshot.snapshot_id} ({latest_snapshot.status}) for user {user_id}") return latest_snapshot except Exception as e: @@ -381,29 +378,22 @@ def cleanup_all_user_snapshots(max_users_per_run=20): """ Run scheduled cleanup of old snapshots for all users. This runs separately from expiry processing. - Limited to max_users_per_run to prevent lambda timeouts. + Limited to max_users_per_run to prevent service timeouts. """ + provider = _get_provider() + try: logger.info("Starting scheduled snapshot cleanup for all users") - # Get all gpu-dev snapshots grouped by user (with pagination) - paginator = ec2_client.get_paginator('describe_snapshots') - page_iterator = paginator.paginate( - OwnerIds=["self"], - Filters=[ - {"Name": "tag-key", "Values": ["gpu-dev-user"]}, - ], - PaginationConfig={'PageSize': 100} - ) - - all_snapshots = [] - for page in page_iterator: - all_snapshots.extend(page.get('Snapshots', [])) + # Get all gpu-dev snapshots (those with gpu-dev-user tag) + # Note: We need to get all snapshots and group by user since + # provider interface doesn't support "tag-key exists" filter + all_snapshots = provider.list_snapshots(use_pagination=True) # Group snapshots by user users_snapshots = {} for snapshot in all_snapshots: - user_tag = next((tag['Value'] for tag in snapshot['Tags'] if tag['Key'] == 'gpu-dev-user'), None) + user_tag = snapshot.tags.get('gpu-dev-user') if user_tag: if user_tag not in users_snapshots: users_snapshots[user_tag] = [] @@ -435,8 +425,8 @@ def cleanup_all_user_snapshots(max_users_per_run=20): def capture_disk_contents(pod_name, namespace, user_id, disk_name, snapshot_id, k8s_client=None, mount_path="/workspace"): """ - Capture disk contents via Kubernetes API exec and upload to S3. - Returns tuple (s3_path, disk_size) or (None, None) if failed. + Capture disk contents via Kubernetes API exec and upload to object storage. + Returns tuple (storage_uri, disk_size) or (None, None) if failed. Args: pod_name: Kubernetes pod name @@ -448,8 +438,10 @@ def capture_disk_contents(pod_name, namespace, user_id, disk_name, snapshot_id, mount_path: Mount point in pod (default: /workspace) Returns: - tuple: (s3_path, disk_size) where disk_size is like "1.2G" or None if failed + tuple: (storage_uri, disk_size) where disk_size is like "1.2G" or None if failed """ + provider = _get_provider() + try: bucket_name = os.environ.get('DISK_CONTENTS_BUCKET') if not bucket_name: @@ -517,11 +509,10 @@ def capture_disk_contents(pod_name, namespace, user_id, disk_name, snapshot_id, logger.warning(f"Kubernetes exec failed: {exec_error}") contents = f"Failed to capture contents: {str(exec_error)}\n\nThis snapshot was created but contents could not be listed." - # Upload to S3 - s3_key = f"{user_id}/{disk_name}/{snapshot_id}-contents.txt" - s3_path = f"s3://{bucket_name}/{s3_key}" + # Upload to object storage using provider + object_key = f"{user_id}/{disk_name}/{snapshot_id}-contents.txt" - logger.info(f"Uploading disk contents to {s3_path}") + logger.info(f"Uploading disk contents to {bucket_name}/{object_key}") metadata = { 'user_id': user_id, @@ -535,79 +526,87 @@ def capture_disk_contents(pod_name, namespace, user_id, disk_name, snapshot_id, if disk_size: metadata['disk_size'] = disk_size - s3_client.put_object( - Bucket=bucket_name, - Key=s3_key, - Body=contents.encode('utf-8'), - ContentType='text/plain', - Metadata=metadata + storage_uri = provider.upload_to_object_storage( + bucket=bucket_name, + key=object_key, + content=contents.encode('utf-8'), + metadata=metadata, + content_type='text/plain' ) - logger.info(f"Successfully uploaded disk contents to {s3_path}") - return s3_path, disk_size + logger.info(f"Successfully uploaded disk contents to {storage_uri}") + return storage_uri, disk_size except Exception as e: logger.error(f"Error capturing disk contents: {str(e)}") return None, None -def get_snapshot_contents(snapshot_id=None, s3_path=None): +def get_snapshot_contents(snapshot_id=None, storage_uri=None): """ - Fetch snapshot contents from S3. - Either snapshot_id or s3_path must be provided. + Fetch snapshot contents from object storage. + Either snapshot_id or storage_uri must be provided. Args: - snapshot_id: Snapshot ID to fetch contents for (will look up S3 path from tags) - s3_path: Direct S3 path (e.g., s3://bucket/user/disk/snap-123-contents.txt) + snapshot_id: Snapshot ID to fetch contents for (will look up storage path from tags) + storage_uri: Direct storage URI (e.g., s3://bucket/user/disk/snap-123-contents.txt) Returns: str: Contents text or None if not found """ + provider = _get_provider() + try: - # If snapshot_id provided, look up S3 path from tags - if snapshot_id and not s3_path: - logger.info(f"Looking up S3 path for snapshot {snapshot_id}") - response = ec2_client.describe_snapshots(SnapshotIds=[snapshot_id]) + # If snapshot_id provided, look up storage path from tags + if snapshot_id and not storage_uri: + logger.info(f"Looking up storage path for snapshot {snapshot_id}") + snapshot = provider.get_snapshot(snapshot_id) - if not response.get('Snapshots'): + if not snapshot: logger.error(f"Snapshot {snapshot_id} not found") return None - snapshot = response['Snapshots'][0] - tags = {tag['Key']: tag['Value'] for tag in snapshot.get('Tags', [])} - s3_path = tags.get('snapshot_content_s3') + storage_uri = snapshot.tags.get('snapshot_content_s3') - if not s3_path: - logger.warning(f"Snapshot {snapshot_id} has no content_s3_path tag") + if not storage_uri: + logger.warning(f"Snapshot {snapshot_id} has no content storage path tag") return None - if not s3_path: - logger.error("No S3 path provided or found") - return None - - # Parse S3 path (s3://bucket/key) - if not s3_path.startswith('s3://'): - logger.error(f"Invalid S3 path format: {s3_path}") + if not storage_uri: + logger.error("No storage path provided or found") return None - path_parts = s3_path[5:].split('/', 1) # Remove 's3://' and split bucket/key - if len(path_parts) != 2: - logger.error(f"Invalid S3 path format: {s3_path}") + # Parse storage URI (s3://bucket/key or gs://bucket/key or file://path) + if storage_uri.startswith('s3://') or storage_uri.startswith('gs://'): + path_parts = storage_uri[5:].split('/', 1) # Remove 's3://' or 'gs://' and split bucket/key + if len(path_parts) != 2: + logger.error(f"Invalid storage URI format: {storage_uri}") + return None + bucket_name, object_key = path_parts + elif storage_uri.startswith('file://'): + # Local filesystem path + file_path = storage_uri[7:] + if os.path.exists(file_path): + with open(file_path, 'r') as f: + return f.read() + else: + logger.error(f"File not found: {file_path}") + return None + else: + logger.error(f"Unsupported storage URI format: {storage_uri}") return None - bucket_name, s3_key = path_parts - - logger.info(f"Fetching disk contents from {s3_path}") + logger.info(f"Fetching disk contents from {storage_uri}") - response = s3_client.get_object(Bucket=bucket_name, Key=s3_key) - contents = response['Body'].read().decode('utf-8') + content_bytes = provider.download_from_object_storage(bucket_name, object_key) + if content_bytes is None: + logger.error(f"Object not found: {storage_uri}") + return None - logger.info(f"Successfully fetched {len(contents)} bytes from S3") + contents = content_bytes.decode('utf-8') + logger.info(f"Successfully fetched {len(contents)} bytes from storage") return contents - except s3_client.exceptions.NoSuchKey: - logger.error(f"S3 object not found: {s3_path}") - return None except Exception as e: logger.error(f"Error fetching snapshot contents: {str(e)}") return None