Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 3 additions & 34 deletions tipg/database.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
"""tipg.db: database events."""

import functools
import os
import pathlib
from importlib.resources import files as resources_files
from typing import List, Optional, Union
from typing import List, Optional

import boto3
import orjson
from buildpg import asyncpg

Expand All @@ -18,26 +15,6 @@
DB_CATALOG_FILE = resources_files(__package__) / "sql" / "dbcatalog.sql"


def get_rds_token(
host: Union[str, None],
port: Union[int, None],
user: Union[str, None],
region: Union[str, None],
) -> str:
"""Get RDS token for IAM auth"""
logger.debug(
f"Retrieving RDS IAM token with host: {host}, port: {port}, user: {user}, region: {region}"
)
rds_client = boto3.client("rds")
token = rds_client.generate_db_auth_token(
DBHostname=host,
Port=port,
DBUsername=user,
Region=region or rds_client.meta.region_name,
)
return token


class connection_factory:
"""Connection creation."""

Expand Down Expand Up @@ -111,15 +88,7 @@ async def connect_to_db(
schemas, tipg_schema, user_sql_files, skip_sql_execution
)

if os.environ.get("IAM_AUTH_ENABLED") == "TRUE":
kwargs["password"] = functools.partial(
get_rds_token,
settings.postgres_host,
settings.postgres_port,
settings.postgres_user,
settings.aws_region,
)
kwargs["ssl"] = "require"
merged_pool_kwargs = {**settings.pool_kwargs, **(kwargs or {})}

app.state.pool = await asyncpg.create_pool_b(
str(settings.database_url),
Expand All @@ -128,7 +97,7 @@ async def connect_to_db(
max_queries=settings.db_max_queries,
max_inactive_connection_lifetime=settings.db_max_inactive_conn_lifetime,
init=con_init,
**kwargs,
**merged_pool_kwargs,
)


Expand Down
28 changes: 28 additions & 0 deletions tipg/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pathlib
from typing import Any, Dict, List, Optional

import boto3
from pydantic import (
BaseModel,
DirectoryPath,
Expand Down Expand Up @@ -151,6 +152,33 @@ class PostgresSettings(BaseSettings):

model_config = {"env_file": ".env", "extra": "ignore"}

def get_rds_token(self) -> str:
"""Generate an RDS IAM token for authentication."""
rds_client = boto3.client("rds")
token = rds_client.generate_db_auth_token(
DBHostname=self.postgres_host,
Port=self.postgres_port,
DBUsername=self.postgres_user,
Region=self.aws_region or rds_client.meta.region_name,
)
return token

@property
def pool_kwargs(self) -> Dict[str, Any]:
"""
Build the default connection parameters for the pool.

If IAM auth is enabled, use a dynamic password callable (bound to get_rds_token).
Otherwise, use a static password if provided.
"""
kwargs: Dict[str, Any] = {}
if self.iam_auth_enabled:
kwargs["password"] = self.get_rds_token
kwargs["ssl"] = "require"
elif self.postgres_pass:
kwargs["password"] = self.postgres_pass
return kwargs

# https://github.com/tiangolo/full-stack-fastapi-postgresql/blob/master/%7B%7Bcookiecutter.project_slug%7D%7D/backend/app/app/core/config.py#L42
@field_validator("database_url", mode="before")
def assemble_db_connection(
Expand Down
Loading