diff --git a/lambda-ecs-durable-python-sam/README.md b/lambda-ecs-durable-python-sam/README.md index 1962094da..88013eb75 100644 --- a/lambda-ecs-durable-python-sam/README.md +++ b/lambda-ecs-durable-python-sam/README.md @@ -2,7 +2,7 @@ This pattern demonstrates how to invoke Amazon ECS tasks from AWS Lambda durable functions using Python. The workflow starts an ECS task, waits for a callback, and resumes based on the task result while maintaining state across the pause/resume cycle. -Learn more about this pattern at Serverless Land Patterns: https://serverlessland.com/patterns/lambda-ecs-python-sam +Learn more about this pattern at Serverless Land Patterns: https://serverlessland.com/patterns/lambda-ecs-durable-python-sam Important: this application uses various AWS services and there are costs associated with these services after the Free Tier usage - please see the [AWS Pricing page](https://aws.amazon.com/pricing/) for details. You are responsible for any AWS costs incurred. No warranty is implied in this example. diff --git a/lambda-ecs-durable-python-sam/src/callback_handler.py b/lambda-ecs-durable-python-sam/src/callback_handler.py index 39d223b73..ab54d5fde 100644 --- a/lambda-ecs-durable-python-sam/src/callback_handler.py +++ b/lambda-ecs-durable-python-sam/src/callback_handler.py @@ -1,6 +1,7 @@ import json import boto3 import os +import hashlib from aws_durable_execution_sdk_python import ( DurableContext, durable_execution, @@ -13,13 +14,18 @@ def start_ecs_task_with_callback(cluster, task_definition, subnet1, subnet2, sec """ Starts an ECS task and passes the callback token via environment variable. The ECS task will call Lambda durable execution callback APIs when complete. + Uses callback_token as idempotency token to prevent duplicate tasks on retry. """ print(f"[CALLBACK] Starting ECS task with callback token") + # Use callback token hash as clientToken for idempotency (max 64 chars) + client_token = hashlib.sha256(callback_token.encode()).hexdigest()[:64] + response = ecs_client.run_task( cluster=cluster, taskDefinition=task_definition, launchType='FARGATE', + clientToken=client_token, networkConfiguration={ 'awsvpcConfiguration': { 'subnets': [subnet1, subnet2], diff --git a/lambda-ecs-durable-python-sam/src/sync_handler.py b/lambda-ecs-durable-python-sam/src/sync_handler.py index 8ba2739e6..10f3b6bef 100644 --- a/lambda-ecs-durable-python-sam/src/sync_handler.py +++ b/lambda-ecs-durable-python-sam/src/sync_handler.py @@ -1,27 +1,36 @@ import json import boto3 import os +import hashlib from aws_durable_execution_sdk_python import ( DurableContext, durable_execution, durable_step, ) from aws_durable_execution_sdk_python.config import Duration +from aws_durable_execution_sdk_python.waits import WaitForConditionConfig, WaitForConditionDecision ecs_client = boto3.client('ecs') + @durable_step def start_ecs_task(step_context, cluster, task_definition, subnet1, subnet2, security_group, message, processing_time): """ Durable step that starts an ECS task. - This step is checkpointed, so if interrupted, it won't re-execute. + Uses a deterministic clientToken for idempotency in case of retry + before checkpoint is saved. """ step_context.logger.info(f"[SYNC] Starting ECS task with message: {message}") + # Generate deterministic idempotency token from inputs + token_input = f"{cluster}:{task_definition}:{message}:{processing_time}" + client_token = hashlib.sha256(token_input.encode()).hexdigest()[:64] + response = ecs_client.run_task( cluster=cluster, taskDefinition=task_definition, launchType='FARGATE', + clientToken=client_token, networkConfiguration={ 'awsvpcConfiguration': { 'subnets': [subnet1, subnet2], @@ -50,97 +59,81 @@ def start_ecs_task(step_context, cluster, task_definition, subnet1, subnet2, sec return task_arn -@durable_step -def check_task_status(step_context, cluster, task_arn): - """ - Durable step that checks ECS task status. - This step is checkpointed and can be retried if it fails. - """ - step_context.logger.info(f"[SYNC] Checking task status: {task_arn}") - + +def check_ecs_status(cluster, task_arn): + """Check ECS task status (called by wait_for_condition).""" describe_response = ecs_client.describe_tasks( cluster=cluster, tasks=[task_arn] ) if not describe_response['tasks']: - raise Exception(f"Task not found: {task_arn}") + return {'status': 'UNKNOWN', 'cluster': cluster, 'task_arn': task_arn} task = describe_response['tasks'][0] - last_status = task['lastStatus'] - - step_context.logger.info(f"[SYNC] Task status: {last_status}") - return { - 'status': last_status, - 'task': task + 'status': task['lastStatus'], + 'task': task, + 'cluster': cluster, + 'task_arn': task_arn } + @durable_execution def lambda_handler(event, context: DurableContext): """ - Lambda Durable Function that invokes an ECS task and waits for completion. - Uses the Durable Execution SDK for automatic checkpointing and replay. - - This function can run for up to 1 year, with automatic state management - and recovery from failures. + Lambda durable function that invokes an ECS task and waits for completion. + Uses wait_for_condition for polling and durable steps for checkpointing. """ - # Get configuration from environment variables cluster = os.environ['ECS_CLUSTER'] task_definition = os.environ['TASK_DEFINITION'] subnet1 = os.environ['SUBNET_1'] subnet2 = os.environ['SUBNET_2'] security_group = os.environ['SECURITY_GROUP'] - # Get input parameters message = event.get('message', 'No message provided') processing_time = event.get('processingTime', 5) try: - # Step 1: Start ECS task (checkpointed) + # Step 1: Start ECS task (checkpointed, with idempotency token) task_arn = context.step(start_ecs_task( cluster, task_definition, subnet1, subnet2, security_group, message, processing_time )) - # Poll for task completion using durable waits - max_attempts = 60 # 5 minutes max (60 * 5 seconds) - poll_interval = 5 # Check every 5 seconds + # Step 2: Poll for task completion using wait_for_condition + result = context.wait_for_condition( + lambda state, ctx: check_ecs_status(state['cluster'], state['task_arn']), + config=WaitForConditionConfig( + initial_state={'cluster': cluster, 'task_arn': task_arn, 'status': 'PENDING'}, + wait_strategy=lambda state, attempt: + WaitForConditionDecision(should_continue=False, delay=Duration.from_seconds(0)) if state.get('status') == 'STOPPED' + else WaitForConditionDecision(should_continue=True, delay=Duration.from_seconds(5)) + ) + ) - for attempt in range(max_attempts): - # Wait before checking status (no compute charges during wait) - context.wait(Duration.from_seconds(poll_interval)) - - # Step 2: Check task status (checkpointed) - status_result = context.step(check_task_status(cluster, task_arn)) - - if status_result['status'] == 'STOPPED': - # Task completed - task = status_result['task'] - stop_code = task.get('stopCode', 'Unknown') - - if stop_code == 'EssentialContainerExited': - exit_code = task['containers'][0].get('exitCode', 1) - - if exit_code == 0: - context.logger.info(f"[SYNC] Task completed successfully") - return { - 'statusCode': 200, - 'body': json.dumps({ - 'status': 'success', - 'message': f'Processed: {message}', - 'processingTime': processing_time, - 'taskArn': task_arn - }) - } - else: - raise Exception(f"Task failed with exit code: {exit_code}") - else: - raise Exception(f"Task stopped unexpectedly: {stop_code}") + task = result.get('task', {}) + stop_code = task.get('stopCode', 'Unknown') - # Timeout - raise Exception(f"Task did not complete within {max_attempts * poll_interval} seconds") + if stop_code == 'EssentialContainerExited': + exit_code = task['containers'][0].get('exitCode', 1) + + if exit_code == 0: + context.logger.info(f"[SYNC] Task completed successfully") + return { + 'statusCode': 200, + 'body': json.dumps({ + 'status': 'success', + 'message': f'Processed: {message}', + 'processingTime': processing_time, + 'taskArn': task_arn + }) + } + else: + raise Exception(f"Task failed with exit code: {exit_code}") + else: + raise Exception(f"Task stopped unexpectedly: {stop_code}") except Exception as e: context.logger.error(f"[SYNC] Error: {str(e)}")