From 3705cad6b7f7736c79c65baaea145365e2139015 Mon Sep 17 00:00:00 2001 From: Patricia Fricke Date: Tue, 18 Mar 2025 14:19:26 +0000 Subject: [PATCH 1/3] iam kwargs refactor --- tipg/database.py | 37 +++---------------------------------- tipg/settings.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/tipg/database.py b/tipg/database.py index 09ac9287..1aa781bf 100644 --- a/tipg/database.py +++ b/tipg/database.py @@ -1,12 +1,9 @@ """tipg.db: database events.""" -import functools -import os import pathlib from importlib.resources import files as resources_files -from typing import List, Optional, Union +from typing import List, Optional -import boto3 import orjson from buildpg import asyncpg @@ -18,26 +15,6 @@ DB_CATALOG_FILE = resources_files(__package__) / "sql" / "dbcatalog.sql" -def get_rds_token( - host: Union[str, None], - port: Union[int, None], - user: Union[str, None], - region: Union[str, None], -) -> str: - """Get RDS token for IAM auth""" - logger.debug( - f"Retrieving RDS IAM token with host: {host}, port: {port}, user: {user}, region: {region}" - ) - rds_client = boto3.client("rds") - token = rds_client.generate_db_auth_token( - DBHostname=host, - Port=port, - DBUsername=user, - Region=region or rds_client.meta.region_name, - ) - return token - - class connection_factory: """Connection creation.""" @@ -111,15 +88,7 @@ async def connect_to_db( schemas, tipg_schema, user_sql_files, skip_sql_execution ) - if os.environ.get("IAM_AUTH_ENABLED") == "TRUE": - kwargs["password"] = functools.partial( - get_rds_token, - settings.postgres_host, - settings.postgres_port, - settings.postgres_user, - settings.aws_region, - ) - kwargs["ssl"] = "require" + merged_pool_kwargs = {**settings.pool_kwargs, **(kwargs or {})} app.state.pool = await asyncpg.create_pool_b( str(settings.database_url), @@ -128,7 +97,7 @@ async def connect_to_db( max_queries=settings.db_max_queries, max_inactive_connection_lifetime=settings.db_max_inactive_conn_lifetime, init=con_init, - **kwargs, + kwargs=merged_pool_kwargs, ) diff --git a/tipg/settings.py b/tipg/settings.py index f118012b..3eece286 100644 --- a/tipg/settings.py +++ b/tipg/settings.py @@ -1,5 +1,6 @@ """tipg config.""" +import boto3 import json import pathlib from typing import Any, Dict, List, Optional @@ -151,6 +152,33 @@ class PostgresSettings(BaseSettings): model_config = {"env_file": ".env", "extra": "ignore"} + def get_rds_token(self) -> str: + """Generate an RDS IAM token for authentication.""" + rds_client = boto3.client("rds") + token = rds_client.generate_db_auth_token( + DBHostname=self.postgres_host, + Port=self.postgres_port, + DBUsername=self.postgres_user, + Region=self.aws_region or rds_client.meta.region_name, + ) + return token + + @property + def pool_kwargs(self) -> Dict[str, Any]: + """ + Build the default connection parameters for the pool. + + If IAM auth is enabled, use a dynamic password callable (bound to get_rds_token). + Otherwise, use a static password if provided. + """ + kwargs: Dict[str, Any] + if self.iam_auth_enabled: + kwargs["password"] = self.get_rds_token + kwargs["ssl"] = "require" + elif self.postgres_pass: + kwargs["password"] = self.postgres_pass + return kwargs + # https://github.com/tiangolo/full-stack-fastapi-postgresql/blob/master/%7B%7Bcookiecutter.project_slug%7D%7D/backend/app/app/core/config.py#L42 @field_validator("database_url", mode="before") def assemble_db_connection( From 58dfbfc85fa39e09ab796ea13476f698db069bc1 Mon Sep 17 00:00:00 2001 From: Patricia Fricke Date: Tue, 18 Mar 2025 14:22:25 +0000 Subject: [PATCH 2/3] iam kwargs refactor --- tipg/settings.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tipg/settings.py b/tipg/settings.py index 3eece286..88e245da 100644 --- a/tipg/settings.py +++ b/tipg/settings.py @@ -1,10 +1,10 @@ """tipg config.""" -import boto3 import json import pathlib from typing import Any, Dict, List, Optional +import boto3 from pydantic import ( BaseModel, DirectoryPath, @@ -162,7 +162,7 @@ def get_rds_token(self) -> str: Region=self.aws_region or rds_client.meta.region_name, ) return token - + @property def pool_kwargs(self) -> Dict[str, Any]: """ @@ -171,7 +171,7 @@ def pool_kwargs(self) -> Dict[str, Any]: If IAM auth is enabled, use a dynamic password callable (bound to get_rds_token). Otherwise, use a static password if provided. """ - kwargs: Dict[str, Any] + kwargs: Dict[str, Any] = {} if self.iam_auth_enabled: kwargs["password"] = self.get_rds_token kwargs["ssl"] = "require" From 0c060a79d387060344588c8b8a76467e63c06521 Mon Sep 17 00:00:00 2001 From: Patricia Fricke Date: Tue, 18 Mar 2025 14:27:06 +0000 Subject: [PATCH 3/3] change kwargs argument in pool function --- tipg/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tipg/database.py b/tipg/database.py index 1aa781bf..7c7f1708 100644 --- a/tipg/database.py +++ b/tipg/database.py @@ -97,7 +97,7 @@ async def connect_to_db( max_queries=settings.db_max_queries, max_inactive_connection_lifetime=settings.db_max_inactive_conn_lifetime, init=con_init, - kwargs=merged_pool_kwargs, + **merged_pool_kwargs, )