Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tipg/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ async def get_collection_index( # noqa: C901
table_id = table["schema"] + "." + table["name"]
confid = table["schema"] + "_" + table["name"]

if table_id == "pg_temp.tipg_catalog" or table_id == "public.tipg_catalog":
if "tipg_catalog" in table_id:
continue

table_conf = table_confs.get(confid, TableConfig())
Expand Down
30 changes: 29 additions & 1 deletion tipg/database.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""tipg.db: database events."""

import functools
import os
import pathlib
from typing import List, Optional
from typing import List, Optional, Union

import boto3
import orjson
from buildpg import asyncpg

Expand All @@ -20,6 +23,26 @@
DB_CATALOG_FILE = resources_files(__package__) / "sql" / "dbcatalog.sql"


def get_rds_token(
host: Union[str, None],
port: Union[int, None],
user: Union[str, None],
region: Union[str, None],
) -> str:
"""Get RDS token for IAM auth"""
logger.debug(
f"Retrieving RDS IAM token with host: {host}, port: {port}, user: {user}, region: {region}"
)
rds_client = boto3.client("rds")
token = rds_client.generate_db_auth_token(
DBHostname=host,
Port=port,
DBUsername=user,
Region=region or rds_client.meta.region_name,
)
return token


class connection_factory:
"""Connection creation."""

Expand Down Expand Up @@ -90,6 +113,11 @@ async def connect_to_db(

con_init = connection_factory(schemas, user_sql_files, skip_sql_execution)

if os.environ.get("IAM_AUTH_ENABLED") == "TRUE":
kwargs["password"] = functools.partial(
get_rds_token, settings.host, settings.port, settings.user, settings.region
)

app.state.pool = await asyncpg.create_pool_b(
str(settings.database_url),
min_size=settings.db_min_conn_size,
Expand Down
48 changes: 12 additions & 36 deletions tipg/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
import json
import pathlib
from typing import Any, Dict, List, Optional
from urllib.parse import quote_plus

import boto3
from pydantic import (
BaseModel,
DirectoryPath,
Field,
PostgresDsn,
ValidationInfo,
field_validator,
model_validator,
)
Expand Down Expand Up @@ -154,45 +153,22 @@ class PostgresSettings(BaseSettings):

# https://github.com/tiangolo/full-stack-fastapi-postgresql/blob/master/%7B%7Bcookiecutter.project_slug%7D%7D/backend/app/app/core/config.py#L42
@field_validator("database_url", mode="before")
def assemble_db_connection(cls, v: Optional[str], info: Any) -> Any:
"""Validate and assemble the database connection string."""
def assemble_db_connection(
cls, v: Optional[str], info: ValidationInfo
) -> PostgresDsn:
"""Validate db url settings."""
if isinstance(v, str):
return v
return PostgresDsn(v)

username = info.data["postgres_user"]
host = info.data.get("postgres_host", "")
port = info.data.get("postgres_port", 5432)
dbname = info.data.get("postgres_dbname", "")

# Determine password/token based on IAM flag
if info.data.get("iam_auth_enabled"):
region = info.data.get("aws_region")
if not region:
raise ValueError(
"aws_region must be provided when IAM authentication is enabled"
)
rds_client = boto3.client("rds", region_name=region)
token = rds_client.generate_db_auth_token(
DBHostname=host,
Port=int(port),
DBUsername=username,
Region=region,
)
password = quote_plus(token)
else:
password = info.data["postgres_pass"]

db_url = PostgresDsn.build(
return PostgresDsn.build(
scheme="postgresql",
username=username,
password=password,
host=host,
port=port,
path=dbname,
username=info.data.get("postgres_user"),
password=info.data.get("postgres_pass"),
host=info.data.get("postgres_host", ""),
port=info.data.get("postgres_port", 5432),
path=info.data.get("postgres_dbname", ""),
)

return db_url


class DatabaseSettings(BaseSettings):
"""TiPg Database settings."""
Expand Down
Loading