From 9982adb4a3a46d5ea163525a939c14427013842c Mon Sep 17 00:00:00 2001 From: Patricia Fricke Date: Tue, 11 Mar 2025 15:07:49 +0000 Subject: [PATCH 1/3] make auth token callable --- tipg/collections.py | 2 +- tipg/database.py | 25 +++++++++++++++++++++++ tipg/settings.py | 48 ++++++++++++--------------------------------- 3 files changed, 38 insertions(+), 37 deletions(-) diff --git a/tipg/collections.py b/tipg/collections.py index 3f20adc9..e801688d 100644 --- a/tipg/collections.py +++ b/tipg/collections.py @@ -962,7 +962,7 @@ async def get_collection_index( # noqa: C901 table_id = table["schema"] + "." + table["name"] confid = table["schema"] + "_" + table["name"] - if table_id == "pg_temp.tipg_catalog" or table_id == "public.tipg_catalog": + if "tipg_catalog" in table_id: continue table_conf = table_confs.get(confid, TableConfig()) diff --git a/tipg/database.py b/tipg/database.py index 446564b4..bb22e9ef 100644 --- a/tipg/database.py +++ b/tipg/database.py @@ -1,8 +1,11 @@ """tipg.db: database events.""" +import functools +import os import pathlib from typing import List, Optional +import boto3 import orjson from buildpg import asyncpg @@ -20,6 +23,23 @@ DB_CATALOG_FILE = resources_files(__package__) / "sql" / "dbcatalog.sql" +def get_rds_token( + host: str | None, port: int | None, user: str | None, region: str | None +) -> str: + """Get RDS token for IAM auth""" + logger.debug( + f"Retrieving RDS IAM token with host: {host}, port: {port}, user: {user}, region: {region}" + ) + rds_client = boto3.client("rds") + token = rds_client.generate_db_auth_token( + DBHostname=host, + Port=port, + DBUsername=user, + Region=region or rds_client.meta.region_name, + ) + return token + + class connection_factory: """Connection creation.""" @@ -90,6 +110,11 @@ async def connect_to_db( con_init = connection_factory(schemas, user_sql_files, skip_sql_execution) + if os.environ.get("IAM_AUTH_ENABLED"): + kwargs["password"] = functools.partial( + get_rds_token, settings.host, settings.port, settings.user, settings.region + ) + app.state.pool = await asyncpg.create_pool_b( str(settings.database_url), min_size=settings.db_min_conn_size, diff --git a/tipg/settings.py b/tipg/settings.py index 5f3ad18b..c0d48c20 100644 --- a/tipg/settings.py +++ b/tipg/settings.py @@ -3,14 +3,13 @@ import json import pathlib from typing import Any, Dict, List, Optional -from urllib.parse import quote_plus -import boto3 from pydantic import ( BaseModel, DirectoryPath, Field, PostgresDsn, + ValidationInfo, field_validator, model_validator, ) @@ -154,45 +153,22 @@ class PostgresSettings(BaseSettings): # https://github.com/tiangolo/full-stack-fastapi-postgresql/blob/master/%7B%7Bcookiecutter.project_slug%7D%7D/backend/app/app/core/config.py#L42 @field_validator("database_url", mode="before") - def assemble_db_connection(cls, v: Optional[str], info: Any) -> Any: - """Validate and assemble the database connection string.""" + def assemble_db_connection( + cls, v: Optional[str], info: ValidationInfo + ) -> PostgresDsn: + """Validate db url settings.""" if isinstance(v, str): - return v + return PostgresDsn(v) - username = info.data["postgres_user"] - host = info.data.get("postgres_host", "") - port = info.data.get("postgres_port", 5432) - dbname = info.data.get("postgres_dbname", "") - - # Determine password/token based on IAM flag - if info.data.get("iam_auth_enabled"): - region = info.data.get("aws_region") - if not region: - raise ValueError( - "aws_region must be provided when IAM authentication is enabled" - ) - rds_client = boto3.client("rds", region_name=region) - token = rds_client.generate_db_auth_token( - DBHostname=host, - Port=int(port), - DBUsername=username, - Region=region, - ) - password = quote_plus(token) - else: - password = info.data["postgres_pass"] - - db_url = PostgresDsn.build( + return PostgresDsn.build( scheme="postgresql", - username=username, - password=password, - host=host, - port=port, - path=dbname, + username=info.data.get("postgres_user"), + password=info.data.get("postgres_pass"), + host=info.data.get("postgres_host", ""), + port=info.data.get("postgres_port", 5432), + path=info.data.get("postgres_dbname", ""), ) - return db_url - class DatabaseSettings(BaseSettings): """TiPg Database settings.""" From a202243ef79eface471a2db3bfc757731bd0e031 Mon Sep 17 00:00:00 2001 From: Patricia Fricke Date: Tue, 11 Mar 2025 15:11:52 +0000 Subject: [PATCH 2/3] make auth token callable --- tipg/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tipg/database.py b/tipg/database.py index bb22e9ef..2f313ffb 100644 --- a/tipg/database.py +++ b/tipg/database.py @@ -110,7 +110,7 @@ async def connect_to_db( con_init = connection_factory(schemas, user_sql_files, skip_sql_execution) - if os.environ.get("IAM_AUTH_ENABLED"): + if os.environ.get("IAM_AUTH_ENABLED") == "TRUE": kwargs["password"] = functools.partial( get_rds_token, settings.host, settings.port, settings.user, settings.region ) From 33f4c8bd9fecb01bc5e8fcb4a8d308a56f5122e6 Mon Sep 17 00:00:00 2001 From: Patricia Fricke Date: Tue, 11 Mar 2025 15:16:29 +0000 Subject: [PATCH 3/3] make auth token callable --- tipg/database.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tipg/database.py b/tipg/database.py index 2f313ffb..c36fd505 100644 --- a/tipg/database.py +++ b/tipg/database.py @@ -3,7 +3,7 @@ import functools import os import pathlib -from typing import List, Optional +from typing import List, Optional, Union import boto3 import orjson @@ -24,7 +24,10 @@ def get_rds_token( - host: str | None, port: int | None, user: str | None, region: str | None + host: Union[str, None], + port: Union[int, None], + user: Union[str, None], + region: Union[str, None], ) -> str: """Get RDS token for IAM auth""" logger.debug(