Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions pycose/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ def verify(cls, key: 'EC2', data: bytes, signature: bytes) -> bool:


class _AesMac(CoseAlgorithm, ABC):
@classmethod
@abstractmethod
def get_key_length(cls) -> int:
raise NotImplementedError()

@classmethod
@abstractmethod
def get_digest_length(cls) -> int:
Expand Down Expand Up @@ -232,6 +237,11 @@ def verify_tag(cls, key: 'SK', tag: bytes, data: bytes):


class _HMAC(CoseAlgorithm, ABC):
@classmethod
@abstractmethod
def get_key_length(cls) -> int:
raise NotImplementedError()

@classmethod
@abstractmethod
def get_digest_length(cls) -> int:
Expand Down Expand Up @@ -301,10 +311,10 @@ def _ecdh(cls, curve: 'CoseCurve', private_key: 'EC2', public_key: 'EC2') -> byt
return shared_key

@classmethod
def derive_kek(cls, curve: 'CoseCurve', private_key: 'EC2', public_key: 'EC2', context: 'CoseKDFContext') -> bytes:
def derive_kek(cls, curve: 'CoseCurve', private_key: 'EC2', public_key: 'EC2', salt: Optional[bytes], context: 'CoseKDFContext') -> bytes:
shared_secret = cls._ecdh(curve, private_key, public_key)

