diff --git a/tipg/collections.py b/tipg/collections.py index 3f20adc9..e801688d 100644 --- a/tipg/collections.py +++ b/tipg/collections.py @@ -962,7 +962,7 @@ async def get_collection_index( # noqa: C901 table_id = table["schema"] + "." + table["name"] confid = table["schema"] + "_" + table["name"] - if table_id == "pg_temp.tipg_catalog" or table_id == "public.tipg_catalog": + if "tipg_catalog" in table_id: continue table_conf = table_confs.get(confid, TableConfig()) diff --git a/tipg/database.py b/tipg/database.py index 446564b4..c36fd505 100644 --- a/tipg/database.py +++ b/tipg/database.py @@ -1,8 +1,11 @@ """tipg.db: database events.""" +import functools +import os import pathlib -from typing import List, Optional +from typing import List, Optional, Union +import boto3 import orjson from buildpg import asyncpg @@ -20,6 +23,26 @@ DB_CATALOG_FILE = resources_files(__package__) / "sql" / "dbcatalog.sql" +def get_rds_token( + host: Union[str, None], + port: Union[int, None], + user: Union[str, None], + region: Union[str, None], +) -> str: + """Get RDS token for IAM auth""" + logger.debug( + f"Retrieving RDS IAM token with host: {host}, port: {port}, user: {user}, region: {region}" + ) + rds_client = boto3.client("rds") + token = rds_client.generate_db_auth_token( + DBHostname=host, + Port=port, + DBUsername=user, + Region=region or rds_client.meta.region_name, + ) + return token + + class connection_factory: """Connection creation.""" @@ -90,6 +113,11 @@ async def connect_to_db( con_init = connection_factory(schemas, user_sql_files, skip_sql_execution) + if os.environ.get("IAM_AUTH_ENABLED") == "TRUE": + kwargs["password"] = functools.partial( + get_rds_token, settings.host, settings.port, settings.user, settings.region + ) + app.state.pool = await asyncpg.create_pool_b( str(settings.database_url), min_size=settings.db_min_conn_size, diff --git a/tipg/settings.py b/tipg/settings.py index 5f3ad18b..c0d48c20 100644 --- a/tipg/settings.py +++ b/tipg/settings.py @@ -3,14 +3,13 @@ import json import pathlib from typing import Any, Dict, List, Optional -from urllib.parse import quote_plus -import boto3 from pydantic import ( BaseModel, DirectoryPath, Field, PostgresDsn, + ValidationInfo, field_validator, model_validator, ) @@ -154,45 +153,22 @@ class PostgresSettings(BaseSettings): # https://github.com/tiangolo/full-stack-fastapi-postgresql/blob/master/%7B%7Bcookiecutter.project_slug%7D%7D/backend/app/app/core/config.py#L42 @field_validator("database_url", mode="before") - def assemble_db_connection(cls, v: Optional[str], info: Any) -> Any: - """Validate and assemble the database connection string.""" + def assemble_db_connection( + cls, v: Optional[str], info: ValidationInfo + ) -> PostgresDsn: + """Validate db url settings.""" if isinstance(v, str): - return v + return PostgresDsn(v) - username = info.data["postgres_user"] - host = info.data.get("postgres_host", "") - port = info.data.get("postgres_port", 5432) - dbname = info.data.get("postgres_dbname", "") - - # Determine password/token based on IAM flag - if info.data.get("iam_auth_enabled"): - region = info.data.get("aws_region") - if not region: - raise ValueError( - "aws_region must be provided when IAM authentication is enabled" - ) - rds_client = boto3.client("rds", region_name=region) - token = rds_client.generate_db_auth_token( - DBHostname=host, - Port=int(port), - DBUsername=username, - Region=region, - ) - password = quote_plus(token) - else: - password = info.data["postgres_pass"] - - db_url = PostgresDsn.build( + return PostgresDsn.build( scheme="postgresql", - username=username, - password=password, - host=host, - port=port, - path=dbname, + username=info.data.get("postgres_user"), + password=info.data.get("postgres_pass"), + host=info.data.get("postgres_host", ""), + port=info.data.get("postgres_port", 5432), + path=info.data.get("postgres_dbname", ""), ) - return db_url - class DatabaseSettings(BaseSettings): """TiPg Database settings."""