Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lambda-ecs-durable-python-sam/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

This pattern demonstrates how to invoke Amazon ECS tasks from AWS Lambda durable functions using Python. The workflow starts an ECS task, waits for a callback, and resumes based on the task result while maintaining state across the pause/resume cycle.

Learn more about this pattern at Serverless Land Patterns: https://serverlessland.com/patterns/lambda-ecs-python-sam
Learn more about this pattern at Serverless Land Patterns: https://serverlessland.com/patterns/lambda-ecs-durable-python-sam

Important: this application uses various AWS services and there are costs associated with these services after the Free Tier usage - please see the [AWS Pricing page](https://aws.amazon.com/pricing/) for details. You are responsible for any AWS costs incurred. No warranty is implied in this example.

Expand Down
6 changes: 6 additions & 0 deletions lambda-ecs-durable-python-sam/src/callback_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import boto3
import os
import hashlib
from aws_durable_execution_sdk_python import (
DurableContext,
durable_execution,
Expand All @@ -13,13 +14,18 @@ def start_ecs_task_with_callback(cluster, task_definition, subnet1, subnet2, sec
"""
Starts an ECS task and passes the callback token via environment variable.
The ECS task will call Lambda durable execution callback APIs when complete.
Uses callback_token as idempotency token to prevent duplicate tasks on retry.
"""
print(f"[CALLBACK] Starting ECS task with callback token")

# Use callback token hash as clientToken for idempotency (max 64 chars)
client_token = hashlib.sha256(callback_token.encode()).hexdigest()[:64]

response = ecs_client.run_task(
cluster=cluster,
taskDefinition=task_definition,
launchType='FARGATE',
clientToken=client_token,
networkConfiguration={
'awsvpcConfiguration': {
'subnets': [subnet1, subnet2],
Expand Down
111 changes: 52 additions & 59 deletions lambda-ecs-durable-python-sam/src/sync_handler.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,36 @@
import json
import boto3
import os
import hashlib
from aws_durable_execution_sdk_python import (
DurableContext,
durable_execution,
durable_step,
)
from aws_durable_execution_sdk_python.config import Duration
from aws_durable_execution_sdk_python.waits import WaitForConditionConfig, WaitForConditionDecision

ecs_client = boto3.client('ecs')


@durable_step
def start_ecs_task(step_context, cluster, task_definition, subnet1, subnet2, security_group, message, processing_time):
"""
Durable step that starts an ECS task.
This step is checkpointed, so if interrupted, it won't re-execute.
Uses a deterministic clientToken for idempotency in case of retry
before checkpoint is saved.
"""
step_context.logger.info(f"[SYNC] Starting ECS task with message: {message}")

# Generate deterministic idempotency token from inputs
token_input = f"{cluster}:{task_definition}:{message}:{processing_time}"
client_token = hashlib.sha256(token_input.encode()).hexdigest()[:64]

response = ecs_client.run_task(
cluster=cluster,
taskDefinition=task_definition,
launchType='FARGATE',
clientToken=client_token,
networkConfiguration={
'awsvpcConfiguration': {
'subnets': [subnet1, subnet2],
Expand Down Expand Up @@ -50,97 +59,81 @@ def start_ecs_task(step_context, cluster, task_definition, subnet1, subnet2, sec

return task_arn

@durable_step
def check_task_status(step_context, cluster, task_arn):
"""
Durable step that checks ECS task status.
This step is checkpointed and can be retried if it fails.
"""
step_context.logger.info(f"[SYNC] Checking task status: {task_arn}")


def check_ecs_status(cluster, task_arn):
"""Check ECS task status (called by wait_for_condition)."""
describe_response = ecs_client.describe_tasks(
cluster=cluster,
tasks=[task_arn]
)

if not describe_response['tasks']:
raise Exception(f"Task not found: {task_arn}")
return {'status': 'UNKNOWN', 'cluster': cluster, 'task_arn': task_arn}

task = describe_response['tasks'][0]
last_status = task['lastStatus']

step_context.logger.info(f"[SYNC] Task status: {last_status}")

return {
'status': last_status,
'task': task
'status': task['lastStatus'],
'task': task,
'cluster': cluster,
'task_arn': task_arn
}


@durable_execution
def lambda_handler(event, context: DurableContext):
"""
Lambda Durable Function that invokes an ECS task and waits for completion.
Uses the Durable Execution SDK for automatic checkpointing and replay.

This function can run for up to 1 year, with automatic state management
and recovery from failures.
Lambda durable function that invokes an ECS task and waits for completion.
Uses wait_for_condition for polling and durable steps for checkpointing.
"""

# Get configuration from environment variables
cluster = os.environ['ECS_CLUSTER']
task_definition = os.environ['TASK_DEFINITION']
subnet1 = os.environ['SUBNET_1']
subnet2 = os.environ['SUBNET_2']
security_group = os.environ['SECURITY_GROUP']

# Get input parameters
message = event.get('message', 'No message provided')
processing_time = event.get('processingTime', 5)

try:
# Step 1: Start ECS task (checkpointed)
# Step 1: Start ECS task (checkpointed, with idempotency token)
task_arn = context.step(start_ecs_task(
cluster, task_definition, subnet1, subnet2,
security_group, message, processing_time
))

# Poll for task completion using durable waits
max_attempts = 60 # 5 minutes max (60 * 5 seconds)
poll_interval = 5 # Check every 5 seconds
# Step 2: Poll for task completion using wait_for_condition
result = context.wait_for_condition(
lambda state, ctx: check_ecs_status(state['cluster'], state['task_arn']),
config=WaitForConditionConfig(
initial_state={'cluster': cluster, 'task_arn': task_arn, 'status': 'PENDING'},
wait_strategy=lambda state, attempt:
WaitForConditionDecision(should_continue=False, delay=Duration.from_seconds(0)) if state.get('status') == 'STOPPED'
else WaitForConditionDecision(should_continue=True, delay=Duration.from_seconds(5))
)
)

for attempt in range(max_attempts):
# Wait before checking status (no compute charges during wait)
context.wait(Duration.from_seconds(poll_interval))

# Step 2: Check task status (checkpointed)
status_result = context.step(check_task_status(cluster, task_arn))

if status_result['status'] == 'STOPPED':
# Task completed
task = status_result['task']
stop_code = task.get('stopCode', 'Unknown')

if stop_code == 'EssentialContainerExited':
exit_code = task['containers'][0].get('exitCode', 1)

if exit_code == 0:
context.logger.info(f"[SYNC] Task completed successfully")
return {
'statusCode': 200,
'body': json.dumps({
'status': 'success',
'message': f'Processed: {message}',
'processingTime': processing_time,
'taskArn': task_arn
})
}
else:
raise Exception(f"Task failed with exit code: {exit_code}")
else:
raise Exception(f"Task stopped unexpectedly: {stop_code}")
task = result.get('task', {})
stop_code = task.get('stopCode', 'Unknown')

# Timeout
raise Exception(f"Task did not complete within {max_attempts * poll_interval} seconds")
if stop_code == 'EssentialContainerExited':
exit_code = task['containers'][0].get('exitCode', 1)

if exit_code == 0:
context.logger.info(f"[SYNC] Task completed successfully")
return {
'statusCode': 200,
'body': json.dumps({
'status': 'success',
'message': f'Processed: {message}',
'processingTime': processing_time,
'taskArn': task_arn
})
}
else:
raise Exception(f"Task failed with exit code: {exit_code}")
else:
raise Exception(f"Task stopped unexpectedly: {stop_code}")

except Exception as e:
context.logger.error(f"[SYNC] Error: {str(e)}")
Expand Down