Skip to content
Merged
153 changes: 108 additions & 45 deletions internal/service/attestation.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/uptrace/bun"

"github.com/highflame-ai/zeroid/domain"
"github.com/highflame-ai/zeroid/internal/attestation"
Expand Down Expand Up @@ -40,6 +41,10 @@ type AttestationService struct {
identitySvc *IdentityService
verifiers *attestation.Registry
policySvc *attestation.PolicyService
// db is the *bun.DB handle used to open transactions in
// VerifyAttestation. The repo methods themselves participate via
// postgres.WithTx(ctx, tx); the service owns the tx lifecycle.
db *bun.DB

// permissive is the runtime-mutable form of cfg.Attestation.AllowUnsafeDevStub.
// Stored as int32 so SetPermissive can flip it without a mutex —
Expand All @@ -50,13 +55,16 @@ type AttestationService struct {
// NewAttestationService creates a new AttestationService. verifiers and
// policySvc are required: VerifyAttestation fails closed when no verifier
// is registered for a proof type or no tenant policy exists, unless
// allowUnsafeDevStub is true (transitional bypass).
// allowUnsafeDevStub is true (transitional bypass). db is required so the
// three writes in VerifyAttestation (issue credential, promote trust,
// mark record verified) can be wrapped in a single transaction.
func NewAttestationService(
repo *postgres.AttestationRepository,
credentialSvc *CredentialService,
identitySvc *IdentityService,
verifiers *attestation.Registry,
policySvc *attestation.PolicyService,
db *bun.DB,
allowUnsafeDevStub bool,
) *AttestationService {
s := &AttestationService{
Expand All @@ -65,6 +73,7 @@ func NewAttestationService(
identitySvc: identitySvc,
verifiers: verifiers,
policySvc: policySvc,
db: db,
}
s.permissive.Store(allowUnsafeDevStub)
return s
Expand Down Expand Up @@ -129,18 +138,24 @@ var ErrAttestationAlreadyVerified = errors.New("attestation already verified")
// - Verifier.Verify returns an error → ErrAttestationRejected.
// - Record already verified → ErrAttestationAlreadyVerified (rejects retries).
//
// Write ordering rationale: credential issuance runs BEFORE identity trust
// promotion, so the most common failure (IssueCredential) leaves nothing
// committed. Trust promotion and record update run last, in that order, so
// a failure between them leaves trust promoted (harmless — backed by a
// valid proof) with the record unmarked. The re-verify guard prevents a
// second IssueCredential call in that retry window.
// Atomicity + serialization: the three side-effecting writes run inside a
// single bun.RunInTx whose first statement is SELECT ... FOR UPDATE on the
// attestation row, so concurrent /verify calls on the same record
// serialize. See the inline comments at the RunInTx site for the full
// commentary on lock order, the in-tx guard, and why the closure-scoped
// locals are assigned to the outer return values exactly once on success.
func (s *AttestationService) VerifyAttestation(ctx context.Context, id, accountID, projectID string) (*VerifyAttestationResult, error) {
record, err := s.repo.GetByID(ctx, id, accountID, projectID)
if err != nil {
return nil, err
}
if record.IsVerified {
// Pre-tx fast-fail. The authoritative check happens inside the tx
// with the row locked; this one only rejects the obvious already-done
// case before we run the verifier, open another DB connection for
// the tx, etc. This is a real saving during retry storms: failing
// here costs one SELECT, whereas failing inside the tx costs BEGIN +
// SELECT FOR UPDATE + ROLLBACK plus all the work between this point
// and the lock.
if record.IsVerified || record.CredentialID != "" {
return nil, fmt.Errorf("%w: record %s", ErrAttestationAlreadyVerified, record.ID)
}

Expand Down Expand Up @@ -186,17 +201,15 @@ func (s *AttestationService) VerifyAttestation(ctx context.Context, id, accountI
}
}

// Load the identity without promoting yet — IssueCredential needs a
// valid, non-nil, usable identity and re-fetching guarantees we see
// the current state (another request might have deactivated it).
identity, err := s.identitySvc.GetIdentity(ctx, record.IdentityID, accountID, projectID)
if err != nil {
return nil, fmt.Errorf("failed to load identity for verified attestation: %w", err)
}

// Step 1: issue the credential. This is the most likely failure point
// (policy checks, scope derivation, signing). Running it first means
// a failure leaves no partial state behind.
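// All side-effects below happen inside RunInTx. Variables local to the
// closure hold the intermediate results; the outer variables are
// assigned exactly once, as the closure's final step just before
// RunInTx commits. If the closure or the commit fails, this function
// returns only the error, so a rollback can't leave a partially-mutated
// record visible to the caller.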
//
// Isolation: nil TxOptions means Postgres default (READ COMMITTED).
// That's sufficient because we serialize the only contended row with
// SELECT ... FOR UPDATE; we don't need REPEATABLE READ semantics
// across the whole transaction.
//
// GrantType is fixed to client_credentials regardless of how the
// identity will subsequently authenticate. Verified attestation is a
Expand All @@ -205,37 +218,87 @@ func (s *AttestationService) VerifyAttestation(ctx context.Context, id, accountI
// returned token represents that boot-time trust, not a user-driven
// session. Downstream flows can still token-exchange / jwt-bearer
// against this credential; the bootstrap shape just doesn't change.
accessToken, cred, err := s.credentialSvc.IssueCredential(ctx, IssueRequest{
Identity: identity,
GrantType: domain.GrantTypeClientCredentials,
})
if err != nil {
return nil, fmt.Errorf("failed to issue post-attestation credential: %w", err)
}
var (
accessToken *domain.AccessToken
cred *domain.IssuedCredential
verifiedRecord *domain.AttestationRecord
)
txErr := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
ctx = postgres.WithTx(ctx, tx)

// Step 2: promote trust level. Backed by the just-verified proof.
promotedTrust := trustLevelForAttestation(record.Level)
if _, err := s.identitySvc.UpdateIdentity(ctx, record.IdentityID, accountID, projectID, UpdateIdentityRequest{
TrustLevel: promotedTrust,
}); err != nil {
return nil, fmt.Errorf("failed to promote identity trust level: %w", err)
}
// Re-fetch with row lock. Concurrent verifies on the same record
// queue here; the second one re-reads the row after the first
// commits, sees CredentialID set, and bails out below.
//
// Lock-order note: this acquires the attestation_records row
// lock FIRST, then implicitly the identities row lock (via
// UpdateIdentity's UPDATE below). Any future code path that
// needs to hold both locks should acquire them in the same order
// to avoid deadlocks. Postgres detects deadlocks (40P01) and
// aborts one tx, so the worst case is a transient retry, not
// silent corruption — but cleaner to keep the order consistent.
//
// The repo method's error already names what failed; no outer
// wrap so the operator-visible message stays accurate even when
// the inner error is the WithTx contract violation rather than
// a runtime DB problem.
locked, err := s.repo.GetByIDForUpdate(ctx, id, accountID, projectID)
if err != nil {
return err
}
if locked.IsVerified || locked.CredentialID != "" {
return fmt.Errorf("%w: record %s", ErrAttestationAlreadyVerified, locked.ID)
}

// Step 3: commit the record with verified flag, audit fields, and
// credential link in a single write.
now := time.Now()
record.IsVerified = true
record.VerifiedAt = &now
if result.ExpiresAt != nil {
record.ExpiresAt = result.ExpiresAt
}
record.CredentialID = cred.ID
if err := s.repo.Update(ctx, record); err != nil {
return nil, fmt.Errorf("failed to update attestation record: %w", err)
// Re-load the identity inside the tx so we don't act on a stale
// snapshot from before someone deactivated it. With READ COMMITTED
// this read sees committed state as of the statement, so a
// concurrent UpdateIdentity that already finished will be visible.
identity, err := s.identitySvc.GetIdentity(ctx, locked.IdentityID, accountID, projectID)
if err != nil {
return fmt.Errorf("failed to load identity for verified attestation: %w", err)
}

issued, issuedCred, err := s.credentialSvc.IssueCredential(ctx, IssueRequest{
Identity: identity,
GrantType: domain.GrantTypeClientCredentials,
})
if err != nil {
return fmt.Errorf("failed to issue post-attestation credential: %w", err)
}

promotedTrust := trustLevelForAttestation(locked.Level)
if _, err := s.identitySvc.UpdateIdentity(ctx, locked.IdentityID, accountID, projectID, UpdateIdentityRequest{
TrustLevel: promotedTrust,
}); err != nil {
return fmt.Errorf("failed to promote identity trust level: %w", err)
}

now := time.Now()
locked.IsVerified = true
locked.VerifiedAt = &now
if result.ExpiresAt != nil {
locked.ExpiresAt = result.ExpiresAt
}
locked.CredentialID = issuedCred.ID
if err := s.repo.Update(ctx, locked); err != nil {
return fmt.Errorf("failed to update attestation record: %w", err)
}

// Promote local-success values to the outer scope. Last step in
// the closure so a return-error path above never assigns a
// partial result.
accessToken = issued
cred = issuedCred
verifiedRecord = locked
return nil
})
if txErr != nil {
return nil, txErr
}

return &VerifyAttestationResult{
Record: record,
Record: verifiedRecord,
AccessToken: accessToken,
Credential: cred,
}, nil
Expand Down
49 changes: 45 additions & 4 deletions internal/store/postgres/attestation.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ func NewAttestationRepository(db *bun.DB) *AttestationRepository {

// Create inserts a new attestation record.
func (r *AttestationRepository) Create(ctx context.Context, record *domain.AttestationRecord) error {
_, err := r.db.NewInsert().Model(record).Exec(ctx)
db := dbOrTx(ctx, r.db)
_, err := db.NewInsert().Model(record).Exec(ctx)
if err != nil {
return fmt.Errorf("failed to create attestation record: %w", err)
}
Expand All @@ -31,7 +32,8 @@ func (r *AttestationRepository) Create(ctx context.Context, record *domain.Attes
// GetByID retrieves an attestation record by its UUID.
func (r *AttestationRepository) GetByID(ctx context.Context, id, accountID, projectID string) (*domain.AttestationRecord, error) {
record := &domain.AttestationRecord{}
err := r.db.NewSelect().Model(record).
db := dbOrTx(ctx, r.db)
err := db.NewSelect().Model(record).
Where("id = ?", id).
Where("account_id = ?", accountID).
Where("project_id = ?", projectID).
Expand All @@ -42,11 +44,47 @@ func (r *AttestationRepository) GetByID(ctx context.Context, id, accountID, proj
return record, nil
}

// GetByIDForUpdate loads the attestation record identified by id within the
// given account and project, taking a row-level write lock on it (Postgres
// SELECT ... FOR UPDATE). The lock lasts until the enclosing transaction
// commits or rolls back, which is why the call is only valid on a context
// prepared with postgres.WithTx.
//
// Without an explicit transaction Postgres would still run the SELECT, but
// the implicit per-statement transaction ends as soon as the statement
// returns, releasing the lock immediately and giving concurrent callers no
// serialization at all. Rather than degrade silently into a race that only
// shows up under load, the method rejects such calls with an error.
//
// Intended for flows in which concurrent writers to the same attestation
// must queue behind one another — chiefly AttestationService.VerifyAttestation,
// where two simultaneous /verify calls on one record could otherwise both
// pass the IsVerified guard, both issue a credential, and leave the DB with
// two credentials backed by a single proof.
//
// Returns the locked record, or an error when no transaction is present on
// ctx or when the query fails (the underlying error is wrapped).
func (r *AttestationRepository) GetByIDForUpdate(ctx context.Context, id, accountID, projectID string) (*domain.AttestationRecord, error) {
	// Guard first: a FOR UPDATE lock taken outside a transaction protects nothing.
	if !hasTx(ctx) {
		return nil, fmt.Errorf("GetByIDForUpdate must be called inside a postgres.WithTx context — the row lock is meaningless without one")
	}

	var locked domain.AttestationRecord
	query := dbOrTx(ctx, r.db).NewSelect().
		Model(&locked).
		Where("id = ?", id).
		Where("account_id = ?", accountID).
		Where("project_id = ?", projectID).
		For("UPDATE")
	if err := query.Scan(ctx); err != nil {
		return nil, fmt.Errorf("failed to get attestation record for update: %w", err)
	}
	return &locked, nil
}

// GetHighestVerifiedLevel returns the highest verified attestation level for an identity.
// Returns an empty string if no verified attestation exists.
func (r *AttestationRepository) GetHighestVerifiedLevel(ctx context.Context, identityID string) (string, error) {
var level string
err := r.db.NewSelect().
db := dbOrTx(ctx, r.db)
err := db.NewSelect().
TableExpr("attestation_records").
ColumnExpr("level").
Where("identity_id = ?", identityID).
Expand All @@ -66,8 +104,11 @@ func (r *AttestationRepository) GetHighestVerifiedLevel(ctx context.Context, ide
}

// Update saves changes to an attestation record (e.g., mark as verified).
// Participates in a caller-provided transaction via postgres.WithTx(ctx, tx);
// falls through to a single auto-commit update otherwise.
func (r *AttestationRepository) Update(ctx context.Context, record *domain.AttestationRecord) error {
_, err := r.db.NewUpdate().Model(record).
db := dbOrTx(ctx, r.db)
_, err := db.NewUpdate().Model(record).
Where("id = ?", record.ID).
Exec(ctx)
if err != nil {
Expand Down
22 changes: 15 additions & 7 deletions internal/store/postgres/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ func NewCredentialRepository(db *bun.DB) *CredentialRepository {
return &CredentialRepository{db: db}
}

// Create inserts a new issued credential.
// Create inserts a new issued credential. Participates in a caller-provided
// transaction via postgres.WithTx(ctx, tx); falls through to a single
// auto-commit insert otherwise.
func (r *CredentialRepository) Create(ctx context.Context, cred *domain.IssuedCredential) error {
_, err := r.db.NewInsert().Model(cred).Exec(ctx)
db := dbOrTx(ctx, r.db)
_, err := db.NewInsert().Model(cred).Exec(ctx)
if err != nil {
return fmt.Errorf("failed to create credential: %w", err)
}
Expand All @@ -32,7 +35,8 @@ func (r *CredentialRepository) Create(ctx context.Context, cred *domain.IssuedCr
// GetByID retrieves a credential by its UUID.
func (r *CredentialRepository) GetByID(ctx context.Context, id, accountID, projectID string) (*domain.IssuedCredential, error) {
cred := &domain.IssuedCredential{}
err := r.db.NewSelect().Model(cred).
db := dbOrTx(ctx, r.db)
err := db.NewSelect().Model(cred).
Where("id = ?", id).
Where("account_id = ?", accountID).
Where("project_id = ?", projectID).
Expand All @@ -46,7 +50,8 @@ func (r *CredentialRepository) GetByID(ctx context.Context, id, accountID, proje
// GetByJTI retrieves a credential by its JWT ID (jti claim).
func (r *CredentialRepository) GetByJTI(ctx context.Context, jti string) (*domain.IssuedCredential, error) {
cred := &domain.IssuedCredential{}
err := r.db.NewSelect().Model(cred).
db := dbOrTx(ctx, r.db)
err := db.NewSelect().Model(cred).
Where("jti = ?", jti).
Scan(ctx)
if err != nil {
Expand All @@ -58,7 +63,8 @@ func (r *CredentialRepository) GetByJTI(ctx context.Context, jti string) (*domai
// ListByIdentity returns all credentials for a given identity.
func (r *CredentialRepository) ListByIdentity(ctx context.Context, identityID, accountID, projectID string) ([]*domain.IssuedCredential, error) {
var creds []*domain.IssuedCredential
err := r.db.NewSelect().Model(&creds).
db := dbOrTx(ctx, r.db)
err := db.NewSelect().Model(&creds).
Where("identity_id = ?", identityID).
Where("account_id = ?", accountID).
Where("project_id = ?", projectID).
Expand All @@ -79,7 +85,8 @@ func (r *CredentialRepository) ListByIdentity(ctx context.Context, identityID, a
func (r *CredentialRepository) RevokeAllActiveForIdentity(ctx context.Context, identityID, reason string) (int64, error) {
now := time.Now()
var count int64
if err := r.db.NewRaw(
db := dbOrTx(ctx, r.db)
if err := db.NewRaw(
"SELECT revoke_credentials_cascade(?, ?, ?)",
identityID, now, reason,
).Scan(ctx, &count); err != nil {
Expand All @@ -95,7 +102,8 @@ func (r *CredentialRepository) RevokeAllActiveForIdentity(ctx context.Context, i
func (r *CredentialRepository) Revoke(ctx context.Context, id, accountID, projectID, reason string) error {
now := time.Now()
var count int64
if err := r.db.NewRaw(
db := dbOrTx(ctx, r.db)
if err := db.NewRaw(
"SELECT revoke_credential_cascade(?, ?, ?, ?, ?)",
id, accountID, projectID, now, reason,
).Scan(ctx, &count); err != nil {
Expand Down
Loading