diff --git a/tipg/database.py b/tipg/database.py index 09ac9287..7c7f1708 100644 --- a/tipg/database.py +++ b/tipg/database.py @@ -1,12 +1,9 @@ """tipg.db: database events.""" -import functools -import os import pathlib from importlib.resources import files as resources_files -from typing import List, Optional, Union +from typing import List, Optional -import boto3 import orjson from buildpg import asyncpg @@ -18,26 +15,6 @@ DB_CATALOG_FILE = resources_files(__package__) / "sql" / "dbcatalog.sql" -def get_rds_token( - host: Union[str, None], - port: Union[int, None], - user: Union[str, None], - region: Union[str, None], -) -> str: - """Get RDS token for IAM auth""" - logger.debug( - f"Retrieving RDS IAM token with host: {host}, port: {port}, user: {user}, region: {region}" - ) - rds_client = boto3.client("rds") - token = rds_client.generate_db_auth_token( - DBHostname=host, - Port=port, - DBUsername=user, - Region=region or rds_client.meta.region_name, - ) - return token - - class connection_factory: """Connection creation.""" @@ -111,15 +88,7 @@ async def connect_to_db( schemas, tipg_schema, user_sql_files, skip_sql_execution ) - if os.environ.get("IAM_AUTH_ENABLED") == "TRUE": - kwargs["password"] = functools.partial( - get_rds_token, - settings.postgres_host, - settings.postgres_port, - settings.postgres_user, - settings.aws_region, - ) - kwargs["ssl"] = "require" + merged_pool_kwargs = {**settings.pool_kwargs, **(kwargs or {})} app.state.pool = await asyncpg.create_pool_b( str(settings.database_url), @@ -128,7 +97,7 @@ async def connect_to_db( max_queries=settings.db_max_queries, max_inactive_connection_lifetime=settings.db_max_inactive_conn_lifetime, init=con_init, - **kwargs, + **merged_pool_kwargs, ) diff --git a/tipg/settings.py b/tipg/settings.py index f118012b..88e245da 100644 --- a/tipg/settings.py +++ b/tipg/settings.py @@ -4,6 +4,7 @@ import pathlib from typing import Any, Dict, List, Optional +import boto3 from pydantic import ( BaseModel, DirectoryPath, @@ -151,6 +152,33 @@ class PostgresSettings(BaseSettings): model_config = {"env_file": ".env", "extra": "ignore"} + def get_rds_token(self) -> str: + """Generate an RDS IAM token for authentication.""" + rds_client = boto3.client("rds") + token = rds_client.generate_db_auth_token( + DBHostname=self.postgres_host, + Port=self.postgres_port, + DBUsername=self.postgres_user, + Region=self.aws_region or rds_client.meta.region_name, + ) + return token + + @property + def pool_kwargs(self) -> Dict[str, Any]: + """ + Build the default connection parameters for the pool. + + If IAM auth is enabled, use a dynamic password callable (bound to get_rds_token). + Otherwise, use a static password if provided. + """ + kwargs: Dict[str, Any] = {} + if self.iam_auth_enabled: + kwargs["password"] = self.get_rds_token + kwargs["ssl"] = "require" + elif self.postgres_pass: + kwargs["password"] = self.postgres_pass + return kwargs + # https://github.com/tiangolo/full-stack-fastapi-postgresql/blob/master/%7B%7Bcookiecutter.project_slug%7D%7D/backend/app/app/core/config.py#L42 @field_validator("database_url", mode="before") def assemble_db_connection(