diff --git a/internal/handler/wellknown.go b/internal/handler/wellknown.go index b1d3555..b244fce 100644 --- a/internal/handler/wellknown.go +++ b/internal/handler/wellknown.go @@ -2,16 +2,21 @@ package handler import ( "context" + "encoding/json" + "fmt" "net/http" "github.com/danielgtaylor/huma/v2" - "github.com/lestrrat-go/jwx/v2/jwk" ) // ── Well-known types ───────────────────────────────────────────────────────── +// JWKSOutput is the published /.well-known/jwks.json payload. We use a generic +// map (not jwk.Set) because we need to rewrite the "use" field on each key +// from "sig" (what jwx stores internally for verifier compatibility) to +// "jwt-svid" (what JWT-SVID §4 requires SPIFFE bundles to advertise). type JWKSOutput struct { - Body jwk.Set + Body map[string]any } type OAuthMetadataOutput struct { @@ -39,7 +44,27 @@ func (a *API) registerWellKnownRoutes(api huma.API) { } func (a *API) jwksOp(_ context.Context, _ *struct{}) (*JWKSOutput, error) { - return &JWKSOutput{Body: a.jwksSvc.KeySet()}, nil + // Marshal the in-memory keyset, then rewrite each key's "use" field to + // "jwt-svid" before returning. JWT-SVID §4 requires this value on every + // key in a SPIFFE bundle. We don't store it that way internally because + // lestrrat-go/jwx's verifier skips keys whose use is anything other than + // "sig" — see internal/signing/jwks.go. + raw, err := json.Marshal(a.jwksSvc.KeySet()) + if err != nil { + return nil, fmt.Errorf("marshal jwks: %w", err) + } + var body map[string]any + if err := json.Unmarshal(raw, &body); err != nil { + return nil, fmt.Errorf("unmarshal jwks: %w", err) + } + if keys, ok := body["keys"].([]any); ok { + for _, k := range keys { + if km, ok := k.(map[string]any); ok { + km["use"] = "jwt-svid" + } + } + } + return &JWKSOutput{Body: body}, nil } func (a *API) oauthMetadataOp(_ context.Context, _ *struct{}) (*OAuthMetadataOutput, error) { diff --git a/internal/signing/jwks.go b/internal/signing/jwks.go index 846c4ca..33509e6 100644 --- a/internal/signing/jwks.go +++ b/internal/signing/jwks.go @@ -184,6 +184,10 @@ func addToKeySet(set jwk.Set, pubKey crypto.PublicKey, keyID string, alg jwa.Sig if err := jwkKey.Set(jwk.AlgorithmKey, alg); err != nil { return fmt.Errorf("failed to set algorithm: %w", err) } + // In-memory keys keep use=sig because lestrrat-go/jwx's verifier skips + // any key whose use is set to anything other than "sig". The published + // /.well-known/jwks.json rewrites this to "jwt-svid" at the handler so + // SPIFFE verifiers see the value JWT-SVID §4 requires. if err := jwkKey.Set(jwk.KeyUsageKey, jwk.ForSignature); err != nil { return fmt.Errorf("failed to set key usage: %w", err) } diff --git a/pkg/authjwt/jwks.go b/pkg/authjwt/jwks.go index 68f30bc..2c4b448 100644 --- a/pkg/authjwt/jwks.go +++ b/pkg/authjwt/jwks.go @@ -149,12 +149,19 @@ func (c *JWKSClient) refresh(ctx context.Context) error { return fmt.Errorf("fetch JWKS: %w", err) } + // SPIFFE bundles publish use=jwt-svid (JWT-SVID §4). lestrrat-go/jwx's + // verifier treats anything other than "sig" as non-signing and skips the + // key, so we normalize on ingest. RFC 7517 says use is informational — + // rewriting it doesn't change what the key actually is. kids := make(map[string]struct{}, set.Len()) for i := 0; i < set.Len(); i++ { key, ok := set.Key(i) if !ok { continue } + if key.KeyUsage() == "jwt-svid" { + _ = key.Set(jwk.KeyUsageKey, jwk.ForSignature) + } kids[key.KeyID()] = struct{}{} } diff --git a/tests/integration/wellknown_test.go b/tests/integration/wellknown_test.go index fd4dffb..0d57a77 100644 --- a/tests/integration/wellknown_test.go +++ b/tests/integration/wellknown_test.go @@ -32,7 +32,7 @@ func TestJWKSEndpoint(t *testing.T) { assert.Equal(t, "EC", ecKey["kty"]) assert.Equal(t, "ES256", ecKey["alg"]) - assert.Equal(t, "sig", ecKey["use"]) + assert.Equal(t, "jwt-svid", ecKey["use"]) assert.Equal(t, testKeyID, ecKey["kid"]) assert.Equal(t, "P-256", ecKey["crv"]) assert.NotEmpty(t, ecKey["x"], "EC key must have x coordinate")