diff --git a/pycose/algorithms.py b/pycose/algorithms.py index 8a2b4c0..9cb1c75 100644 --- a/pycose/algorithms.py +++ b/pycose/algorithms.py @@ -198,6 +198,11 @@ def verify(cls, key: 'EC2', data: bytes, signature: bytes) -> bool: class _AesMac(CoseAlgorithm, ABC): + @classmethod + @abstractmethod + def get_key_length(cls) -> int: + raise NotImplementedError() + @classmethod @abstractmethod def get_digest_length(cls) -> int: @@ -232,6 +237,11 @@ def verify_tag(cls, key: 'SK', tag: bytes, data: bytes): class _HMAC(CoseAlgorithm, ABC): + @classmethod + @abstractmethod + def get_key_length(cls) -> int: + raise NotImplementedError() + @classmethod @abstractmethod def get_digest_length(cls) -> int: @@ -301,10 +311,10 @@ def _ecdh(cls, curve: 'CoseCurve', private_key: 'EC2', public_key: 'EC2') -> byt return shared_key @classmethod - def derive_kek(cls, curve: 'CoseCurve', private_key: 'EC2', public_key: 'EC2', context: 'CoseKDFContext') -> bytes: + def derive_kek(cls, curve: 'CoseCurve', private_key: 'EC2', public_key: 'EC2', salt: Optional[bytes], context: 'CoseKDFContext') -> bytes: shared_secret = cls._ecdh(curve, private_key, public_key) - kdf = HKDF(algorithm=cls.get_hash_func(), length=context.supp_pub_info.key_data_length, salt=None, + kdf = HKDF(algorithm=cls.get_hash_func(), length=context.supp_pub_info.key_data_length, salt=salt, info=context.encode(), backend=default_backend()) return kdf.derive(shared_secret) @@ -971,8 +981,23 @@ class DirectHKDFAES128(CoseAlgorithm): fullname = "DIRECT_HKDF_AES_128" +class _DirectHkdf(CoseAlgorithm): + + @classmethod + def derive_cek(cls, shared_key: 'SK', salt: Optional[bytes], context: 'CoseKDFContext') -> bytes: + + kdf = HKDF(algorithm=cls.get_hash_func(), length=context.supp_pub_info.key_data_length, salt=salt, + info=context.encode(), + backend=default_backend()) + return kdf.derive(shared_key.k) + + @classmethod + def get_hash_func(cls) -> HashAlgorithm: + raise NotImplementedError() + + @CoseAlgorithm.register_attribute() -class DirecHKDFSHA512(CoseAlgorithm): +class DirectHKDFSHA512(_DirectHkdf): """ Shared secret w/ HKDF and SHA-512 @@ -985,9 +1010,13 @@ class DirecHKDFSHA512(CoseAlgorithm): identifier = - 11 fullname = "DIRECT_HKDF_SHA_512" + @classmethod + def get_hash_func(cls) -> HashAlgorithm: + return SHA512() + @CoseAlgorithm.register_attribute() -class DirectHKDFSHA256(CoseAlgorithm): +class DirectHKDFSHA256(_DirectHkdf): """ Shared secret w/ HKDF and SHA-256 @@ -1000,6 +1029,10 @@ class DirectHKDFSHA256(CoseAlgorithm): identifier = - 10 fullname = "DIRECT_HKDF_SHA_256" + @classmethod + def get_hash_func(cls) -> HashAlgorithm: + return SHA256() + @CoseAlgorithm.register_attribute() class EdDSA(CoseAlgorithm): @@ -1199,6 +1232,10 @@ class HMAC25664(_HMAC): identifier = 4 fullname = 'HMAC_256_64' + @classmethod + def get_key_length(cls) -> int: + return 32 + @classmethod def get_digest_length(cls) -> int: return 8 @@ -1213,6 +1250,10 @@ class HMAC256(_HMAC): identifier = 5 fullname = 'HMAC_256' + @classmethod + def get_key_length(cls) -> int: + return 32 + @classmethod def get_digest_length(cls) -> int: return 32 @@ -1227,6 +1268,10 @@ class HMAC384(_HMAC): identifier = 6 fullname = 'HMAC_384' + @classmethod + def get_key_length(cls) -> int: + return 48 + @classmethod def get_digest_length(cls) -> int: return 48 @@ -1241,6 +1286,10 @@ class HMAC512(_HMAC): identifier = 7 fullname = 'HMAC_512' + @classmethod + def get_key_length(cls) -> int: + return 64 + @classmethod def get_digest_length(cls) -> int: return 64 diff --git a/pycose/messages/encmessage.py b/pycose/messages/encmessage.py index cc43e07..15a8305 100644 --- a/pycose/messages/encmessage.py +++ b/pycose/messages/encmessage.py @@ -3,7 +3,7 @@ from pycose import utils, headers from pycose.exceptions import CoseException -from pycose.keys.keyops import EncryptOp +from pycose.keys.keyops import EncryptOp, DecryptOp from pycose.keys.keyparam import KpAlg, KpKeyOps from pycose.keys.symmetric import SymmetricKey from pycose.messages import enccommon, cosemessage @@ -78,12 +78,10 @@ def encrypt(self, *args, **kwargs) -> bytes: r_types = CoseRecipient.verify_recipients(self.recipients) if DirectEncryption in r_types: - # key should already be known - payload = super(EncMessage, self).encrypt() + self.key = self.recipients[0].compute_cek(target_algorithm) elif DirectKeyAgreement in r_types: - self.key = self.recipients[0].compute_cek(target_algorithm, "encrypt") - payload = super(EncMessage, self).encrypt() + self.key = self.recipients[0].compute_cek(target_algorithm, EncryptOp) elif KeyWrap in r_types or KeyAgreementWithKeyWrap in r_types: key_bytes = os.urandom(self.get_attr(headers.Algorithm).get_key_length()) @@ -94,10 +92,11 @@ def encrypt(self, *args, **kwargs) -> bytes: key_bytes = r.payload r.encrypt(target_algorithm) self.key = SymmetricKey(k=key_bytes, optional_params={KpAlg: target_algorithm, KpKeyOps: [EncryptOp]}) - payload = super(EncMessage, self).encrypt() else: - raise CoseException('Unsupported COSE recipient class') + raise CoseException(f'Unsupported COSE recipient class: {r_types}') + + payload = super(EncMessage, self).encrypt() return payload @@ -108,17 +107,18 @@ def decrypt(self, recipient: 'Recipient', *args, **kwargs) -> bytes: if not CoseRecipient.has_recipient(recipient, self.recipients): raise CoseException(f"Cannot find recipient: {recipient}") - r_types = CoseRecipient.verify_recipients(self.recipients) + CoseRecipient.verify_recipients(self.recipients) - if DirectEncryption in r_types: - # key should already be known - payload = super(EncMessage, self).decrypt() + if isinstance(recipient, DirectEncryption): + self.key = recipient.compute_cek(target_algorithm) + + elif isinstance(recipient, (DirectKeyAgreement, KeyWrap, KeyAgreementWithKeyWrap)): + self.key = recipient.compute_cek(target_algorithm, DecryptOp) - elif DirectKeyAgreement in r_types or KeyWrap in r_types or KeyAgreementWithKeyWrap in r_types: - self.key = recipient.compute_cek(target_algorithm, "decrypt") - payload = super(EncMessage, self).decrypt() else: - raise CoseException('Unsupported COSE recipient class') + raise CoseException(f'Unsupported COSE recipient class: {recipient}') + + payload = super(EncMessage, self).decrypt() return payload diff --git a/pycose/messages/macmessage.py b/pycose/messages/macmessage.py index 94befe8..d45a28b 100644 --- a/pycose/messages/macmessage.py +++ b/pycose/messages/macmessage.py @@ -14,13 +14,16 @@ from pycose import utils, headers from pycose.exceptions import CoseException -from pycose.keys.keyops import MacCreateOp +from pycose.keys.keyops import MacCreateOp, MacVerifyOp +from pycose.keys.symmetric import SymmetricKey +from pycose.keys.keyparam import KpAlg, KpKeyOps from pycose.messages import cosemessage, maccommon from pycose.messages.recipient import CoseRecipient, DirectEncryption, DirectKeyAgreement, KeyWrap, \ KeyAgreementWithKeyWrap if TYPE_CHECKING: - from pycose.keys.symmetric import SK, SymmetricKey + from pycose.keys.symmetric import SK + from pycose.messages.recipient import Recipient CBOR = bytes @@ -72,18 +75,36 @@ def encode(self, tag: bool = True, mac: bool = True, *args, **kwargs) -> CBOR: res = super(MacMessage, self).encode(message, tag) return res + def verify_tag(self, recipient: 'Recipient', *args, **kwargs) -> bool: + target_algorithm = self.get_attr(headers.Algorithm) + + # check if recipient exists + if not CoseRecipient.has_recipient(recipient, self.recipients): + raise CoseException(f"Cannot find recipient: {recipient}") + + CoseRecipient.verify_recipients(self.recipients) + + if isinstance(recipient, DirectEncryption): + self.key = recipient.compute_cek(target_algorithm) + + elif isinstance(recipient, (DirectKeyAgreement, KeyWrap, KeyAgreementWithKeyWrap)): + self.key = recipient.compute_cek(target_algorithm, MacVerifyOp) + + else: + raise CoseException(f'Unsupported COSE recipient class: {type(recipient)}') + + return super(MacMessage, self).verify_tag() + def compute_tag(self, *args, **kwargs) -> bytes: target_algorithm = self.get_attr(headers.Algorithm) r_types = CoseRecipient.verify_recipients(self.recipients) if DirectEncryption in r_types: - # key should already be known - payload = super(MacMessage, self).compute_tag() + self.key = self.recipients[0].compute_cek(target_algorithm) elif DirectKeyAgreement in r_types: - self.key = self.recipients[0].compute_cek(target_algorithm, "encrypt") - payload = super(MacMessage, self).compute_tag() + self.key = self.recipients[0].compute_cek(target_algorithm, MacCreateOp) elif KeyWrap in r_types or KeyAgreementWithKeyWrap in r_types: key_bytes = os.urandom(self.get_attr(headers.Algorithm).get_key_length()) @@ -94,11 +115,12 @@ def compute_tag(self, *args, **kwargs) -> bytes: else: key_bytes = r.payload r.encrypt(target_algorithm) - self.key = SymmetricKey(k=key_bytes, alg=target_algorithm, key_ops=[MacCreateOp]) - payload = super(MacMessage, self).compute_tag() + self.key = SymmetricKey(k=key_bytes, optional_params={KpAlg: target_algorithm, KpKeyOps: [MacCreateOp]}) else: - raise CoseException('Unsupported COSE recipient class') + raise CoseException(f'Unsupported COSE recipient class: {r_types}') + + payload = super(MacMessage, self).compute_tag() return payload diff --git a/pycose/messages/recipient.py b/pycose/messages/recipient.py index 69ab52a..83f2e30 100644 --- a/pycose/messages/recipient.py +++ b/pycose/messages/recipient.py @@ -14,8 +14,10 @@ A128KW, \ A192KW, \ A256KW, \ + DirectHKDFAES128, \ DirectHKDFAES256, \ DirectHKDFSHA256, \ + DirectHKDFSHA512, \ EcdhEsHKDF256, \ EcdhEsHKDF512, \ EcdhEsA128KW, \ @@ -28,7 +30,8 @@ EcdhSsA256KW from pycose.exceptions import CoseException, CoseMalformedMessage, CoseIllegalAlgorithm from pycose.keys.ec2 import EC2Key, EC2KpD -from pycose.keys.keyops import DeriveKeyOp, EncryptOp, DecryptOp, WrapOp, UnwrapOp, DeriveBitsOp +from pycose.keys.keyops import KeyOps, DeriveKeyOp, EncryptOp, DecryptOp, \ + WrapOp, UnwrapOp, DeriveBitsOp, MacCreateOp, MacVerifyOp from pycose.keys.keyparam import KpAlg, KpKeyOps from pycose.keys.rsa import RSAKey from pycose.keys.symmetric import SymmetricKey @@ -185,7 +188,7 @@ def _setup_ephemeral_key(self, peer_key, optional_params: dict = None): self.uhdr_update({headers.EphemeralKey: ephemeral_public_key}) -@CoseRecipient.record_rc([Direct, DirectHKDFSHA256, DirectHKDFAES256]) +@CoseRecipient.record_rc([Direct, DirectHKDFSHA256, DirectHKDFSHA512, DirectHKDFAES128, DirectHKDFAES256]) class DirectEncryption(CoseRecipient): @classmethod @@ -229,14 +232,20 @@ def encode(self, *args, **kwargs) -> list: return recipient - def compute_cek(self, target_alg: 'CoseAlgorithm') -> Optional['SK']: + def compute_cek(self, target_alg: '_EncAlg') -> 'SK': alg = self.get_attr(headers.Algorithm) if alg == Direct: - return None - else: + return self.key + elif alg in {DirectHKDFSHA256, DirectHKDFSHA512}: self.key.verify(SymmetricKey, algorithm=alg, key_ops=[DeriveKeyOp, DeriveBitsOp]) - _ = target_alg + + salt = self.get_attr(headers.Salt) + keybytes = alg.derive_cek(shared_key=self.key, salt=salt, context=self.get_kdf_context(target_alg)) + return SymmetricKey(k=keybytes) + elif alg in {DirectHKDFAES128, DirectHKDFAES256}: raise NotImplementedError() + else: + raise ValueError(f"Inappropriate alg value: {alg}") def __repr__(self) -> str: phdr, uhdr = self._hdr_repr() @@ -315,15 +324,17 @@ def _compute_kek(self, target_alg: '_EncAlg', ops: 'str') -> bytes: return self.key.k - def compute_cek(self, target_alg: '_EncAlg', ops: str) -> Optional['SK']: - if ops == "encrypt": + def compute_cek(self, target_alg: '_EncAlg', key_op: KeyOps) -> Optional['SK']: + if key_op in {EncryptOp, MacCreateOp}: if self.payload == b'': return None else: - return SymmetricKey(k=self.payload, optional_params={KpAlg: target_alg, KpKeyOps: [EncryptOp]}) - else: + return SymmetricKey(k=self.payload, optional_params={KpAlg: target_alg, KpKeyOps: [key_op]}) + elif key_op in {DecryptOp, MacVerifyOp}: return SymmetricKey(k=self.decrypt(target_alg), - optional_params={KpAlg: target_alg, KpKeyOps: [DecryptOp]}) + optional_params={KpAlg: target_alg, KpKeyOps: [key_op]}) + else: + raise CoseException(f"Invalid compute_cek op: {key_op}") def encrypt(self, target_alg: '_EncAlg') -> bytes: alg = self.get_attr(headers.Algorithm) @@ -432,20 +443,22 @@ def encode(self, *args, **kwargs) -> list: return recipient def _compute_kek(self, target_alg: '_EncAlg', peer_key: 'EC2Key', local_key: 'EC2Key', kex_alg) -> bytes: + salt = self.get_attr(headers.Salt) + return kex_alg.derive_kek(peer_key.crv, local_key, peer_key, salt, self.get_kdf_context(target_alg)) - return kex_alg.derive_kek(peer_key.crv, local_key, peer_key, self.get_kdf_context(target_alg)) - - def compute_cek(self, target_alg: '_EncAlg', ops: str) -> 'SK': + def compute_cek(self, target_alg: '_EncAlg', key_op: KeyOps) -> 'SK': alg = self.get_attr(headers.Algorithm) if alg in {EcdhSsHKDF256, EcdhSsHKDF512, EcdhEsHKDF256, EcdhEsHKDF512}: - if ops == "encrypt": + if key_op in {EncryptOp, MacCreateOp}: peer_key = self.local_attrs.get(headers.StaticKey) - else: + elif key_op in {DecryptOp, MacVerifyOp}: if alg in {EcdhSsHKDF256, EcdhSsHKDF512}: peer_key = self.get_attr(headers.StaticKey) else: peer_key = self.get_attr(headers.EphemeralKey) + else: + raise CoseException(f"Invalid compute_cek op: {key_op}") else: raise CoseIllegalAlgorithm(f"Algorithm {alg} unsupported for {self.__name__}") @@ -494,15 +507,17 @@ def context(self): def context(self, context: str): self._context = context - def compute_cek(self, target_alg: '_EncAlg', ops: str) -> Optional['SK']: - if ops == "encrypt": + def compute_cek(self, target_alg: '_EncAlg', key_op: KeyOps) -> Optional['SK']: + if key_op in {EncryptOp, MacCreateOp}: if self.payload == b'': return None else: return SymmetricKey(k=self.payload, optional_params={KpAlg: target_alg, KpKeyOps: [EncryptOp]}) - else: + elif key_op in {DecryptOp, MacVerifyOp}: return SymmetricKey(k=self.decrypt(target_alg), optional_params={KpAlg: target_alg, KpKeyOps: [DecryptOp]}) + else: + raise CoseException(f"Invalid compute_cek op: {key_op}") def encode(self, *args, **kwargs) -> list: @@ -515,8 +530,8 @@ def encode(self, *args, **kwargs) -> list: return recipient def _compute_kek(self, target_alg: '_EncAlg', peer_key: 'EC2Key', local_key: 'EC2Key', kex_alg) -> bytes: - - key_bytes = kex_alg.derive_kek(peer_key.crv, local_key, peer_key, self.get_kdf_context(target_alg)) + salt = self.get_attr(headers.Salt) + key_bytes = kex_alg.derive_kek(peer_key.crv, local_key, peer_key, salt, self.get_kdf_context(target_alg)) return key_bytes def encrypt(self, target_alg) -> bytes: diff --git a/tests/test_attributes.py b/tests/test_attributes.py index 8f4a941..c0ec94d 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -198,7 +198,7 @@ def test_allow_unknown_header_attribute_encoding_decoding(): msg = EncMessage(phdr={Algorithm: AESCCM1664128, "Custom-Header-Attr1": 7879}, uhdr={KID: b'foo', IV: unhexlify(b'00000000000000000000000000'), "Custom-Header-Attr2": 878}, recipients=[DirectEncryption(uhdr={Algorithm: Direct, "Custom-Header-Attr3": 9999})]) - msg.key = SymmetricKey.generate_key(key_len=16) + msg.recipients[0].key = SymmetricKey.generate_key(key_len=16) assert "Custom-Header-Attr1" in msg.phdr assert "Custom-Header-Attr2" in msg.uhdr @@ -228,7 +228,7 @@ def test_allow_unknown_header_attribute_encoding_decoding(): msg = MacMessage(phdr={Algorithm: HMAC256, "Custom-Header-Attr1": 7879}, uhdr={KID: b'foo', IV: unhexlify(b'00000000000000000000000000'), "Custom-Header-Attr2": 878}, recipients=[DirectEncryption(uhdr={Algorithm: Direct, "Custom-Header-Attr3": 9999})]) - msg.key = SymmetricKey.generate_key(key_len=16) + msg.recipients[0].key = SymmetricKey.generate_key(key_len=16) assert "Custom-Header-Attr1" in msg.phdr assert "Custom-Header-Attr2" in msg.uhdr @@ -272,6 +272,7 @@ def test_allow_unknown_header_attribute_encoding_decoding(): assert "Custom-Header-Attr1" in msg_decoded.phdr assert "Custom-Header-Attr2" in msg_decoded.uhdr + def test_no_reencoding_of_protected_header(): # The following protected header encodes {Alg: Es256, "foo": 1}, however, # it is crafted such that it would not be emitted by cbor2. @@ -284,5 +285,5 @@ def test_no_reencoding_of_protected_header(): msg = msg.encode() msg_decoded = Sign1Message.decode(msg) - + assert msg_decoded.phdr_encoded == phdr_encoded diff --git a/tests/test_encmessage.py b/tests/test_encmessage.py index 7f6ad03..03bc9ce 100644 --- a/tests/test_encmessage.py +++ b/tests/test_encmessage.py @@ -19,8 +19,8 @@ def test_encrypt_direct_encryption_encoding(test_encrypt_direct_encryption_files key = CoseKey.from_dict(test_encrypt_direct_encryption_files["cek"]) key.key_ops = [EncryptOp] - - msg.key = key + # first recipient is arbitrary choice + msg.recipients[0].key = key assert msg.phdr_encoded == test_output['protected'] assert msg.uhdr_encoded == test_output['unprotected'] @@ -38,8 +38,8 @@ def test_encrypt_direct_encryption_decoding(test_encrypt_direct_encryption_files key = CoseKey.from_dict(test_encrypt_direct_encryption_files["cek"]) key.key_ops = [DecryptOp] - - msg.key = key + # first recipient is arbitrary choice + msg.recipients[0].key = key assert msg.phdr == test_input['protected'] assert msg.uhdr == test_input['unprotected'] @@ -167,7 +167,7 @@ def test_encrypt_key_agreement_key_wrap_encoding(test_encrypt_key_agreement_key_ assert msg.phdr == test_input['protected'] assert msg.uhdr == test_input['unprotected'] - for i, (r, r_output) in enumerate(zip(msg.recipients, test_output['recipients'])): + for _i, (r, r_output) in enumerate(zip(msg.recipients, test_output['recipients'])): r.payload = test_encrypt_key_agreement_key_wrap_files['random_key'].k assert r.phdr_encoded == r_output['protected'] assert r.uhdr_encoded == r_output['unprotected'] @@ -198,7 +198,7 @@ def test_encrypt_key_agreement_key_wrap_decoding(test_encrypt_key_agreement_key_ assert r.payload == r_output['ciphertext'] assert r.get_kdf_context((r.get_attr(headers.Algorithm)).get_key_wrap_func()).encode() == r_output['context'] assert r.decrypt((r.get_attr(headers.Algorithm)).get_key_wrap_func()) == \ - test_encrypt_key_agreement_key_wrap_files['random_key'].k + test_encrypt_key_agreement_key_wrap_files['random_key'].k for r in msg.recipients: assert msg.decrypt(r) == test_input['plaintext'] diff --git a/tests/test_macmessage.py b/tests/test_macmessage.py index 60532fe..35e089e 100644 --- a/tests/test_macmessage.py +++ b/tests/test_macmessage.py @@ -21,8 +21,8 @@ def test_mac_direct_encryption_encoding(test_mac_direct_encryption_files): key = CoseKey.from_dict(test_mac_direct_encryption_files["cek"]) key.key_ops = [MacCreateOp] - - msg.key = key + # first recipient is arbitrary choice + msg.recipients[0].key = key assert msg.phdr_encoded == test_output['protected'] assert msg.uhdr_encoded == test_output['unprotected'] @@ -41,8 +41,8 @@ def test_mac_direct_encryption_decoding(test_mac_direct_encryption_files): key = CoseKey.from_dict(test_mac_direct_encryption_files["cek"]) key.key_ops = [MacVerifyOp] - - msg.key = key + # first recipient is arbitrary choice + msg.recipients[0].key = key assert msg.phdr == test_input['protected'] assert msg.uhdr == test_input['unprotected']