diff --git a/internal/service/attestation.go b/internal/service/attestation.go index a71be0e..fee7730 100644 --- a/internal/service/attestation.go +++ b/internal/service/attestation.go @@ -10,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/rs/zerolog/log" + "github.com/uptrace/bun" "github.com/highflame-ai/zeroid/domain" "github.com/highflame-ai/zeroid/internal/attestation" @@ -40,6 +41,10 @@ type AttestationService struct { identitySvc *IdentityService verifiers *attestation.Registry policySvc *attestation.PolicyService + // db is the *bun.DB handle used to open transactions in + // VerifyAttestation. The repo methods themselves participate via + // postgres.WithTx(ctx, tx); the service owns the tx lifecycle. + db *bun.DB // permissive is the runtime-mutable form of cfg.Attestation.AllowUnsafeDevStub. // Stored as int32 so SetPermissive can flip it without a mutex — @@ -50,13 +55,16 @@ type AttestationService struct { // NewAttestationService creates a new AttestationService. verifiers and // policySvc are required: VerifyAttestation fails closed when no verifier // is registered for a proof type or no tenant policy exists, unless -// allowUnsafeDevStub is true (transitional bypass). +// allowUnsafeDevStub is true (transitional bypass). db is required so the +// three writes in VerifyAttestation (issue credential, promote trust, +// mark record verified) can be wrapped in a single transaction. func NewAttestationService( repo *postgres.AttestationRepository, credentialSvc *CredentialService, identitySvc *IdentityService, verifiers *attestation.Registry, policySvc *attestation.PolicyService, + db *bun.DB, allowUnsafeDevStub bool, ) *AttestationService { s := &AttestationService{ @@ -65,6 +73,7 @@ func NewAttestationService( identitySvc: identitySvc, verifiers: verifiers, policySvc: policySvc, + db: db, } s.permissive.Store(allowUnsafeDevStub) return s @@ -129,18 +138,24 @@ var ErrAttestationAlreadyVerified = errors.New("attestation already verified") // - Verifier.Verify returns an error → ErrAttestationRejected. // - Record already verified → ErrAttestationAlreadyVerified (rejects retries). // -// Write ordering rationale: credential issuance runs BEFORE identity trust -// promotion, so the most common failure (IssueCredential) leaves nothing -// committed. Trust promotion and record update run last, in that order, so -// a failure between them leaves trust promoted (harmless — backed by a -// valid proof) with the record unmarked. The re-verify guard prevents a -// second IssueCredential call in that retry window. +// Atomicity + serialization: the three side-effecting writes run inside a +// single bun.RunInTx whose first statement is SELECT ... FOR UPDATE on the +// attestation row, so concurrent /verify calls on the same record +// serialize. See the inline comments at the RunInTx site for the full +// commentary on lock order, the in-tx guard, and why the closure-scoped +// locals are assigned to the outer return values exactly once on success. func (s *AttestationService) VerifyAttestation(ctx context.Context, id, accountID, projectID string) (*VerifyAttestationResult, error) { record, err := s.repo.GetByID(ctx, id, accountID, projectID) if err != nil { return nil, err } - if record.IsVerified { + // Pre-tx fast-fail. The authoritative check happens inside the tx + // with the row locked; this one rejects the obvious already-done + // case before we run the verifier, open another DB connection for + // the tx, etc. Real saving on retry storms — failing here is one + // SELECT, failing inside the tx is BEGIN + SELECT FOR UPDATE + + // ROLLBACK plus everything between this point and the lock. + if record.IsVerified || record.CredentialID != "" { return nil, fmt.Errorf("%w: record %s", ErrAttestationAlreadyVerified, record.ID) } @@ -186,17 +201,15 @@ func (s *AttestationService) VerifyAttestation(ctx context.Context, id, accountI } } - // Load the identity without promoting yet — IssueCredential needs a - // valid, non-nil, usable identity and re-fetching guarantees we see - // the current state (another request might have deactivated it). - identity, err := s.identitySvc.GetIdentity(ctx, record.IdentityID, accountID, projectID) - if err != nil { - return nil, fmt.Errorf("failed to load identity for verified attestation: %w", err) - } - - // Step 1: issue the credential. This is the most likely failure point - // (policy checks, scope derivation, signing). Running it first means - // a failure leaves no partial state behind. + // All side-effects below happen inside RunInTx. Local closure-scoped + // vars hold the result; the outer return values are assigned exactly + // once on commit, so a rollback can't leave a partially-mutated + // record visible to the caller. + // + // Isolation: nil TxOptions means Postgres default (READ COMMITTED). + // That's sufficient because we serialize the only contended row with + // SELECT ... FOR UPDATE; we don't need REPEATABLE READ semantics + // across the whole transaction. // // GrantType is fixed to client_credentials regardless of how the // identity will subsequently authenticate. Verified attestation is a @@ -205,37 +218,87 @@ func (s *AttestationService) VerifyAttestation(ctx context.Context, id, accountI // returned token represents that boot-time trust, not a user-driven // session. Downstream flows can still token-exchange / jwt-bearer // against this credential; the bootstrap shape just doesn't change. - accessToken, cred, err := s.credentialSvc.IssueCredential(ctx, IssueRequest{ - Identity: identity, - GrantType: domain.GrantTypeClientCredentials, - }) - if err != nil { - return nil, fmt.Errorf("failed to issue post-attestation credential: %w", err) - } + var ( + accessToken *domain.AccessToken + cred *domain.IssuedCredential + verifiedRecord *domain.AttestationRecord + ) + txErr := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + ctx = postgres.WithTx(ctx, tx) - // Step 2: promote trust level. Backed by the just-verified proof. - promotedTrust := trustLevelForAttestation(record.Level) - if _, err := s.identitySvc.UpdateIdentity(ctx, record.IdentityID, accountID, projectID, UpdateIdentityRequest{ - TrustLevel: promotedTrust, - }); err != nil { - return nil, fmt.Errorf("failed to promote identity trust level: %w", err) - } + // Re-fetch with row lock. Concurrent verifies on the same record + // queue here; the second one re-reads the row after the first + // commits, sees CredentialID set, and bails out below. + // + // Lock-order note: this acquires the attestation_records row + // lock FIRST, then implicitly the identities row lock (via + // UpdateIdentity's UPDATE below). Any future code path that + // needs to hold both locks should acquire them in the same order + // to avoid deadlocks. Postgres detects deadlocks (40P01) and + // aborts one tx, so the worst case is a transient retry, not + // silent corruption — but cleaner to keep the order consistent. + // + // The repo method's error already names what failed; no outer + // wrap so the operator-visible message stays accurate even when + // the inner error is the WithTx contract violation rather than + // a runtime DB problem. + locked, err := s.repo.GetByIDForUpdate(ctx, id, accountID, projectID) + if err != nil { + return err + } + if locked.IsVerified || locked.CredentialID != "" { + return fmt.Errorf("%w: record %s", ErrAttestationAlreadyVerified, locked.ID) + } - // Step 3: commit the record with verified flag, audit fields, and - // credential link in a single write. - now := time.Now() - record.IsVerified = true - record.VerifiedAt = &now - if result.ExpiresAt != nil { - record.ExpiresAt = result.ExpiresAt - } - record.CredentialID = cred.ID - if err := s.repo.Update(ctx, record); err != nil { - return nil, fmt.Errorf("failed to update attestation record: %w", err) + // Re-load the identity inside the tx so we don't act on a stale + // snapshot from before someone deactivated it. With READ COMMITTED + // this read sees committed state as of the statement, so a + // concurrent UpdateIdentity that already finished will be visible. + identity, err := s.identitySvc.GetIdentity(ctx, locked.IdentityID, accountID, projectID) + if err != nil { + return fmt.Errorf("failed to load identity for verified attestation: %w", err) + } + + issued, issuedCred, err := s.credentialSvc.IssueCredential(ctx, IssueRequest{ + Identity: identity, + GrantType: domain.GrantTypeClientCredentials, + }) + if err != nil { + return fmt.Errorf("failed to issue post-attestation credential: %w", err) + } + + promotedTrust := trustLevelForAttestation(locked.Level) + if _, err := s.identitySvc.UpdateIdentity(ctx, locked.IdentityID, accountID, projectID, UpdateIdentityRequest{ + TrustLevel: promotedTrust, + }); err != nil { + return fmt.Errorf("failed to promote identity trust level: %w", err) + } + + now := time.Now() + locked.IsVerified = true + locked.VerifiedAt = &now + if result.ExpiresAt != nil { + locked.ExpiresAt = result.ExpiresAt + } + locked.CredentialID = issuedCred.ID + if err := s.repo.Update(ctx, locked); err != nil { + return fmt.Errorf("failed to update attestation record: %w", err) + } + + // Promote local-success values to the outer scope. Last step in + // the closure so a return-error path above never assigns a + // partial result. + accessToken = issued + cred = issuedCred + verifiedRecord = locked + return nil + }) + if txErr != nil { + return nil, txErr } return &VerifyAttestationResult{ - Record: record, + Record: verifiedRecord, AccessToken: accessToken, Credential: cred, }, nil diff --git a/internal/store/postgres/attestation.go b/internal/store/postgres/attestation.go index 1f22edf..30b7d10 100644 --- a/internal/store/postgres/attestation.go +++ b/internal/store/postgres/attestation.go @@ -21,7 +21,8 @@ func NewAttestationRepository(db *bun.DB) *AttestationRepository { // Create inserts a new attestation record. func (r *AttestationRepository) Create(ctx context.Context, record *domain.AttestationRecord) error { - _, err := r.db.NewInsert().Model(record).Exec(ctx) + db := dbOrTx(ctx, r.db) + _, err := db.NewInsert().Model(record).Exec(ctx) if err != nil { return fmt.Errorf("failed to create attestation record: %w", err) } @@ -31,7 +32,8 @@ func (r *AttestationRepository) Create(ctx context.Context, record *domain.Attes // GetByID retrieves an attestation record by its UUID. func (r *AttestationRepository) GetByID(ctx context.Context, id, accountID, projectID string) (*domain.AttestationRecord, error) { record := &domain.AttestationRecord{} - err := r.db.NewSelect().Model(record). + db := dbOrTx(ctx, r.db) + err := db.NewSelect().Model(record). Where("id = ?", id). Where("account_id = ?", accountID). Where("project_id = ?", projectID). @@ -42,11 +44,47 @@ func (r *AttestationRepository) GetByID(ctx context.Context, id, accountID, proj return record, nil } +// GetByIDForUpdate retrieves an attestation record by its UUID and acquires +// a row-level write lock (Postgres SELECT ... FOR UPDATE). The lock is held +// until the surrounding transaction commits or rolls back, so this method +// MUST be called inside a transaction (postgres.WithTx). +// +// Outside an explicit transaction Postgres still executes the SELECT +// successfully, but the implicit per-statement transaction commits as +// soon as the statement returns and the lock is released — concurrent +// callers see no useful serialization. We fail fast here rather than +// downgrade silently: misuse should surface as a loud error, not a +// race that only manifests under load. +// +// Use this in flows where the same attestation must serialize against +// concurrent writers — most notably AttestationService.VerifyAttestation, +// where two simultaneous /verify calls on the same record could otherwise +// each pass the IsVerified guard, each issue a credential, and leave the +// DB with two credentials from one proof. +func (r *AttestationRepository) GetByIDForUpdate(ctx context.Context, id, accountID, projectID string) (*domain.AttestationRecord, error) { + if !hasTx(ctx) { + return nil, fmt.Errorf("GetByIDForUpdate must be called inside a postgres.WithTx context — the row lock is meaningless without one") + } + record := &domain.AttestationRecord{} + db := dbOrTx(ctx, r.db) + err := db.NewSelect().Model(record). + Where("id = ?", id). + Where("account_id = ?", accountID). + Where("project_id = ?", projectID). + For("UPDATE"). + Scan(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get attestation record for update: %w", err) + } + return record, nil +} + // GetHighestVerifiedLevel returns the highest verified attestation level for an identity. // Returns an empty string if no verified attestation exists. func (r *AttestationRepository) GetHighestVerifiedLevel(ctx context.Context, identityID string) (string, error) { var level string - err := r.db.NewSelect(). + db := dbOrTx(ctx, r.db) + err := db.NewSelect(). TableExpr("attestation_records"). ColumnExpr("level"). Where("identity_id = ?", identityID). @@ -66,8 +104,11 @@ func (r *AttestationRepository) GetHighestVerifiedLevel(ctx context.Context, ide } // Update saves changes to an attestation record (e.g., mark as verified). +// Participates in a caller-provided transaction via postgres.WithTx(ctx, tx); +// falls through to a single auto-commit update otherwise. func (r *AttestationRepository) Update(ctx context.Context, record *domain.AttestationRecord) error { - _, err := r.db.NewUpdate().Model(record). + db := dbOrTx(ctx, r.db) + _, err := db.NewUpdate().Model(record). Where("id = ?", record.ID). Exec(ctx) if err != nil { diff --git a/internal/store/postgres/credential.go b/internal/store/postgres/credential.go index 2c7a770..c9cf488 100644 --- a/internal/store/postgres/credential.go +++ b/internal/store/postgres/credential.go @@ -20,9 +20,12 @@ func NewCredentialRepository(db *bun.DB) *CredentialRepository { return &CredentialRepository{db: db} } -// Create inserts a new issued credential. +// Create inserts a new issued credential. Participates in a caller-provided +// transaction via postgres.WithTx(ctx, tx); falls through to a single +// auto-commit insert otherwise. func (r *CredentialRepository) Create(ctx context.Context, cred *domain.IssuedCredential) error { - _, err := r.db.NewInsert().Model(cred).Exec(ctx) + db := dbOrTx(ctx, r.db) + _, err := db.NewInsert().Model(cred).Exec(ctx) if err != nil { return fmt.Errorf("failed to create credential: %w", err) } @@ -32,7 +35,8 @@ func (r *CredentialRepository) Create(ctx context.Context, cred *domain.IssuedCr // GetByID retrieves a credential by its UUID. func (r *CredentialRepository) GetByID(ctx context.Context, id, accountID, projectID string) (*domain.IssuedCredential, error) { cred := &domain.IssuedCredential{} - err := r.db.NewSelect().Model(cred). + db := dbOrTx(ctx, r.db) + err := db.NewSelect().Model(cred). Where("id = ?", id). Where("account_id = ?", accountID). Where("project_id = ?", projectID). @@ -46,7 +50,8 @@ func (r *CredentialRepository) GetByID(ctx context.Context, id, accountID, proje // GetByJTI retrieves a credential by its JWT ID (jti claim). func (r *CredentialRepository) GetByJTI(ctx context.Context, jti string) (*domain.IssuedCredential, error) { cred := &domain.IssuedCredential{} - err := r.db.NewSelect().Model(cred). + db := dbOrTx(ctx, r.db) + err := db.NewSelect().Model(cred). Where("jti = ?", jti). Scan(ctx) if err != nil { @@ -58,7 +63,8 @@ func (r *CredentialRepository) GetByJTI(ctx context.Context, jti string) (*domai // ListByIdentity returns all credentials for a given identity. func (r *CredentialRepository) ListByIdentity(ctx context.Context, identityID, accountID, projectID string) ([]*domain.IssuedCredential, error) { var creds []*domain.IssuedCredential - err := r.db.NewSelect().Model(&creds). + db := dbOrTx(ctx, r.db) + err := db.NewSelect().Model(&creds). Where("identity_id = ?", identityID). Where("account_id = ?", accountID). Where("project_id = ?", projectID). @@ -79,7 +85,8 @@ func (r *CredentialRepository) ListByIdentity(ctx context.Context, identityID, a func (r *CredentialRepository) RevokeAllActiveForIdentity(ctx context.Context, identityID, reason string) (int64, error) { now := time.Now() var count int64 - if err := r.db.NewRaw( + db := dbOrTx(ctx, r.db) + if err := db.NewRaw( "SELECT revoke_credentials_cascade(?, ?, ?)", identityID, now, reason, ).Scan(ctx, &count); err != nil { @@ -95,7 +102,8 @@ func (r *CredentialRepository) RevokeAllActiveForIdentity(ctx context.Context, i func (r *CredentialRepository) Revoke(ctx context.Context, id, accountID, projectID, reason string) error { now := time.Now() var count int64 - if err := r.db.NewRaw( + db := dbOrTx(ctx, r.db) + if err := db.NewRaw( "SELECT revoke_credential_cascade(?, ?, ?, ?, ?)", id, accountID, projectID, now, reason, ).Scan(ctx, &count); err != nil { diff --git a/internal/store/postgres/identity.go b/internal/store/postgres/identity.go index 021d4b0..0ef8509 100644 --- a/internal/store/postgres/identity.go +++ b/internal/store/postgres/identity.go @@ -25,7 +25,8 @@ func NewIdentityRepository(db *bun.DB) *IdentityRepository { // Create inserts a new identity. func (r *IdentityRepository) Create(ctx context.Context, identity *domain.Identity) error { - _, err := r.db.NewInsert().Model(identity).Exec(ctx) + db := dbOrTx(ctx, r.db) + _, err := db.NewInsert().Model(identity).Exec(ctx) if err != nil { return fmt.Errorf("failed to create identity: %w", err) } @@ -35,7 +36,8 @@ func (r *IdentityRepository) Create(ctx context.Context, identity *domain.Identi // GetByID retrieves an identity by its UUID, scoped to account + project. func (r *IdentityRepository) GetByID(ctx context.Context, id, accountID, projectID string) (*domain.Identity, error) { identity := &domain.Identity{} - err := r.db.NewSelect().Model(identity). + db := dbOrTx(ctx, r.db) + err := db.NewSelect().Model(identity). Where("id = ?", id). Where("account_id = ?", accountID). Where("project_id = ?", projectID). @@ -49,7 +51,8 @@ func (r *IdentityRepository) GetByID(ctx context.Context, id, accountID, project // GetByExternalID retrieves an identity by external ID within a tenant. func (r *IdentityRepository) GetByExternalID(ctx context.Context, externalID, accountID, projectID string) (*domain.Identity, error) { identity := &domain.Identity{} - err := r.db.NewSelect().Model(identity). + db := dbOrTx(ctx, r.db) + err := db.NewSelect().Model(identity). Where("external_id = ?", externalID). Where("account_id = ?", accountID). Where("project_id = ?", projectID). @@ -63,7 +66,8 @@ func (r *IdentityRepository) GetByExternalID(ctx context.Context, externalID, ac // GetByWIMSEURI retrieves an identity by its WIMSE URI, scoped to tenant. func (r *IdentityRepository) GetByWIMSEURI(ctx context.Context, wimseURI, accountID, projectID string) (*domain.Identity, error) { identity := &domain.Identity{} - err := r.db.NewSelect().Model(identity). + db := dbOrTx(ctx, r.db) + err := db.NewSelect().Model(identity). Where("wimse_uri = ?", wimseURI). Where("account_id = ?", accountID). Where("project_id = ?", projectID). @@ -79,7 +83,8 @@ func (r *IdentityRepository) GetByWIMSEURI(ctx context.Context, wimseURI, accoun // and filters using JSONB containment: labels @> {"key": "value"}. func (r *IdentityRepository) List(ctx context.Context, accountID, projectID string, identityTypes []string, label, trustLevel, isActive, search string, limit, offset int) ([]*domain.Identity, int, error) { var identities []*domain.Identity - q := r.db.NewSelect().Model(&identities). + db := dbOrTx(ctx, r.db) + q := db.NewSelect().Model(&identities). Where("account_id = ?", accountID). Where("project_id = ?", projectID). OrderExpr("created_at DESC") @@ -132,10 +137,13 @@ func (r *IdentityRepository) List(ctx context.Context, accountID, projectID stri return identities, total, nil } -// Update saves changes to an existing identity. +// Update saves changes to an existing identity. Participates in a caller- +// provided transaction via postgres.WithTx(ctx, tx); falls through to a +// single auto-commit update otherwise. func (r *IdentityRepository) Update(ctx context.Context, identity *domain.Identity) error { identity.ModifiedBy = middleware.GetCallerName(ctx) - _, err := r.db.NewUpdate().Model(identity). + db := dbOrTx(ctx, r.db) + _, err := db.NewUpdate().Model(identity). Where("id = ? AND account_id = ? AND project_id = ?", identity.ID, identity.AccountID, identity.ProjectID). Exec(ctx) if err != nil { @@ -145,16 +153,26 @@ func (r *IdentityRepository) Update(ctx context.Context, identity *domain.Identi } // Delete removes an identity. +// +// The pre-DELETE UPDATE stamps modified_by so the AFTER DELETE trigger can +// read the actor from OLD.modified_by. Its error is propagated rather than +// swallowed: in Postgres, a failed statement inside a transaction aborts +// the whole tx, so the subsequent DELETE would fail with a generic +// "current transaction is aborted" message that loses the original cause. +// Outside a tx the same propagation just makes a benign-looking failure +// loud — that's still preferable to silently triggering audit gaps. func (r *IdentityRepository) Delete(ctx context.Context, id, accountID, projectID string) error { - // Pre-stamp modified_by so the AFTER DELETE trigger can read the actor from OLD.modified_by. + db := dbOrTx(ctx, r.db) if callerID := middleware.GetCallerName(ctx); callerID != "" { - _, _ = r.db.NewUpdate(). + if _, err := db.NewUpdate(). TableExpr("identities"). Set("modified_by = ?", callerID). Where("id = ? AND account_id = ? AND project_id = ?", id, accountID, projectID). - Exec(ctx) + Exec(ctx); err != nil { + return fmt.Errorf("failed to stamp modified_by before delete: %w", err) + } } - _, err := r.db.NewDelete(). + _, err := db.NewDelete(). TableExpr("identities"). Where("id = ?", id). Where("account_id = ?", accountID). diff --git a/internal/store/postgres/tx.go b/internal/store/postgres/tx.go new file mode 100644 index 0000000..644dfba --- /dev/null +++ b/internal/store/postgres/tx.go @@ -0,0 +1,66 @@ +package postgres + +import ( + "context" + + "github.com/uptrace/bun" +) + +// txKey is the context-value key for an in-flight transaction. Unexported +// + struct{}-typed so external packages can't collide with us in the same +// context tree. +type txKey struct{} + +// WithTx attaches a transaction to ctx. Repo methods that call +// dbOrTx pick the transaction up automatically. Use this together with +// bun.DB.RunInTx in services that need to coordinate writes across multiple +// repos atomically: +// +// err := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +// ctx = postgres.WithTx(ctx, tx) +// if err := s.fooRepo.Create(ctx, foo); err != nil { return err } +// if err := s.barRepo.Update(ctx, bar); err != nil { return err } +// return nil +// }) +// +// Repo callers that don't open a transaction get the repo's own *bun.DB +// handle and behave as before (auto-commit per statement). +// +// Nesting: calling WithTx on a context that already carries a tx +// REPLACES the parent tx for any descendant ctx — there is no automatic +// savepoint or merge. Postgres doesn't support truly nested transactions +// anyway, and bun.RunInTx returns the existing tx when called recursively. +// For most flows the right pattern is "open one tx at the outermost +// service call, attach it once, do all the work inside" — don't nest. +func WithTx(ctx context.Context, tx bun.Tx) context.Context { + return context.WithValue(ctx, txKey{}, tx) +} + +// txFromContext extracts the in-flight transaction attached by WithTx. +// Returns the tx and true when one is set; the zero bun.Tx and false +// otherwise. Used as the shared core of dbOrTx and hasTx so the +// type-assertion key and shape live in exactly one place. +func txFromContext(ctx context.Context) (bun.Tx, bool) { + tx, ok := ctx.Value(txKey{}).(bun.Tx) + return tx, ok +} + +// dbOrTx returns the in-flight transaction from ctx if one is set, falling +// back to the repo's default DB handle. Repo methods that participate in +// transactions call this once at the top and use the returned bun.IDB for +// every statement in the method. +func dbOrTx(ctx context.Context, fallback bun.IDB) bun.IDB { + if tx, ok := txFromContext(ctx); ok { + return tx + } + return fallback +} + +// hasTx reports whether ctx carries an in-flight transaction attached +// with WithTx. Repo methods that ONLY make sense inside a transaction +// (e.g., SELECT ... FOR UPDATE callers) use this to fail fast on misuse +// instead of silently downgrading to a per-statement implicit tx. +func hasTx(ctx context.Context) bool { + _, ok := txFromContext(ctx) + return ok +} diff --git a/server.go b/server.go index 7712b27..7f1996f 100644 --- a/server.go +++ b/server.go @@ -183,7 +183,7 @@ func NewServer(cfg Config) (*Server, error) { signalSvc := service.NewSignalService(signalRepo, credentialRepo, identityRepo) identitySvc := service.NewIdentityService(identityRepo, credentialPolicySvc, apiKeyRepo, credentialSvc, signalSvc, cfg.WIMSEDomain) attestationPolicySvc := attestation.NewPolicyService(attestationPolicyRepo, attestationVerifiers) - attestationSvc := service.NewAttestationService(attestationRepo, credentialSvc, identitySvc, attestationVerifiers, attestationPolicySvc, cfg.Attestation.AllowUnsafeDevStub) + attestationSvc := service.NewAttestationService(attestationRepo, credentialSvc, identitySvc, attestationVerifiers, attestationPolicySvc, db, cfg.Attestation.AllowUnsafeDevStub) oauthClientSvc := service.NewOAuthClientService(oauthClientRepo) apiKeySvc := service.NewAPIKeyService(apiKeyRepo, credentialPolicySvc, identitySvc) refreshTokenSvc := service.NewRefreshTokenService(refreshTokenRepo, db) diff --git a/tests/integration/attestation_oidc_test.go b/tests/integration/attestation_oidc_test.go index bebb6b4..52b5e41 100644 --- a/tests/integration/attestation_oidc_test.go +++ b/tests/integration/attestation_oidc_test.go @@ -2,6 +2,7 @@ package integration_test import ( "bytes" + "context" "crypto/rand" "crypto/rsa" "encoding/base64" @@ -9,6 +10,7 @@ import ( "io" "net/http" "net/http/httptest" + "sync" "testing" "time" @@ -327,6 +329,140 @@ func TestAttestationDoubleVerifyIsRejected(t *testing.T) { assertErrorBodyContains(t, second, "already verified") } +// TestAttestationConcurrentVerifyMintsExactlyOneCredential pins the +// row-lock added in the #98 transaction wrap. Without SELECT ... FOR +// UPDATE on the attestation record, two simultaneous /verify calls +// could each pass the IsVerified guard, each enter the tx, each +// IssueCredential, and leave the DB with two credentials minted from +// one proof. The lock serializes the second verify behind the first; +// when it acquires the lock the record already has CredentialID set and +// it bails out via ErrAttestationAlreadyVerified. +// +// The goroutines below intentionally use raw http.NewRequest / +// http.DefaultClient.Do rather than the verifyAttestation helper. That +// helper's call chain ends in require.NoError, which calls t.FailNow — +// per the testing package contract FailNow MUST be called from the +// goroutine running the test, not workers. Doing the HTTP call inline +// keeps every assertion on the main goroutine after wg.Wait(). +func TestAttestationConcurrentVerifyMintsExactlyOneCredential(t *testing.T) { + iss := newOIDCIssuer(t) + defer iss.close() + + reg := registerAgent(t, uid("attest-race")) + upsertOIDCPolicy(t, map[string]any{ + "issuers": []map[string]any{{"url": iss.URL}}, + }) + + token := iss.sign(map[string]any{"sub": "ci-job-race"}) + id := submitAttestation(t, reg.AgentID, "oidc_token", token) + + body, err := json.Marshal(map[string]any{"attestation_id": id}) + require.NoError(t, err) + url := testServer.URL + adminPath("/attestation/verify") + headers := adminHeaders() + + const N = 8 + results := make([]struct { + status int + err error + }, N) + + var wg sync.WaitGroup + wg.Add(N) + for i := range N { + go func(slot int) { + defer wg.Done() + req, reqErr := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) + if reqErr != nil { + results[slot].err = reqErr + return + } + req.Header.Set("Content-Type", "application/json") + for k, v := range headers { + req.Header.Set(k, v) + } + resp, doErr := http.DefaultClient.Do(req) + if doErr != nil { + results[slot].err = doErr + return + } + results[slot].status = resp.StatusCode + _ = resp.Body.Close() + }(i) + } + wg.Wait() + + // Exactly one verify wins, the rest see ErrAttestationAlreadyVerified + // once they acquire the row lock and re-check the guard. + var ok, conflict int + for slot, r := range results { + require.NoErrorf(t, r.err, "goroutine %d failed transport-level: %v", slot, r.err) + switch r.status { + case http.StatusOK: + ok++ + case http.StatusConflict: + conflict++ + default: + t.Errorf("goroutine %d: unexpected status %d (want 200 or 409)", slot, r.status) + } + } + assert.Equal(t, 1, ok, "exactly one concurrent verify must succeed") + assert.Equal(t, N-1, conflict, "all other concurrent verifies must be rejected as already-verified") + + // Hard guarantee: the credential count for the identity is exactly 1. + count, err := testDB.NewSelect(). + Table("issued_credentials"). + Where("identity_id = ?", reg.AgentID). + Count(context.Background()) + require.NoError(t, err) + assert.Equal(t, 1, count, + "exactly one credential must be persisted for the attested identity — the row lock must serialize concurrent verifies") +} + +// TestAttestationVerifyGuardCatchesPartialFailureRetry pins the +// defense-in-depth guard added alongside the transaction wrap in #98: +// the three writes in VerifyAttestation now run in a single bun.RunInTx, +// so a partial-state record (CredentialID set, IsVerified=false) is +// no longer reachable through the normal flow. But the guard at the +// top of the function still trips on CredentialID != "" so a record +// in that state — reachable only via direct DB manipulation, a future +// code path that bypasses the verify flow, or a hypothetical +// commit-then-error-return bug — does not get a second credential +// minted. We plant the state by hand to exercise the guard. +func TestAttestationVerifyGuardCatchesPartialFailureRetry(t *testing.T) { + iss := newOIDCIssuer(t) + defer iss.close() + + reg := registerAgent(t, uid("attest-partial")) + upsertOIDCPolicy(t, map[string]any{ + "issuers": []map[string]any{{"url": iss.URL}}, + }) + + token := iss.sign(map[string]any{"sub": "ci-job-partial"}) + id := submitAttestation(t, reg.AgentID, "oidc_token", token) + + first := verifyAttestation(t, id) + require.Equal(t, http.StatusOK, first.StatusCode, "first verify expected 200") + _ = first.Body.Close() + + // Simulate Step 2 / Step 3 failure: rewind IsVerified to false but + // leave CredentialID intact. This is the exact state the guard is + // meant to catch — a retry would mint a second credential without it. + _, err := testDB.NewUpdate(). + Table("attestation_records"). + Set("is_verified = false"). + Set("verified_at = NULL"). + Where("id = ?", id). + Exec(context.Background()) + require.NoError(t, err, "failed to plant partial-failure state") + + second := verifyAttestation(t, id) + defer func() { _ = second.Body.Close() }() + assert.Equal(t, http.StatusConflict, second.StatusCode, + "retry on a record with credential_id set must be rejected even when is_verified=false") + assertErrorBodyContains(t, second, "already verified") +} + // TestAttestationPolicyUpsertReactivatesDisabled verifies the upsert-against- // inactive-row bug is fixed: disabling a policy via is_active=false and then // PUTting a fresh config must update the row in place, not violate the diff --git a/tests/integration/postgres_tx_test.go b/tests/integration/postgres_tx_test.go new file mode 100644 index 0000000..15a262e --- /dev/null +++ b/tests/integration/postgres_tx_test.go @@ -0,0 +1,160 @@ +package integration_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + + "github.com/highflame-ai/zeroid/domain" + "github.com/highflame-ai/zeroid/internal/store/postgres" +) + +// TestPostgresWithTxRollbackPersistsNothing pins the foundational invariant +// behind the #98 transaction wrap in AttestationService.VerifyAttestation: +// when a closure passed to bun.RunInTx returns a non-nil error, every +// participating repo write rolls back. Without this guarantee, the +// attestation atomicity claim collapses — Step 1's IssueCredential could +// commit even when Step 2 or 3 fails inside the closure. +// +// This test exercises the WithTx + dbOrTx mechanism end-to-end against +// the real Postgres testcontainer, separately from the higher-level +// VerifyAttestation flow that uses it. We run an Identity insert + an +// IssuedCredential insert inside a tx, force an error after both writes, +// and confirm neither row landed in the DB. +func TestPostgresWithTxRollbackPersistsNothing(t *testing.T) { + ctx := context.Background() + identityRepo := postgres.NewIdentityRepository(testDB) + credRepo := postgres.NewCredentialRepository(testDB) + + identityID := uuid.NewString() + credID := uuid.NewString() + jti := "tx-rollback-" + uuid.NewString() + + rollbackErr := errors.New("force rollback") + + txErr := testDB.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + ctx = postgres.WithTx(ctx, tx) + + identity := &domain.Identity{ + ID: identityID, + AccountID: testAccountID, + ProjectID: testProjectID, + ExternalID: "tx-rollback-id-" + identityID[:8], + Name: "tx rollback fixture", + IdentityType: domain.IdentityTypeAgent, + TrustLevel: domain.TrustLevelUnverified, + Status: domain.IdentityStatusActive, + WIMSEURI: "spiffe://test/" + identityID, + AllowedScopes: []string{}, + Capabilities: []byte("{}"), + Labels: []byte("{}"), + Metadata: []byte("{}"), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + if err := identityRepo.Create(ctx, identity); err != nil { + return err + } + + cred := &domain.IssuedCredential{ + ID: credID, + IdentityID: &identityID, + AccountID: testAccountID, + ProjectID: testProjectID, + JTI: jti, + Subject: identity.WIMSEURI, + Scopes: []string{}, + IssuedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Hour), + TTLSeconds: 3600, + GrantType: domain.GrantTypeClientCredentials, + } + if err := credRepo.Create(ctx, cred); err != nil { + return err + } + + return rollbackErr + }) + + require.Error(t, txErr) + require.ErrorIs(t, txErr, rollbackErr, + "the closure's error must surface to the caller — RunInTx should not swallow it") + + // Foundational claim: nothing committed. + identityCount, err := testDB.NewSelect(). + Table("identities"). + Where("id = ?", identityID). + Count(ctx) + require.NoError(t, err) + assert.Zero(t, identityCount, "identity row must not exist after rollback") + + credCount, err := testDB.NewSelect(). + Table("issued_credentials"). + Where("id = ?", credID). + Count(ctx) + require.NoError(t, err) + assert.Zero(t, credCount, "credential row must not exist after rollback") +} + +// TestGetByIDForUpdateRequiresTx pins the contract that GetByIDForUpdate +// fails fast when called without a postgres.WithTx context. Without this +// guard, a future caller that forgets to open a transaction would +// silently downgrade to a per-statement implicit tx — the SELECT FOR +// UPDATE acquires the lock and immediately releases it on the implicit +// commit, providing no useful serialization. The bug only manifests +// under concurrent load. Loud failure here = caught at code-review time +// instead of in production. +func TestGetByIDForUpdateRequiresTx(t *testing.T) { + repo := postgres.NewAttestationRepository(testDB) + _, err := repo.GetByIDForUpdate(context.Background(), uuid.NewString(), testAccountID, testProjectID) + require.Error(t, err) + assert.Contains(t, err.Error(), "must be called inside", + "the contract violation message must name the missing WithTx so a future debugger sees the cause") +} + +// TestPostgresWithoutTxFallsBackToAutoCommit pins the other half of the +// dbOrTx contract: when no tx is attached to ctx, repo writes use the +// repo's default *bun.DB handle and auto-commit per statement, exactly +// as before the transaction work was introduced. Without this property, +// every existing call site of these repos would have changed behavior. +func TestPostgresWithoutTxFallsBackToAutoCommit(t *testing.T) { + ctx := context.Background() + identityRepo := postgres.NewIdentityRepository(testDB) + + identityID := uuid.NewString() + identity := &domain.Identity{ + ID: identityID, + AccountID: testAccountID, + ProjectID: testProjectID, + ExternalID: "tx-fallback-id-" + identityID[:8], + Name: "tx fallback fixture", + IdentityType: domain.IdentityTypeAgent, + TrustLevel: domain.TrustLevelUnverified, + Status: domain.IdentityStatusActive, + WIMSEURI: "spiffe://test-fallback/" + identityID, + AllowedScopes: []string{}, + Capabilities: []byte("{}"), + Labels: []byte("{}"), + Metadata: []byte("{}"), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + require.NoError(t, identityRepo.Create(ctx, identity), + "Create with no tx in ctx must succeed via auto-commit") + + count, err := testDB.NewSelect(). + Table("identities"). + Where("id = ?", identityID). + Count(ctx) + require.NoError(t, err) + assert.Equal(t, 1, count, "identity row must be persisted by the auto-commit path") + + // Tidy. No tx so this also auto-commits. + require.NoError(t, identityRepo.Delete(ctx, identityID, testAccountID, testProjectID)) +}