diff --git a/.gitignore b/.gitignore index b55a187b95..79e6854bd7 100644 --- a/.gitignore +++ b/.gitignore @@ -54,4 +54,4 @@ traces/ .cache/* # Claude AI files -.claude/ \ No newline at end of file +.claude/ diff --git a/lib/ocrypto/asym_decryption.go b/lib/ocrypto/asym_decryption.go index 426f859723..bb41212383 100644 --- a/lib/ocrypto/asym_decryption.go +++ b/lib/ocrypto/asym_decryption.go @@ -7,6 +7,7 @@ import ( "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" + "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/x509" @@ -26,6 +27,38 @@ type AsymDecryption struct { type PrivateKeyDecryptor interface { // Decrypt decrypts ciphertext with private key. Decrypt(data []byte) ([]byte, error) + + // PrivateKeyInPemFormat returns the private key in PEM format. + PrivateKeyInPemFormat() (string, error) + + // Public returns the corresponding public-key encryptor. + Public() (PublicKeyEncryptor, error) + + // KeyType returns the key type, e.g. RSA or EC. + KeyType() KeyType +} + +func NewPrivateKeyDecryptor(kt KeyType) (PrivateKeyDecryptor, error) { + switch { + case IsRSAKeyType(kt): + bits, err := RSAKeyTypeToBits(kt) + if err != nil { + return nil, err + } + keyPair, err := NewRSAKeyPair(bits) + if err != nil { + return nil, err + } + return keyPair, nil + case IsECKeyType(kt): + mode, err := ECKeyTypeToMode(kt) + if err != nil { + return nil, err + } + return NewECPrivateKey(mode) + default: + return nil, fmt.Errorf("unsupported key type: %v", kt) + } } // FromPrivatePEM creates and returns a new AsymDecryption. @@ -109,6 +142,33 @@ func (asymDecryption AsymDecryption) Decrypt(data []byte) ([]byte, error) { return bytes, nil } +func (asymDecryption AsymDecryption) PrivateKeyInPemFormat() (string, error) { + return privateKeyInPemFormat(asymDecryption.PrivateKey) +} + +func (asymDecryption AsymDecryption) Public() (PublicKeyEncryptor, error) { + if asymDecryption.PrivateKey == nil { + return nil, errors.New("failed to generate public key encryptor, private key is empty") + } + + return &AsymEncryption{PublicKey: &asymDecryption.PrivateKey.PublicKey}, nil +} + +func (asymDecryption AsymDecryption) KeyType() KeyType { + if asymDecryption.PrivateKey == nil { + return KeyType("rsa:[unknown]") + } + + switch asymDecryption.PrivateKey.Size() { + case RSA2048Size / 8: //nolint:mnd // standard key size in bytes + return RSA2048Key + case RSA4096Size / 8: //nolint:mnd // large key size in bytes + return RSA4096Key + default: + return KeyType(fmt.Sprintf("rsa:%d", asymDecryption.PrivateKey.Size()*8)) //nolint:mnd // convert to bits + } +} + type ECDecryptor struct { sk *ecdh.PrivateKey salt []byte @@ -124,6 +184,20 @@ func NewECDecryptor(sk *ecdh.PrivateKey) (ECDecryptor, error) { return ECDecryptor{sk, salt, nil}, nil } +func NewECPrivateKey(mode ECCMode) (ECDecryptor, error) { + curve, err := curveFromECCMode(mode) + if err != nil { + return ECDecryptor{}, err + } + + sk, err := curve.GenerateKey(rand.Reader) + if err != nil { + return ECDecryptor{}, fmt.Errorf("ecdh.GenerateKey failed: %w", err) + } + + return NewECDecryptor(sk) +} + func NewSaltedECDecryptor(sk *ecdh.PrivateKey, salt, info []byte) (ECDecryptor, error) { return ECDecryptor{sk, salt, info}, nil } @@ -133,30 +207,30 @@ func (e ECDecryptor) Decrypt(_ []byte) ([]byte, error) { return nil, errors.New("ecdh standard decrypt unimplemented") } -func (e ECDecryptor) DecryptWithEphemeralKey(data, ephemeral []byte) ([]byte, error) { - var ek *ecdh.PublicKey +func (e ECDecryptor) PrivateKeyInPemFormat() (string, error) { + return privateKeyInPemFormat(e.sk) +} - if pubFromDSN, err := x509.ParsePKIXPublicKey(ephemeral); err == nil { - switch pubFromDSN := pubFromDSN.(type) { - case *ecdsa.PublicKey: - ek, err = ConvertToECDHPublicKey(pubFromDSN) - if err != nil { - return nil, fmt.Errorf("ecdh conversion failure: %w", err) - } - case *ecdh.PublicKey: - ek = pubFromDSN - default: - return nil, fmt.Errorf("unsupported public key of type: %T", pubFromDSN) - } - } else { - ekDSA, err := UncompressECPubKey(convCurve(e.sk.Curve()), ephemeral) - if err != nil { - return nil, err - } - ek, err = ekDSA.ECDH() - if err != nil { - return nil, fmt.Errorf("ecdh failure: %w", err) - } +func (e ECDecryptor) Public() (PublicKeyEncryptor, error) { + if e.sk == nil { + return nil, errors.New("failed to generate public key encryptor, private key is empty") + } + + return newECIES(e.sk.PublicKey(), e.salt, e.info) +} + +func (e ECDecryptor) KeyType() KeyType { + if e.sk == nil { + return KeyType("ec:[unknown]") + } + + return keyTypeFromECDHCurve(e.sk.Curve()) +} + +func (e ECDecryptor) DecryptWithEphemeralKey(data, ephemeral []byte) ([]byte, error) { + ek, err := e.parseEphemeralPublicKey(ephemeral) + if err != nil { + return nil, fmt.Errorf("failed to parse ephemeral public key: %w", err) } ikm, err := e.sk.ECDH(ek) @@ -196,15 +270,61 @@ func (e ECDecryptor) DecryptWithEphemeralKey(data, ephemeral []byte) ([]byte, er return plaintext, nil } -func convCurve(c ecdh.Curve) elliptic.Curve { +func (e ECDecryptor) deriveSharedKey(publicKeyInPem string) ([]byte, error) { + if e.sk == nil { + return nil, errors.New("failed to derive shared key, private key is empty") + } + + pub, err := getPublicPart(publicKeyInPem) + if err != nil { + return nil, err + } + + ecdhPublicKey, err := ConvertToECDHPublicKey(pub) + if err != nil { + return nil, fmt.Errorf("unsupported public key type: %w", err) + } + + sharedKey, err := e.sk.ECDH(ecdhPublicKey) + if err != nil { + return nil, fmt.Errorf("there was a problem deriving a shared ECDH key: %w", err) + } + + return sharedKey, nil +} + +// parseEphemeralPublicKey parses an ephemeral public key from DER (PKIX) or compressed EC point bytes. +func (e ECDecryptor) parseEphemeralPublicKey(ephemeral []byte) (*ecdh.PublicKey, error) { + if pub, err := x509.ParsePKIXPublicKey(ephemeral); err == nil { + switch pub := pub.(type) { + case *ecdsa.PublicKey: + return ConvertToECDHPublicKey(pub) + case *ecdh.PublicKey: + return pub, nil + default: + return nil, fmt.Errorf("unsupported public key of type: %T", pub) + } + } + curve, err := convCurve(e.sk.Curve()) + if err != nil { + return nil, err + } + ekDSA, err := UncompressECPubKey(curve, ephemeral) + if err != nil { + return nil, err + } + return ekDSA.ECDH() +} + +func convCurve(c ecdh.Curve) (elliptic.Curve, error) { switch c { case ecdh.P256(): - return elliptic.P256() + return elliptic.P256(), nil case ecdh.P384(): - return elliptic.P384() + return elliptic.P384(), nil case ecdh.P521(): - return elliptic.P521() + return elliptic.P521(), nil default: - return nil + return nil, fmt.Errorf("unsupported ECDH curve: %v", c) } } diff --git a/lib/ocrypto/asym_decryption_test.go b/lib/ocrypto/asym_decryption_test.go new file mode 100644 index 0000000000..2b173baf3d --- /dev/null +++ b/lib/ocrypto/asym_decryption_test.go @@ -0,0 +1,162 @@ +package ocrypto + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewPrivateKeyDecryptor(t *testing.T) { + t.Parallel() + + t.Run("RSA2048", func(t *testing.T) { + t.Parallel() + d, err := NewPrivateKeyDecryptor(RSA2048Key) + require.NoError(t, err) + require.Equal(t, RSA2048Key, d.KeyType()) + }) + + t.Run("RSA4096", func(t *testing.T) { + t.Parallel() + d, err := NewPrivateKeyDecryptor(RSA4096Key) + require.NoError(t, err) + require.Equal(t, RSA4096Key, d.KeyType()) + }) + + t.Run("EC256", func(t *testing.T) { + t.Parallel() + d, err := NewPrivateKeyDecryptor(EC256Key) + require.NoError(t, err) + require.Equal(t, EC256Key, d.KeyType()) + }) + + t.Run("EC384", func(t *testing.T) { + t.Parallel() + d, err := NewPrivateKeyDecryptor(EC384Key) + require.NoError(t, err) + require.Equal(t, EC384Key, d.KeyType()) + }) + + t.Run("EC521", func(t *testing.T) { + t.Parallel() + d, err := NewPrivateKeyDecryptor(EC521Key) + require.NoError(t, err) + require.Equal(t, EC521Key, d.KeyType()) + }) + + t.Run("UnsupportedType", func(t *testing.T) { + t.Parallel() + _, err := NewPrivateKeyDecryptor(KeyType("dsa:1024")) + require.Error(t, err) + require.Contains(t, err.Error(), "unsupported key type") + }) +} + +func TestECDecryptorMethods(t *testing.T) { + t.Parallel() + + modes := []struct { + mode ECCMode + kt KeyType + }{ + {ECCModeSecp256r1, EC256Key}, + {ECCModeSecp384r1, EC384Key}, + {ECCModeSecp521r1, EC521Key}, + } + + for _, tc := range modes { + t.Run(string(tc.kt), func(t *testing.T) { + t.Parallel() + + dec, err := NewECPrivateKey(tc.mode) + require.NoError(t, err) + + // KeyType + require.Equal(t, tc.kt, dec.KeyType()) + + // PrivateKeyInPemFormat round-trip + privPEM, err := dec.PrivateKeyInPemFormat() + require.NoError(t, err) + require.NotEmpty(t, privPEM) + dec2, err := FromPrivatePEM(privPEM) + require.NoError(t, err) + require.Equal(t, tc.kt, dec2.KeyType()) + + // Public — returns an encryptor whose public key round-trips + enc, err := dec.Public() + require.NoError(t, err) + pubPEM, err := enc.PublicKeyInPemFormat() + require.NoError(t, err) + require.NotEmpty(t, pubPEM) + require.Equal(t, tc.kt, enc.KeyType()) + }) + } + + t.Run("NilKeyGuards", func(t *testing.T) { + t.Parallel() + nilDec := ECDecryptor{} + _, err := nilDec.Public() + require.Error(t, err) + _, err = nilDec.PrivateKeyInPemFormat() + require.Error(t, err) + }) +} + +func TestRsaKeyPairNewMethods(t *testing.T) { + t.Parallel() + + sizes := []struct { + bits int + kt KeyType + }{ + {RSA2048Size, RSA2048Key}, + {RSA4096Size, RSA4096Key}, + } + + for _, tc := range sizes { + t.Run(string(tc.kt), func(t *testing.T) { + t.Parallel() + kp, err := NewRSAKeyPair(tc.bits) + require.NoError(t, err) + require.Equal(t, tc.kt, kp.KeyType()) + + enc, err := kp.Public() + require.NoError(t, err) + require.Equal(t, tc.kt, enc.KeyType()) + }) + } + + t.Run("NilKeyGuard", func(t *testing.T) { + t.Parallel() + _, err := RsaKeyPair{}.Public() + require.Error(t, err) + }) +} + +func TestAsymDecryptionKeyType(t *testing.T) { + t.Parallel() + + for _, bits := range []int{RSA2048Size, RSA4096Size} { + kp, err := NewRSAKeyPair(bits) + require.NoError(t, err) + privPEM, err := kp.PrivateKeyInPemFormat() + require.NoError(t, err) + d, err := FromPrivatePEM(privPEM) + require.NoError(t, err) + ad, ok := d.(AsymDecryption) + require.True(t, ok) + if bits == RSA2048Size { + require.Equal(t, RSA2048Key, ad.KeyType()) + } else { + require.Equal(t, RSA4096Key, ad.KeyType()) + } + } + + t.Run("NilKeyGuard", func(t *testing.T) { + t.Parallel() + kt := AsymDecryption{}.KeyType() + // nil key returns a sentinel string, not a valid key type + require.False(t, IsRSAKeyType(kt)) + require.False(t, IsECKeyType(kt)) + }) +} diff --git a/lib/ocrypto/asym_encryption.go b/lib/ocrypto/asym_encryption.go index c44aa64cce..75156c11cf 100644 --- a/lib/ocrypto/asym_encryption.go +++ b/lib/ocrypto/asym_encryption.go @@ -94,7 +94,10 @@ func FromPublicPEMWithSalt(publicKeyInPem string, salt, info []byte) (PublicKeyE func newECIES(pub *ecdh.PublicKey, salt, info []byte) (ECEncryptor, error) { ek, err := pub.Curve().GenerateKey(rand.Reader) - return ECEncryptor{pub, ek, salt, info}, err + if err != nil { + return ECEncryptor{}, fmt.Errorf("newECIES: failed to generate ephemeral key: %w", err) + } + return ECEncryptor{pub, ek, salt, info}, nil } // NewAsymEncryption creates and returns a new AsymEncryption. @@ -181,9 +184,13 @@ func (e AsymEncryption) EphemeralKey() []byte { } func (e ECEncryptor) EphemeralKey() []byte { + if e.ek == nil { + return nil + } publicKeyBytes, err := x509.MarshalPKIXPublicKey(e.ek.PublicKey()) if err != nil { - return nil + // MarshalPKIXPublicKey failing on a freshly-generated ecdh.PublicKey is unexpected. + panic(fmt.Sprintf("ocrypto: EphemeralKey: unexpected marshal failure: %v", err)) } return publicKeyBytes } @@ -193,8 +200,12 @@ func (e AsymEncryption) Metadata() (map[string]string, error) { } func (e ECEncryptor) Metadata() (map[string]string, error) { + ek := e.EphemeralKey() + if len(ek) == 0 { + return nil, errors.New("ECEncryptor.Metadata: ephemeral key is empty") + } m := make(map[string]string) - m["ephemeralPublicKey"] = string(e.EphemeralKey()) + m["ephemeralPublicKey"] = string(ek) return m, nil } @@ -271,5 +282,13 @@ func (e ECEncryptor) Encrypt(data []byte) ([]byte, error) { // PublicKeyInPemFormat Returns public key in pem format. func (e ECEncryptor) PublicKeyInPemFormat() (string, error) { - return publicKeyInPemFormat(e.ek.Public()) + return publicKeyInPemFormat(e.pub) +} + +func (e ECEncryptor) EphemeralPublicKeyInPemFormat() (string, error) { + if e.ek == nil { + return "", errors.New("failed to generate PEM formatted public key") + } + + return publicKeyInPemFormat(e.ek.PublicKey()) } diff --git a/lib/ocrypto/ec_decrypt_compressed_test.go b/lib/ocrypto/ec_decrypt_compressed_test.go index 45a94e795f..7d7868523c 100644 --- a/lib/ocrypto/ec_decrypt_compressed_test.go +++ b/lib/ocrypto/ec_decrypt_compressed_test.go @@ -26,17 +26,22 @@ func TestECDecryptWithCompressedEphemeralKey(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - receiverKeys, err := NewECKeyPair(test.mode) + receiverKey, err := NewECPrivateKey(test.mode) if err != nil { - t.Fatalf("NewECKeyPair failed: %v", err) + t.Fatalf("NewECPrivateKey failed: %v", err) } - pubPEM, err := receiverKeys.PublicKeyInPemFormat() + pubEncryptor, err := receiverKey.Public() + if err != nil { + t.Fatalf("Public failed: %v", err) + } + + pubPEM, err := pubEncryptor.PublicKeyInPemFormat() if err != nil { t.Fatalf("PublicKeyInPemFormat failed: %v", err) } - privPEM, err := receiverKeys.PrivateKeyInPemFormat() + privPEM, err := receiverKey.PrivateKeyInPemFormat() if err != nil { t.Fatalf("PrivateKeyInPemFormat failed: %v", err) } diff --git a/lib/ocrypto/ec_key_pair.go b/lib/ocrypto/ec_key_pair.go index f9a9554d4f..3a882f77b4 100644 --- a/lib/ocrypto/ec_key_pair.go +++ b/lib/ocrypto/ec_key_pair.go @@ -5,6 +5,7 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/pem" @@ -44,12 +45,19 @@ const ( RSA4096Size = 4096 ) +// KeyPair represents a cryptographic key pair. +// +// Deprecated: Prefer PrivateKeyDecryptor from asym_decryption.go, which separates +// key-management from algorithm-specific capabilities. type KeyPair interface { PublicKeyInPemFormat() (string, error) PrivateKeyInPemFormat() (string, error) GetKeyType() KeyType } +// NewKeyPair creates a new key pair of the given type. +// +// Deprecated: Use NewPrivateKeyDecryptor instead. func NewKeyPair(kt KeyType) (KeyPair, error) { switch kt { case RSA2048Key, RSA4096Key: @@ -69,6 +77,10 @@ func NewKeyPair(kt KeyType) (KeyPair, error) { } } +// ECKeyPair combines private and public EC key material in one value. +// +// Deprecated: Prefer PrivateKeyDecryptor implementations from asym_decryption.go +// and derive the public half via PrivateKeyDecryptor.Public(). type ECKeyPair struct { PrivateKey *ecdsa.PrivateKey } @@ -112,6 +124,21 @@ func GetECCurveFromECCMode(mode ECCMode) (elliptic.Curve, error) { return c, nil } +func curveFromECCMode(mode ECCMode) (ecdh.Curve, error) { + switch mode { + case ECCModeSecp256r1: + return ecdh.P256(), nil + case ECCModeSecp384r1: + return ecdh.P384(), nil + case ECCModeSecp521r1: + return ecdh.P521(), nil + case ECCModeSecp256k1: + return nil, errors.New("unsupported ECC mode") + default: + return nil, fmt.Errorf("unsupported ECC mode %d", mode) + } +} + func (mode ECCMode) String() string { switch mode { case ECCModeSecp256r1: @@ -165,6 +192,8 @@ func RSAKeyTypeToBits(kt KeyType) (int, error) { } // NewECKeyPair Generates an EC key pair of the given bit size. +// +// Deprecated: Prefer NewECPrivateKey and derive the public half via PrivateKeyDecryptor.Public(). func NewECKeyPair(mode ECCMode) (ECKeyPair, error) { var c elliptic.Curve @@ -186,22 +215,7 @@ func NewECKeyPair(mode ECCMode) (ECKeyPair, error) { // PrivateKeyInPemFormat Returns private key in pem format. func (keyPair ECKeyPair) PrivateKeyInPemFormat() (string, error) { - if keyPair.PrivateKey == nil { - return "", errors.New("failed to generate PEM formatted private key") - } - - privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(keyPair.PrivateKey) - if err != nil { - return "", fmt.Errorf("x509.MarshalPKCS8PrivateKey failed: %w", err) - } - - privateKeyPem := pem.EncodeToMemory( - &pem.Block{ - Type: "PRIVATE KEY", - Bytes: privateKeyBytes, - }, - ) - return string(privateKeyPem), nil + return privateKeyInPemFormat(keyPair.PrivateKey) } // PublicKeyInPemFormat Returns public key in pem format. @@ -210,18 +224,7 @@ func (keyPair ECKeyPair) PublicKeyInPemFormat() (string, error) { return "", errors.New("failed to generate PEM formatted public key") } - publicKeyBytes, err := x509.MarshalPKIXPublicKey(&keyPair.PrivateKey.PublicKey) - if err != nil { - return "", fmt.Errorf("x509.MarshalPKIXPublicKey failed: %w", err) - } - - publicKeyPem := pem.EncodeToMemory( - &pem.Block{ - Type: "PUBLIC KEY", - Bytes: publicKeyBytes, - }, - ) - return string(publicKeyPem), nil + return publicKeyInPemFormat(&keyPair.PrivateKey.PublicKey) } // KeySize Return the size of this ec key pair. @@ -232,6 +235,50 @@ func (keyPair ECKeyPair) KeySize() (int, error) { return keyPair.PrivateKey.Params().N.BitLen(), nil } +func (keyPair ECKeyPair) Decrypt(data []byte) ([]byte, error) { + ecdhPrivateKey, err := ConvertToECDHPrivateKey(keyPair.PrivateKey) + if err != nil { + return nil, err + } + dec, err := NewECDecryptor(ecdhPrivateKey) + if err != nil { + return nil, err + } + return dec.Decrypt(data) +} + +func (keyPair ECKeyPair) Public() (PublicKeyEncryptor, error) { + ecdhPrivateKey, err := ConvertToECDHPrivateKey(keyPair.PrivateKey) + if err != nil { + return nil, err + } + dec, err := NewECDecryptor(ecdhPrivateKey) + if err != nil { + return nil, err + } + return dec.Public() +} + +func (keyPair ECKeyPair) KeyType() KeyType { + if keyPair.PrivateKey == nil { + return KeyType("ec:[unknown]") + } + + return keyTypeFromEllipticCurve(keyPair.PrivateKey.Curve) +} + +func (keyPair ECKeyPair) DeriveSharedKey(publicKeyInPem string) ([]byte, error) { + ecdhPrivateKey, err := ConvertToECDHPrivateKey(keyPair.PrivateKey) + if err != nil { + return nil, err + } + dec, err := NewECDecryptor(ecdhPrivateKey) + if err != nil { + return nil, err + } + return dec.deriveSharedKey(publicKeyInPem) +} + // CompressedECPublicKey - return a compressed key from the supplied curve and public key func CompressedECPublicKey(mode ECCMode, pubKey ecdsa.PublicKey) ([]byte, error) { curve, err := GetECCurveFromECCMode(mode) @@ -365,6 +412,8 @@ func ECPrivateKeyFromPem(privateECKeyInPem []byte) (*ecdh.PrivateKey, error) { } // ComputeECDHKey calculate shared secret from public key from one party and the private key from another party. +// +// Deprecated: Use PrivateKeyDecryptor.DeriveSharedKey. func ComputeECDHKey(privateKeyInPem []byte, publicKeyInPem []byte) ([]byte, error) { ecdhPrivateKey, err := ECPrivateKeyFromPem(privateKeyInPem) if err != nil { @@ -430,35 +479,12 @@ func UncompressECPubKey(curve elliptic.Curve, compressedPubKey []byte) (*ecdsa.P // ECPrivateKeyInPemFormat Returns private key in pem format. func ECPrivateKeyInPemFormat(privateKey ecdsa.PrivateKey) (string, error) { - privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) - if err != nil { - return "", fmt.Errorf("x509.MarshalPKCS8PrivateKey failed: %w", err) - } - - privateKeyPem := pem.EncodeToMemory( - &pem.Block{ - Type: "PRIVATE KEY", - Bytes: privateKeyBytes, - }, - ) - return string(privateKeyPem), nil + return privateKeyInPemFormat(privateKey) } // ECPublicKeyInPemFormat Returns public key in pem format. func ECPublicKeyInPemFormat(publicKey ecdsa.PublicKey) (string, error) { - pkb, err := x509.MarshalPKIXPublicKey(publicKey) - if err != nil { - return "", fmt.Errorf("x509.MarshalPKIXPublicKey failed: %w", err) - } - - publicKeyPem := pem.EncodeToMemory( - &pem.Block{ - Type: "PUBLIC KEY", - Bytes: pkb, - }, - ) - - return string(publicKeyPem), nil + return publicKeyInPemFormat(publicKey) } // GetECKeySize returns the curve size from a PEM-encoded EC public key @@ -492,5 +518,65 @@ func GetECKeySize(pemData []byte) (int, error) { // GetKeyType returns the key type (ECKey) func (keyPair ECKeyPair) GetKeyType() KeyType { - return EC256Key + return keyPair.KeyType() +} + +func privateKeyInPemFormat(pk any) (string, error) { + switch key := pk.(type) { + case nil: + return "", errors.New("failed to generate PEM formatted private key") + case *ecdsa.PrivateKey: + if key == nil { + return "", errors.New("failed to generate PEM formatted private key") + } + case *ecdh.PrivateKey: + if key == nil { + return "", errors.New("failed to generate PEM formatted private key") + } + case *rsa.PrivateKey: + if key == nil { + return "", errors.New("failed to generate PEM formatted private key") + } + } + + privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(pk) + if err != nil { + return "", fmt.Errorf("x509.MarshalPKCS8PrivateKey failed: %w", err) + } + + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: privateKeyBytes, + }) + + return string(privateKeyPEM), nil +} + +func keyTypeFromECDHCurve(curve ecdh.Curve) KeyType { + switch curve { + case ecdh.P256(): + return EC256Key + case ecdh.P384(): + return EC384Key + case ecdh.P521(): + return EC521Key + default: + if n, ok := curve.(fmt.Stringer); ok { + return KeyType("ec:" + n.String()) + } + return KeyType("ec:[unknown]") + } +} + +func keyTypeFromEllipticCurve(curve elliptic.Curve) KeyType { + switch curve { + case elliptic.P256(): + return EC256Key + case elliptic.P384(): + return EC384Key + case elliptic.P521(): + return EC521Key + default: + return KeyType("ec:[unknown]") + } } diff --git a/lib/ocrypto/ec_key_pair_test.go b/lib/ocrypto/ec_key_pair_test.go index 18e875cb52..8c6ac184f8 100644 --- a/lib/ocrypto/ec_key_pair_test.go +++ b/lib/ocrypto/ec_key_pair_test.go @@ -1,9 +1,9 @@ package ocrypto import ( - "crypto/sha256" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -36,77 +36,43 @@ func TestECKeyPair(t *testing.T) { size = 99999 // deliberately bad value } - if keySize != size { - t.Fatalf("invalid key size for mode %d, expected:%d actual:%d", - modeGood, size, keySize) - } + assert.Equal(t, keySize, size, "invalid key size for mode %d", modeGood) } // Fail case emptyECKeyPair := ECKeyPair{} _, err := emptyECKeyPair.PrivateKeyInPemFormat() - if err == nil { - t.Fatal("EcKeyPair.PrivateKeyInPemFormat() fail to return error") - } + require.Error(t, err, "EcKeyPair.PrivateKeyInPemFormat() fail to return error") _, err = emptyECKeyPair.PublicKeyInPemFormat() - if err == nil { - t.Fatal("EcKeyPair.PublicKeyInPemFormat() fail to return error") - } + require.Error(t, err, "EcKeyPair.PublicKeyInPemFormat() fail to return error") _, err = emptyECKeyPair.KeySize() - if err == nil { - t.Fatal("EcKeyPair.keySize() fail to return error") - } + require.Error(t, err, "EcKeyPair.keySize() fail to return error") for _, modeBad := range []ECCMode{ECCModeSecp256k1} { _, err := NewECKeyPair(modeBad) - if err == nil { - t.Fatalf("did not fail as expected: NewECKeyPair(%d): %v", modeBad, err) - } + assert.Error(t, err, "did not fail as expected: NewECKeyPair(%d)", modeBad) } } func TestECRewrapKeyGenerate(t *testing.T) { - kasECKeyPair, err := NewECKeyPair(ECCModeSecp256r1) - require.NoError(t, err, "fail on NewECKeyPair") - - kasPubKeyAsPem, err := kasECKeyPair.PublicKeyInPemFormat() - require.NoError(t, err, "fail to generate ec public key in pem format") - - kasPrivateKeyAsPem, err := kasECKeyPair.PrivateKeyInPemFormat() - require.NoError(t, err, "fail to generate ec private key in pem format") + // KAS key pair + kasKey, err := NewECPrivateKey(ECCModeSecp256r1) + require.NoError(t, err, "fail on NewECPrivateKey") - sdkECKeyPair, err := NewECKeyPair(ECCModeSecp256r1) - require.NoError(t, err, "fail on NewECKeyPair") + kasPublicKey, err := kasKey.Public() + require.NoError(t, err, "fail to get KAS public key") - sdkPubKeyAsPem, err := sdkECKeyPair.PublicKeyInPemFormat() - require.NoError(t, err, "fail to generate ec public key in pem format") + sampleKey := []byte("samplekey") + wrappedKey, err := kasPublicKey.Encrypt(sampleKey) + require.NoError(t, err, "fail unable to encypt samplekey") - sdkPrivateKeyAsPem, err := sdkECKeyPair.PrivateKeyInPemFormat() - require.NoError(t, err, "fail to generate ec private key in pem format") + unwrappedKey, err := kasKey.DecryptWithEphemeralKey(wrappedKey, kasPublicKey.EphemeralKey()) + require.NoError(t, err, "fail to unwrap") - kasECDHKey, err := ComputeECDHKey([]byte(kasPrivateKeyAsPem), []byte(sdkPubKeyAsPem)) - require.NoError(t, err, "fail to calculate ecdh key") - - // slat - digest := sha256.New() - digest.Write([]byte("TDF")) - - kasSymmetricKey, err := CalculateHKDF(digest.Sum(nil), kasECDHKey) - require.NoError(t, err, "fail to calculate HKDF key") - - sdkECDHKey, err := ComputeECDHKey([]byte(sdkPrivateKeyAsPem), []byte(kasPubKeyAsPem)) - require.NoError(t, err, "fail to calculate ecdh key") - - sdkSymmetricKey, err := CalculateHKDF(digest.Sum(nil), sdkECDHKey) - require.NoError(t, err, "fail to calculate HKDF key") - - if string(kasSymmetricKey) != string(sdkSymmetricKey) { - t.Fatalf("symmetric keys on both kas and sdk should be same kas:%s sdk:%s", - string(kasSymmetricKey), string(sdkSymmetricKey)) - } + assert.Equal(t, sampleKey, unwrappedKey) } func TestECDSASignature(t *testing.T) { diff --git a/lib/ocrypto/rsa_key_pair.go b/lib/ocrypto/rsa_key_pair.go index 914eb8f80c..3f9b77cb3c 100644 --- a/lib/ocrypto/rsa_key_pair.go +++ b/lib/ocrypto/rsa_key_pair.go @@ -3,8 +3,6 @@ package ocrypto import ( "crypto/rand" "crypto/rsa" - "crypto/x509" - "encoding/pem" "errors" "fmt" ) @@ -30,22 +28,7 @@ func NewRSAKeyPair(bits int) (RsaKeyPair, error) { // PrivateKeyInPemFormat Returns private key in pem format. func (keyPair RsaKeyPair) PrivateKeyInPemFormat() (string, error) { - if keyPair.privateKey == nil { - return "", errors.New("failed to generate PEM formatted private key") - } - - privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(keyPair.privateKey) - if err != nil { - return "", fmt.Errorf("x509.MarshalPKCS8PrivateKey failed: %w", err) - } - - privateKeyPem := pem.EncodeToMemory( - &pem.Block{ - Type: "PRIVATE KEY", - Bytes: privateKeyBytes, - }, - ) - return string(privateKeyPem), nil + return privateKeyInPemFormat(keyPair.privateKey) } // PublicKeyInPemFormat Returns public key in pem format. @@ -54,19 +37,7 @@ func (keyPair RsaKeyPair) PublicKeyInPemFormat() (string, error) { return "", errors.New("failed to generate PEM formatted public key") } - publicKeyBytes, err := x509.MarshalPKIXPublicKey(&keyPair.privateKey.PublicKey) - if err != nil { - return "", fmt.Errorf("x509.MarshalPKIXPublicKey failed: %w", err) - } - - publicKeyPem := pem.EncodeToMemory( - &pem.Block{ - Type: "PUBLIC KEY", - Bytes: publicKeyBytes, - }, - ) - - return string(publicKeyPem), nil + return publicKeyInPemFormat(&keyPair.privateKey.PublicKey) } // KeySize Return the size of this rsa key pair. @@ -77,7 +48,34 @@ func (keyPair RsaKeyPair) KeySize() (int, error) { return keyPair.privateKey.N.BitLen(), nil } +func (keyPair RsaKeyPair) Decrypt(data []byte) ([]byte, error) { + return AsymDecryption{PrivateKey: keyPair.privateKey}.Decrypt(data) +} + +func (keyPair RsaKeyPair) Public() (PublicKeyEncryptor, error) { + if keyPair.privateKey == nil { + return nil, errors.New("failed to generate public key encryptor, private key is empty") + } + + return &AsymEncryption{PublicKey: &keyPair.privateKey.PublicKey}, nil +} + +func (keyPair RsaKeyPair) KeyType() KeyType { + if keyPair.privateKey == nil { + return KeyType("rsa:[unknown]") + } + + switch keyPair.privateKey.Size() { + case RSA2048Size / 8: //nolint:mnd // standard key size in bytes + return RSA2048Key + case RSA4096Size / 8: //nolint:mnd // large key size in bytes + return RSA4096Key + default: + return KeyType(fmt.Sprintf("rsa:%d", keyPair.privateKey.Size()*8)) //nolint:mnd // convert to bits + } +} + // GetKeyType returns the key type (RSAKey) func (keyPair RsaKeyPair) GetKeyType() KeyType { - return RSA2048Key + return keyPair.KeyType() } diff --git a/sdk/codegen/runner/generate.go b/sdk/codegen/runner/generate.go index 756b1e5c56..3b6e54f86a 100644 --- a/sdk/codegen/runner/generate.go +++ b/sdk/codegen/runner/generate.go @@ -177,12 +177,12 @@ func New%s%s%sConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts func generateInterfaceType(interfaceName string, methods []string, packageName, prefix, suffix string) string { // Generate the interface type definition var builder strings.Builder - builder.WriteString(fmt.Sprintf(` + fmt.Fprintf(&builder, ` type %s%s%s interface { -`, prefix, interfaceName, suffix)) +`, prefix, interfaceName, suffix) for _, method := range methods { - builder.WriteString(fmt.Sprintf(` %s(ctx context.Context, req *%s.%sRequest) (*%s.%sResponse, error) -`, method, packageName, method, packageName, method)) + fmt.Fprintf(&builder, ` %s(ctx context.Context, req *%s.%sRequest) (*%s.%sResponse, error) +`, method, packageName, method, packageName, method) } builder.WriteString("}\n") return builder.String() diff --git a/sdk/experimental/tdf/key_access.go b/sdk/experimental/tdf/key_access.go index 6e97701ad5..1fe8676f0e 100644 --- a/sdk/experimental/tdf/key_access.go +++ b/sdk/experimental/tdf/key_access.go @@ -165,78 +165,44 @@ func wrapKeyWithPublicKey(symKey []byte, pubKeyInfo keysplit.KASPublicKey) (stri // Determine key type based on algorithm ktype := ocrypto.KeyType(pubKeyInfo.Algorithm) + kasPublicKey, err := ocrypto.FromPublicPEM(pubKeyInfo.PEM) + if err != nil { + return "", "", "", fmt.Errorf("failed to create parse KAS public key: %w", err) + } + if ocrypto.IsECKeyType(ktype) { - // Handle EC key wrapping - return wrapKeyWithEC(ktype, pubKeyInfo.PEM, symKey) + if epk, ok := kasPublicKey.(ocrypto.ECEncryptor); ok { + // Handle EC key wrapping + return wrapKeyWithEC(ktype, epk, symKey) + } + return "", "", "", fmt.Errorf("incorrect key type for %v", ktype) } // Handle RSA key wrapping - wrapped, err := wrapKeyWithRSA(pubKeyInfo.PEM, symKey) + wrapped, err := wrapKeyWithRSA(kasPublicKey, symKey) return wrapped, "wrapped", "", err } // wrapKeyWithEC encrypts a key using EC public key with ECIES -func wrapKeyWithEC(keyType ocrypto.KeyType, kasPublicKeyPEM string, symKey []byte) (string, string, string, error) { - // Convert key type to ECC mode - mode, err := ocrypto.ECKeyTypeToMode(keyType) - if err != nil { - return "", "", "", fmt.Errorf("failed to convert key type to ECC mode: %w", err) - } - - // Generate ephemeral key pair - ecKeyPair, err := ocrypto.NewECKeyPair(mode) - if err != nil { - return "", "", "", fmt.Errorf("failed to create EC key pair: %w", err) - } - - // Get ephemeral public key in PEM format - ephemeralPubKey, err := ecKeyPair.PublicKeyInPemFormat() - if err != nil { - return "", "", "", fmt.Errorf("failed to get ephemeral public key: %w", err) +func wrapKeyWithEC(keyType ocrypto.KeyType, kasPublicKey ocrypto.ECEncryptor, symKey []byte) (string, string, string, error) { + if !ocrypto.IsECKeyType(kasPublicKey.KeyType()) { + return "", "", "", fmt.Errorf("unexpected KAS public key type: %v", kasPublicKey.KeyType()) } - // Get ephemeral private key - ephemeralPrivKey, err := ecKeyPair.PrivateKeyInPemFormat() + wrapped, err := kasPublicKey.Encrypt(symKey) if err != nil { - return "", "", "", fmt.Errorf("failed to get ephemeral private key: %w", err) + return "", "", "", fmt.Errorf("failed to wrap with %v: %w", keyType, err) } - // Compute ECDH shared secret - ecdhKey, err := ocrypto.ComputeECDHKey([]byte(ephemeralPrivKey), []byte(kasPublicKeyPEM)) + epk, err := kasPublicKey.EphemeralPublicKeyInPemFormat() if err != nil { - return "", "", "", fmt.Errorf("failed to compute ECDH key: %w", err) + return "", "", "", fmt.Errorf("failed to export ephemeral public key: %w", err) } - // Derive wrapping key using HKDF - salt := tdfSalt() - wrapKey, err := ocrypto.CalculateHKDF(salt, ecdhKey) - if err != nil { - return "", "", "", fmt.Errorf("failed to derive wrap key: %w", err) - } - - // Ensure we have the right length for wrapping, trim if needed, or error if too short - if len(wrapKey) > len(symKey) { - wrapKey = wrapKey[:len(symKey)] - } else if len(wrapKey) < len(symKey) { - return "", "", "", fmt.Errorf("wrap key too short: got %d, expected at least %d", - len(wrapKey), len(symKey)) - } - - wrapped := make([]byte, len(symKey)) - for i := range symKey { - wrapped[i] = symKey[i] ^ wrapKey[i] - } - - return string(ocrypto.Base64Encode(wrapped)), "eccWrapped", ephemeralPubKey, nil + return string(ocrypto.Base64Encode(wrapped)), "eccWrapped", epk, nil } // wrapKeyWithRSA encrypts a key using RSA public key with OAEP padding -func wrapKeyWithRSA(kasPublicKeyPEM string, symKey []byte) (string, error) { - // Create RSA encryptor from PEM - encryptor, err := ocrypto.FromPublicPEM(kasPublicKeyPEM) - if err != nil { - return "", fmt.Errorf("failed to create RSA encryptor: %w", err) - } - +func wrapKeyWithRSA(encryptor ocrypto.PublicKeyEncryptor, symKey []byte) (string, error) { // Encrypt with OAEP padding encryptedKey, err := encryptor.Encrypt(symKey) if err != nil { diff --git a/sdk/experimental/tdf/key_access_test.go b/sdk/experimental/tdf/key_access_test.go index 9dac3ca3b6..253e918b47 100644 --- a/sdk/experimental/tdf/key_access_test.go +++ b/sdk/experimental/tdf/key_access_test.go @@ -90,10 +90,13 @@ func TestBuildKeyAccessObjects(t *testing.T) { // Test that buildKeyAccessObjects correctly handles elliptic curve keys with ephemeral key generation // Generate a real EC P-256 key pair for testing - ecKeyPair, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1) - require.NoError(t, err, "Should generate EC key pair") + ecPrivateKey, err := ocrypto.NewECPrivateKey(ocrypto.ECCModeSecp256r1) + require.NoError(t, err, "Should generate EC private key") - ecPublicKeyPEM, err := ecKeyPair.PublicKeyInPemFormat() + ecPublicKey, err := ecPrivateKey.Public() + require.NoError(t, err, "Should derive EC public key") + + ecPublicKeyPEM, err := ecPublicKey.PublicKeyInPemFormat() require.NoError(t, err, "Should get public key in PEM format") splitResult := createTestSplitResult(testKAS1URL, ecPublicKeyPEM, "ec:secp256r1") @@ -397,10 +400,13 @@ func TestWrapKeyWithPublicKey(t *testing.T) { require.NoError(t, err) // Generate a real EC P-256 key pair for testing - ecKeyPair, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1) - require.NoError(t, err, "Should generate EC key pair") + ecPrivateKey, err := ocrypto.NewECPrivateKey(ocrypto.ECCModeSecp256r1) + require.NoError(t, err, "Should generate EC private key") + + ecPublicKey, err := ecPrivateKey.Public() + require.NoError(t, err, "Should derive EC public key") - ecPublicKeyPEM, err := ecKeyPair.PublicKeyInPemFormat() + ecPublicKeyPEM, err := ecPublicKey.PublicKeyInPemFormat() require.NoError(t, err, "Should get public key in PEM format") pubKeyInfo := keysplit.KASPublicKey{ diff --git a/sdk/kas_client.go b/sdk/kas_client.go index bb2b1f7f3a..9980aed1ee 100644 --- a/sdk/kas_client.go +++ b/sdk/kas_client.go @@ -2,9 +2,9 @@ package sdk import ( "context" - "crypto/sha256" "encoding/base64" "encoding/json" + "encoding/pem" "errors" "fmt" "net" @@ -36,7 +36,7 @@ type KASClient struct { accessTokenSource auth.AccessTokenSource httpClient *http.Client connectOptions []connect.ClientOption - sessionKey ocrypto.KeyPair + sessionKey ocrypto.PrivateKeyDecryptor // Set this to enable legacy, non-batch rewrap requests supportSingleRewrapEndpoint bool @@ -63,7 +63,7 @@ type additionalRewrapContext struct { Obligations obligationContext `json:"obligations"` } -func newKASClient(httpClient *http.Client, options []connect.ClientOption, accessTokenSource auth.AccessTokenSource, sessionKey ocrypto.KeyPair, fulfillableObligations []string) *KASClient { +func newKASClient(httpClient *http.Client, options []connect.ClientOption, accessTokenSource auth.AccessTokenSource, sessionKey ocrypto.PrivateKeyDecryptor, fulfillableObligations []string) *KASClient { return &KASClient{ accessTokenSource: accessTokenSource, httpClient: httpClient, @@ -174,7 +174,11 @@ func (k *KASClient) unwrap(ctx context.Context, requests ...*kas.UnsignedRewrapR if k.sessionKey == nil { return nil, errors.New("session key is nil") } - pubKey, err := k.sessionKey.PublicKeyInPemFormat() + publicKeyEncryptor, err := k.sessionKey.Public() + if err != nil { + return nil, fmt.Errorf("failed to create public key encryptor: %w", err) + } + pubKey, err := publicKeyEncryptor.PublicKeyInPemFormat() if err != nil { return nil, fmt.Errorf("ocrypto.PublicKeyInPermFormat failed: %w", err) } @@ -183,47 +187,35 @@ func (k *KASClient) unwrap(ctx context.Context, requests ...*kas.UnsignedRewrapR return nil, fmt.Errorf("error making rewrap request to kas: %w", err) } - if ocrypto.IsECKeyType(k.sessionKey.GetKeyType()) { + if ocrypto.IsECKeyType(k.sessionKey.KeyType()) { return k.handleECKeyResponse(response) } return k.handleRSAKeyResponse(response) } func (k *KASClient) handleECKeyResponse(response *kas.RewrapResponse) (map[string][]kaoResult, error) { - kasEphemeralPublicKey := response.GetSessionPublicKey() - clientPrivateKey, err := k.sessionKey.PrivateKeyInPemFormat() - if err != nil { - return nil, fmt.Errorf("failed to get private key: %w", err) - } - ecdhKey, err := ocrypto.ComputeECDHKey([]byte(clientPrivateKey), []byte(kasEphemeralPublicKey)) - if err != nil { - return nil, fmt.Errorf("ocrypto.ComputeECDHKey failed: %w", err) - } - - digest := sha256.New() - digest.Write([]byte("TDF")) - salt := digest.Sum(nil) - sessionKey, err := ocrypto.CalculateHKDF(salt, ecdhKey) - if err != nil { - return nil, fmt.Errorf("ocrypto.CalculateHKDF failed: %w", err) + kasEphemeralPublicKeyPEM := response.GetSessionPublicKey() + block, _ := pem.Decode([]byte(kasEphemeralPublicKeyPEM)) + if block == nil { + return nil, errors.New("failed to decode KAS session public key PEM") } - aesGcm, err := ocrypto.NewAESGcm(sessionKey) - if err != nil { - return nil, fmt.Errorf("ocrypto.NewAESGcm failed: %w", err) + ecDecryptor, ok := k.sessionKey.(ocrypto.ECDecryptor) + if !ok { + return nil, errors.New("session key is not an EC decryptor") } - return k.processECResponse(response, aesGcm) + return k.processECResponse(response, ecDecryptor, block.Bytes) } -func (k *KASClient) processECResponse(response *kas.RewrapResponse, aesGcm ocrypto.AesGcm) (map[string][]kaoResult, error) { +func (k *KASClient) processECResponse(response *kas.RewrapResponse, ecDecryptor ocrypto.ECDecryptor, ephemeralKeyDER []byte) (map[string][]kaoResult, error) { policyResults := make(map[string][]kaoResult) for _, results := range response.GetResponses() { var kaoKeys []kaoResult for _, kao := range results.GetResults() { requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata()) if kao.GetStatus() == statusPermit { - key, err := aesGcm.Decrypt(kao.GetKasWrappedKey()) + key, err := ecDecryptor.DecryptWithEphemeralKey(kao.GetKasWrappedKey(), ephemeralKeyDER) if err != nil { kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO}) } else { @@ -268,20 +260,10 @@ func (k *KASClient) retrieveObligationsFromMetadata(metadata map[string]*structp } func (k *KASClient) handleRSAKeyResponse(response *kas.RewrapResponse) (map[string][]kaoResult, error) { - clientPrivateKey, err := k.sessionKey.PrivateKeyInPemFormat() - if err != nil { - return nil, fmt.Errorf("ocrypto.PrivateKeyInPemFormat failed: %w", err) - } - - asymDecryption, err := ocrypto.NewAsymDecryption(clientPrivateKey) - if err != nil { - return nil, fmt.Errorf("ocrypto.NewAsymDecryption failed: %w", err) - } - - return k.processRSAResponse(response, asymDecryption) + return k.processRSAResponse(response, k.sessionKey) } -func (k *KASClient) processRSAResponse(response *kas.RewrapResponse, asymDecryption ocrypto.AsymDecryption) (map[string][]kaoResult, error) { +func (k *KASClient) processRSAResponse(response *kas.RewrapResponse, asymDecryption ocrypto.PrivateKeyDecryptor) (map[string][]kaoResult, error) { policyResults := make(map[string][]kaoResult) for _, results := range response.GetResponses() { var kaoKeys []kaoResult diff --git a/sdk/kas_client_test.go b/sdk/kas_client_test.go index 7fd4b6b60c..ea6e57c27a 100644 --- a/sdk/kas_client_test.go +++ b/sdk/kas_client_test.go @@ -2,7 +2,6 @@ package sdk import ( "context" - "crypto/sha256" "encoding/base64" "encoding/json" "net/http" @@ -540,35 +539,27 @@ func Test_processRSAResponse(t *testing.T) { func Test_processECResponse(t *testing.T) { c := newKASClient(nil, nil, nil, nil, nil) - // 1. Set up keys for encryption - kasPublicKey, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1) + // 1. Generate client EC session key pair (what the SDK would use for session negotiation) + clientPrivateKey, err := ocrypto.NewECPrivateKey(ocrypto.ECCModeSecp256r1) require.NoError(t, err) - kasPublicKeyPEM, err := kasPublicKey.PublicKeyInPemFormat() + clientPublicKey, err := clientPrivateKey.Public() require.NoError(t, err) - - clientPrivateKey, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1) - require.NoError(t, err) - clientPrivateKeyPEM, err := clientPrivateKey.PrivateKeyInPemFormat() + clientPublicKeyPEM, err := clientPublicKey.PublicKeyInPemFormat() require.NoError(t, err) - // 2. Compute shared secret and derive session key (for encryption) - ecdhKey, err := ocrypto.ComputeECDHKey([]byte(clientPrivateKeyPEM), []byte(kasPublicKeyPEM)) - require.NoError(t, err) - - digest := sha256.New() - digest.Write([]byte("TDF")) - salt := digest.Sum(nil) - sessionKey, err := ocrypto.CalculateHKDF(salt, ecdhKey) - require.NoError(t, err) - - // 3. Create AES-GCM cipher for encryption - encryptor, err := ocrypto.NewAESGcm(sessionKey) + // 2. Simulate KAS encrypting symmetricKey2 for the client using ECIES + // (same TDF salt used by NewECDecryptor/NewECPrivateKey on the client side) + asymEncrypt, err := ocrypto.FromPublicPEMWithSalt(clientPublicKeyPEM, tdfSalt(), nil) require.NoError(t, err) symmetricKey2 := []byte("supersecretkey2") - wrappedKey2, err := encryptor.Encrypt(symmetricKey2) + wrappedKey2, err := asymEncrypt.Encrypt(symmetricKey2) require.NoError(t, err) + // 3. Get the KAS-generated ephemeral key DER bytes (what KAS sends as SessionPublicKey) + kasEphemeralKeyDER := asymEncrypt.EphemeralKey() + require.NotEmpty(t, kasEphemeralKeyDER) + // 5. Create mock response with multiple policies response := &kaspb.RewrapResponse{ Responses: []*kaspb.PolicyRewrapResult{ @@ -615,12 +606,8 @@ func Test_processECResponse(t *testing.T) { }, } - // 6. Create AES-GCM cipher for decryption (using the same session key) - decryptor, err := ocrypto.NewAESGcm(sessionKey) - require.NoError(t, err) - - // 7. Process the response - policyResults, err := c.processECResponse(response, decryptor) + // 4. Process the response using client's private key and KAS ephemeral key DER + policyResults, err := c.processECResponse(response, clientPrivateKey, kasEphemeralKeyDER) require.NoError(t, err) require.Len(t, policyResults, 2) diff --git a/sdk/tdf.go b/sdk/tdf.go index db22208b5f..fbfec932d1 100644 --- a/sdk/tdf.go +++ b/sdk/tdf.go @@ -65,7 +65,7 @@ type Reader struct { aesGcm ocrypto.AesGcm payloadSize int64 payloadKey []byte - kasSessionKey ocrypto.KeyPair + kasSessionKey ocrypto.PrivateKeyDecryptor config TDFReaderConfig requiredObligations *RequiredObligations } @@ -675,11 +675,7 @@ func createKeyAccess(kasInfo KASInfo, symKey []byte, policyBinding PolicyBinding ktype := ocrypto.KeyType(kasInfo.Algorithm) if ocrypto.IsECKeyType(ktype) { - mode, err := ocrypto.ECKeyTypeToMode(ktype) - if err != nil { - return KeyAccess{}, err - } - wrappedKeyInfo, err := generateWrapKeyWithEC(mode, kasInfo.PublicKey, symKey) + wrappedKeyInfo, err := generateWrapKeyWithEC(kasInfo.PublicKey, symKey) if err != nil { return KeyAccess{}, err } @@ -704,45 +700,29 @@ func tdfSalt() []byte { return salt } -func generateWrapKeyWithEC(mode ocrypto.ECCMode, kasPublicKey string, symKey []byte) (ecKeyWrappedKeyInfo, error) { - ecKeyPair, err := ocrypto.NewECKeyPair(mode) - if err != nil { - return ecKeyWrappedKeyInfo{}, fmt.Errorf("ocrypto.NewECKeyPair failed:%w", err) - } - - emphermalPublicKey, err := ecKeyPair.PublicKeyInPemFormat() - if err != nil { - return ecKeyWrappedKeyInfo{}, fmt.Errorf("generateWrapKeyWithEC: failed to get EC public key: %w", err) - } - - emphermalPrivateKey, err := ecKeyPair.PrivateKeyInPemFormat() +func generateWrapKeyWithEC(kasPublicKey string, symKey []byte) (ecKeyWrappedKeyInfo, error) { + asymEncrypt, err := ocrypto.FromPublicPEMWithSalt(kasPublicKey, tdfSalt(), nil) if err != nil { - return ecKeyWrappedKeyInfo{}, fmt.Errorf("generateWrapKeyWithEC: failed to get EC private key: %w", err) + return ecKeyWrappedKeyInfo{}, fmt.Errorf("generateWrapKeyWithEC: failed to create EC encryptor: %w", err) } - ecdhKey, err := ocrypto.ComputeECDHKey([]byte(emphermalPrivateKey), []byte(kasPublicKey)) - if err != nil { - return ecKeyWrappedKeyInfo{}, fmt.Errorf("generateWrapKeyWithEC: ocrypto.ComputeECDHKey failed:%w", err) + ecEnc, ok := asymEncrypt.(ocrypto.ECEncryptor) + if !ok { + return ecKeyWrappedKeyInfo{}, errors.New("generateWrapKeyWithEC: KAS public key is not an EC key") } - salt := tdfSalt() - sessionKey, err := ocrypto.CalculateHKDF(salt, ecdhKey) - if err != nil { - return ecKeyWrappedKeyInfo{}, fmt.Errorf("generateWrapKeyWithEC: ocrypto.CalculateHKDF failed:%w", err) - } - - gcm, err := ocrypto.NewAESGcm(sessionKey) + wrappedKey, err := asymEncrypt.Encrypt(symKey) if err != nil { - return ecKeyWrappedKeyInfo{}, fmt.Errorf("generateWrapKeyWithEC: ocrypto.NewAESGcm failed:%w", err) + return ecKeyWrappedKeyInfo{}, fmt.Errorf("generateWrapKeyWithEC: failed to wrap key: %w", err) } - wrappedKey, err := gcm.Encrypt(symKey) + ephemeralPublicKey, err := ecEnc.EphemeralPublicKeyInPemFormat() if err != nil { - return ecKeyWrappedKeyInfo{}, fmt.Errorf("generateWrapKeyWithEC: ocrypto.AESGcm.Encrypt failed:%w", err) + return ecKeyWrappedKeyInfo{}, fmt.Errorf("generateWrapKeyWithEC: failed to get ephemeral public key: %w", err) } return ecKeyWrappedKeyInfo{ - publicKey: emphermalPublicKey, + publicKey: ephemeralPublicKey, wrappedKey: string(ocrypto.Base64Encode(wrappedKey)), }, nil } diff --git a/sdk/tdf_config.go b/sdk/tdf_config.go index b77e85030c..46bca51cd9 100644 --- a/sdk/tdf_config.go +++ b/sdk/tdf_config.go @@ -268,7 +268,7 @@ type TDFReaderConfig struct { disableAssertionVerification bool schemaValidationIntensity SchemaValidationIntensity - kasSessionKey ocrypto.KeyPair + kasSessionKey ocrypto.PrivateKeyDecryptor kasAllowlist AllowList // KAS URLs that are allowed to be used for reading TDFs ignoreAllowList bool // If true, the kasAllowlist will be ignored, and all KAS URLs will be allowed fulfillableObligationFQNs []string @@ -405,9 +405,9 @@ func WithDisableAssertionVerification(disable bool) TDFReaderOption { func WithSessionKeyType(keyType ocrypto.KeyType) TDFReaderOption { return func(c *TDFReaderConfig) error { - kasSessionKey, err := ocrypto.NewKeyPair(keyType) + kasSessionKey, err := ocrypto.NewPrivateKeyDecryptor(keyType) if err != nil { - return fmt.Errorf("failed to create RSA key pair: %w", err) + return fmt.Errorf("failed to create private key decryptor: %w", err) } c.kasSessionKey = kasSessionKey return nil @@ -446,7 +446,7 @@ func WithTDFFulfillableObligationFQNs(fqns []string) TDFReaderOption { } } -func withSessionKey(k ocrypto.KeyPair) TDFReaderOption { +func withSessionKey(k ocrypto.PrivateKeyDecryptor) TDFReaderOption { return func(c *TDFReaderConfig) error { c.kasSessionKey = k return nil diff --git a/sdk/tdf_test.go b/sdk/tdf_test.go index 0aa58daa3f..96fd14fea1 100644 --- a/sdk/tdf_test.go +++ b/sdk/tdf_test.go @@ -2906,7 +2906,7 @@ func (f *FakeKas) getRewrapResponse(rewrapRequest string, fulfillableObligations var sessionKey string if e, found := asymEncrypt.(ocrypto.ECEncryptor); found { - sessionKey, err = e.PublicKeyInPemFormat() + sessionKey, err = e.EphemeralPublicKeyInPemFormat() f.s.Require().NoError(err, "unable to serialize ephemeral key") } resp.SessionPublicKey = sessionKey diff --git a/service/internal/security/basic_manager.go b/service/internal/security/basic_manager.go index 6197e50eb3..a3cc20a33e 100644 --- a/service/internal/security/basic_manager.go +++ b/service/internal/security/basic_manager.go @@ -3,10 +3,8 @@ package security import ( "context" "crypto/elliptic" - "crypto/x509" "encoding/base64" "encoding/hex" - "encoding/pem" "errors" "fmt" "log/slog" @@ -80,13 +78,9 @@ func (b *BasicManager) Decrypt(ctx context.Context, keyDetails trust.KeyDetails, } return protectedKey, nil case ocrypto.EC256Key, ocrypto.EC384Key, ocrypto.EC521Key: - ecPrivKey, err := ocrypto.ECPrivateKeyFromPem(privKey) - if err != nil { - return nil, fmt.Errorf("failed to create EC private key from PEM: %w", err) - } - ecDecryptor, err := ocrypto.NewECDecryptor(ecPrivKey) - if err != nil { - return nil, fmt.Errorf("failed to create ECDecryptor: %w", err) + ecDecryptor, ok := decrypter.(ocrypto.ECDecryptor) + if !ok { + return nil, fmt.Errorf("failed to create ECDecryptor: unexpected type %T", decrypter) } plaintext, err := ecDecryptor.DecryptWithEphemeralKey(ciphertext, ephemeralPublicKey) if err != nil { @@ -102,47 +96,9 @@ func (b *BasicManager) Decrypt(ctx context.Context, keyDetails trust.KeyDetails, return nil, fmt.Errorf("unsupported algorithm: %s", keyDetails.Algorithm()) } -func (b *BasicManager) DeriveKey(ctx context.Context, keyDetails trust.KeyDetails, ephemeralPublicKeyBytes []byte, curve elliptic.Curve) (ocrypto.ProtectedKey, error) { - // Implementation of DeriveKey method - privateKeyCtx, err := keyDetails.ExportPrivateKey(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get private key: %w", err) - } - - privKey, err := b.unwrap(ctx, string(keyDetails.ID()), privateKeyCtx.WrappedKey) - if err != nil { - return nil, fmt.Errorf("failed to unwrap private key: %w", err) - } - - ephemeralECDSAPublicKey, err := ocrypto.UncompressECPubKey(curve, ephemeralPublicKeyBytes) - if err != nil { - return nil, fmt.Errorf("failed to uncompress ephemeral public key: %w", err) - } - - derBytes, err := x509.MarshalPKIXPublicKey(ephemeralECDSAPublicKey) - if err != nil { - return nil, fmt.Errorf("failed to marshal ECDSA public key: %w", err) - } - pemBlock := &pem.Block{ - Type: "PUBLIC KEY", - Bytes: derBytes, - } - ephemeralECDSAPublicKeyPEM := pem.EncodeToMemory(pemBlock) - - symmetricKey, err := ocrypto.ComputeECDHKey(privKey, ephemeralECDSAPublicKeyPEM) - if err != nil { - return nil, fmt.Errorf("failed to compute ECDH key: %w", err) - } - - key, err := ocrypto.CalculateHKDF(TDFSalt(), symmetricKey) - if err != nil { - return nil, fmt.Errorf("failed to calculate HKDF: %w", err) - } - protectedKey, err := ocrypto.NewAESProtectedKey(key) - if err != nil { - return nil, fmt.Errorf("failed to create protected key: %w", err) - } - return protectedKey, nil +// Deprecated: Prefer to directly unwrap the value with Decrypt. +func (b *BasicManager) DeriveKey(_ context.Context, _ trust.KeyDetails, _ []byte, _ elliptic.Curve) (ocrypto.ProtectedKey, error) { + return nil, errors.New("unsupported operation") } type OCEncapsulator struct { diff --git a/service/internal/security/basic_manager_test.go b/service/internal/security/basic_manager_test.go index ad237c9662..006664e64e 100644 --- a/service/internal/security/basic_manager_test.go +++ b/service/internal/security/basic_manager_test.go @@ -152,9 +152,19 @@ func generateRSAKeyAndPEM() (ocrypto.RsaKeyPair, error) { return ocrypto.NewRSAKeyPair(2048) } -// Helper function to generate EC key pair and PEM encode private key -func generateECKeyAndPEM(curve ocrypto.ECCMode) (ocrypto.ECKeyPair, error) { - return ocrypto.NewECKeyPair(curve) +// Helper function to generate an EC private key and derive the matching public key. +func generateECKeyAndPEM(curve ocrypto.ECCMode) (ocrypto.ECDecryptor, ocrypto.PublicKeyEncryptor, error) { + privateKey, err := ocrypto.NewECPrivateKey(curve) + if err != nil { + return ocrypto.ECDecryptor{}, nil, err + } + + publicKey, err := privateKey.Public() + if err != nil { + return ocrypto.ECDecryptor{}, nil, err + } + + return privateKey, publicKey, nil } func compressEphemeralPublicKey(t *testing.T, der []byte) []byte { @@ -327,11 +337,11 @@ func TestBasicManager_Decrypt(t *testing.T) { wrappedRSAPrivKeyStr, err := wrapKeyWithAESGCM([]byte(rsaPrivKey), rootKey) require.NoError(t, err) - ecKey, err := generateECKeyAndPEM(ocrypto.ECCModeSecp256r1) + ecKey, ecPublicKey, err := generateECKeyAndPEM(ocrypto.ECCModeSecp256r1) require.NoError(t, err) ecPrivKey, err := ecKey.PrivateKeyInPemFormat() require.NoError(t, err) - ecPubKey, err := ecKey.PublicKeyInPemFormat() + ecPubKey, err := ecPublicKey.PublicKeyInPemFormat() require.NoError(t, err) wrappedECPrivKeyStr, err := wrapKeyWithAESGCM([]byte(ecPrivKey), rootKey) @@ -419,11 +429,11 @@ func TestBasicManager_Decrypt(t *testing.T) { }, } { t.Run("successful EC decryption with compressed ephemeral key "+tc.name, func(t *testing.T) { - curveKey, err := generateECKeyAndPEM(tc.mode) + curveKey, curvePublicKey, err := generateECKeyAndPEM(tc.mode) require.NoError(t, err) curvePrivKey, err := curveKey.PrivateKeyInPemFormat() require.NoError(t, err) - curvePubKey, err := curveKey.PublicKeyInPemFormat() + curvePubKey, err := curvePublicKey.PublicKeyInPemFormat() require.NoError(t, err) wrappedCurvePrivKeyStr, err := wrapKeyWithAESGCM([]byte(curvePrivKey), rootKey) @@ -512,84 +522,6 @@ func TestBasicManager_Decrypt(t *testing.T) { }) } -func TestBasicManager_DeriveKey(t *testing.T) { - log := logger.CreateTestLogger() - testCache := newTestCache(t, log) - rootKeyHex := "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f" - rootKey, _ := hex.DecodeString(rootKeyHex) - - ecKey, err := generateECKeyAndPEM(ocrypto.ECCModeSecp256r1) - require.NoError(t, err) - ecPrivKey, err := ecKey.PrivateKeyInPemFormat() - require.NoError(t, err) - - wrappedECPrivKeyStr, err := wrapKeyWithAESGCM([]byte(ecPrivKey), rootKey) - require.NoError(t, err) - - bm, err := NewBasicManager(log, testCache, rootKeyHex) - require.NoError(t, err) - - clientEphemeralECDHKey, err := ecdh.P256().GenerateKey(rand.Reader) - require.NoError(t, err) - // Ensure the public key is in compressed format as expected by ocrypto.UncompressECPubKey - ecdhPubKey := clientEphemeralECDHKey.PublicKey() - pubKeyBytes, err := x509.MarshalPKIXPublicKey(ecdhPubKey) - require.NoError(t, err) - parsedPubKey, err := x509.ParsePKIXPublicKey(pubKeyBytes) - require.NoError(t, err) - clientECDSAKey, ok := parsedPubKey.(*ecdsa.PublicKey) - require.True(t, ok, "failed to convert ecdh.PublicKey to *ecdsa.PublicKey") - clientEphemeralPublicKeyBytes, err := ocrypto.CompressedECPublicKey(ocrypto.ECCModeSecp256r1, *clientECDSAKey) - require.NoError(t, err) - - t.Run("successful key derivation", func(t *testing.T) { - mockDetails := new(MockKeyDetails) - mockDetails.MID = "ec-kid-derive" - mockDetails.MAlgorithm = AlgorithmECP256R1 - mockDetails.MPrivateKey = &policy.PrivateKeyCtx{WrappedKey: wrappedECPrivKeyStr} - - // Set up mock expectations - mockDetails.On("ID").Return(trust.KeyIdentifier(mockDetails.MID)) - mockDetails.On("Algorithm").Return(mockDetails.MAlgorithm) - mockDetails.On("ExportPrivateKey").Return(&trust.PrivateKey{WrappingKeyID: trust.KeyIdentifier(mockDetails.MPrivateKey.GetKeyId()), WrappedKey: mockDetails.MPrivateKey.GetWrappedKey()}, nil) - - protectedKey, err := bm.DeriveKey(t.Context(), mockDetails, clientEphemeralPublicKeyBytes, elliptic.P256()) - require.NoError(t, err) - require.NotNil(t, protectedKey) - - ecdhPrivKey, err := ocrypto.ECPrivateKeyFromPem([]byte(ecPrivKey)) // ECDH private key - require.NoError(t, err) - - // We need to compute the shared secret using the private key and the client ephemeral public key - clientEphemeralECDSAPubKey, err := ocrypto.UncompressECPubKey(elliptic.P256(), clientEphemeralPublicKeyBytes) - require.NoError(t, err) - clientECDHPublicKey, err := ocrypto.ConvertToECDHPublicKey(clientEphemeralECDSAPubKey) - require.NoError(t, err) - - expectedSharedSecret, err := ocrypto.ComputeECDHKeyFromECDHKeys(clientECDHPublicKey, ecdhPrivKey) - require.NoError(t, err) - expectedDerivedKey, err := ocrypto.CalculateHKDF(TDFSalt(), expectedSharedSecret) - require.NoError(t, err) - - // Use noOpEncapsulator to get raw key data for testing - noOpEnc := &noOpEncapsulator{} - //nolint:staticcheck // Export is used in tests until ProtectedKey deprecation is removed upstream. - actualDerivedKey, err := protectedKey.Export(noOpEnc) - require.NoError(t, err) - assert.Equal(t, expectedDerivedKey, actualDerivedKey) - }) - - t.Run("fail ExportPrivateKey for DeriveKey", func(t *testing.T) { - mockDetails := new(MockKeyDetails) - mockDetails.On("ID").Return(trust.KeyIdentifier("fail-export-derive")) - mockDetails.On("ExportPrivateKey").Return(nil, errors.New("export failed derive")) - - _, err := bm.DeriveKey(t.Context(), mockDetails, clientEphemeralPublicKeyBytes, elliptic.P256()) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to get private key") - }) -} - func TestBasicManager_GenerateECSessionKey(t *testing.T) { log := logger.CreateTestLogger() testCache := newTestCache(t, log) diff --git a/service/internal/security/in_process_provider.go b/service/internal/security/in_process_provider.go index 67c86487aa..1f0779ad90 100644 --- a/service/internal/security/in_process_provider.go +++ b/service/internal/security/in_process_provider.go @@ -4,8 +4,6 @@ import ( "context" "crypto" "crypto/elliptic" - "crypto/x509" - "encoding/pem" "errors" "fmt" "log/slog" @@ -280,45 +278,8 @@ func (a *InProcessProvider) Decrypt(ctx context.Context, keyDetails trust.KeyDet } // DeriveKey computes an ECDH shared secret and derives an AES key via HKDF. -func (a *InProcessProvider) DeriveKey(_ context.Context, keyDetails trust.KeyDetails, ephemeralPublicKeyBytes []byte, curve elliptic.Curve) (ocrypto.ProtectedKey, error) { - kid := string(keyDetails.ID()) - k, ok := a.cryptoProvider.keysByID[kid] - if !ok { - return nil, ErrKeyPairInfoNotFound - } - ec, ok := k.(StandardECCrypto) - if !ok { - return nil, ErrKeyPairInfoMalformed - } - - ephemeralECDSAPublicKey, err := ocrypto.UncompressECPubKey(curve, ephemeralPublicKeyBytes) - if err != nil { - return nil, err - } - - derBytes, err := x509.MarshalPKIXPublicKey(ephemeralECDSAPublicKey) - if err != nil { - return nil, fmt.Errorf("failed to marshal ECDSA public key: %w", err) - } - ephemeralECDSAPublicKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "PUBLIC KEY", - Bytes: derBytes, - }) - - symmetricKey, err := ocrypto.ComputeECDHKey([]byte(ec.ecPrivateKeyPem), ephemeralECDSAPublicKeyPEM) - if err != nil { - return nil, fmt.Errorf("ocrypto.ComputeECDHKey failed: %w", err) - } - - key, err := ocrypto.CalculateHKDF(TDFSalt(), symmetricKey) - if err != nil { - return nil, fmt.Errorf("ocrypto.CalculateHKDF failed:%w", err) - } - protectedKey, err := ocrypto.NewAESProtectedKey(key) - if err != nil { - return nil, fmt.Errorf("failed to create protected key: %w", err) - } - return protectedKey, nil +func (a *InProcessProvider) DeriveKey(_ context.Context, _ trust.KeyDetails, _ []byte, _ elliptic.Curve) (ocrypto.ProtectedKey, error) { + return nil, errors.New("unsupported operation") } // GenerateECSessionKey generates a session key for ECDH-based response encryption. diff --git a/service/kas/access/rewrap.go b/service/kas/access/rewrap.go index d62975a480..f8015e5afb 100644 --- a/service/kas/access/rewrap.go +++ b/service/kas/access/rewrap.go @@ -921,7 +921,7 @@ func (p *Provider) tdf3Rewrap(ctx context.Context, requests []*kaspb.UnsignedRew var sessionKey string if e, ok := asymEncrypt.(ocrypto.ECEncryptor); ok { - sessionKey, err = e.PublicKeyInPemFormat() + sessionKey, err = e.EphemeralPublicKeyInPemFormat() if err != nil { p.Logger.ErrorContext(ctx, "unable to serialize ephemeral key", slog.Any("error", err)) // This may be a 500, but could also be caused by a bad clientPublicKey diff --git a/service/trust/key_manager.go b/service/trust/key_manager.go index 6521fd095f..ab71040c90 100644 --- a/service/trust/key_manager.go +++ b/service/trust/key_manager.go @@ -23,12 +23,15 @@ type KeyManager interface { Name() string // Decrypt decrypts data that was encrypted with the key identified by keyID + // This is exclusively used for unwrapping Key Access Object splits. // For EC keys, ephemeralPublicKey must be non-nil // For RSA keys, ephemeralPublicKey should be nil // Returns an UnwrappedKeyData interface for further operations Decrypt(ctx context.Context, key KeyDetails, ciphertext []byte, ephemeralPublicKey []byte) (ProtectedKey, error) // DeriveKey computes an agreed upon secret key derived from an ECDH exchange. + // + // Deprecated: Directly use Decrypt when appropriate. DeriveKey(ctx context.Context, key KeyDetails, ephemeralPublicKeyBytes []byte, curve elliptic.Curve) (ProtectedKey, error) // GenerateECSessionKey generates a private session key, for use with a client-provided ephemeral public key