diff --git a/terraform-gpu-devservers/lambda/reservation_expiry/index.py b/terraform-gpu-devservers/lambda/reservation_expiry/index.py index 7249f0a8..cf647339 100644 --- a/terraform-gpu-devservers/lambda/reservation_expiry/index.py +++ b/terraform-gpu-devservers/lambda/reservation_expiry/index.py @@ -48,6 +48,36 @@ # Global Kubernetes client (reused across Lambda execution) _k8s_client = None +# Internal deadline tracking — set from context in handler() +# Reserve 30 seconds before Lambda timeout for graceful cleanup logging +_DEADLINE_BUFFER_SECONDS = 30 +_lambda_deadline = None + + +def set_lambda_deadline(context): + """Set internal deadline from Lambda context, with buffer for cleanup.""" + global _lambda_deadline + if context and hasattr(context, 'get_remaining_time_in_millis'): + remaining_ms = context.get_remaining_time_in_millis() + _lambda_deadline = time.time() + (remaining_ms / 1000) - _DEADLINE_BUFFER_SECONDS + logger.info(f"Lambda deadline set: {remaining_ms / 1000:.0f}s remaining, internal deadline in {remaining_ms / 1000 - _DEADLINE_BUFFER_SECONDS:.0f}s") + else: + # Fallback: assume 15 min timeout with buffer + _lambda_deadline = time.time() + 900 - _DEADLINE_BUFFER_SECONDS + logger.warning("No Lambda context available, using default 15min deadline") + + +def time_remaining() -> float: + """Return seconds remaining before internal deadline. 
Negative means overdue.""" + if _lambda_deadline is None: + return 900 - _DEADLINE_BUFFER_SECONDS # No deadline set yet: mirror the no-context fallback (15 min minus buffer) + return _lambda_deadline - time.time() + + +def is_deadline_approaching(min_seconds: float = 60) -> bool: + """Return True if less than min_seconds remain before the internal deadline.""" + return time_remaining() < min_seconds + def get_k8s_client(): """Get or create the global Kubernetes client (singleton pattern)""" @@ -318,6 +348,7 @@ def cleanup_soft_deleted_snapshots() -> int: def handler(event, context): """Main Lambda handler""" try: + set_lambda_deadline(context) current_time = int(time.time()) logger.info( f"Running reservation expiry and cleanup check at timestamp {current_time} ({datetime.fromtimestamp(current_time)})" @@ -441,6 +472,14 @@ def handler(event, context): # Process active reservations for expiry for reservation in active_reservations: + # Check deadline before processing each reservation + if is_deadline_approaching(min_seconds=120): + logger.warning( + f"Lambda deadline approaching ({time_remaining():.0f}s remaining) — " + f"stopping expiry processing. Remaining reservations will be processed in next invocation." + ) + break + expires_at_str = reservation.get("expires_at", "") try: expires_at = int( @@ -1511,9 +1550,53 @@ def create_warning_message(reservation: dict[str, Any], minutes_left: int) -> st def cleanup_pod(pod_name: str, namespace: str = "gpu-dev", reservation_data: dict = None) -> None: - """Clean up Kubernetes pod and associated resources""" + """Clean up Kubernetes pod and associated resources. + + Operation ordering is designed so that the most critical state updates + happen first. If the Lambda times out partway through, the system is + left in a recoverable state: + + 1. Mark disk as not-in-use in DynamoDB (prevents stuck disks) + 2. Delete DNS records and domain mappings (prevents stale routing) + 3. Delete ALB/NLB resources (prevents stale load balancer rules) + 4. 
Capture disk contents & initiate EBS snapshot (best-effort) + 5. Delete K8s service and pod (frees GPU resources) + 6. Wait for snapshot & clean up volume (best-effort, deadline-aware) + """ try: - logger.info(f"Cleaning up pod {pod_name} in namespace {namespace}") + logger.info(f"Cleaning up pod {pod_name} in namespace {namespace} ({time_remaining():.0f}s remaining)") + + # ===================================================================== + # PHASE 1: Critical state updates (must complete even if Lambda is + # about to time out — these are fast DynamoDB operations) + # ===================================================================== + + # Mark disk as not in use IMMEDIATELY so users aren't locked out + # of their persistent disk if later operations fail or time out + user_id = None + volume_id = None + disk_name = None + + if reservation_data: + user_id = reservation_data.get('user_id') + volume_id = reservation_data.get('ebs_volume_id') + disk_name = reservation_data.get('disk_name') + reservation_id = reservation_data.get('reservation_id') + + # Fallback: if disk_name not in reservation, look it up from disks table + if user_id and not disk_name and reservation_id: + disk_name = find_disk_by_reservation(user_id, reservation_id) + + if user_id and disk_name: + try: + mark_disk_not_in_use(user_id, disk_name) + logger.info(f"Marked disk '{disk_name}' as not in use (early — before cleanup)") + except Exception as mark_error: + logger.warning(f"Failed to mark disk as not in use early: {mark_error}") + + # ===================================================================== + # PHASE 2: Network cleanup (DNS + ALB) — fast, idempotent + # ===================================================================== # Clean up DNS records if domain is configured if get_dns_enabled() and reservation_data: @@ -1556,177 +1639,121 @@ def cleanup_pod(pod_name: str, namespace: str = "gpu-dev", reservation_data: dic logger.error(f"Error cleaning up ALB/NLB resources: 
{alb_error}") # Don't re-raise - continue with pod cleanup + # ===================================================================== + # PHASE 3: Snapshot & disk content capture (best-effort, skip if + # deadline is too close) + # ===================================================================== + # Configure Kubernetes client logger.info(f"Setting up Kubernetes client for cleanup...") k8s_client = get_k8s_client() v1 = client.CoreV1Api(k8s_client) logger.info(f"Kubernetes client configured successfully") - # Create shutdown snapshot if pod has persistent storage - try: - user_id = None - volume_id = None - disk_name = None - - # Get user_id, volume_id, and disk_name from reservation data if provided - if reservation_data: - user_id = reservation_data.get('user_id') - volume_id = reservation_data.get('ebs_volume_id') - disk_name = reservation_data.get('disk_name') - - # Quick check - if we have reservation data with EBS info, use it directly - if user_id and volume_id: - logger.info(f"Found persistent storage in reservation data: volume {volume_id} for user {user_id}") - - # If no reservation data or missing info, try to get from pod spec - elif not user_id or not volume_id: - try: - pod = v1.read_namespaced_pod(name=pod_name, namespace=namespace) + snapshot_id = None - # Extract user_id from pod labels or annotations - if pod.metadata.labels: - user_id = pod.metadata.labels.get('user-id') or user_id + # Only attempt snapshot operations if we have enough time + if is_deadline_approaching(min_seconds=120): + logger.warning( + f"Skipping snapshot operations — only {time_remaining():.0f}s remaining. " + f"Snapshot will not be created for this expiry." 
+ ) + else: + try: + # If no reservation data or missing info, try to get from pod spec + if not user_id or not volume_id: + try: + pod = v1.read_namespaced_pod(name=pod_name, namespace=namespace) - # Look for EBS volume in pod spec - if pod.spec.volumes: - for volume in pod.spec.volumes: - if volume.aws_elastic_block_store: - # Extract volume ID from AWS EBS volume - volume_id = volume.aws_elastic_block_store.volume_id - break + # Extract user_id from pod labels or annotations + if pod.metadata.labels: + user_id = pod.metadata.labels.get('user-id') or user_id - except Exception as pod_read_error: - logger.warning(f"Could not read pod {pod_name} for snapshot info: {pod_read_error}") + # Look for EBS volume in pod spec + if pod.spec.volumes: + for volume in pod.spec.volumes: + if volume.aws_elastic_block_store: + volume_id = volume.aws_elastic_block_store.volume_id + break - # If disk_name not in reservation data, try to get it from volume tags - if volume_id and not disk_name: - try: - ec2_client = boto3.client('ec2') - vol_response = ec2_client.describe_volumes(VolumeIds=[volume_id]) - if vol_response['Volumes']: - tags = {tag['Key']: tag['Value'] for tag in vol_response['Volumes'][0].get('Tags', [])} - disk_name = tags.get('disk_name') - logger.info(f"Retrieved disk_name '{disk_name}' from volume tags") - except Exception as tag_error: - logger.warning(f"Could not read volume tags for disk_name: {tag_error}") - - # Create shutdown snapshot if we have the necessary info - if user_id and volume_id: - logger.info(f"Creating shutdown snapshot for user {user_id}, volume {volume_id}, disk {disk_name or 'unnamed'}") - - # Step 1: Capture disk contents before creating snapshot - content_s3_path = None - disk_size = None - if disk_name: + except Exception as pod_read_error: + logger.warning(f"Could not read pod {pod_name} for snapshot info: {pod_read_error}") + + # If disk_name not in reservation data, try to get it from volume tags + if volume_id and not disk_name: try: 
- logger.info(f"Capturing disk contents for disk '{disk_name}'") - # Create a temporary snapshot ID for the S3 path (we'll update after actual snapshot creation) - temp_snapshot_id = f"pending-{int(time.time())}" - content_s3_path, disk_size = capture_disk_contents( - pod_name=pod_name, - namespace=namespace, - user_id=user_id, - disk_name=disk_name, - snapshot_id=temp_snapshot_id, - k8s_client=get_k8s_client(), - mount_path="/home/dev" - ) - if content_s3_path: - logger.info(f"Successfully captured disk contents to {content_s3_path}" + (f" (size: {disk_size})" if disk_size else "")) - else: - logger.warning(f"Failed to capture disk contents for disk '{disk_name}'") - except Exception as capture_error: - logger.warning(f"Error capturing disk contents: {capture_error}") - # Continue with snapshot even if content capture fails - - # Step 2: Create snapshot with disk_name, content_s3_path, and disk_size - snapshot_id, was_created = safe_create_snapshot( - volume_id=volume_id, - user_id=user_id, - snapshot_type="shutdown", - disk_name=disk_name, - content_s3_path=content_s3_path, - disk_size=disk_size - ) + ec2 = boto3.client('ec2') + vol_response = ec2.describe_volumes(VolumeIds=[volume_id]) + if vol_response['Volumes']: + tags = {tag['Key']: tag['Value'] for tag in vol_response['Volumes'][0].get('Tags', [])} + disk_name = tags.get('disk_name') + logger.info(f"Retrieved disk_name '{disk_name}' from volume tags") + except Exception as tag_error: + logger.warning(f"Could not read volume tags for disk_name: {tag_error}") - if snapshot_id: - logger.info(f"Shutdown snapshot {snapshot_id} initiated for {pod_name}") + # Create shutdown snapshot if we have the necessary info + if user_id and volume_id: + logger.info(f"Creating shutdown snapshot for user {user_id}, volume {volume_id}, disk {disk_name or 'unnamed'}") - # Step 3: Wait for snapshot to complete (with timeout) - try: - logger.info(f"Waiting for snapshot {snapshot_id} to complete...") - ec2_client = 
boto3.client('ec2') - waiter = ec2_client.get_waiter('snapshot_completed') - waiter.wait( - SnapshotIds=[snapshot_id], - WaiterConfig={ - 'Delay': 15, # Check every 15 seconds - 'MaxAttempts': 120 # Wait up to 30 minutes (15s * 120 = 1800s) - } - ) - logger.info(f"Snapshot {snapshot_id} completed successfully") - - # Step 3.5: Update DynamoDB to reflect snapshot completion - if disk_name: - try: - # Get snapshot details to get volume size and content S3 path - snapshot_response = ec2_client.describe_snapshots(SnapshotIds=[snapshot_id]) - if snapshot_response.get('Snapshots'): - snapshot = snapshot_response['Snapshots'][0] - size_gb = snapshot.get('VolumeSize') - tags = {tag['Key']: tag['Value'] for tag in snapshot.get('Tags', [])} - snapshot_content_s3 = tags.get('snapshot_content_s3') - snapshot_disk_size = tags.get('disk_size') - logger.info(f"Updating DynamoDB for completed snapshot {snapshot_id} (disk: {disk_name}, size: {size_gb}GB, disk_size: {snapshot_disk_size})") - update_disk_snapshot_completed(user_id, disk_name, size_gb, snapshot_content_s3, snapshot_disk_size) - logger.info(f"Successfully updated DynamoDB for disk '{disk_name}'") - except Exception as update_error: - logger.warning(f"Error updating DynamoDB for snapshot completion: {update_error}") - # Don't fail cleanup if DynamoDB update fails - - # Step 4: Delete the EBS volume after snapshot completes + # Capture disk contents before creating snapshot (skip if tight on time) + content_s3_path = None + disk_size = None + if disk_name and not is_deadline_approaching(min_seconds=180): try: - logger.info(f"Deleting EBS volume {volume_id} after successful snapshot") - ec2_client.delete_volume(VolumeId=volume_id) - logger.info(f"Successfully deleted volume {volume_id}") - - # Step 5: Mark disk as no longer in use (allows CLI to show as available) - if disk_name: - try: - mark_disk_not_in_use(user_id, disk_name) - logger.info(f"Marked disk '{disk_name}' as not in use") - except Exception as mark_error: - 
logger.warning(f"Failed to mark disk as not in use: {mark_error}") - except Exception as delete_error: - logger.error(f"Failed to delete volume {volume_id}: {delete_error}") - # Don't fail the whole cleanup if volume deletion fails - - except Exception as waiter_error: - logger.warning(f"Error waiting for snapshot completion or deleting volume: {waiter_error}") - # Continue with pod deletion even if snapshot wait/delete fails + logger.info(f"Capturing disk contents for disk '{disk_name}'") + temp_snapshot_id = f"pending-{int(time.time())}" + content_s3_path, disk_size = capture_disk_contents( + pod_name=pod_name, + namespace=namespace, + user_id=user_id, + disk_name=disk_name, + snapshot_id=temp_snapshot_id, + k8s_client=get_k8s_client(), + mount_path="/home/dev" + ) + if content_s3_path: + logger.info(f"Successfully captured disk contents to {content_s3_path}" + (f" (size: {disk_size})" if disk_size else "")) + else: + logger.warning(f"Failed to capture disk contents for disk '{disk_name}'") + except Exception as capture_error: + logger.warning(f"Error capturing disk contents: {capture_error}") + elif disk_name: + logger.warning(f"Skipping disk content capture — only {time_remaining():.0f}s remaining") + + # Initiate snapshot (non-blocking creation) + snapshot_id, was_created = safe_create_snapshot( + volume_id=volume_id, + user_id=user_id, + snapshot_type="shutdown", + disk_name=disk_name, + content_s3_path=content_s3_path, + disk_size=disk_size + ) + if snapshot_id: + logger.info(f"Shutdown snapshot {snapshot_id} initiated for {pod_name}") + else: + logger.warning(f"Failed to create shutdown snapshot for {pod_name}") else: - logger.warning(f"Failed to create shutdown snapshot for {pod_name}") - else: - logger.info(f"No persistent storage found for pod {pod_name} - skipping shutdown snapshot") + logger.info(f"No persistent storage found for pod {pod_name} - skipping shutdown snapshot") - except Exception as snapshot_error: - logger.warning(f"Error creating 
shutdown snapshot for {pod_name}: {snapshot_error}") - # Continue with pod deletion even if snapshot fails + except Exception as snapshot_error: + logger.warning(f"Error during snapshot operations for {pod_name}: {snapshot_error}") + # Continue with pod deletion even if snapshot fails + + # ===================================================================== + # PHASE 4: Delete K8s resources (service + pod) — this frees GPUs + # ===================================================================== # Send final warning message before deletion try: - logger.info(f"Sending final warning message to pod {pod_name}") send_wall_message_to_pod( pod_name, "🚨 FINAL WARNING: Reservation expired! Pod will be deleted now. All unsaved work will be lost!", namespace, ) - logger.info(f"Final warning message sent to pod {pod_name}") except Exception as warn_error: - logger.warning( - f"Could not send final warning to pod {pod_name}: {warn_error}" - ) + logger.warning(f"Could not send final warning to pod {pod_name}: {warn_error}") # Delete the NodePort service first service_name = f"{pod_name}-ssh" @@ -1775,10 +1802,6 @@ def cleanup_pod(pod_name: str, namespace: str = "gpu-dev", reservation_data: dic logger.info(f"Pod cleanup completed successfully for {pod_name}") - # NOTE: EBS volumes (persistent disks) are deleted after snapshot creation - # Snapshots are used to recreate volumes for the user's next reservation - # This ensures clean state and prevents disk attachment conflicts - # Trigger availability table update after pod cleanup try: trigger_availability_update() @@ -1787,25 +1810,72 @@ def cleanup_pod(pod_name: str, namespace: str = "gpu-dev", reservation_data: dic logger.warning( f"Failed to trigger availability update after pod cleanup: {update_error}" ) - # Don't fail the expiry for this - # Final safety: ensure disk is marked as not in use - # This handles edge cases where volume deletion failed but pod is gone - if reservation_data: - final_user_id = 
reservation_data.get('user_id') - final_disk_name = reservation_data.get('disk_name') - final_reservation_id = reservation_data.get('reservation_id') + # ===================================================================== + # PHASE 5: Best-effort snapshot completion wait & volume cleanup + # Only if we have time remaining before Lambda deadline + # ===================================================================== - # Fallback: if disk_name not in reservation, look it up from disks table - if final_user_id and not final_disk_name and final_reservation_id: - final_disk_name = find_disk_by_reservation(final_user_id, final_reservation_id) + if snapshot_id and volume_id and user_id: + remaining = time_remaining() + if remaining > 60: + # Calculate how long we can afford to wait for the snapshot + # Leave 30s buffer for volume deletion + DynamoDB update + max_wait_seconds = max(30, remaining - 30) + max_attempts = max(2, int(max_wait_seconds / 15)) + + logger.info( + f"Waiting for snapshot {snapshot_id} completion " + f"(max {max_wait_seconds:.0f}s / {max_attempts} attempts, " + f"{remaining:.0f}s remaining)" + ) - if final_user_id and final_disk_name: try: - mark_disk_not_in_use(final_user_id, final_disk_name) - logger.info(f"Final cleanup: ensured disk '{final_disk_name}' is marked as not in use") - except Exception as final_disk_error: - logger.warning(f"Final disk cleanup failed (non-fatal): {final_disk_error}") + ec2 = boto3.client('ec2') + waiter = ec2.get_waiter('snapshot_completed') + waiter.wait( + SnapshotIds=[snapshot_id], + WaiterConfig={ + 'Delay': 15, + 'MaxAttempts': max_attempts + } + ) + logger.info(f"Snapshot {snapshot_id} completed successfully") + + # Update DynamoDB to reflect snapshot completion + if disk_name: + try: + snapshot_response = ec2.describe_snapshots(SnapshotIds=[snapshot_id]) + if snapshot_response.get('Snapshots'): + snapshot = snapshot_response['Snapshots'][0] + size_gb = snapshot.get('VolumeSize') + tags = {tag['Key']: 
tag['Value'] for tag in snapshot.get('Tags', [])} + snapshot_content_s3 = tags.get('snapshot_content_s3') + snapshot_disk_size = tags.get('disk_size') + logger.info(f"Updating DynamoDB for completed snapshot {snapshot_id} (disk: {disk_name}, size: {size_gb}GB)") + update_disk_snapshot_completed(user_id, disk_name, size_gb, snapshot_content_s3, snapshot_disk_size) + except Exception as update_error: + logger.warning(f"Error updating DynamoDB for snapshot completion: {update_error}") + + # Delete EBS volume after successful snapshot + try: + logger.info(f"Deleting EBS volume {volume_id} after successful snapshot") + ec2.delete_volume(VolumeId=volume_id) + logger.info(f"Successfully deleted volume {volume_id}") + except Exception as delete_error: + logger.error(f"Failed to delete volume {volume_id}: {delete_error}") + + except Exception as waiter_error: + logger.warning( + f"Snapshot {snapshot_id} did not complete within deadline " + f"({time_remaining():.0f}s remaining): {waiter_error}. " + f"Volume {volume_id} will be cleaned up on next run or manually." + ) + else: + logger.warning( + f"Not enough time to wait for snapshot {snapshot_id} " + f"({remaining:.0f}s remaining). Volume {volume_id} cleanup deferred." + ) except Exception as e: logger.error(f"Error cleaning up pod {pod_name}: {str(e)}")