kdf = HKDF(algorithm=cls.get_hash_func(), length=context.supp_pub_info.key_data_length, salt=None,
kdf = HKDF(algorithm=cls.get_hash_func(), length=context.supp_pub_info.key_data_length, salt=salt,
info=context.encode(),
backend=default_backend())
return kdf.derive(shared_secret)
Expand Down Expand Up @@ -971,8 +981,23 @@ class DirectHKDFAES128(CoseAlgorithm):
fullname = "DIRECT_HKDF_AES_128"


class _DirectHkdf(CoseAlgorithm):

@classmethod
def derive_cek(cls, shared_key: 'SK', salt: Optional[bytes], context: 'CoseKDFContext') -> bytes:

kdf = HKDF(algorithm=cls.get_hash_func(), length=context.supp_pub_info.key_data_length, salt=salt,
info=context.encode(),
backend=default_backend())
return kdf.derive(shared_key.k)

@classmethod
def get_hash_func(cls) -> HashAlgorithm:
raise NotImplementedError()


@CoseAlgorithm.register_attribute()
class DirecHKDFSHA512(CoseAlgorithm):
class DirectHKDFSHA512(_DirectHkdf):
"""
Shared secret w/ HKDF and SHA-512

Expand All @@ -985,9 +1010,13 @@ class DirecHKDFSHA512(CoseAlgorithm):
identifier = - 11
fullname = "DIRECT_HKDF_SHA_512"

@classmethod
def get_hash_func(cls) -> HashAlgorithm:
return SHA512()


@CoseAlgorithm.register_attribute()
class DirectHKDFSHA256(CoseAlgorithm):
class DirectHKDFSHA256(_DirectHkdf):
"""
Shared secret w/ HKDF and SHA-256

Expand All @@ -1000,6 +1029,10 @@ class DirectHKDFSHA256(CoseAlgorithm):
identifier = - 10
fullname = "DIRECT_HKDF_SHA_256"

@classmethod
def get_hash_func(cls) -> HashAlgorithm:
return SHA256()


@CoseAlgorithm.register_attribute()
class EdDSA(CoseAlgorithm):
Expand Down Expand Up @@ -1199,6 +1232,10 @@ class HMAC25664(_HMAC):
identifier = 4
fullname = 'HMAC_256_64'

@classmethod
def get_key_length(cls) -> int:
return 32

@classmethod
def get_digest_length(cls) -> int:
return 8
Expand All @@ -1213,6 +1250,10 @@ class HMAC256(_HMAC):
identifier = 5
fullname = 'HMAC_256'

@classmethod
def get_key_length(cls) -> int:
return 32

@classmethod
def get_digest_length(cls) -> int:
return 32
Expand All @@ -1227,6 +1268,10 @@ class HMAC384(_HMAC):
identifier = 6
fullname = 'HMAC_384'

@classmethod
def get_key_length(cls) -> int:
return 48

@classmethod
def get_digest_length(cls) -> int:
return 48
Expand All @@ -1241,6 +1286,10 @@ class HMAC512(_HMAC):
identifier = 7
fullname = 'HMAC_512'

@classmethod
def get_key_length(cls) -> int:
return 64

@classmethod
def get_digest_length(cls) -> int:
return 64
Expand Down
30 changes: 15 additions & 15 deletions pycose/messages/encmessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pycose import utils, headers
from pycose.exceptions import CoseException
from pycose.keys.keyops import EncryptOp
from pycose.keys.keyops import EncryptOp, DecryptOp
from pycose.keys.keyparam import KpAlg, KpKeyOps
from pycose.keys.symmetric import SymmetricKey
from pycose.messages import enccommon, cosemessage
Expand Down Expand Up @@ -78,12 +78,10 @@ def encrypt(self, *args, **kwargs) -> bytes:
r_types = CoseRecipient.verify_recipients(self.recipients)

if DirectEncryption in r_types:
# key should already be known
payload = super(EncMessage, self).encrypt()
self.key = self.recipients[0].compute_cek(target_algorithm)

elif DirectKeyAgreement in r_types:
self.key = self.recipients[0].compute_cek(target_algorithm, "encrypt")
payload = super(EncMessage, self).encrypt()
self.key = self.recipients[0].compute_cek(target_algorithm, EncryptOp)

elif KeyWrap in r_types or KeyAgreementWithKeyWrap in r_types:
key_bytes = os.urandom(self.get_attr(headers.Algorithm).get_key_length())
Expand All @@ -94,10 +92,11 @@ def encrypt(self, *args, **kwargs) -> bytes:
key_bytes = r.payload
r.encrypt(target_algorithm)
self.key = SymmetricKey(k=key_bytes, optional_params={KpAlg: target_algorithm, KpKeyOps: [EncryptOp]})
payload = super(EncMessage, self).encrypt()

else:
raise CoseException('Unsupported COSE recipient class')
raise CoseException(f'Unsupported COSE recipient class: {r_types}')

payload = super(EncMessage, self).encrypt()

return payload

Expand All @@ -108,17 +107,18 @@ def decrypt(self, recipient: 'Recipient', *args, **kwargs) -> bytes:
if not CoseRecipient.has_recipient(recipient, self.recipients):
raise CoseException(f"Cannot find recipient: {recipient}")

r_types = CoseRecipient.verify_recipients(self.recipients)
CoseRecipient.verify_recipients(self.recipients)

if DirectEncryption in r_types:
# key should already be known
payload = super(EncMessage, self).decrypt()
if isinstance(recipient, DirectEncryption):
self.key = recipient.compute_cek(target_algorithm)

elif isinstance(recipient, (DirectKeyAgreement, KeyWrap, KeyAgreementWithKeyWrap)):
self.key = recipient.compute_cek(target_algorithm, DecryptOp)

elif DirectKeyAgreement in r_types or KeyWrap in r_types or KeyAgreementWithKeyWrap in r_types:
self.key = recipient.compute_cek(target_algorithm, "decrypt")
payload = super(EncMessage, self).decrypt()
else:
raise CoseException('Unsupported COSE recipient class')
raise CoseException(f'Unsupported COSE recipient class: {recipient}')

payload = super(EncMessage, self).decrypt()

return payload

Expand Down
40 changes: 31 additions & 9 deletions pycose/messages/macmessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@

from pycose import utils, headers
from pycose.exceptions import CoseException
from pycose.keys.keyops import MacCreateOp
from pycose.keys.keyops import MacCreateOp, MacVerifyOp
from pycose.keys.symmetric import SymmetricKey
from pycose.keys.keyparam import KpAlg, KpKeyOps
from pycose.messages import cosemessage, maccommon
from pycose.messages.recipient import CoseRecipient, DirectEncryption, DirectKeyAgreement, KeyWrap, \
KeyAgreementWithKeyWrap

if TYPE_CHECKING:
from pycose.keys.symmetric import SK, SymmetricKey
from pycose.keys.symmetric import SK
from pycose.messages.recipient import Recipient

CBOR = bytes

Expand Down Expand Up @@ -72,18 +75,36 @@ def encode(self, tag: bool = True, mac: bool = True, *args, **kwargs) -> CBOR:
res = super(MacMessage, self).encode(message, tag)
return res

def verify_tag(self, recipient: 'Recipient', *args, **kwargs) -> bool:
target_algorithm = self.get_attr(headers.Algorithm)

# check if recipient exists
if not CoseRecipient.has_recipient(recipient, self.recipients):
raise CoseException(f"Cannot find recipient: {recipient}")

CoseRecipient.verify_recipients(self.recipients)

if isinstance(recipient, DirectEncryption):
self.key = recipient.compute_cek(target_algorithm)

elif isinstance(recipient, (DirectKeyAgreement, KeyWrap, KeyAgreementWithKeyWrap)):
self.key = recipient.compute_cek(target_algorithm, MacVerifyOp)

else:
raise CoseException(f'Unsupported COSE recipient class: {type(recipient)}')

return super(MacMessage, self).verify_tag()

def compute_tag(self, *args, **kwargs) -> bytes:
target_algorithm = self.get_attr(headers.Algorithm)

r_types = CoseRecipient.verify_recipients(self.recipients)

if DirectEncryption in r_types:
# key should already be known
payload = super(MacMessage, self).compute_tag()
self.key = self.recipients[0].compute_cek(target_algorithm)

elif DirectKeyAgreement in r_types:
self.key = self.recipients[0].compute_cek(target_algorithm, "encrypt")
payload = super(MacMessage, self).compute_tag()
self.key = self.recipients[0].compute_cek(target_algorithm, MacCreateOp)

elif KeyWrap in r_types or KeyAgreementWithKeyWrap in r_types:
key_bytes = os.urandom(self.get_attr(headers.Algorithm).get_key_length())
Expand All @@ -94,11 +115,12 @@ def compute_tag(self, *args, **kwargs) -> bytes:
else:
key_bytes = r.payload
r.encrypt(target_algorithm)
self.key = SymmetricKey(k=key_bytes, alg=target_algorithm, key_ops=[MacCreateOp])
payload = super(MacMessage, self).compute_tag()
self.key = SymmetricKey(k=key_bytes, optional_params={KpAlg: target_algorithm, KpKeyOps: [MacCreateOp]})

else:
raise CoseException('Unsupported COSE recipient class')
raise CoseException(f'Unsupported COSE recipient class: {r_types}')

payload = super(MacMessage, self).compute_tag()

return payload

Expand Down
Loading
Loading