Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions internal/handler/wellknown.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@ package handler

import (
"context"
"encoding/json"
"fmt"
"net/http"

"github.com/danielgtaylor/huma/v2"
"github.com/lestrrat-go/jwx/v2/jwk"
)

// ── Well-known types ─────────────────────────────────────────────────────────

// JWKSOutput is the published /.well-known/jwks.json payload. We use a generic
// map (not jwk.Set) because we need to rewrite the "use" field on each key
// from "sig" (what jwx stores internally for verifier compatibility) to
// "jwt-svid" (what JWT-SVID §4 requires SPIFFE bundles to advertise).
type JWKSOutput struct {
Body jwk.Set
Body map[string]any
}

type OAuthMetadataOutput struct {
Expand Down Expand Up @@ -39,7 +44,27 @@ func (a *API) registerWellKnownRoutes(api huma.API) {
}

func (a *API) jwksOp(_ context.Context, _ *struct{}) (*JWKSOutput, error) {
return &JWKSOutput{Body: a.jwksSvc.KeySet()}, nil
// Marshal the in-memory keyset, then rewrite each key's "use" field to
// "jwt-svid" before returning. JWT-SVID §4 requires this value on every
// key in a SPIFFE bundle. We don't store it that way internally because
// lestrrat-go/jwx's verifier skips keys whose use is anything other than
// "sig" — see internal/signing/jwks.go.
raw, err := json.Marshal(a.jwksSvc.KeySet())
if err != nil {
return nil, fmt.Errorf("marshal jwks: %w", err)
}
var body map[string]any
if err := json.Unmarshal(raw, &body); err != nil {
return nil, fmt.Errorf("unmarshal jwks: %w", err)
}
if keys, ok := body["keys"].([]any); ok {
for _, k := range keys {
if km, ok := k.(map[string]any); ok {
km["use"] = "jwt-svid"
}
}
}
return &JWKSOutput{Body: body}, nil
}

func (a *API) oauthMetadataOp(_ context.Context, _ *struct{}) (*OAuthMetadataOutput, error) {
Expand Down
4 changes: 4 additions & 0 deletions internal/signing/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ func addToKeySet(set jwk.Set, pubKey crypto.PublicKey, keyID string, alg jwa.Sig
if err := jwkKey.Set(jwk.AlgorithmKey, alg); err != nil {
return fmt.Errorf("failed to set algorithm: %w", err)
}
// In-memory keys keep use=sig because lestrrat-go/jwx's verifier skips
// any key whose use is set to anything other than "sig". The published
// /.well-known/jwks.json rewrites this to "jwt-svid" at the handler so
// SPIFFE verifiers see the value JWT-SVID §4 requires.
if err := jwkKey.Set(jwk.KeyUsageKey, jwk.ForSignature); err != nil {
return fmt.Errorf("failed to set key usage: %w", err)
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/authjwt/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,19 @@ func (c *JWKSClient) refresh(ctx context.Context) error {
return fmt.Errorf("fetch JWKS: %w", err)
}

// SPIFFE bundles publish use=jwt-svid (JWT-SVID §4). lestrrat-go/jwx's
// verifier treats anything other than "sig" as non-signing and skips the
// key, so we normalize on ingest. RFC 7517 says use is informational —
// rewriting it doesn't change what the key actually is.
kids := make(map[string]struct{}, set.Len())
for i := 0; i < set.Len(); i++ {
key, ok := set.Key(i)
if !ok {
continue
}
if key.KeyUsage() == "jwt-svid" {
_ = key.Set(jwk.KeyUsageKey, jwk.ForSignature)
}
kids[key.KeyID()] = struct{}{}
}

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/wellknown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestJWKSEndpoint(t *testing.T) {

assert.Equal(t, "EC", ecKey["kty"])
assert.Equal(t, "ES256", ecKey["alg"])
assert.Equal(t, "sig", ecKey["use"])
assert.Equal(t, "jwt-svid", ecKey["use"])
assert.Equal(t, testKeyID, ecKey["kid"])
assert.Equal(t, "P-256", ecKey["crv"])
assert.NotEmpty(t, ecKey["x"], "EC key must have x coordinate")
Expand Down