diff --git a/include/sframe/map.h b/include/sframe/map.h index b6f8617..2a796b0 100644 --- a/include/sframe/map.h +++ b/include/sframe/map.h @@ -83,28 +83,35 @@ class map : private vector>, N> namespace SFRAME_NAMESPACE { -// NOTE: NOT RECOMMENDED FOR USE OUTSIDE THIS LIBRARY -// -// We have used public inheritance from std::map to simplify the interface -// here. This works fine for the use cases we have within this library. If you -// choose to use this map type outside this library, you MUST NOT store it as a -// std::map pointer or reference. This will cause memory leaks, because the -// destructor ~std::map is not virtual. template -class map : public std::map +class map : private std::map { private: using parent = std::map; public: + template + void emplace(Args&&... args) + { + parent::emplace(std::forward(args)...); + } + + auto find(const K& key) { return parent::find(key); } + auto find(const K& key) const { return parent::find(key); } + bool contains(const K& key) const { return this->count(key) > 0; } + const V& at(const K& key) const { return parent::at(key); } + V& at(const K& key) { return parent::at(key); } + + void erase(const K& key) { parent::erase(key); } + template void erase_if_key(F&& f) { - for (auto iter = this->begin(); iter != this->end();) { + for (auto iter = parent::begin(); iter != parent::end();) { if (f(iter->first)) { - iter = this->erase(iter); + iter = parent::erase(iter); } else { ++iter; } diff --git a/include/sframe/result.h b/include/sframe/result.h index 5a896c3..0a74c20 100644 --- a/include/sframe/result.h +++ b/include/sframe/result.h @@ -45,6 +45,8 @@ class SFrameError private: SFrameErrorType type_; + // Message storage is borrowed; callers must pass a string with static or + // otherwise stable lifetime. const char* message_ = nullptr; }; diff --git a/include/sframe/vector.h b/include/sframe/vector.h index 79f23e0..d86065a 100644 --- a/include/sframe/vector.h +++ b/include/sframe/vector.h @@ -29,12 +29,14 @@ class vector constexpr vector(std::initializer_list content) { + std::fill(_data.begin(), _data.end(), T()); resize(content.size()); std::copy(content.begin(), content.end(), _data.begin()); } constexpr vector(gsl::span content) { + std::fill(_data.begin(), _data.end(), T()); resize(content.size()); std::copy(content.begin(), content.end(), _data.begin()); } @@ -44,6 +46,7 @@ class vector template constexpr vector(const vector& content) { + std::fill(_data.begin(), _data.end(), T()); resize(content.size()); std::copy(content.begin(), content.end(), _data.begin()); } @@ -95,15 +98,8 @@ class vector namespace SFRAME_NAMESPACE { -// NOTE: NOT RECOMMENDED FOR USE OUTSIDE THIS LIBRARY -// -// We have used public inheritance from std::vector to simplify the interface -// here. This works fine for the use cases we have within this library. If you -// choose to use this vector type outside this library, you MUST NOT store it as -// a std::vector pointer or reference. This will cause memory leaks, because -// the destructor ~std::vector is not virtual. template -class vector : public std::vector +class vector : private std::vector { private: using parent = std::vector; @@ -126,16 +122,39 @@ class vector : public std::vector template constexpr vector(const vector& content) - : parent(content) + : parent(content.begin(), content.end()) { } + T* data() { return parent::data(); } + const T* data() const { return parent::data(); } + + auto begin() const { return parent::begin(); } + auto begin() { return parent::begin(); } + + auto end() const { return parent::end(); } + auto end() { return parent::end(); } + + auto size() const { return parent::size(); } + auto capacity() const { return parent::capacity(); } + void resize(size_t size) { parent::resize(size); } + + auto& operator[](size_t i) { return parent::operator[](i); } + const auto& operator[](size_t i) const { return parent::operator[](i); } + void append(gsl::span content) { const auto start = this->size(); this->resize(start + content.size()); std::copy(content.begin(), content.end(), this->begin() + start); } + + operator gsl::span() const + { + return gsl::span(parent::data(), parent::size()); + } + + operator gsl::span() { return gsl::span(parent::data(), parent::size()); } }; } // namespace SFRAME_NAMESPACE diff --git a/src/crypto.cpp b/src/crypto.cpp index d090d86..8b8a8df 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -1,6 +1,10 @@ #include "crypto.h" #include "header.h" +#include + +#include + namespace SFRAME_NAMESPACE { Result @@ -94,4 +98,38 @@ cipher_overhead(CipherSuite suite) } } +Result +checked_int(size_t size) +{ + if (size > INT_MAX) { + return SFrameError(SFrameErrorType::invalid_parameter_error, + "Input too large for OpenSSL"); + } + + return static_cast(size); +} + +Result +validate_ctr_size(size_t size) +{ + static constexpr uint64_t max_ctr_size = uint64_t(1) << 36; + if (uint64_t(size) > max_ctr_size) { + return SFrameError(SFrameErrorType::invalid_parameter_error, + "CTR input too large"); + } + + auto size_int = checked_int(size); + if (size_int.is_err()) { + return size_int.error(); + } + + return Result::ok(); +} + +void +clear_openssl_errors() +{ + ERR_clear_error(); +} + } // namespace SFRAME_NAMESPACE diff --git a/src/crypto.h b/src/crypto.h index cf95e41..9f1a10e 100644 --- a/src/crypto.h +++ b/src/crypto.h @@ -15,6 +15,12 @@ Result cipher_nonce_size(CipherSuite suite); Result cipher_overhead(CipherSuite suite); +Result +checked_int(size_t size); +Result +validate_ctr_size(size_t size); +void +clear_openssl_errors(); /// /// HKDF diff --git a/src/crypto_boringssl.cpp b/src/crypto_boringssl.cpp index 96dd5ba..0f1957b 100644 --- a/src/crypto_boringssl.cpp +++ b/src/crypto_boringssl.cpp @@ -69,6 +69,7 @@ openssl_cipher(CipherSuite suite) Result> hkdf_extract(CipherSuite suite, input_bytes salt, input_bytes ikm) { + clear_openssl_errors(); SFRAME_VALUE_OR_RETURN(md, openssl_digest_type(suite)); auto out = owned_bytes(EVP_MD_size(md)); auto out_len = size_t(out.size()); @@ -88,6 +89,7 @@ hkdf_extract(CipherSuite suite, input_bytes salt, input_bytes ikm) Result> hkdf_expand(CipherSuite suite, input_bytes prk, input_bytes info, size_t size) { + clear_openssl_errors(); SFRAME_VALUE_OR_RETURN(md, openssl_digest_type(suite)); auto out = owned_bytes(size); if (1 != HKDF_expand(out.data(), @@ -115,6 +117,7 @@ compute_tag(CipherSuite suite, input_bytes ct, size_t tag_size) { + clear_openssl_errors(); using scoped_hmac_ctx = std::unique_ptr; auto ctx = scoped_hmac_ctx(HMAC_CTX_new(), HMAC_CTX_free); @@ -122,7 +125,7 @@ compute_tag(CipherSuite suite, // Guard against sending nullptr to HMAC_Init_ex const auto* key_data = auth_key.data(); - auto key_size = static_cast(auth_key.size()); + SFRAME_VALUE_OR_RETURN(key_size, checked_int(auth_key.size())); const auto non_null_zero_length_key = uint8_t(0); if (key_data == nullptr) { key_data = &non_null_zero_length_key; @@ -173,6 +176,7 @@ ctr_crypt(CipherSuite suite, output_bytes out, input_bytes in) { + clear_openssl_errors(); if (out.size() != in.size()) { return SFrameError(SFrameErrorType::buffer_too_small_error, "CTR size mismatch"); @@ -194,12 +198,14 @@ ctr_crypt(CipherSuite suite, } int outlen = 0; - auto in_size_int = static_cast(in.size()); + SFRAME_VALUE_OR_RETURN(in_size_int, checked_int(in.size())); if (1 != EVP_EncryptUpdate( ctx.get(), out.data(), &outlen, in.data(), in_size_int)) { return SFrameErrorType::crypto_error; } + // CTR is a streaming mode, so finalization does not emit more bytes and a + // null output pointer is fine here. if (1 != EVP_EncryptFinal(ctx.get(), nullptr, &outlen)) { return SFrameErrorType::crypto_error; } @@ -216,6 +222,7 @@ seal_ctr(CipherSuite suite, input_bytes pt) { SFRAME_VALUE_OR_RETURN(tag_size, cipher_overhead(suite)); + SFRAME_VOID_OR_RETURN(validate_ctr_size(pt.size())); if (ct.size() < pt.size() + tag_size) { return SFrameError(SFrameErrorType::buffer_too_small_error, "Ciphertext buffer too small"); @@ -247,6 +254,7 @@ seal_aead(CipherSuite suite, input_bytes aad, input_bytes pt) { + clear_openssl_errors(); SFRAME_VALUE_OR_RETURN(tag_size, cipher_overhead(suite)); if (ct.size() < pt.size() + tag_size) { return SFrameError(SFrameErrorType::buffer_too_small_error, @@ -264,7 +272,7 @@ seal_aead(CipherSuite suite, } int outlen = 0; - auto aad_size_int = static_cast(aad.size()); + SFRAME_VALUE_OR_RETURN(aad_size_int, checked_int(aad.size())); if (aad.size() > 0) { if (1 != EVP_EncryptUpdate( ctx.get(), nullptr, &outlen, aad.data(), aad_size_int)) { @@ -272,7 +280,7 @@ seal_aead(CipherSuite suite, } } - auto pt_size_int = static_cast(pt.size()); + SFRAME_VALUE_OR_RETURN(pt_size_int, checked_int(pt.size())); if (1 != EVP_EncryptUpdate( ctx.get(), ct.data(), &outlen, pt.data(), pt_size_int)) { return SFrameErrorType::crypto_error; @@ -286,7 +294,7 @@ seal_aead(CipherSuite suite, auto tag = ct.subspan(pt.size(), tag_size); auto tag_ptr = const_cast(static_cast(tag.data())); - auto tag_size_downcast = static_cast(tag.size()); + SFRAME_VALUE_OR_RETURN(tag_size_downcast, checked_int(tag.size())); if (1 != EVP_CIPHER_CTX_ctrl( ctx.get(), EVP_CTRL_GCM_GET_TAG, tag_size_downcast, tag_ptr)) { return SFrameErrorType::crypto_error; @@ -334,6 +342,7 @@ open_ctr(CipherSuite suite, } auto inner_ct_size = ct.size() - tag_size; + SFRAME_VOID_OR_RETURN(validate_ctr_size(inner_ct_size)); auto inner_ct = ct.subspan(0, inner_ct_size); auto tag = ct.subspan(inner_ct_size, tag_size); @@ -365,6 +374,7 @@ open_aead(CipherSuite suite, input_bytes aad, input_bytes ct) { + clear_openssl_errors(); SFRAME_VALUE_OR_RETURN(tag_size, cipher_overhead(suite)); if (ct.size() < tag_size) { return SFrameError(SFrameErrorType::buffer_too_small_error, @@ -389,14 +399,14 @@ open_aead(CipherSuite suite, auto tag = ct.subspan(inner_ct_size, tag_size); auto tag_ptr = const_cast(static_cast(tag.data())); - auto tag_size_downcast = static_cast(tag.size()); + SFRAME_VALUE_OR_RETURN(tag_size_downcast, checked_int(tag.size())); if (1 != EVP_CIPHER_CTX_ctrl( ctx.get(), EVP_CTRL_GCM_SET_TAG, tag_size_downcast, tag_ptr)) { return SFrameErrorType::crypto_error; } int out_size; - auto aad_size_int = static_cast(aad.size()); + SFRAME_VALUE_OR_RETURN(aad_size_int, checked_int(aad.size())); if (aad.size() > 0) { if (1 != EVP_DecryptUpdate( ctx.get(), nullptr, &out_size, aad.data(), aad_size_int)) { @@ -404,7 +414,7 @@ open_aead(CipherSuite suite, } } - auto inner_ct_size_int = static_cast(inner_ct_size); + SFRAME_VALUE_OR_RETURN(inner_ct_size_int, checked_int(inner_ct_size)); if (1 != EVP_DecryptUpdate( ctx.get(), pt.data(), &out_size, ct.data(), inner_ct_size_int)) { return SFrameErrorType::crypto_error; diff --git a/src/crypto_openssl11.cpp b/src/crypto_openssl11.cpp index 9b33785..82c0b43 100644 --- a/src/crypto_openssl11.cpp +++ b/src/crypto_openssl11.cpp @@ -88,6 +88,7 @@ struct HMAC static Result create(CipherSuite suite, input_bytes key) { + clear_openssl_errors(); SFRAME_VALUE_OR_RETURN(type, openssl_digest_type(suite)); auto ctx = scoped_hmac_ctx(HMAC_CTX_new(), HMAC_CTX_free); @@ -95,11 +96,11 @@ struct HMAC // Some FIPS-enabled libraries are overly conservative in their // interpretation of NIST SP 800-131A, which requires HMAC keys to be at // least 112 bits long. That document does not impose that requirement on - // HKDF, so we disable FIPS enforcement for purposes of HKDF. + // HKDF, so this override is limited to the HKDF helper paths in this file. // // https://doi.org/10.6028/NIST.SP.800-131Ar2 static const auto fips_min_hmac_key_len = 14; - auto key_size = static_cast(key.size()); + SFRAME_VALUE_OR_RETURN(key_size, checked_int(key.size())); if (FIPS_mode() != 0 && key_size < fips_min_hmac_key_len) { HMAC_CTX_set_flags(ctx.get(), EVP_MD_CTX_FLAG_NON_FIPS_ALLOW); } @@ -144,6 +145,7 @@ struct HMAC Result> hkdf_extract(CipherSuite suite, input_bytes salt, input_bytes ikm) { + clear_openssl_errors(); SFRAME_VALUE_OR_RETURN(h, HMAC::create(suite, salt)); SFRAME_VOID_OR_RETURN(h.write(ikm)); @@ -156,6 +158,7 @@ hkdf_extract(CipherSuite suite, input_bytes salt, input_bytes ikm) Result> hkdf_expand(CipherSuite suite, input_bytes prk, input_bytes info, size_t size) { + clear_openssl_errors(); // Ensure that we need only one hash invocation if (size > max_hkdf_extract_size) { return SFrameError(SFrameErrorType::invalid_parameter_error, @@ -199,6 +202,7 @@ compute_tag(CipherSuite suite, input_bytes ct, size_t tag_size) { + clear_openssl_errors(); auto len_block = owned_bytes<24>(); auto len_view = output_bytes(len_block); encode_uint(aad.size(), len_view.first(8)); @@ -224,6 +228,7 @@ ctr_crypt(CipherSuite suite, output_bytes out, input_bytes in) { + clear_openssl_errors(); if (out.size() != in.size()) { return SFrameError(SFrameErrorType::buffer_too_small_error, "CTR size mismatch"); @@ -245,12 +250,14 @@ ctr_crypt(CipherSuite suite, } int outlen = 0; - auto in_size_int = static_cast(in.size()); + SFRAME_VALUE_OR_RETURN(in_size_int, checked_int(in.size())); if (1 != EVP_EncryptUpdate( ctx.get(), out.data(), &outlen, in.data(), in_size_int)) { return SFrameErrorType::crypto_error; } + // CTR is a streaming mode, so finalization does not emit more bytes and a + // null output pointer is fine here. if (1 != EVP_EncryptFinal(ctx.get(), nullptr, &outlen)) { return SFrameErrorType::crypto_error; } @@ -267,6 +274,7 @@ seal_ctr(CipherSuite suite, input_bytes pt) { SFRAME_VALUE_OR_RETURN(tag_size, cipher_overhead(suite)); + SFRAME_VOID_OR_RETURN(validate_ctr_size(pt.size())); if (ct.size() < pt.size() + tag_size) { return SFrameError(SFrameErrorType::buffer_too_small_error, "Ciphertext buffer too small"); @@ -298,6 +306,7 @@ seal_aead(CipherSuite suite, input_bytes aad, input_bytes pt) { + clear_openssl_errors(); SFRAME_VALUE_OR_RETURN(tag_size, cipher_overhead(suite)); if (ct.size() < pt.size() + tag_size) { return SFrameError(SFrameErrorType::buffer_too_small_error, @@ -315,7 +324,7 @@ seal_aead(CipherSuite suite, } int outlen = 0; - auto aad_size_int = static_cast(aad.size()); + SFRAME_VALUE_OR_RETURN(aad_size_int, checked_int(aad.size())); if (aad.size() > 0) { if (1 != EVP_EncryptUpdate( ctx.get(), nullptr, &outlen, aad.data(), aad_size_int)) { @@ -323,7 +332,7 @@ seal_aead(CipherSuite suite, } } - auto pt_size_int = static_cast(pt.size()); + SFRAME_VALUE_OR_RETURN(pt_size_int, checked_int(pt.size())); if (1 != EVP_EncryptUpdate( ctx.get(), ct.data(), &outlen, pt.data(), pt_size_int)) { return SFrameErrorType::crypto_error; @@ -337,7 +346,7 @@ seal_aead(CipherSuite suite, auto tag = ct.subspan(pt.size(), tag_size); auto tag_ptr = const_cast(static_cast(tag.data())); - auto tag_size_downcast = static_cast(tag.size()); + SFRAME_VALUE_OR_RETURN(tag_size_downcast, checked_int(tag.size())); if (1 != EVP_CIPHER_CTX_ctrl( ctx.get(), EVP_CTRL_GCM_GET_TAG, tag_size_downcast, tag_ptr)) { return SFrameErrorType::crypto_error; @@ -385,6 +394,7 @@ open_ctr(CipherSuite suite, } auto inner_ct_size = ct.size() - tag_size; + SFRAME_VOID_OR_RETURN(validate_ctr_size(inner_ct_size)); auto inner_ct = ct.subspan(0, inner_ct_size); auto tag = ct.subspan(inner_ct_size, tag_size); @@ -416,6 +426,7 @@ open_aead(CipherSuite suite, input_bytes aad, input_bytes ct) { + clear_openssl_errors(); SFRAME_VALUE_OR_RETURN(tag_size, cipher_overhead(suite)); if (ct.size() < tag_size) { return SFrameError(SFrameErrorType::buffer_too_small_error, @@ -440,14 +451,14 @@ open_aead(CipherSuite suite, auto tag = ct.subspan(inner_ct_size, tag_size); auto tag_ptr = const_cast(static_cast(tag.data())); - auto tag_size_downcast = static_cast(tag.size()); + SFRAME_VALUE_OR_RETURN(tag_size_downcast, checked_int(tag.size())); if (1 != EVP_CIPHER_CTX_ctrl( ctx.get(), EVP_CTRL_GCM_SET_TAG, tag_size_downcast, tag_ptr)) { return SFrameErrorType::crypto_error; } int out_size; - auto aad_size_int = static_cast(aad.size()); + SFRAME_VALUE_OR_RETURN(aad_size_int, checked_int(aad.size())); if (aad.size() > 0) { if (1 != EVP_DecryptUpdate( ctx.get(), nullptr, &out_size, aad.data(), aad_size_int)) { @@ -455,7 +466,7 @@ open_aead(CipherSuite suite, } } - auto inner_ct_size_int = static_cast(inner_ct_size); + SFRAME_VALUE_OR_RETURN(inner_ct_size_int, checked_int(inner_ct_size)); if (1 != EVP_DecryptUpdate( ctx.get(), pt.data(), &out_size, ct.data(), inner_ct_size_int)) { return SFrameErrorType::crypto_error; diff --git a/src/crypto_openssl3.cpp b/src/crypto_openssl3.cpp index 36fb125..b19b58d 100644 --- a/src/crypto_openssl3.cpp +++ b/src/crypto_openssl3.cpp @@ -73,6 +73,7 @@ using scoped_evp_kdf_ctx = Result> hkdf_extract(CipherSuite suite, input_bytes salt, input_bytes ikm) { + clear_openssl_errors(); auto mode = EVP_KDF_HKDF_MODE_EXTRACT_ONLY; SFRAME_VALUE_OR_RETURN(digest_name, openssl_digest_name(suite)); auto* salt_ptr = @@ -109,6 +110,7 @@ hkdf_extract(CipherSuite suite, input_bytes salt, input_bytes ikm) Result> hkdf_expand(CipherSuite suite, input_bytes prk, input_bytes info, size_t size) { + clear_openssl_errors(); auto mode = EVP_KDF_HKDF_MODE_EXPAND_ONLY; SFRAME_VALUE_OR_RETURN(digest_name, openssl_digest_name(suite)); auto* prk_ptr = const_cast(reinterpret_cast(prk.data())); @@ -150,6 +152,7 @@ compute_tag(CipherSuite suite, input_bytes ct, size_t tag_size) { + clear_openssl_errors(); using scoped_evp_mac = std::unique_ptr; using scoped_evp_mac_ctx = std::unique_ptr; @@ -213,6 +216,7 @@ ctr_crypt(CipherSuite suite, output_bytes out, input_bytes in) { + clear_openssl_errors(); if (out.size() != in.size()) { return SFrameError(SFrameErrorType::buffer_too_small_error, "CTR size mismatch"); @@ -234,12 +238,14 @@ ctr_crypt(CipherSuite suite, } int outlen = 0; - auto in_size_int = static_cast(in.size()); + SFRAME_VALUE_OR_RETURN(in_size_int, checked_int(in.size())); if (1 != EVP_EncryptUpdate( ctx.get(), out.data(), &outlen, in.data(), in_size_int)) { return SFrameErrorType::crypto_error; } + // CTR is a streaming mode, so finalization does not emit more bytes and a + // null output pointer is fine here. if (1 != EVP_EncryptFinal(ctx.get(), nullptr, &outlen)) { return SFrameErrorType::crypto_error; } @@ -256,6 +262,7 @@ seal_ctr(CipherSuite suite, input_bytes pt) { SFRAME_VALUE_OR_RETURN(tag_size, cipher_overhead(suite)); + SFRAME_VOID_OR_RETURN(validate_ctr_size(pt.size())); if (ct.size() < pt.size() + tag_size) { return SFrameError(SFrameErrorType::buffer_too_small_error, "Ciphertext buffer too small"); @@ -287,6 +294,7 @@ seal_aead(CipherSuite suite, input_bytes aad, input_bytes pt) { + clear_openssl_errors(); SFRAME_VALUE_OR_RETURN(tag_size, cipher_overhead(suite)); if (ct.size() < pt.size() + tag_size) { return SFrameError(SFrameErrorType::buffer_too_small_error, @@ -304,7 +312,7 @@ seal_aead(CipherSuite suite, } int outlen = 0; - auto aad_size_int = static_cast(aad.size()); + SFRAME_VALUE_OR_RETURN(aad_size_int, checked_int(aad.size())); if (aad.size() > 0) { if (1 != EVP_EncryptUpdate( ctx.get(), nullptr, &outlen, aad.data(), aad_size_int)) { @@ -312,7 +320,7 @@ seal_aead(CipherSuite suite, } } - auto pt_size_int = static_cast(pt.size()); + SFRAME_VALUE_OR_RETURN(pt_size_int, checked_int(pt.size())); if (1 != EVP_EncryptUpdate( ctx.get(), ct.data(), &outlen, pt.data(), pt_size_int)) { return SFrameErrorType::crypto_error; @@ -326,7 +334,7 @@ seal_aead(CipherSuite suite, auto tag = ct.subspan(pt.size(), tag_size); auto tag_ptr = const_cast(static_cast(tag.data())); - auto tag_size_downcast = static_cast(tag.size()); + SFRAME_VALUE_OR_RETURN(tag_size_downcast, checked_int(tag.size())); if (1 != EVP_CIPHER_CTX_ctrl( ctx.get(), EVP_CTRL_GCM_GET_TAG, tag_size_downcast, tag_ptr)) { return SFrameErrorType::crypto_error; @@ -374,6 +382,7 @@ open_ctr(CipherSuite suite, } auto inner_ct_size = ct.size() - tag_size; + SFRAME_VOID_OR_RETURN(validate_ctr_size(inner_ct_size)); auto inner_ct = ct.subspan(0, inner_ct_size); auto tag = ct.subspan(inner_ct_size, tag_size); @@ -405,6 +414,7 @@ open_aead(CipherSuite suite, input_bytes aad, input_bytes ct) { + clear_openssl_errors(); SFRAME_VALUE_OR_RETURN(tag_size, cipher_overhead(suite)); if (ct.size() < tag_size) { return SFrameError(SFrameErrorType::buffer_too_small_error, @@ -429,14 +439,14 @@ open_aead(CipherSuite suite, auto tag = ct.subspan(inner_ct_size, tag_size); auto tag_ptr = const_cast(static_cast(tag.data())); - auto tag_size_downcast = static_cast(tag.size()); + SFRAME_VALUE_OR_RETURN(tag_size_downcast, checked_int(tag.size())); if (1 != EVP_CIPHER_CTX_ctrl( ctx.get(), EVP_CTRL_GCM_SET_TAG, tag_size_downcast, tag_ptr)) { return SFrameErrorType::crypto_error; } int out_size; - auto aad_size_int = static_cast(aad.size()); + SFRAME_VALUE_OR_RETURN(aad_size_int, checked_int(aad.size())); if (aad.size() > 0) { if (1 != EVP_DecryptUpdate( ctx.get(), nullptr, &out_size, aad.data(), aad_size_int)) { @@ -444,7 +454,7 @@ open_aead(CipherSuite suite, } } - auto inner_ct_size_int = static_cast(inner_ct_size); + SFRAME_VALUE_OR_RETURN(inner_ct_size_int, checked_int(inner_ct_size)); if (1 != EVP_DecryptUpdate( ctx.get(), pt.data(), &out_size, ct.data(), inner_ct_size_int)) { return SFrameErrorType::crypto_error; diff --git a/src/header.cpp b/src/header.cpp index 58a072f..b1bc9c1 100644 --- a/src/header.cpp +++ b/src/header.cpp @@ -1,5 +1,7 @@ #include "header.h" +#include + namespace SFRAME_NAMESPACE { static size_t @@ -23,6 +25,7 @@ void encode_uint(uint64_t val, output_bytes buffer) { size_t size = buffer.size(); + std::fill(buffer.begin(), buffer.end(), uint8_t(0)); for (size_t i = 0; i < size && i < 8; i++) { buffer[size - i - 1] = uint8_t(val >> (8 * i)); } @@ -31,6 +34,11 @@ encode_uint(uint64_t val, output_bytes buffer) static Result decode_uint(input_bytes data) { + if (data.empty()) { + return SFrameError(SFrameErrorType::invalid_parameter_error, + "Integer encoding is empty"); + } + if (!data.empty() && data[0] == 0) { return SFrameError(SFrameErrorType::invalid_parameter_error, "Integer is not minimally encoded"); @@ -150,6 +158,11 @@ Header::parse(input_bytes buffer) } const auto cfg = ConfigByte{ buffer[0] }; + if (cfg.encoded_size() > buffer.size()) { + return SFrameError(SFrameErrorType::buffer_too_small_error, + "Ciphertext too small to decode header"); + } + const auto after_cfg = buffer.subspan(1); SFRAME_VALUE_OR_RETURN(kid_result, cfg.kid.read(after_cfg)); const auto [key_id, after_kid] = kid_result; diff --git a/src/sframe.cpp b/src/sframe.cpp index b389cfc..60c28c1 100644 --- a/src/sframe.cpp +++ b/src/sframe.cpp @@ -3,6 +3,8 @@ #include "crypto.h" #include "header.h" +#include + namespace SFRAME_NAMESPACE { /// @@ -57,12 +59,12 @@ KeyRecord::from_base_key(CipherSuite suite, SFRAME_VALUE_OR_RETURN(key_size, cipher_key_size(suite)); SFRAME_VALUE_OR_RETURN(nonce_size, cipher_nonce_size(suite)); - const auto empty_byte_string = owned_bytes<1>(); + const auto empty_salt_storage = owned_bytes<1>(); + const auto empty_salt = input_bytes(empty_salt_storage).first(0); const auto key_label = sframe_key_label(suite, key_id); const auto salt_label = sframe_salt_label(suite, key_id); - SFRAME_VALUE_OR_RETURN(secret, - hkdf_extract(suite, empty_byte_string, base_key)); + SFRAME_VALUE_OR_RETURN(secret, hkdf_extract(suite, empty_salt, base_key)); SFRAME_VALUE_OR_RETURN(key, hkdf_expand(suite, secret, key_label, key_size)); SFRAME_VALUE_OR_RETURN(salt, hkdf_expand(suite, secret, salt_label, nonce_size)); @@ -142,6 +144,16 @@ Context::protect(KeyID key_id, { SFRAME_VOID_OR_RETURN(require_key(key_id)); auto& key_record = keys.at(key_id); + if (key_record.usage != KeyUsage::protect) { + return SFrameError(SFrameErrorType::invalid_key_usage_error, + "Key is not valid for protect"); + } + + if (key_record.counter == std::numeric_limits::max()) { + return SFrameError(SFrameErrorType::invalid_parameter_error, + "Counter exhausted"); + } + const auto counter = key_record.counter; key_record.counter += 1; @@ -210,6 +222,10 @@ Context::unprotect_inner(const Header& header, SFRAME_VOID_OR_RETURN(require_key(header.key_id)); const auto& key_and_salt = keys.at(header.key_id); + if (key_and_salt.usage != KeyUsage::unprotect) { + return SFrameError(SFrameErrorType::invalid_key_usage_error, + "Key is not valid for unprotect"); + } SFRAME_VALUE_OR_RETURN(aad, form_aad(header, metadata)); const auto nonce = form_nonce(header.counter, key_and_salt.salt); @@ -359,10 +375,10 @@ MLSContext::remove_epoch(EpochID epoch_id) void MLSContext::purge_epoch(EpochID epoch_id) { - const auto drop_bits = epoch_id & epoch_bits; + const auto drop_bits = epoch_id & epoch_mask; keys.erase_if_key( - [&](const auto& epoch) { return (epoch & epoch_bits) == drop_bits; }); + [&](const auto& epoch) { return (epoch & epoch_mask) == drop_bits; }); } Result