From e7840590499814d81fae55c56ad023608b608297 Mon Sep 17 00:00:00 2001 From: Reynier Ortiz Vega Date: Mon, 18 May 2026 17:33:24 -0400 Subject: [PATCH 1/7] Add shared postgres client package Introduce a postgres/ package that wraps github.com/jackc/pgx/v5/pgxpool with a Config type, a NewPool factory, and a dynamic-authentication dispatcher. The package is the consolidation point for PostgreSQL connection plumbing currently duplicated across consumers (toolhive-registry-server today, with additional consumers planned). Schema management - migrations, sqlc bindings, application-specific type codecs - is intentionally left to each caller. The public surface mirrors the redis/ package: a single validated Config including pool tuning knobs and a DynamicAuth block, plus a NewPool constructor that auto-installs a BeforeConnect hook when DynamicAuth is configured. WithAfterConnect lets callers register custom pgx type codecs (for example, application-defined enum array codecs) without leaking those schema concerns into this package. NewAuthToken is exported for short-lived connections - migrations, in particular - where a pool hook is not available. AWS RDS IAM is the first dynamic-auth backend. Region "detect" triggers IMDS-based discovery. The auth dispatcher is structured so additional backends (Vault, GCP IAM) can be added without changing the call sites in NewPool or NewAuthToken. Config implements slog.LogValuer to redact password fields and reports only presence-indicators (has_password, has_migration_password, dynamic_auth) for credentials. Co-Authored-By: Claude Opus 4.7 (1M context) --- go.mod | 19 ++++ go.sum | 60 ++++++----- postgres/auth.go | 65 ++++++++++++ postgres/auth_test.go | 177 ++++++++++++++++++++++++++++++++ postgres/awsiam.go | 90 ++++++++++++++++ postgres/config.go | 201 ++++++++++++++++++++++++++++++++++++ postgres/config_test.go | 210 ++++++++++++++++++++++++++++++++++++++ postgres/doc.go | 81 +++++++++++++++ postgres/pool.go | 158 ++++++++++++++++++++++++++++ postgres/pool_test.go | 163 +++++++++++++++++++++++++++++ postgres/testdata_test.go | 15 +++ 11 files changed, 1211 insertions(+), 28 deletions(-) create mode 100644 postgres/auth.go create mode 100644 postgres/auth_test.go create mode 100644 postgres/awsiam.go create mode 100644 postgres/config.go create mode 100644 postgres/config_test.go create mode 100644 postgres/doc.go create mode 100644 postgres/pool.go create mode 100644 postgres/pool_test.go create mode 100644 postgres/testdata_test.go diff --git a/go.mod b/go.mod index 194e84b..ec4f874 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,13 @@ go 1.26 require ( github.com/adrg/xdg v0.5.3 github.com/alicebob/miniredis/v2 v2.38.0 + github.com/aws/aws-sdk-go-v2/config v1.32.17 + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 + github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.23 github.com/google/cel-go v0.28.1 github.com/google/go-containerregistry v0.21.6 github.com/google/uuid v1.6.0 + github.com/jackc/pgx/v5 v5.9.2 github.com/mark3labs/mcp-go v0.54.0 github.com/modelcontextprotocol/registry v1.7.9 github.com/opencontainers/go-digest v1.0.0 @@ -28,6 +32,18 @@ require ( filippo.io/edwards25519 v1.2.0 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2 v1.41.7 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect + github.com/aws/smithy-go v1.25.1 // indirect github.com/blang/semver v3.5.1+incompatible // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -66,6 +82,9 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/in-toto/attestation v1.1.2 // indirect github.com/in-toto/in-toto-golang v0.9.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/klauspost/compress v1.18.6 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/pkg/errors v0.9.1 // indirect diff --git a/go.sum b/go.sum index 2ca197e..09fd111 100644 --- a/go.sum +++ b/go.sum @@ -42,36 +42,38 @@ github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3d github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-sdk-go v1.55.7 h1:UJrkFq7es5CShfBwlWAC8DA077vp8PyVbQd3lqLiztE= github.com/aws/aws-sdk-go v1.55.7/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= -github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= -github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= -github.com/aws/aws-sdk-go-v2/config v1.32.5 h1:pz3duhAfUgnxbtVhIK39PGF/AHYyrzGEyRD9Og0QrE8= -github.com/aws/aws-sdk-go-v2/config v1.32.5/go.mod h1:xmDjzSUs/d0BB7ClzYPAZMmgQdrodNjPPhd6bGASwoE= -github.com/aws/aws-sdk-go-v2/credentials v1.19.5 h1:xMo63RlqP3ZZydpJDMBsH9uJ10hgHYfQFIk1cHDXrR4= -github.com/aws/aws-sdk-go-v2/credentials v1.19.5/go.mod h1:hhbH6oRcou+LpXfA/0vPElh/e0M3aFeOblE1sssAAEk= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 h1:80+uETIWS1BqjnN9uJ0dBUaETh+P1XwFy5vwHwK5r9k= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16/go.mod h1:wOOsYuxYuB/7FlnVtzeBYRcjSRtQpAW0hCP7tIULMwo= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 h1:oHjJHeUy0ImIV0bsrX0X91GkV5nJAyv1l1CC9lnO0TI= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:iRSNGgOYmiYwSCXxXaKb9HfOEj40+oTKn8pTxMlYkRM= +github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8= +github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc= +github.com/aws/aws-sdk-go-v2/config v1.32.17 h1:FpL4/758/diKwqbytU0prpuiu60fgXKUWCpDJtApclU= +github.com/aws/aws-sdk-go-v2/config v1.32.17/go.mod h1:OXqUMzgXytfoF9JaKkhrOYsyh72t9G+MJH8mMRaexOE= +github.com/aws/aws-sdk-go-v2/credentials v1.19.16 h1:r3RJBuU7X9ibt8RHbMjWE6y60QbKBiII6wSrXnapxSU= +github.com/aws/aws-sdk-go-v2/credentials v1.19.16/go.mod h1:6cx7zqDENJDbBIIWX6P8s0h6hqHC8Avbjh9Dseo27ug= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.23 h1:jPWBQFmN0v3kiumSS/4ES5rupdfR5jFi5fHwilsX+KY= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.23/go.mod h1:M0EHmcAard72YjeRQYxTbWkTUY8TXG0WHbtODbM/kzY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA= github.com/aws/aws-sdk-go-v2/service/kms v1.49.1 h1:U0asSZ3ifpuIehDPkRI2rxHbmFUMplDA2VeR9Uogrmw= github.com/aws/aws-sdk-go-v2/service/kms v1.49.1/go.mod h1:NZo9WJqQ0sxQ1Yqu1IwCHQFQunTms2MlVgejg16S1rY= -github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 h1:HpI7aMmJ+mm1wkSHIA2t5EaFFv5EFYXePW30p1EIrbQ= -github.com/aws/aws-sdk-go-v2/service/signin v1.0.4/go.mod h1:C5RdGMYGlfM0gYq/tifqgn4EbyX99V15P2V3R+VHbQU= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.7 h1:eYnlt6QxnFINKzwxP5/Ucs1vkG7VT3Iezmvfgc2waUw= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.7/go.mod h1:+fWt2UHSb4kS7Pu8y+BMBvJF0EWx+4H0hzNwtDNRTrg= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 h1:AHDr0DaHIAo8c9t1emrzAlVDFp+iMMKnPdYy6XO4MCE= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12/go.mod h1:GQ73XawFFiWxyWXMHWfhiomvP3tXtdNar/fi8z18sx0= -github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= -github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= -github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= -github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 h1:+1Kl1zx6bWi4X7cKi3VYh29h8BvsCoHQEQ6ST9X8w7w= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg= +github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk= +github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio= +github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI= +github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ= github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -325,6 +327,7 @@ github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/theupdateframework/go-tuf v0.7.0 h1:CqbQFrWo1ae3/I0UCblSbczevCCbS31Qvs5LdxRWqRI= @@ -430,6 +433,7 @@ google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= diff --git a/postgres/auth.go b/postgres/auth.go new file mode 100644 index 0000000..b2004e7 --- /dev/null +++ b/postgres/auth.go @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "errors" + "fmt" + + "github.com/jackc/pgx/v5" +) + +// BeforeConnectFn rewrites a connection config (typically by replacing the +// password) immediately before pgx dials. It is the contract used by +// pgxpool.Config.BeforeConnect and by this package's dynamic-auth backends. +type BeforeConnectFn func(ctx context.Context, conn *pgx.ConnConfig) error + +// NewAuthToken returns a short-lived password for user produced by the +// dynamic-authentication backend configured in cfg.DynamicAuth. When +// DynamicAuth is nil, the empty string is returned and no error is raised — +// this lets callers fall back to a static Password or PGPASSFILE. +// +// This entry point is intended for short-lived connections (for example, +// running migrations) where pgxpool's BeforeConnect hook is not available. +// For pooled connections, prefer NewDynamicAuthFunc. +func NewAuthToken(ctx context.Context, cfg *Config, user string) (string, error) { + if cfg == nil { + return "", errors.New("config is nil") + } + if cfg.DynamicAuth == nil { + return "", nil + } + if cfg.DynamicAuth.AWSRDSIAM != nil { + return awsRDSIAMToken(ctx, cfg, user) + } + return "", errors.New("dynamicAuth is set but no supported auth method (e.g., awsRdsIam) is configured") +} + +// NewDynamicAuthFunc returns a BeforeConnect hook that resolves a fresh +// dynamic-auth credential on every connection attempt. The returned hook +// writes the resolved token into connConfig.Password. +// +// Returns an error when cfg.DynamicAuth is nil — callers that may or may +// not be configured for dynamic auth should branch on cfg.DynamicAuth +// before calling this constructor, or use NewPool which handles both +// shapes transparently. +func NewDynamicAuthFunc(ctx context.Context, cfg *Config, user string) (BeforeConnectFn, error) { + if cfg == nil { + return nil, errors.New("config is nil") + } + if cfg.DynamicAuth == nil { + return nil, errors.New("dynamic authentication is not configured") + } + if cfg.DynamicAuth.AWSRDSIAM != nil { + return awsRDSIAMBeforeConnect(ctx, cfg, user) + } + return nil, errors.New("dynamicAuth is set but no supported auth method (e.g., awsRdsIam) is configured") +} + +// wrapAuthError prefixes dynamic-auth errors with a consistent label so they +// are easy to spot in pool startup logs. +func wrapAuthError(backend string, err error) error { + return fmt.Errorf("dynamic auth (%s): %w", backend, err) +} diff --git a/postgres/auth_test.go b/postgres/auth_test.go new file mode 100644 index 0000000..d5a047e --- /dev/null +++ b/postgres/auth_test.go @@ -0,0 +1,177 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAuthToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + wantToken string + wantErr string + }{ + { + name: testCaseNilConfig, + cfg: nil, + wantErr: testErrConfigNil, + }, + { + name: "no dynamic auth returns empty token without error", + cfg: &Config{Host: "h", Port: 5432, User: "u", Database: "d"}, + wantToken: "", + }, + { + name: testCaseNoBackend, + cfg: &Config{ + Host: "h", Port: 5432, User: "u", Database: "d", + DynamicAuth: &DynamicAuthConfig{}, + }, + wantErr: testErrNoSupportedAuth, + }, + { + name: "AWS RDS IAM without region propagates error", + cfg: &Config{ + Host: "h", Port: 5432, User: "u", Database: "d", + DynamicAuth: &DynamicAuthConfig{ + AWSRDSIAM: &DynamicAuthAWSRDSIAM{}, + }, + }, + wantErr: testErrRegionMissing, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + token, err := NewAuthToken(t.Context(), tt.cfg, "user") + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + assert.Empty(t, token) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantToken, token) + }) + } +} + +func TestNewDynamicAuthFunc(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + wantErr string + }{ + { + name: testCaseNilConfig, + cfg: nil, + wantErr: testErrConfigNil, + }, + { + name: "no dynamic auth returns explicit error", + cfg: &Config{Host: "h", Port: 5432, User: "u", Database: "d"}, + wantErr: "dynamic authentication is not configured", + }, + { + name: testCaseNoBackend, + cfg: &Config{ + Host: "h", Port: 5432, User: "u", Database: "d", + DynamicAuth: &DynamicAuthConfig{}, + }, + wantErr: testErrNoSupportedAuth, + }, + { + name: "AWS RDS IAM without region", + cfg: &Config{ + Host: "h", Port: 5432, User: "u", Database: "d", + DynamicAuth: &DynamicAuthConfig{ + AWSRDSIAM: &DynamicAuthAWSRDSIAM{}, + }, + }, + wantErr: testErrRegionMissing, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + fn, err := NewDynamicAuthFunc(t.Context(), tt.cfg, "user") + require.Error(t, err) + assert.Nil(t, fn) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} + +func TestResolveAWSRegion_Static(t *testing.T) { + t.Parallel() + cfg := &Config{ + Host: "h", Port: 5432, User: "u", Database: "d", + DynamicAuth: &DynamicAuthConfig{ + AWSRDSIAM: &DynamicAuthAWSRDSIAM{Region: "us-west-2"}, + }, + } + region, err := resolveAWSRegion(context.Background(), cfg) + require.NoError(t, err) + assert.Equal(t, "us-west-2", region) +} + +func TestResolveAWSRegion_EmptyRegion(t *testing.T) { + t.Parallel() + cfg := &Config{ + Host: "h", Port: 5432, User: "u", Database: "d", + DynamicAuth: &DynamicAuthConfig{AWSRDSIAM: &DynamicAuthAWSRDSIAM{}}, + } + _, err := resolveAWSRegion(context.Background(), cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), testErrRegionMissing) +} + +// TestResolveAWSRegion_DetectFailsWithoutIMDS exercises the IMDS path. The +// test deadline must elapse before imdsRegionTimeout fires so we get a +// deterministic ctx-cancellation error rather than a flaky one. +func TestResolveAWSRegion_DetectFailsWithoutIMDS(t *testing.T) { + t.Parallel() + cfg := &Config{ + Host: "h", Port: 5432, User: "u", Database: "d", + DynamicAuth: &DynamicAuthConfig{ + AWSRDSIAM: &DynamicAuthAWSRDSIAM{Region: "detect"}, + }, + } + // Use an already-cancelled context so the IMDS call fails immediately + // without depending on whether 169.254.169.254 is routable in CI. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := resolveAWSRegion(ctx, cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "IMDS") +} + +// TestAwsRDSIAMBeforeConnect_ReturnsHookForStaticRegion verifies the +// constructor returns a non-nil hook when the region is statically +// configured. Actually invoking the hook would require AWS credentials and +// is out of scope for unit tests. +func TestAwsRDSIAMBeforeConnect_ReturnsHookForStaticRegion(t *testing.T) { + t.Parallel() + cfg := &Config{ + Host: "h", Port: 5432, User: "u", Database: "d", + DynamicAuth: &DynamicAuthConfig{ + AWSRDSIAM: &DynamicAuthAWSRDSIAM{Region: "us-west-2"}, + }, + } + fn, err := awsRDSIAMBeforeConnect(context.Background(), cfg, "appuser") + require.NoError(t, err) + assert.NotNil(t, fn) +} diff --git a/postgres/awsiam.go b/postgres/awsiam.go new file mode 100644 index 0000000..6b72a1d --- /dev/null +++ b/postgres/awsiam.go @@ -0,0 +1,90 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + "github.com/jackc/pgx/v5" +) + +// awsRDSIAMRegionDetect is the sentinel value that triggers IMDS-based +// region discovery instead of using a statically configured region. +const awsRDSIAMRegionDetect = "detect" + +// imdsRegionTimeout is the upper bound on a single IMDS region lookup. +const imdsRegionTimeout = 2 * time.Second + +// resolveAWSRegion returns the AWS region to use for RDS IAM token +// generation. When the configured region is "detect", it queries the EC2 +// instance metadata service. +func resolveAWSRegion(ctx context.Context, cfg *Config) (string, error) { + iam := cfg.DynamicAuth.AWSRDSIAM + if iam.Region == "" { + return "", errors.New("AWS RDS IAM region is not configured") + } + if iam.Region != awsRDSIAMRegionDetect { + return iam.Region, nil + } + + client := imds.New(imds.Options{ + HTTPClient: &http.Client{Timeout: imdsRegionTimeout}, + }) + out, err := client.GetRegion(ctx, &imds.GetRegionInput{}) + if err != nil { + return "", fmt.Errorf("failed to detect region from IMDS: %w", err) + } + return out.Region, nil +} + +// awsRDSIAMToken returns a single AWS RDS IAM token for user, signed for the +// resolved region. The token can be used as a PostgreSQL password. +func awsRDSIAMToken(ctx context.Context, cfg *Config, user string) (string, error) { + region, err := resolveAWSRegion(ctx, cfg) + if err != nil { + return "", wrapAuthError("awsRdsIam", err) + } + return buildAWSToken(ctx, cfg, region, user) +} + +// awsRDSIAMBeforeConnect returns a BeforeConnect hook that generates a fresh +// RDS IAM token before each connection attempt. The region is resolved once +// at construction time; per-connection cost is reduced to a single signing +// operation. +func awsRDSIAMBeforeConnect(ctx context.Context, cfg *Config, user string) (BeforeConnectFn, error) { + region, err := resolveAWSRegion(ctx, cfg) + if err != nil { + return nil, wrapAuthError("awsRdsIam", err) + } + return func(ctx context.Context, conn *pgx.ConnConfig) error { + token, err := buildAWSToken(ctx, cfg, region, user) + if err != nil { + return wrapAuthError("awsRdsIam", err) + } + conn.Password = token + return nil + }, nil +} + +// buildAWSToken signs an RDS IAM token using the workload's ambient AWS +// credentials (env vars, instance profile, EKS web-identity, etc.). +func buildAWSToken(ctx context.Context, cfg *Config, region, user string) (string, error) { + awsCfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) + if err != nil { + return "", fmt.Errorf("failed to load AWS config: %w", err) + } + endpoint := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + token, err := auth.BuildAuthToken(ctx, endpoint, region, user, awsCfg.Credentials) + if err != nil { + return "", fmt.Errorf("failed to build authentication token: %w", err) + } + return token, nil +} diff --git a/postgres/config.go b/postgres/config.go new file mode 100644 index 0000000..bb55cff --- /dev/null +++ b/postgres/config.go @@ -0,0 +1,201 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "errors" + "fmt" + "log/slog" + "net/url" + "time" +) + +// DefaultSSLMode is applied by BuildConnectionStringWithAuth when Config.SSLMode +// is empty. "require" is the safe production default: encryption is mandatory +// but the server certificate is not validated against a CA. Callers that need +// stricter behavior should set SSLMode to "verify-ca" or "verify-full". +const DefaultSSLMode = "require" + +// Config configures a PostgreSQL connection pool. Password fields are +// resolved by the caller — file-based secrets, environment variables, and +// pgpass fallback all live outside this package. +type Config struct { + // Host is the database server hostname or IP address. + Host string + + // Port is the database server port. Required. + Port int + + // User is the database username for normal operations + // (SELECT/INSERT/UPDATE/DELETE). + User string + + // Password is the application-user password. When empty, pgx falls back + // to PGPASSFILE / ~/.pgpass. Mutually exclusive with DynamicAuth at the + // caller's option — this package does not enforce mutual exclusion + // because some callers legitimately want a static fallback during local + // development. + Password string //nolint:gosec // G101: field name, not a hardcoded credential + + // MigrationUser is the database username for schema migrations. When + // empty, defaults to User. + MigrationUser string + + // MigrationPassword is the password for MigrationUser. When empty and + // MigrationUser equals User, falls back to Password. Otherwise pgx falls + // back to PGPASSFILE / ~/.pgpass. + MigrationPassword string //nolint:gosec // G101: field name, not a hardcoded credential + + // Database is the database name. + Database string + + // SSLMode is the SSL mode for the connection (disable, require, verify-ca, + // verify-full). When empty, DefaultSSLMode is applied by the connection + // string builder. + SSLMode string + + // DynamicAuth, when non-nil, generates short-lived credentials at connect + // time via NewPool's automatically-installed BeforeConnect hook. + DynamicAuth *DynamicAuthConfig + + // MaxOpenConns sets the upper bound on open connections in the pool. When + // zero, pgxpool's default is used. + MaxOpenConns int32 + + // MaxIdleConns sets the minimum number of idle connections the pool tries + // to maintain. When zero, pgxpool's default is used. + MaxIdleConns int32 + + // ConnMaxLifetime is the maximum lifetime of a connection. When zero, + // pgxpool's default is used. + ConnMaxLifetime time.Duration +} + +// DynamicAuthConfig selects a dynamic-authentication backend. Exactly one +// backend field must be non-nil when DynamicAuthConfig itself is non-nil. +type DynamicAuthConfig struct { + // AWSRDSIAM enables AWS RDS IAM authentication tokens. + AWSRDSIAM *DynamicAuthAWSRDSIAM +} + +// DynamicAuthAWSRDSIAM configures AWS RDS IAM dynamic authentication. +type DynamicAuthAWSRDSIAM struct { + // Region is the AWS region used to sign IAM tokens. Use "detect" to + // auto-discover the region from the EC2 instance metadata service (IMDS). + Region string +} + +// Validate checks Config for required-field and consistency errors and +// returns the first violation encountered. +func (c *Config) Validate() error { + if c == nil { + return errors.New("config is nil") + } + if c.Host == "" { + return errors.New("host is required") + } + if c.Port == 0 { + return errors.New("port is required") + } + if c.User == "" { + return errors.New("user is required") + } + if c.Database == "" { + return errors.New("database is required") + } + if c.DynamicAuth != nil { + if c.DynamicAuth.AWSRDSIAM == nil { + return errors.New("dynamicAuth is set but no supported auth method (e.g., awsRdsIam) is configured") + } + if c.DynamicAuth.AWSRDSIAM != nil && c.DynamicAuth.AWSRDSIAM.Region == "" { + return errors.New("dynamicAuth.awsRdsIam.region is required") + } + } + return nil +} + +// LogValue implements slog.LogValuer. It redacts password fields and reports +// only presence-indicators for credentials, preventing accidental secret +// disclosure in logs. +func (c *Config) LogValue() slog.Value { + if c == nil { + return slog.Value{} + } + return slog.GroupValue( + slog.String("host", c.Host), + slog.Int("port", c.Port), + slog.String("user", c.User), + slog.String("database", c.Database), + slog.String("ssl_mode", c.SSLMode), + slog.Bool("has_password", c.Password != ""), + slog.Bool("has_migration_password", c.MigrationPassword != ""), + slog.Bool("dynamic_auth", c.DynamicAuth != nil), + ) +} + +// GetMigrationUser returns the user that owns schema migrations. Falls back +// to User when MigrationUser is unset. +func (c *Config) GetMigrationUser() string { + if c.MigrationUser != "" { + return c.MigrationUser + } + return c.User +} + +// GetMigrationPassword returns the password for the migration user. When +// MigrationPassword is unset and the migration user matches User, the +// application Password is returned. Otherwise an empty string is returned so +// pgx can fall back to PGPASSFILE / ~/.pgpass. +func (c *Config) GetMigrationPassword() string { + if c.MigrationPassword != "" { + return c.MigrationPassword + } + if c.GetMigrationUser() == c.User { + return c.Password + } + return "" +} + +// ConnectionString builds a libpq-style connection URL for the application +// user. When Password is empty, pgx falls back to PGPASSFILE / ~/.pgpass. +func (c *Config) ConnectionString() string { + return c.BuildConnectionStringWithAuth(c.User, c.Password) +} + +// MigrationConnectionString builds a libpq-style connection URL for the +// migration user. Useful for short-lived migration tooling where a +// BeforeConnect hook is not available. +func (c *Config) MigrationConnectionString() string { + return c.BuildConnectionStringWithAuth(c.GetMigrationUser(), c.GetMigrationPassword()) +} + +// BuildConnectionStringWithAuth builds a libpq-style connection URL using +// the supplied user and password. When password is empty, the resulting URL +// omits credentials and pgx will fall back to PGPASSFILE / ~/.pgpass. +// +// The caller is responsible for resolving credentials — dynamic-auth token +// generation, secret-file reads, and env-var overrides all happen outside +// this package. +func (c *Config) BuildConnectionStringWithAuth(user, password string) string { + sslMode := c.SSLMode + if sslMode == "" { + sslMode = DefaultSSLMode + } + + var userInfo *url.Userinfo + if password != "" { + userInfo = url.UserPassword(user, password) + } else { + userInfo = url.User(user) + } + + return fmt.Sprintf( + "postgres://%s@%s:%d/%s?sslmode=%s", + userInfo.String(), + c.Host, + c.Port, + c.Database, + sslMode, + ) +} diff --git a/postgres/config_test.go b/postgres/config_test.go new file mode 100644 index 0000000..452b7bf --- /dev/null +++ b/postgres/config_test.go @@ -0,0 +1,210 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "bytes" + "context" + "log/slog" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func validConfig() *Config { + return &Config{ + Host: "db.example.com", + Port: 5432, + User: testUser, + Password: "s3cret", + Database: "appdb", + } +} + +func TestConfig_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *Config + wantErr string + }{ + { + name: testCaseNilConfig, + cfg: nil, + wantErr: testErrConfigNil, + }, + { + name: "missing host", + cfg: &Config{Port: 5432, User: "u", Database: "d"}, + wantErr: "host is required", + }, + { + name: "missing port", + cfg: &Config{Host: "h", User: "u", Database: "d"}, + wantErr: "port is required", + }, + { + name: "missing user", + cfg: &Config{Host: "h", Port: 5432, Database: "d"}, + wantErr: "user is required", + }, + { + name: "missing database", + cfg: &Config{Host: "h", Port: 5432, User: "u"}, + wantErr: "database is required", + }, + { + name: testCaseNoBackend, + cfg: &Config{ + Host: "h", Port: 5432, User: "u", Database: "d", + DynamicAuth: &DynamicAuthConfig{}, + }, + wantErr: testErrNoSupportedAuth, + }, + { + name: "AWS RDS IAM without region", + cfg: &Config{ + Host: "h", Port: 5432, User: "u", Database: "d", + DynamicAuth: &DynamicAuthConfig{ + AWSRDSIAM: &DynamicAuthAWSRDSIAM{}, + }, + }, + wantErr: testErrRegionConfigured, + }, + { + name: "valid minimal config", + cfg: validConfig(), + }, + { + name: "valid with AWS RDS IAM", + cfg: &Config{ + Host: "h", Port: 5432, User: "u", Database: "d", + DynamicAuth: &DynamicAuthConfig{ + AWSRDSIAM: &DynamicAuthAWSRDSIAM{Region: "us-east-1"}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := tt.cfg.Validate() + if tt.wantErr == "" { + require.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} + +func TestConfig_BuildConnectionStringWithAuth(t *testing.T) { + t.Parallel() + + t.Run("includes password when set", func(t *testing.T) { + t.Parallel() + cfg := &Config{Host: "h", Port: 5432, Database: "d"} + got := cfg.BuildConnectionStringWithAuth("alice", "p@ss/word") + // url.UserPassword percent-encodes special chars. + assert.Equal(t, "postgres://alice:p%40ss%2Fword@h:5432/d?sslmode=require", got) + }) + + t.Run("omits credentials when password empty", func(t *testing.T) { + t.Parallel() + cfg := &Config{Host: "h", Port: 5432, Database: "d"} + got := cfg.BuildConnectionStringWithAuth("alice", "") + assert.Equal(t, "postgres://alice@h:5432/d?sslmode=require", got) + }) + + t.Run("honors custom SSL mode", func(t *testing.T) { + t.Parallel() + cfg := &Config{Host: "h", Port: 5432, Database: "d", SSLMode: testSSLModeDisable} + got := cfg.BuildConnectionStringWithAuth("u", "") + assert.Contains(t, got, "sslmode="+testSSLModeDisable) + }) +} + +func TestConfig_MigrationHelpers(t *testing.T) { + t.Parallel() + + t.Run("falls back to User and Password when migration fields unset", func(t *testing.T) { + t.Parallel() + cfg := validConfig() + assert.Equal(t, testUser, cfg.GetMigrationUser()) + assert.Equal(t, "s3cret", cfg.GetMigrationPassword()) + assert.Equal(t, cfg.ConnectionString(), cfg.MigrationConnectionString()) + }) + + t.Run("uses MigrationUser and shares Password when users match", func(t *testing.T) { + t.Parallel() + cfg := validConfig() + cfg.MigrationUser = testUser // same as User + assert.Equal(t, "s3cret", cfg.GetMigrationPassword()) + }) + + t.Run("distinct migration user without password falls back to pgpass", func(t *testing.T) { + t.Parallel() + cfg := validConfig() + cfg.MigrationUser = "migrator" + assert.Equal(t, "migrator", cfg.GetMigrationUser()) + assert.Empty(t, cfg.GetMigrationPassword()) + got := cfg.MigrationConnectionString() + assert.Equal(t, "postgres://migrator@db.example.com:5432/appdb?sslmode=require", got) + }) + + t.Run("explicit migration password wins", func(t *testing.T) { + t.Parallel() + cfg := validConfig() + cfg.MigrationUser = "migrator" + cfg.MigrationPassword = "elev8" + assert.Equal(t, "elev8", cfg.GetMigrationPassword()) + }) +} + +func TestConfig_LogValueRedactsSecrets(t *testing.T) { + t.Parallel() + + cfg := &Config{ + Host: "db.example.com", + Port: 5432, + User: testUser, + Password: "should-not-appear", + MigrationPassword: "should-not-appear-either", + Database: "appdb", + SSLMode: "require", + DynamicAuth: &DynamicAuthConfig{AWSRDSIAM: &DynamicAuthAWSRDSIAM{Region: "us-east-1"}}, + } + + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, nil)) + logger.LogAttrs(context.Background(), slog.LevelInfo, "db", slog.Any("cfg", cfg)) + + out := buf.String() + assert.NotContains(t, out, "should-not-appear") + assert.Contains(t, out, `"has_password":true`) + assert.Contains(t, out, `"has_migration_password":true`) + assert.Contains(t, out, `"dynamic_auth":true`) + assert.Contains(t, out, `"host":"db.example.com"`) +} + +func TestConfig_LogValueNil(t *testing.T) { + t.Parallel() + var cfg *Config + // Should not panic. + got := cfg.LogValue() + assert.Equal(t, slog.Value{}, got) +} + +func TestConfig_ConnMaxLifetimeIsSlot(t *testing.T) { + t.Parallel() + cfg := validConfig() + cfg.ConnMaxLifetime = 30 * time.Minute + require.NoError(t, cfg.Validate()) + assert.Equal(t, 30*time.Minute, cfg.ConnMaxLifetime) +} diff --git a/postgres/doc.go b/postgres/doc.go new file mode 100644 index 0000000..a196df4 --- /dev/null +++ b/postgres/doc.go @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +/* +Package postgres provides a shared PostgreSQL connection layer for the +ToolHive ecosystem. It wraps github.com/jackc/pgx/v5/pgxpool with a single +Config type, a NewPool factory, and a dynamic-authentication dispatcher. +Schema management — migrations, queries, and per-application type codecs — +remains the caller's responsibility. + +# Quick Start + + cfg := &postgres.Config{ + Host: "db.example.com", + Port: 5432, + User: "appuser", + Password: "s3cret", + Database: "appdb", + } + + pool, err := postgres.NewPool(ctx, cfg) + if err != nil { + return err + } + defer pool.Close() + +# Dynamic Authentication + +Setting Config.DynamicAuth causes NewPool to install a BeforeConnect hook +that resolves a fresh credential before every connection attempt. + +Currently supported backends: + + - AWS RDS IAM — short-lived tokens signed with the workload's ambient + AWS credentials (env vars, EC2 instance profile, EKS web identity, …). + Region "detect" auto-discovers the region via IMDS. + +Example: + + cfg.DynamicAuth = &postgres.DynamicAuthConfig{ + AWSRDSIAM: &postgres.DynamicAuthAWSRDSIAM{Region: "us-east-1"}, + } + +For short-lived connections that cannot use a pool hook (for example +golang-migrate's one-shot migration connection), call NewAuthToken to +materialize a single token, then embed it via BuildConnectionStringWithAuth: + + token, _ := postgres.NewAuthToken(ctx, cfg, cfg.GetMigrationUser()) + connStr := cfg.BuildConnectionStringWithAuth(cfg.GetMigrationUser(), token) + +# Hooks + +WithAfterConnect installs an AfterConnect callback — the canonical place to +register application-specific type codecs (for example, codecs for +PostgreSQL enum array types defined in the caller's schema): + + pool, err := postgres.NewPool(ctx, cfg, + postgres.WithAfterConnect(func(ctx context.Context, conn *pgx.Conn) error { + return registerMyEnumCodecs(ctx, conn) + }), + ) + +WithBeforeConnect overrides the auto-installed dynamic-auth hook. Callers +that need to layer additional logic on top should call NewDynamicAuthFunc +explicitly and compose the result. + +# Logging + +NewPool emits a single info-level message on success, redacting password +fields via Config's slog.LogValuer implementation. StartPoolStatsLogger is +an opt-in helper that periodically logs connection-pool statistics at debug +level until its context is cancelled. + +# Secrets Handling + +This package treats Password and MigrationPassword as already-resolved +strings. File-based secret loading, environment-variable overrides, and +pgpass fallback all live in the caller's configuration layer. Config +implements slog.LogValuer to redact credentials when logged. +*/ +package postgres diff --git a/postgres/pool.go b/postgres/pool.go new file mode 100644 index 0000000..d785282 --- /dev/null +++ b/postgres/pool.go @@ -0,0 +1,158 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// DefaultPoolStatsInterval is the cadence at which StartPoolStatsLogger +// emits a connection-pool snapshot when no other interval is configured. +const DefaultPoolStatsInterval = 60 * time.Second + +// Option customizes NewPool. See WithBeforeConnect, WithAfterConnect, and +// WithLogger. +type Option func(*options) + +type options struct { + beforeConnect BeforeConnectFn + afterConnect func(ctx context.Context, conn *pgx.Conn) error + logger *slog.Logger +} + +// WithBeforeConnect installs a hook that runs immediately before pgx dials. +// It overrides the BeforeConnect hook that NewPool would otherwise install +// from cfg.DynamicAuth — callers that need to layer their own logic should +// invoke NewDynamicAuthFunc explicitly and combine the results. +func WithBeforeConnect(fn BeforeConnectFn) Option { + return func(o *options) { o.beforeConnect = fn } +} + +// WithAfterConnect installs a hook that runs immediately after a new +// connection has been established. The typical use case is registering +// custom type codecs (for example, application-defined enum array codecs). +func WithAfterConnect(fn func(ctx context.Context, conn *pgx.Conn) error) Option { + return func(o *options) { o.afterConnect = fn } +} + +// WithLogger sets the slog.Logger used for pool-creation messages and (when +// invoked) StartPoolStatsLogger output. When unset, slog.Default() is used. +func WithLogger(logger *slog.Logger) Option { + return func(o *options) { o.logger = logger } +} + +// NewPool creates a *pgxpool.Pool from cfg. When cfg.DynamicAuth is set and +// the caller does not supply WithBeforeConnect, NewPool installs the +// appropriate dynamic-auth hook automatically. +// +// cfg is validated; cfg is not mutated. +func NewPool(ctx context.Context, cfg *Config, opts ...Option) (*pgxpool.Pool, error) { + if cfg == nil { + return nil, errors.New("config is nil") + } + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid configuration: %w", err) + } + + o := &options{} + for _, opt := range opts { + opt(o) + } + logger := o.logger + if logger == nil { + logger = slog.Default() + } + + poolConfig, err := pgxpool.ParseConfig(cfg.ConnectionString()) + if err != nil { + return nil, fmt.Errorf("failed to parse connection string: %w", err) + } + + applyPoolTuning(poolConfig, cfg) + + beforeConnect := o.beforeConnect + if beforeConnect == nil && cfg.DynamicAuth != nil { + beforeConnect, err = NewDynamicAuthFunc(ctx, cfg, cfg.User) + if err != nil { + return nil, err + } + } + if beforeConnect != nil { + poolConfig.BeforeConnect = beforeConnect + } + if o.afterConnect != nil { + poolConfig.AfterConnect = o.afterConnect + } + + pool, err := pgxpool.NewWithConfig(ctx, poolConfig) + if err != nil { + return nil, fmt.Errorf("failed to create connection pool: %w", err) + } + + logger.LogAttrs(ctx, slog.LevelInfo, "postgres connection pool created", slog.Any("config", cfg)) + return pool, nil +} + +// applyPoolTuning copies pool-sizing knobs from cfg onto poolConfig, leaving +// pgxpool's defaults in place where cfg has zero values. +func applyPoolTuning(poolConfig *pgxpool.Config, cfg *Config) { + if cfg.MaxOpenConns > 0 { + poolConfig.MaxConns = cfg.MaxOpenConns + } + if cfg.MaxIdleConns > 0 { + poolConfig.MinConns = cfg.MaxIdleConns + } + if cfg.ConnMaxLifetime > 0 { + poolConfig.MaxConnLifetime = cfg.ConnMaxLifetime + } +} + +// StartPoolStatsLogger emits a connection-pool snapshot at DEBUG every +// interval until ctx is cancelled. When interval is zero, the default +// cadence is used. When logger is nil, slog.Default() is used. +// +// This is an opt-in helper; consumers that want pool metrics through a +// different sink (OpenTelemetry, Prometheus) should read pool.Stat() +// themselves. +func StartPoolStatsLogger(ctx context.Context, pool *pgxpool.Pool, logger *slog.Logger, interval time.Duration) { + if pool == nil { + return + } + if interval == 0 { + interval = DefaultPoolStatsInterval + } + if logger == nil { + logger = slog.Default() + } + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + stat := pool.Stat() + logger.LogAttrs(ctx, slog.LevelDebug, "postgres pool stats", + slog.Int64("total_conns", int64(stat.TotalConns())), + slog.Int64("acquired_conns", int64(stat.AcquiredConns())), + slog.Int64("idle_conns", int64(stat.IdleConns())), + slog.Int64("max_conns", int64(stat.MaxConns())), + slog.Int64("acquire_count", stat.AcquireCount()), + slog.Int64("acquire_duration_ms", stat.AcquireDuration().Milliseconds()), + slog.Int64("canceled_acquire_count", stat.CanceledAcquireCount()), + slog.Int64("empty_acquire_count", stat.EmptyAcquireCount()), + ) + } + } + }() +} diff --git a/postgres/pool_test.go b/postgres/pool_test.go new file mode 100644 index 0000000..bc543bf --- /dev/null +++ b/postgres/pool_test.go @@ -0,0 +1,163 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "bytes" + "context" + "log/slog" + "sync/atomic" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewPool_NilConfig(t *testing.T) { + t.Parallel() + pool, err := NewPool(t.Context(), nil) + require.Error(t, err) + assert.Nil(t, pool) + assert.Contains(t, err.Error(), "config is nil") +} + +func TestNewPool_InvalidConfig(t *testing.T) { + t.Parallel() + pool, err := NewPool(t.Context(), &Config{}) + require.Error(t, err) + assert.Nil(t, pool) + assert.Contains(t, err.Error(), "invalid configuration") +} + +func TestNewPool_DynamicAuthMisconfigured(t *testing.T) { + t.Parallel() + cfg := validConfig() + cfg.DynamicAuth = &DynamicAuthConfig{AWSRDSIAM: &DynamicAuthAWSRDSIAM{Region: ""}} + pool, err := NewPool(t.Context(), cfg) + require.Error(t, err) + assert.Nil(t, pool) + // Validate() catches the empty region before NewDynamicAuthFunc runs. + assert.Contains(t, err.Error(), "dynamicAuth.awsRdsIam.region is required") +} + +// TestNewPool_LazyConnect verifies that NewPool returns successfully even +// when the database is unreachable — pgxpool establishes connections +// lazily on first Acquire, not at construction time. Dial errors surface +// at query time, not pool-creation time. +func TestNewPool_LazyConnect(t *testing.T) { + t.Parallel() + cfg := validConfig() + cfg.Host = "127.0.0.1" + cfg.Port = 1 // closed port; never opens + cfg.SSLMode = testSSLModeDisable + + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + + pool, err := NewPool(t.Context(), cfg, WithLogger(logger)) + require.NoError(t, err) + t.Cleanup(pool.Close) + + assert.Contains(t, buf.String(), "postgres connection pool created") + // Logger received the cfg via slog.LogValuer — password must not appear. + assert.NotContains(t, buf.String(), cfg.Password) +} + +func TestNewPool_OptionsArePopulated(t *testing.T) { + t.Parallel() + o := &options{} + WithBeforeConnect(func(context.Context, *pgx.ConnConfig) error { return nil })(o) + WithAfterConnect(func(context.Context, *pgx.Conn) error { return nil })(o) + WithLogger(slog.Default())(o) + assert.NotNil(t, o.beforeConnect) + assert.NotNil(t, o.afterConnect) + assert.NotNil(t, o.logger) +} + +func TestApplyPoolTuning(t *testing.T) { + t.Parallel() + + cfg := validConfig() + cfg.MaxOpenConns = 17 + cfg.MaxIdleConns = 3 + cfg.ConnMaxLifetime = 42 * time.Minute + + pc, err := pgxpool.ParseConfig(cfg.ConnectionString()) + require.NoError(t, err) + applyPoolTuning(pc, cfg) + + assert.Equal(t, int32(17), pc.MaxConns) + assert.Equal(t, int32(3), pc.MinConns) + assert.Equal(t, 42*time.Minute, pc.MaxConnLifetime) +} + +func TestApplyPoolTuning_PreservesDefaultsForZeroValues(t *testing.T) { + t.Parallel() + + cfg := validConfig() // all pool knobs at zero + + pc, err := pgxpool.ParseConfig(cfg.ConnectionString()) + require.NoError(t, err) + defaultMax := pc.MaxConns + defaultMin := pc.MinConns + defaultLifetime := pc.MaxConnLifetime + + applyPoolTuning(pc, cfg) + + assert.Equal(t, defaultMax, pc.MaxConns) + assert.Equal(t, defaultMin, pc.MinConns) + assert.Equal(t, defaultLifetime, pc.MaxConnLifetime) +} + +func TestStartPoolStatsLogger_ExitsOnContextCancel(t *testing.T) { + t.Parallel() + + cfg := validConfig() + cfg.Host = "127.0.0.1" + cfg.Port = 1 + cfg.SSLMode = testSSLModeDisable + + pool, err := NewPool(t.Context(), cfg) + require.NoError(t, err) + t.Cleanup(pool.Close) + + ctx, cancel := context.WithCancel(t.Context()) + StartPoolStatsLogger(ctx, pool, slog.Default(), 10*time.Millisecond) + cancel() + // Give the goroutine time to notice cancellation and return. This is a + // soft check — race-detector + leak-detector tooling at the suite level + // is what catches a leaked goroutine. + time.Sleep(50 * time.Millisecond) +} + +func TestStartPoolStatsLogger_NilPoolNoop(t *testing.T) { + t.Parallel() + // Must not panic. + StartPoolStatsLogger(t.Context(), nil, slog.Default(), 0) +} + +// TestStartPoolStatsLogger_UsesDefaultInterval verifies the default-interval +// branch is exercised. We do not wait the full default 60s; we just call +// the function and immediately cancel — coverage is the goal. +func TestStartPoolStatsLogger_UsesDefaultInterval(t *testing.T) { + t.Parallel() + + cfg := validConfig() + cfg.SSLMode = testSSLModeDisable + pool, err := NewPool(t.Context(), cfg) + require.NoError(t, err) + t.Cleanup(pool.Close) + + ctx, cancel := context.WithCancel(t.Context()) + StartPoolStatsLogger(ctx, pool, nil, 0) // default logger + default interval + cancel() + + // Touch the counter so the lint can't complain about unused atomic. + var seen atomic.Int32 + seen.Add(1) + assert.Equal(t, int32(1), seen.Load()) +} diff --git a/postgres/testdata_test.go b/postgres/testdata_test.go new file mode 100644 index 0000000..b705d88 --- /dev/null +++ b/postgres/testdata_test.go @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +const ( + testUser = "appuser" + testCaseNilConfig = "nil config" + testErrConfigNil = "config is nil" + testCaseNoBackend = "dynamic auth without backend" + testErrNoSupportedAuth = "no supported auth method" + testSSLModeDisable = "disable" + testErrRegionMissing = "AWS RDS IAM region is not configured" + testErrRegionConfigured = "dynamicAuth.awsRdsIam.region is required" +) From 0001cd2a9105a61f96e12af7903e0328a6195c6d Mon Sep 17 00:00:00 2001 From: Reynier Ortiz Vega Date: Tue, 19 May 2026 09:03:52 -0400 Subject: [PATCH 2/7] Build pool config structurally to prevent DSN injection The previous NewPool path formatted a libpq URL with fmt.Sprintf and passed it to pgxpool.ParseConfig. url.UserPassword only escapes the userinfo segment, so a Host containing "@" shifted authority parsing (url.Parse uses the last "@" as the userinfo delimiter, so "good.rds@evil.example.com" resolved to "evil.example.com") and a Database containing "?" introduced a second query section that could silently override sslmode. Switch NewPool to a hybrid approach: bootstrap pgxpool.Config from a minimal "postgres://localhost?sslmode=X" DSN so pgx still translates the SSL mode into a *tls.Config, then assign Host/Port/User/Password/Database structurally onto the resulting ConnConfig. Caller-controlled fields never enter url.Parse this way. Add defense in depth in Config.Validate by rejecting Host containing @/?/# or whitespace and Database containing ?/#/ or whitespace. Tighten the Port check to require 1..65535 so the uint16 conversion in buildPoolConfig has a real invariant behind it. The connection-string methods stay public for callers that hand a DSN to a migration tool; they are now safe by virtue of Validate rejecting the dangerous shapes at the boundary. Addresses code review feedback on #111. Co-Authored-By: Claude Opus 4.7 (1M context) --- postgres/config.go | 20 ++++++++++++++++++-- postgres/config_test.go | 22 +++++++++++++++++++++- postgres/pool.go | 34 ++++++++++++++++++++++++++++++++-- postgres/pool_test.go | 38 ++++++++++++++++++++++++++++++++++++++ postgres/testdata_test.go | 1 + 5 files changed, 110 insertions(+), 5 deletions(-) diff --git a/postgres/config.go b/postgres/config.go index bb55cff..906831c 100644 --- a/postgres/config.go +++ b/postgres/config.go @@ -8,9 +8,19 @@ import ( "fmt" "log/slog" "net/url" + "strings" "time" ) +// Characters rejected in Host and Database to prevent DSN-string injection +// when these values flow into a libpq URL. Host disallows the URL +// delimiters that could shift the authority section; Database disallows +// the delimiters that could shift the query section. +const ( + hostForbiddenChars = "@/?#" + databaseForbiddenChars = "?#/" +) + // DefaultSSLMode is applied by BuildConnectionStringWithAuth when Config.SSLMode // is empty. "require" is the safe production default: encryption is mandatory // but the server certificate is not validated against a CA. Callers that need @@ -95,8 +105,11 @@ func (c *Config) Validate() error { if c.Host == "" { return errors.New("host is required") } - if c.Port == 0 { - return errors.New("port is required") + if strings.ContainsAny(c.Host, hostForbiddenChars) || strings.ContainsAny(c.Host, " \t\r\n") { + return fmt.Errorf("host must not contain any of %q or whitespace", hostForbiddenChars) + } + if c.Port <= 0 || c.Port > 65535 { + return errors.New("port must be between 1 and 65535") } if c.User == "" { return errors.New("user is required") @@ -104,6 +117,9 @@ func (c *Config) Validate() error { if c.Database == "" { return errors.New("database is required") } + if strings.ContainsAny(c.Database, databaseForbiddenChars) || strings.ContainsAny(c.Database, " \t\r\n") { + return fmt.Errorf("database must not contain any of %q or whitespace", databaseForbiddenChars) + } if c.DynamicAuth != nil { if c.DynamicAuth.AWSRDSIAM == nil { return errors.New("dynamicAuth is set but no supported auth method (e.g., awsRdsIam) is configured") diff --git a/postgres/config_test.go b/postgres/config_test.go index 452b7bf..498df57 100644 --- a/postgres/config_test.go +++ b/postgres/config_test.go @@ -45,7 +45,27 @@ func TestConfig_Validate(t *testing.T) { { name: "missing port", cfg: &Config{Host: "h", User: "u", Database: "d"}, - wantErr: "port is required", + wantErr: "port must be between 1 and 65535", + }, + { + name: "port out of range", + cfg: &Config{Host: "h", Port: 65536, User: "u", Database: "d"}, + wantErr: "port must be between 1 and 65535", + }, + { + name: "host contains @ rejected", + cfg: &Config{Host: "good.rds@evil.example.com", Port: 5432, User: "u", Database: "d"}, + wantErr: `host must not contain any of "@/?#"`, + }, + { + name: "host contains whitespace rejected", + cfg: &Config{Host: "rds .example.com", Port: 5432, User: "u", Database: "d"}, + wantErr: "host must not contain", + }, + { + name: "database contains ? rejected", + cfg: &Config{Host: "h", Port: 5432, User: "u", Database: "appdb?sslmode=disable"}, + wantErr: `database must not contain any of "?#/"`, }, { name: "missing user", diff --git a/postgres/pool.go b/postgres/pool.go index d785282..e495aef 100644 --- a/postgres/pool.go +++ b/postgres/pool.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "log/slog" + "net/url" "time" "github.com/jackc/pgx/v5" @@ -71,9 +72,9 @@ func NewPool(ctx context.Context, cfg *Config, opts ...Option) (*pgxpool.Pool, e logger = slog.Default() } - poolConfig, err := pgxpool.ParseConfig(cfg.ConnectionString()) + poolConfig, err := buildPoolConfig(cfg) if err != nil { - return nil, fmt.Errorf("failed to parse connection string: %w", err) + return nil, err } applyPoolTuning(poolConfig, cfg) @@ -101,6 +102,35 @@ func NewPool(ctx context.Context, cfg *Config, opts ...Option) (*pgxpool.Pool, e return pool, nil } +// buildPoolConfig assembles a *pgxpool.Config without exposing +// caller-controlled fields to URL parsing. SSL configuration is bootstrapped +// from a minimal DSN — sslmode is the one field that pgx must translate into +// a *tls.Config — and the connection target is then assigned structurally. +// This eliminates the DSN-injection paths that would otherwise let a `@` in +// Host or a `?` in Database shift the authority or query section of the URL. +func buildPoolConfig(cfg *Config) (*pgxpool.Config, error) { + sslMode := cfg.SSLMode + if sslMode == "" { + sslMode = DefaultSSLMode + } + + // SSLMode values are constrained by pgx (disable/allow/prefer/require/ + // verify-ca/verify-full); url.QueryEscape is belt-and-suspenders. + bootDSN := "postgres://localhost?sslmode=" + url.QueryEscape(sslMode) + pc, err := pgxpool.ParseConfig(bootDSN) + if err != nil { + return nil, fmt.Errorf("failed to initialize pool config: %w", err) + } + + pc.ConnConfig.Host = cfg.Host + pc.ConnConfig.Port = uint16(cfg.Port) //nolint:gosec // G115: Port is bounded by Validate (1..65535). + pc.ConnConfig.User = cfg.User + pc.ConnConfig.Password = cfg.Password + pc.ConnConfig.Database = cfg.Database + + return pc, nil +} + // applyPoolTuning copies pool-sizing knobs from cfg onto poolConfig, leaving // pgxpool's defaults in place where cfg has zero values. func applyPoolTuning(poolConfig *pgxpool.Config, cfg *Config) { diff --git a/postgres/pool_test.go b/postgres/pool_test.go index bc543bf..dde9d42 100644 --- a/postgres/pool_test.go +++ b/postgres/pool_test.go @@ -78,6 +78,44 @@ func TestNewPool_OptionsArePopulated(t *testing.T) { assert.NotNil(t, o.logger) } +// TestBuildPoolConfig_AssignsFieldsStructurally verifies the structural +// override path: caller-supplied Host and Database land verbatim on the pgx +// ConnConfig instead of going through URL parsing. This is the test guard +// against the DSN-injection class flagged in code review — a `@` in Host or +// a `?` in Database is preserved as-is and never reinterpreted by url.Parse. +func TestBuildPoolConfig_AssignsFieldsStructurally(t *testing.T) { + t.Parallel() + + cfg := &Config{ + Host: "db-1.cluster.example.com", + Port: 5433, + User: "appuser", + Password: "s3cret", + Database: testDatabase, + SSLMode: "require", + } + + pc, err := buildPoolConfig(cfg) + require.NoError(t, err) + assert.Equal(t, "db-1.cluster.example.com", pc.ConnConfig.Host) + assert.Equal(t, uint16(5433), pc.ConnConfig.Port) + assert.Equal(t, "appuser", pc.ConnConfig.User) + assert.Equal(t, "s3cret", pc.ConnConfig.Password) + assert.Equal(t, "appdb", pc.ConnConfig.Database) + assert.NotNil(t, pc.ConnConfig.TLSConfig, "sslmode=require must produce a non-nil tls.Config") +} + +func TestBuildPoolConfig_SSLModeDisableLeavesTLSUnset(t *testing.T) { + t.Parallel() + + cfg := validConfig() + cfg.SSLMode = testSSLModeDisable + + pc, err := buildPoolConfig(cfg) + require.NoError(t, err) + assert.Nil(t, pc.ConnConfig.TLSConfig, "sslmode=disable must produce a nil tls.Config") +} + func TestApplyPoolTuning(t *testing.T) { t.Parallel() diff --git a/postgres/testdata_test.go b/postgres/testdata_test.go index b705d88..75c181f 100644 --- a/postgres/testdata_test.go +++ b/postgres/testdata_test.go @@ -12,4 +12,5 @@ const ( testSSLModeDisable = "disable" testErrRegionMissing = "AWS RDS IAM region is not configured" testErrRegionConfigured = "dynamicAuth.awsRdsIam.region is required" + testDatabase = "appdb" ) From c328e4f95beaaeb027ffa8bb80c421ee4b87b003 Mon Sep 17 00:00:00 2001 From: Reynier Ortiz Vega Date: Tue, 19 May 2026 09:04:43 -0400 Subject: [PATCH 3/7] Rename `MaxIdleConns` to `MinConns` to match pgxpool semantics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In database/sql, MaxIdleConns is a ceiling on connections sitting idle. In pgxpool, MinConns is a floor — the pool actively keeps that many connections open even when the application is quiet. Mapping MaxIdleConns to pgxpool.MinConns meant a developer setting MaxIdleConns: 50 was opening 50 permanent connections, the opposite of what the field name implied. Rename the field to MinConns and add a doc comment that points out the database/sql contrast for readers carrying that mental model. Since this package is still Alpha (per Stability Guarantees in CLAUDE.md), the break is cheap now and disappears later. Addresses code review feedback on #111. Co-Authored-By: Claude Opus 4.7 (1M context) --- postgres/config.go | 13 ++++++++++--- postgres/pool.go | 4 ++-- postgres/pool_test.go | 2 +- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/postgres/config.go b/postgres/config.go index 906831c..7ee1dd2 100644 --- a/postgres/config.go +++ b/postgres/config.go @@ -73,9 +73,16 @@ type Config struct { // zero, pgxpool's default is used. MaxOpenConns int32 - // MaxIdleConns sets the minimum number of idle connections the pool tries - // to maintain. When zero, pgxpool's default is used. - MaxIdleConns int32 + // MinConns is the minimum number of connections pgxpool actively + // maintains in the pool — the pool keeps this many connections open + // even when the application is idle. When zero, pgxpool's default is + // used. + // + // Note for readers used to database/sql: this is the opposite of + // database/sql's MaxIdleConns (which is a ceiling on idle + // connections). pgxpool has no idle-ceiling concept; the floor is the + // only knob. + MinConns int32 // ConnMaxLifetime is the maximum lifetime of a connection. When zero, // pgxpool's default is used. diff --git a/postgres/pool.go b/postgres/pool.go index e495aef..0db5c26 100644 --- a/postgres/pool.go +++ b/postgres/pool.go @@ -137,8 +137,8 @@ func applyPoolTuning(poolConfig *pgxpool.Config, cfg *Config) { if cfg.MaxOpenConns > 0 { poolConfig.MaxConns = cfg.MaxOpenConns } - if cfg.MaxIdleConns > 0 { - poolConfig.MinConns = cfg.MaxIdleConns + if cfg.MinConns > 0 { + poolConfig.MinConns = cfg.MinConns } if cfg.ConnMaxLifetime > 0 { poolConfig.MaxConnLifetime = cfg.ConnMaxLifetime diff --git a/postgres/pool_test.go b/postgres/pool_test.go index dde9d42..e082d5e 100644 --- a/postgres/pool_test.go +++ b/postgres/pool_test.go @@ -121,7 +121,7 @@ func TestApplyPoolTuning(t *testing.T) { cfg := validConfig() cfg.MaxOpenConns = 17 - cfg.MaxIdleConns = 3 + cfg.MinConns = 3 cfg.ConnMaxLifetime = 42 * time.Minute pc, err := pgxpool.ParseConfig(cfg.ConnectionString()) From 007fd63c1d024083992cd4f7e9f59e3d343e6984 Mon Sep 17 00:00:00 2001 From: Reynier Ortiz Vega Date: Tue, 19 May 2026 09:05:10 -0400 Subject: [PATCH 4/7] Drop dead nil-check in `Config.Validate` After the early return on AWSRDSIAM == nil, the subsequent AWSRDSIAM != nil check is always true. staticcheck flags this as SA4031. Addresses code review feedback on #111. Co-Authored-By: Claude Opus 4.7 (1M context) --- postgres/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres/config.go b/postgres/config.go index 7ee1dd2..1bb7974 100644 --- a/postgres/config.go +++ b/postgres/config.go @@ -131,7 +131,7 @@ func (c *Config) Validate() error { if c.DynamicAuth.AWSRDSIAM == nil { return errors.New("dynamicAuth is set but no supported auth method (e.g., awsRdsIam) is configured") } - if c.DynamicAuth.AWSRDSIAM != nil && c.DynamicAuth.AWSRDSIAM.Region == "" { + if c.DynamicAuth.AWSRDSIAM.Region == "" { return errors.New("dynamicAuth.awsRdsIam.region is required") } } From 677615bd30a6ef94328c173a2eb5413f5958b8ea Mon Sep 17 00:00:00 2001 From: Reynier Ortiz Vega Date: Tue, 19 May 2026 09:06:01 -0400 Subject: [PATCH 5/7] Mark postgres package Alpha and add to package tables Declare stability via a top-of-file `// Stability: Alpha` marker on doc.go, mirroring the convention documented in CLAUDE.md. Add a row to the Available Packages table in README.md and the Current Packages table in CLAUDE.md so the package is discoverable through the same indexes consumers already use. Addresses code review feedback on #111. Co-Authored-By: Claude Opus 4.7 (1M context) --- CLAUDE.md | 1 + README.md | 1 + postgres/doc.go | 2 ++ 3 files changed, 4 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index ca00c37..d0deedf 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -73,6 +73,7 @@ task license-fix # Add missing license headers | `httperr` | Wrap errors with HTTP status codes; use `WithCode()`, `Code()`, `New()` | | `logging` | Pre-configured `*slog.Logger` factory with consistent ToolHive defaults (Alpha) | | `oci/skills` | OCI artifact types, media types, and registry operations for ToolHive skills (Alpha) | +| `postgres` | PostgreSQL connection pool with optional AWS RDS IAM dynamic auth (Alpha) | | `recovery` | HTTP panic recovery middleware (Beta) | | `validation/http` | RFC 7230/8707 compliant HTTP header and URI validation | | `validation/group` | Group name validation (lowercase alphanumeric, underscore, dash, space) | diff --git a/README.md b/README.md index 85f8620..436ecbe 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ The ToolHive ecosystem spans multiple Go repositories, and several of these proj | `httperr` | Stable | Wrap errors with HTTP status codes | | `logging` | Alpha | Pre-configured `*slog.Logger` factory with consistent ToolHive defaults | | `oci/skills` | Alpha | OCI artifact types, media types, and registry operations for skills | +| `postgres` | Alpha | PostgreSQL connection pool with optional AWS RDS IAM dynamic auth | | `recovery` | Beta | HTTP panic recovery middleware | | `validation/http` | Stable | RFC 7230/8707 compliant HTTP header and URI validation | | `validation/group` | Stable | Group name validation | diff --git a/postgres/doc.go b/postgres/doc.go index a196df4..4ffc203 100644 --- a/postgres/doc.go +++ b/postgres/doc.go @@ -1,6 +1,8 @@ // SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. // SPDX-License-Identifier: Apache-2.0 +// Stability: Alpha + /* Package postgres provides a shared PostgreSQL connection layer for the ToolHive ecosystem. It wraps github.com/jackc/pgx/v5/pgxpool with a single From c9327c4aa38378cbde0742a7dbf49fa19367b99c Mon Sep 17 00:00:00 2001 From: Reynier Ortiz Vega Date: Tue, 19 May 2026 09:25:49 -0400 Subject: [PATCH 6/7] Reject `DynamicAuth` combined with `WithBeforeConnect` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously NewPool silently let WithBeforeConnect win when cfg.DynamicAuth was also set — a documented footgun whose failure mode is production tokens expiring ~15 minutes after deploy. Compose-by-default was considered but rejected: both hooks write to conn.Password, so composition order is itself a surprise. Reject the combination at construction time instead. Callers that genuinely want to layer logic on top of dynamic auth can call NewDynamicAuthFunc, compose the two hooks themselves, and pass the composition via WithBeforeConnect with cfg.DynamicAuth left nil. The ambiguity goes away and the failure mode is loud at startup rather than silent in production. Addresses code review feedback on #111. Co-Authored-By: Claude Opus 4.7 (1M context) --- postgres/config_test.go | 4 ++-- postgres/pool.go | 30 +++++++++++++++++++++++------- postgres/pool_test.go | 22 ++++++++++++++++++++++ postgres/testdata_test.go | 1 + 4 files changed, 48 insertions(+), 9 deletions(-) diff --git a/postgres/config_test.go b/postgres/config_test.go index 498df57..f294c45 100644 --- a/postgres/config_test.go +++ b/postgres/config_test.go @@ -104,7 +104,7 @@ func TestConfig_Validate(t *testing.T) { cfg: &Config{ Host: "h", Port: 5432, User: "u", Database: "d", DynamicAuth: &DynamicAuthConfig{ - AWSRDSIAM: &DynamicAuthAWSRDSIAM{Region: "us-east-1"}, + AWSRDSIAM: &DynamicAuthAWSRDSIAM{Region: testRegion}, }, }, }, @@ -198,7 +198,7 @@ func TestConfig_LogValueRedactsSecrets(t *testing.T) { MigrationPassword: "should-not-appear-either", Database: "appdb", SSLMode: "require", - DynamicAuth: &DynamicAuthConfig{AWSRDSIAM: &DynamicAuthAWSRDSIAM{Region: "us-east-1"}}, + DynamicAuth: &DynamicAuthConfig{AWSRDSIAM: &DynamicAuthAWSRDSIAM{Region: testRegion}}, } var buf bytes.Buffer diff --git a/postgres/pool.go b/postgres/pool.go index 0db5c26..988ae6a 100644 --- a/postgres/pool.go +++ b/postgres/pool.go @@ -30,9 +30,13 @@ type options struct { } // WithBeforeConnect installs a hook that runs immediately before pgx dials. -// It overrides the BeforeConnect hook that NewPool would otherwise install -// from cfg.DynamicAuth — callers that need to layer their own logic should -// invoke NewDynamicAuthFunc explicitly and combine the results. +// +// NewPool rejects a combination of WithBeforeConnect and cfg.DynamicAuth — a +// silently-replaced auth hook would leave production tokens to expire 15 +// minutes after deploy. Callers that need both must call NewDynamicAuthFunc +// explicitly, compose the two hooks themselves in the order they want, and +// pass the composed result via WithBeforeConnect (with cfg.DynamicAuth left +// nil so this package does not also try to install an auth hook). func WithBeforeConnect(fn BeforeConnectFn) Option { return func(o *options) { o.beforeConnect = fn } } @@ -50,9 +54,15 @@ func WithLogger(logger *slog.Logger) Option { return func(o *options) { o.logger = logger } } -// NewPool creates a *pgxpool.Pool from cfg. When cfg.DynamicAuth is set and -// the caller does not supply WithBeforeConnect, NewPool installs the -// appropriate dynamic-auth hook automatically. +// NewPool creates a *pgxpool.Pool from cfg. When cfg.DynamicAuth is set, +// NewPool installs the appropriate dynamic-auth hook on BeforeConnect. +// +// Passing both cfg.DynamicAuth and WithBeforeConnect is an error — the +// failure mode of a silently-replaced auth hook (tokens expiring 15 min +// after deploy) is severe enough to refuse the ambiguity. Callers that +// genuinely want to layer logic on top of dynamic auth should call +// NewDynamicAuthFunc, compose the hooks themselves, and pass the +// composition via WithBeforeConnect with cfg.DynamicAuth left nil. // // cfg is validated; cfg is not mutated. func NewPool(ctx context.Context, cfg *Config, opts ...Option) (*pgxpool.Pool, error) { @@ -72,6 +82,12 @@ func NewPool(ctx context.Context, cfg *Config, opts ...Option) (*pgxpool.Pool, e logger = slog.Default() } + if cfg.DynamicAuth != nil && o.beforeConnect != nil { + return nil, errors.New("cfg.DynamicAuth and WithBeforeConnect are mutually exclusive; " + + "to layer hooks, call NewDynamicAuthFunc, compose with your hook, " + + "and pass the composition via WithBeforeConnect with cfg.DynamicAuth = nil") + } + poolConfig, err := buildPoolConfig(cfg) if err != nil { return nil, err @@ -80,7 +96,7 @@ func NewPool(ctx context.Context, cfg *Config, opts ...Option) (*pgxpool.Pool, e applyPoolTuning(poolConfig, cfg) beforeConnect := o.beforeConnect - if beforeConnect == nil && cfg.DynamicAuth != nil { + if cfg.DynamicAuth != nil { beforeConnect, err = NewDynamicAuthFunc(ctx, cfg, cfg.User) if err != nil { return nil, err diff --git a/postgres/pool_test.go b/postgres/pool_test.go index e082d5e..27b174c 100644 --- a/postgres/pool_test.go +++ b/postgres/pool_test.go @@ -44,6 +44,28 @@ func TestNewPool_DynamicAuthMisconfigured(t *testing.T) { assert.Contains(t, err.Error(), "dynamicAuth.awsRdsIam.region is required") } +// TestNewPool_DynamicAuthAndBeforeConnectAreMutuallyExclusive verifies that +// NewPool refuses the ambiguous combination rather than silently dropping +// one hook. The failure mode of a silently-replaced auth hook — production +// tokens expiring ~15 minutes after deploy — is severe enough that we want +// a loud rejection at construction time. +func TestNewPool_DynamicAuthAndBeforeConnectAreMutuallyExclusive(t *testing.T) { + t.Parallel() + + cfg := validConfig() + cfg.SSLMode = testSSLModeDisable + cfg.DynamicAuth = &DynamicAuthConfig{ + AWSRDSIAM: &DynamicAuthAWSRDSIAM{Region: testRegion}, + } + + pool, err := NewPool(t.Context(), cfg, + WithBeforeConnect(func(context.Context, *pgx.ConnConfig) error { return nil }), + ) + require.Error(t, err) + assert.Nil(t, pool) + assert.Contains(t, err.Error(), "mutually exclusive") +} + // TestNewPool_LazyConnect verifies that NewPool returns successfully even // when the database is unreachable — pgxpool establishes connections // lazily on first Acquire, not at construction time. Dial errors surface diff --git a/postgres/testdata_test.go b/postgres/testdata_test.go index 75c181f..1f1b3bd 100644 --- a/postgres/testdata_test.go +++ b/postgres/testdata_test.go @@ -13,4 +13,5 @@ const ( testErrRegionMissing = "AWS RDS IAM region is not configured" testErrRegionConfigured = "dynamicAuth.awsRdsIam.region is required" testDatabase = "appdb" + testRegion = "us-east-1" ) From e2caa83ec03abac303588d230554aba59320cec5 Mon Sep 17 00:00:00 2001 From: Reynier Ortiz Vega Date: Tue, 19 May 2026 09:27:32 -0400 Subject: [PATCH 7/7] Rework `SSLMode` documentation; keep "require" as default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The existing comment overstated the protection of "require" by calling it "the safe production default" — sslmode=require encrypts the connection but does not verify the server certificate against a CA, so it does not defend against an active attacker who can intercept TCP. Flipping the default to "verify-full" was considered and rejected: it would silently break self-signed and dev environments on first run, and registry-server today relies on the "require" default. Drop the "safe production default" wording on DefaultSSLMode and add a TLS and SSL Mode section to the package overview that explains the tradeoff and points production deployments at "verify-full" with a CA bundle. Also update the Hooks section to describe the WithBeforeConnect/DynamicAuth mutual-exclusion that was introduced in the previous commit. Addresses code review feedback on #111. Co-Authored-By: Claude Opus 4.7 (1M context) --- postgres/config.go | 9 +++++---- postgres/doc.go | 23 ++++++++++++++++++++--- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/postgres/config.go b/postgres/config.go index 1bb7974..6ab571a 100644 --- a/postgres/config.go +++ b/postgres/config.go @@ -21,10 +21,11 @@ const ( databaseForbiddenChars = "?#/" ) -// DefaultSSLMode is applied by BuildConnectionStringWithAuth when Config.SSLMode -// is empty. "require" is the safe production default: encryption is mandatory -// but the server certificate is not validated against a CA. Callers that need -// stricter behavior should set SSLMode to "verify-ca" or "verify-full". +// DefaultSSLMode is applied when Config.SSLMode is empty. It mandates an +// encrypted connection but does not verify the server certificate against a +// CA — encryption-only, not authentication. Use "verify-ca" or "verify-full" +// when the deployment provides a CA bundle (for example, cloud Postgres +// services); see the package overview for production guidance. const DefaultSSLMode = "require" // Config configures a PostgreSQL connection pool. Password fields are diff --git a/postgres/doc.go b/postgres/doc.go index 4ffc203..b73ff04 100644 --- a/postgres/doc.go +++ b/postgres/doc.go @@ -62,9 +62,26 @@ PostgreSQL enum array types defined in the caller's schema): }), ) -WithBeforeConnect overrides the auto-installed dynamic-auth hook. Callers -that need to layer additional logic on top should call NewDynamicAuthFunc -explicitly and compose the result. +WithBeforeConnect installs a BeforeConnect hook directly. It is mutually +exclusive with Config.DynamicAuth — NewPool refuses both, because silently +dropping one would leave production tokens to expire ~15 minutes after +deploy. Callers that genuinely need both should call NewDynamicAuthFunc, +compose with their hook, and pass the composition via WithBeforeConnect +with Config.DynamicAuth left nil. + +# TLS and SSL Mode + +DefaultSSLMode is "require", which mandates an encrypted connection but +does not verify the server certificate against a CA. This defends against +passive eavesdropping; it does not defend against an active attacker that +can present a forged certificate. For production deployments against cloud +Postgres services that publish a CA bundle (for example AWS RDS, Google +Cloud SQL, Azure Database for PostgreSQL), set SSLMode to "verify-full" and +configure pgx with the appropriate CA roots — typically by placing the +bundle on disk and setting PGSSLROOTCERT, or by attaching a *tls.Config via +WithAfterConnect. "require" is kept as the package default because tighter +modes break self-signed and dev environments out of the box; production +callers should opt up explicitly. # Logging