Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ traces/
.cache/*

# Claude AI files
.claude/
.claude/
176 changes: 148 additions & 28 deletions lib/ocrypto/asym_decryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
Expand All @@ -26,6 +27,38 @@ type AsymDecryption struct {
type PrivateKeyDecryptor interface {
// Decrypt decrypts ciphertext with private key.
Decrypt(data []byte) ([]byte, error)

// PrivateKeyInPemFormat returns the private key in PEM format.
PrivateKeyInPemFormat() (string, error)

// Public returns the corresponding public-key encryptor.
Public() (PublicKeyEncryptor, error)

// KeyType returns the key type, e.g. RSA or EC.
KeyType() KeyType
}

func NewPrivateKeyDecryptor(kt KeyType) (PrivateKeyDecryptor, error) {
switch {
case IsRSAKeyType(kt):
bits, err := RSAKeyTypeToBits(kt)
if err != nil {
return nil, err
}
keyPair, err := NewRSAKeyPair(bits)
if err != nil {
return nil, err
}
return keyPair, nil
case IsECKeyType(kt):
mode, err := ECKeyTypeToMode(kt)
if err != nil {
return nil, err
}
return NewECPrivateKey(mode)
default:
return nil, fmt.Errorf("unsupported key type: %v", kt)
}
}

// FromPrivatePEM creates and returns a new AsymDecryption.
Expand Down Expand Up @@ -109,6 +142,33 @@ func (asymDecryption AsymDecryption) Decrypt(data []byte) ([]byte, error) {
return bytes, nil
}

func (asymDecryption AsymDecryption) PrivateKeyInPemFormat() (string, error) {
return privateKeyInPemFormat(asymDecryption.PrivateKey)
}

func (asymDecryption AsymDecryption) Public() (PublicKeyEncryptor, error) {
if asymDecryption.PrivateKey == nil {
return nil, errors.New("failed to generate public key encryptor, private key is empty")
}

return &AsymEncryption{PublicKey: &asymDecryption.PrivateKey.PublicKey}, nil
}

func (asymDecryption AsymDecryption) KeyType() KeyType {
if asymDecryption.PrivateKey == nil {
return KeyType("rsa:[unknown]")
}
Comment on lines +158 to +160
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The check for asymDecryption.PrivateKey == nil is redundant here because PrivateKey is a field of the struct and the method is called on an instance. If the instance is initialized, this field should be checked at the point of creation or usage, not inside every getter method. This adds unnecessary complexity.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, that would make sense.. unfortunately FromRSA and the corresponding SDK WithSession...RSA options don't return error so it would be hard to wire that in at the moment


switch asymDecryption.PrivateKey.Size() {
case RSA2048Size / 8: //nolint:mnd // standard key size in bytes
return RSA2048Key
case RSA4096Size / 8: //nolint:mnd // large key size in bytes
return RSA4096Key
default:
return KeyType(fmt.Sprintf("rsa:%d", asymDecryption.PrivateKey.Size()*8)) //nolint:mnd // convert to bits
}
}

type ECDecryptor struct {
sk *ecdh.PrivateKey
salt []byte
Expand All @@ -124,6 +184,20 @@ func NewECDecryptor(sk *ecdh.PrivateKey) (ECDecryptor, error) {
return ECDecryptor{sk, salt, nil}, nil
}

func NewECPrivateKey(mode ECCMode) (ECDecryptor, error) {
curve, err := curveFromECCMode(mode)
if err != nil {
return ECDecryptor{}, err
}

sk, err := curve.GenerateKey(rand.Reader)
if err != nil {
return ECDecryptor{}, fmt.Errorf("ecdh.GenerateKey failed: %w", err)
}

return NewECDecryptor(sk)
}

func NewSaltedECDecryptor(sk *ecdh.PrivateKey, salt, info []byte) (ECDecryptor, error) {
return ECDecryptor{sk, salt, info}, nil
}
Expand All @@ -133,30 +207,30 @@ func (e ECDecryptor) Decrypt(_ []byte) ([]byte, error) {
return nil, errors.New("ecdh standard decrypt unimplemented")
}

func (e ECDecryptor) DecryptWithEphemeralKey(data, ephemeral []byte) ([]byte, error) {
var ek *ecdh.PublicKey
func (e ECDecryptor) PrivateKeyInPemFormat() (string, error) {
return privateKeyInPemFormat(e.sk)
}

if pubFromDSN, err := x509.ParsePKIXPublicKey(ephemeral); err == nil {
switch pubFromDSN := pubFromDSN.(type) {
case *ecdsa.PublicKey:
ek, err = ConvertToECDHPublicKey(pubFromDSN)
if err != nil {
return nil, fmt.Errorf("ecdh conversion failure: %w", err)
}
case *ecdh.PublicKey:
ek = pubFromDSN
default:
return nil, fmt.Errorf("unsupported public key of type: %T", pubFromDSN)
}
} else {
ekDSA, err := UncompressECPubKey(convCurve(e.sk.Curve()), ephemeral)
if err != nil {
return nil, err
}
ek, err = ekDSA.ECDH()
if err != nil {
return nil, fmt.Errorf("ecdh failure: %w", err)
}
func (e ECDecryptor) Public() (PublicKeyEncryptor, error) {
if e.sk == nil {
return nil, errors.New("failed to generate public key encryptor, private key is empty")
}

return newECIES(e.sk.PublicKey(), e.salt, e.info)
}

func (e ECDecryptor) KeyType() KeyType {
if e.sk == nil {
return KeyType("ec:[unknown]")
}
Comment on lines +223 to +225
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the RSA implementation, checking for e.sk == nil inside the KeyType() method is defensive programming that might hide initialization issues. It is better to ensure the struct is properly initialized upon creation.


return keyTypeFromECDHCurve(e.sk.Curve())
}

func (e ECDecryptor) DecryptWithEphemeralKey(data, ephemeral []byte) ([]byte, error) {
ek, err := e.parseEphemeralPublicKey(ephemeral)
if err != nil {
return nil, fmt.Errorf("failed to parse ephemeral public key: %w", err)
}

ikm, err := e.sk.ECDH(ek)
Expand Down Expand Up @@ -196,15 +270,61 @@ func (e ECDecryptor) DecryptWithEphemeralKey(data, ephemeral []byte) ([]byte, er
return plaintext, nil
}

func convCurve(c ecdh.Curve) elliptic.Curve {
func (e ECDecryptor) deriveSharedKey(publicKeyInPem string) ([]byte, error) {
if e.sk == nil {
return nil, errors.New("failed to derive shared key, private key is empty")
}

pub, err := getPublicPart(publicKeyInPem)
if err != nil {
return nil, err
}

ecdhPublicKey, err := ConvertToECDHPublicKey(pub)
if err != nil {
return nil, fmt.Errorf("unsupported public key type: %w", err)
}

sharedKey, err := e.sk.ECDH(ecdhPublicKey)
if err != nil {
return nil, fmt.Errorf("there was a problem deriving a shared ECDH key: %w", err)
}

return sharedKey, nil
}

// parseEphemeralPublicKey parses an ephemeral public key from DER (PKIX) or compressed EC point bytes.
func (e ECDecryptor) parseEphemeralPublicKey(ephemeral []byte) (*ecdh.PublicKey, error) {
if pub, err := x509.ParsePKIXPublicKey(ephemeral); err == nil {
switch pub := pub.(type) {
case *ecdsa.PublicKey:
return ConvertToECDHPublicKey(pub)
case *ecdh.PublicKey:
return pub, nil
default:
return nil, fmt.Errorf("unsupported public key of type: %T", pub)
}
}
curve, err := convCurve(e.sk.Curve())
if err != nil {
return nil, err
}
ekDSA, err := UncompressECPubKey(curve, ephemeral)
if err != nil {
return nil, err
}
return ekDSA.ECDH()
}

func convCurve(c ecdh.Curve) (elliptic.Curve, error) {
switch c {
case ecdh.P256():
return elliptic.P256()
return elliptic.P256(), nil
case ecdh.P384():
return elliptic.P384()
return elliptic.P384(), nil
case ecdh.P521():
return elliptic.P521()
return elliptic.P521(), nil
default:
return nil
return nil, fmt.Errorf("unsupported ECDH curve: %v", c)
}
}
162 changes: 162 additions & 0 deletions lib/ocrypto/asym_decryption_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package ocrypto

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestNewPrivateKeyDecryptor(t *testing.T) {
t.Parallel()

t.Run("RSA2048", func(t *testing.T) {
t.Parallel()
d, err := NewPrivateKeyDecryptor(RSA2048Key)
require.NoError(t, err)
require.Equal(t, RSA2048Key, d.KeyType())
})

t.Run("RSA4096", func(t *testing.T) {
t.Parallel()
d, err := NewPrivateKeyDecryptor(RSA4096Key)
require.NoError(t, err)
require.Equal(t, RSA4096Key, d.KeyType())
})

t.Run("EC256", func(t *testing.T) {
t.Parallel()
d, err := NewPrivateKeyDecryptor(EC256Key)
require.NoError(t, err)
require.Equal(t, EC256Key, d.KeyType())
})

t.Run("EC384", func(t *testing.T) {
t.Parallel()
d, err := NewPrivateKeyDecryptor(EC384Key)
require.NoError(t, err)
require.Equal(t, EC384Key, d.KeyType())
})

t.Run("EC521", func(t *testing.T) {
t.Parallel()
d, err := NewPrivateKeyDecryptor(EC521Key)
require.NoError(t, err)
require.Equal(t, EC521Key, d.KeyType())
})

t.Run("UnsupportedType", func(t *testing.T) {
t.Parallel()
_, err := NewPrivateKeyDecryptor(KeyType("dsa:1024"))
require.Error(t, err)
require.Contains(t, err.Error(), "unsupported key type")
})
}

func TestECDecryptorMethods(t *testing.T) {
t.Parallel()

modes := []struct {
mode ECCMode
kt KeyType
}{
{ECCModeSecp256r1, EC256Key},
{ECCModeSecp384r1, EC384Key},
{ECCModeSecp521r1, EC521Key},
}

for _, tc := range modes {
t.Run(string(tc.kt), func(t *testing.T) {
t.Parallel()

dec, err := NewECPrivateKey(tc.mode)
require.NoError(t, err)

// KeyType
require.Equal(t, tc.kt, dec.KeyType())

// PrivateKeyInPemFormat round-trip
privPEM, err := dec.PrivateKeyInPemFormat()
require.NoError(t, err)
require.NotEmpty(t, privPEM)
dec2, err := FromPrivatePEM(privPEM)
require.NoError(t, err)
require.Equal(t, tc.kt, dec2.KeyType())

// Public — returns an encryptor whose public key round-trips
enc, err := dec.Public()
require.NoError(t, err)
pubPEM, err := enc.PublicKeyInPemFormat()
require.NoError(t, err)
require.NotEmpty(t, pubPEM)
require.Equal(t, tc.kt, enc.KeyType())
})
}

t.Run("NilKeyGuards", func(t *testing.T) {
t.Parallel()
nilDec := ECDecryptor{}
_, err := nilDec.Public()
require.Error(t, err)
_, err = nilDec.PrivateKeyInPemFormat()
require.Error(t, err)
})
}

func TestRsaKeyPairNewMethods(t *testing.T) {
t.Parallel()

sizes := []struct {
bits int
kt KeyType
}{
{RSA2048Size, RSA2048Key},
{RSA4096Size, RSA4096Key},
}

for _, tc := range sizes {
t.Run(string(tc.kt), func(t *testing.T) {
t.Parallel()
kp, err := NewRSAKeyPair(tc.bits)
require.NoError(t, err)
require.Equal(t, tc.kt, kp.KeyType())

enc, err := kp.Public()
require.NoError(t, err)
require.Equal(t, tc.kt, enc.KeyType())
})
}

t.Run("NilKeyGuard", func(t *testing.T) {
t.Parallel()
_, err := RsaKeyPair{}.Public()
require.Error(t, err)
})
}

func TestAsymDecryptionKeyType(t *testing.T) {
t.Parallel()

for _, bits := range []int{RSA2048Size, RSA4096Size} {
kp, err := NewRSAKeyPair(bits)
require.NoError(t, err)
privPEM, err := kp.PrivateKeyInPemFormat()
require.NoError(t, err)
d, err := FromPrivatePEM(privPEM)
require.NoError(t, err)
ad, ok := d.(AsymDecryption)
require.True(t, ok)
if bits == RSA2048Size {
require.Equal(t, RSA2048Key, ad.KeyType())
} else {
require.Equal(t, RSA4096Key, ad.KeyType())
}
}
Comment on lines +136 to +153
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

Consider using subtests for each RSA key size to improve test diagnostics.

The loop iterates over multiple RSA key sizes but doesn't wrap each iteration in t.Run(). This is inconsistent with the other tests in this file and makes it harder to identify which key size failed if there's an assertion error.

♻️ Proposed fix to add subtests
 func TestAsymDecryptionKeyType(t *testing.T) {
 	t.Parallel()

-	for _, bits := range []int{RSA2048Size, RSA4096Size} {
+	sizes := []struct {
+		bits int
+		kt   KeyType
+	}{
+		{RSA2048Size, RSA2048Key},
+		{RSA4096Size, RSA4096Key},
+	}
+
+	for _, tc := range sizes {
+		t.Run(string(tc.kt), func(t *testing.T) {
+			t.Parallel()
+			kp, err := NewRSAKeyPair(tc.bits)
+			require.NoError(t, err)
+			privPEM, err := kp.PrivateKeyInPemFormat()
+			require.NoError(t, err)
+			d, err := FromPrivatePEM(privPEM)
+			require.NoError(t, err)
+			ad, ok := d.(AsymDecryption)
+			require.True(t, ok)
+			require.Equal(t, tc.kt, ad.KeyType())
+		})
+	}
-		kp, err := NewRSAKeyPair(bits)
-		require.NoError(t, err)
-		privPEM, err := kp.PrivateKeyInPemFormat()
-		require.NoError(t, err)
-		d, err := FromPrivatePEM(privPEM)
-		require.NoError(t, err)
-		ad, ok := d.(AsymDecryption)
-		require.True(t, ok)
-		if bits == RSA2048Size {
-			require.Equal(t, RSA2048Key, ad.KeyType())
-		} else {
-			require.Equal(t, RSA4096Key, ad.KeyType())
-		}
-	}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func TestAsymDecryptionKeyType(t *testing.T) {
t.Parallel()
for _, bits := range []int{RSA2048Size, RSA4096Size} {
kp, err := NewRSAKeyPair(bits)
require.NoError(t, err)
privPEM, err := kp.PrivateKeyInPemFormat()
require.NoError(t, err)
d, err := FromPrivatePEM(privPEM)
require.NoError(t, err)
ad, ok := d.(AsymDecryption)
require.True(t, ok)
if bits == RSA2048Size {
require.Equal(t, RSA2048Key, ad.KeyType())
} else {
require.Equal(t, RSA4096Key, ad.KeyType())
}
}
func TestAsymDecryptionKeyType(t *testing.T) {
t.Parallel()
sizes := []struct {
bits int
kt KeyType
}{
{RSA2048Size, RSA2048Key},
{RSA4096Size, RSA4096Key},
}
for _, tc := range sizes {
t.Run(string(tc.kt), func(t *testing.T) {
t.Parallel()
kp, err := NewRSAKeyPair(tc.bits)
require.NoError(t, err)
privPEM, err := kp.PrivateKeyInPemFormat()
require.NoError(t, err)
d, err := FromPrivatePEM(privPEM)
require.NoError(t, err)
ad, ok := d.(AsymDecryption)
require.True(t, ok)
require.Equal(t, tc.kt, ad.KeyType())
})
}
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@lib/ocrypto/asym_decryption_test.go` around lines 136 - 153, The
TestAsymDecryptionKeyType loop should be converted to subtests so failures show
which RSA size failed: for each bits in {RSA2048Size, RSA4096Size} call t.Run
with a descriptive name (e.g. "RSA2048"/"RSA4096" or fmt.Sprintf("%d", bits)),
capture bits in the closure, and move t.Parallel() into the subtest body; inside
the subtest keep the existing logic (NewRSAKeyPair, PrivateKeyInPemFormat,
FromPrivatePEM, type assertion to AsymDecryption and the KeyType checks against
RSA2048Key/RSA4096Key).


t.Run("NilKeyGuard", func(t *testing.T) {
t.Parallel()
kt := AsymDecryption{}.KeyType()
// nil key returns a sentinel string, not a valid key type
require.False(t, IsRSAKeyType(kt))
require.False(t, IsECKeyType(kt))
})
}
Loading
Loading