From bc44e63b2afa656b679ae19505d8eeb5591903cf Mon Sep 17 00:00:00 2001 From: Tim Hurski Date: Wed, 12 Nov 2025 16:33:52 +0000 Subject: [PATCH 1/5] Move ECB encryption mode into its own module --- cpp/src/gandiva/CMakeLists.txt | 4 +- cpp/src/gandiva/encrypt_utils.cc | 106 -------- cpp/src/gandiva/encrypt_utils.h | 42 ---- cpp/src/gandiva/encrypt_utils_ecb.cc | 141 +++++++++++ cpp/src/gandiva/encrypt_utils_ecb.h | 63 +++++ ...tils_test.cc => encrypt_utils_ecb_test.cc} | 34 +-- cpp/src/gandiva/function_registry_string.cc | 15 +- cpp/src/gandiva/gdv_function_stubs.cc | 229 ++++++++++++++---- cpp/src/gandiva/gdv_function_stubs.h | 29 ++- cpp/src/gandiva/gdv_function_stubs_test.cc | 85 ++++++- 10 files changed, 521 insertions(+), 227 deletions(-) delete mode 100644 cpp/src/gandiva/encrypt_utils.cc delete mode 100644 cpp/src/gandiva/encrypt_utils.h create mode 100644 cpp/src/gandiva/encrypt_utils_ecb.cc create mode 100644 cpp/src/gandiva/encrypt_utils_ecb.h rename cpp/src/gandiva/{encrypt_utils_test.cc => encrypt_utils_ecb_test.cc} (71%) diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 2eb76efa3a7..02c79f6782a 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -56,7 +56,7 @@ set(SRC_FILES decimal_xlarge.cc engine.cc date_utils.cc - encrypt_utils.cc + encrypt_utils_ecb.cc expr_decomposer.cc expr_validator.cc expression.cc @@ -256,7 +256,7 @@ add_gandiva_test(internals-test llvm_generator_test.cc annotator_test.cc tree_expr_test.cc - encrypt_utils_test.cc + encrypt_utils_ecb_test.cc expr_decomposer_test.cc exported_funcs_registry_test.cc expression_registry_test.cc diff --git a/cpp/src/gandiva/encrypt_utils.cc b/cpp/src/gandiva/encrypt_utils.cc deleted file mode 100644 index 9dee10cfdb6..00000000000 --- a/cpp/src/gandiva/encrypt_utils.cc +++ /dev/null @@ -1,106 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "gandiva/encrypt_utils.h" -#include - -#include - -namespace gandiva { -GANDIVA_EXPORT -int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key, - int32_t key_len, unsigned char* cipher) { - int32_t cipher_len = 0; - int32_t len = 0; - EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new(); - const EVP_CIPHER* cipher_algo = get_cipher_algo(key_len); - - if (!en_ctx) { - throw std::runtime_error("could not create a new evp cipher ctx for encryption"); - } - - if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr, - reinterpret_cast(key), nullptr)) { - throw std::runtime_error("could not initialize evp cipher ctx for encryption"); - } - - if (!EVP_EncryptUpdate(en_ctx, cipher, &len, - reinterpret_cast(plaintext), - plaintext_len)) { - throw std::runtime_error("could not update evp cipher ctx for encryption"); - } - - cipher_len += len; - - if (!EVP_EncryptFinal_ex(en_ctx, cipher + len, &len)) { - throw std::runtime_error("could not finish evp cipher ctx for encryption"); - } - - cipher_len += len; - - EVP_CIPHER_CTX_free(en_ctx); - return cipher_len; -} - -GANDIVA_EXPORT -int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key, - int32_t key_len, unsigned char* plaintext) { - int32_t plaintext_len = 0; - int32_t len = 0; 
- EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new(); - const EVP_CIPHER* cipher_algo = get_cipher_algo(key_len); - - if (!de_ctx) { - throw std::runtime_error("could not create a new evp cipher ctx for decryption"); - } - - if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr, - reinterpret_cast(key), nullptr)) { - throw std::runtime_error("could not initialize evp cipher ctx for decryption"); - } - - if (!EVP_DecryptUpdate(de_ctx, plaintext, &len, - reinterpret_cast(ciphertext), - ciphertext_len)) { - throw std::runtime_error("could not update evp cipher ctx for decryption"); - } - - plaintext_len += len; - - if (!EVP_DecryptFinal_ex(de_ctx, plaintext + len, &len)) { - throw std::runtime_error("could not finish evp cipher ctx for decryption"); - } - - plaintext_len += len; - - EVP_CIPHER_CTX_free(de_ctx); - return plaintext_len; -} - -const EVP_CIPHER* get_cipher_algo(int32_t key_length){ - switch (key_length) { - case 16: - return EVP_aes_128_ecb(); - case 24: - return EVP_aes_192_ecb(); - case 32: - return EVP_aes_256_ecb(); - default: - throw std::runtime_error("unsupported key length"); - } -} -} // namespace gandiva diff --git a/cpp/src/gandiva/encrypt_utils.h b/cpp/src/gandiva/encrypt_utils.h deleted file mode 100644 index f02b029f01b..00000000000 --- a/cpp/src/gandiva/encrypt_utils.h +++ /dev/null @@ -1,42 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include -#include -#include "gandiva/visibility.h" - -namespace gandiva { - -/** - * Encrypt data using aes algorithm - **/ -GANDIVA_EXPORT -int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key, - int32_t key_len, unsigned char* cipher); - -/** - * Decrypt data using aes algorithm - **/ -GANDIVA_EXPORT -int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key, - int32_t key_len, unsigned char* plaintext); - -const EVP_CIPHER* get_cipher_algo(int32_t key_length); - -} // namespace gandiva diff --git a/cpp/src/gandiva/encrypt_utils_ecb.cc b/cpp/src/gandiva/encrypt_utils_ecb.cc new file mode 100644 index 00000000000..4d2a310121f --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_ecb.cc @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/encrypt_utils_ecb.h" +#include +#include +#include +#include +#include + +namespace gandiva { + +namespace { + +std::string get_openssl_error_string() { + unsigned long error_code = ERR_get_error(); + if (error_code == 0) { + return "Unknown OpenSSL error"; + } + return std::string(ERR_reason_error_string(error_code)); +} + +const EVP_CIPHER* get_ecb_cipher_algo(int32_t key_length) { + switch (key_length) { + case 16: + return EVP_aes_128_ecb(); + case 24: + return EVP_aes_192_ecb(); + case 32: + return EVP_aes_256_ecb(); + default: { + std::ostringstream oss; + oss << "Unsupported key length for AES-ECB: " << key_length + << " bytes. Supported lengths: 16, 24, 32 bytes"; + throw std::runtime_error(oss.str()); + } + } +} + +} // namespace + +GANDIVA_EXPORT +int32_t aes_encrypt_ecb(const char* plaintext, int32_t plaintext_len, const char* key, + int32_t key_len, unsigned char* cipher) { + int32_t cipher_len = 0; + int32_t len = 0; + EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new(); + const EVP_CIPHER* cipher_algo = get_ecb_cipher_algo(key_len); + + if (!en_ctx) { + throw std::runtime_error("Could not create EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr, + reinterpret_cast(key), nullptr)) { + EVP_CIPHER_CTX_free(en_ctx); + throw std::runtime_error("Could not initialize EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + if (!EVP_EncryptUpdate(en_ctx, cipher, &len, + reinterpret_cast(plaintext), + plaintext_len)) { + EVP_CIPHER_CTX_free(en_ctx); + throw std::runtime_error("Could not update EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + cipher_len += len; + + if (!EVP_EncryptFinal_ex(en_ctx, cipher + len, &len)) { + EVP_CIPHER_CTX_free(en_ctx); + throw std::runtime_error("Could not finalize EVP cipher 
context for encryption: " + + get_openssl_error_string()); + } + + cipher_len += len; + + EVP_CIPHER_CTX_free(en_ctx); + return cipher_len; +} + +GANDIVA_EXPORT +int32_t aes_decrypt_ecb(const char* ciphertext, int32_t ciphertext_len, const char* key, + int32_t key_len, unsigned char* plaintext) { + int32_t plaintext_len = 0; + int32_t len = 0; + EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new(); + const EVP_CIPHER* cipher_algo = get_ecb_cipher_algo(key_len); + + if (!de_ctx) { + throw std::runtime_error("Could not create EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr, + reinterpret_cast(key), nullptr)) { + EVP_CIPHER_CTX_free(de_ctx); + throw std::runtime_error("Could not initialize EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + if (!EVP_DecryptUpdate(de_ctx, plaintext, &len, + reinterpret_cast(ciphertext), + ciphertext_len)) { + EVP_CIPHER_CTX_free(de_ctx); + throw std::runtime_error("Could not update EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + plaintext_len += len; + + if (!EVP_DecryptFinal_ex(de_ctx, plaintext + len, &len)) { + EVP_CIPHER_CTX_free(de_ctx); + throw std::runtime_error("Could not finalize EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + plaintext_len += len; + + EVP_CIPHER_CTX_free(de_ctx); + return plaintext_len; +} + +} // namespace gandiva + diff --git a/cpp/src/gandiva/encrypt_utils_ecb.h b/cpp/src/gandiva/encrypt_utils_ecb.h new file mode 100644 index 00000000000..51be2644120 --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_ecb.h @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include "gandiva/visibility.h" + +namespace gandiva { + +/** + * Encrypt data using AES-ECB algorithm (legacy, insecure) + * + * WARNING: ECB mode is deterministic and should not be used for sensitive data. + * Use other encryption modes for better security. + * + * @param plaintext The data to encrypt + * @param plaintext_len Length of plaintext in bytes + * @param key The encryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) + * @param key_len Length of key in bytes + * @param cipher Output buffer for encrypted data + * @return Length of encrypted data in bytes + * @throws std::runtime_error on encryption failure + */ +GANDIVA_EXPORT +int32_t aes_encrypt_ecb(const char* plaintext, int32_t plaintext_len, const char* key, + int32_t key_len, unsigned char* cipher); + +/** + * Decrypt data using AES-ECB algorithm (legacy, insecure) + * + * WARNING: ECB mode is deterministic and should not be used for sensitive data. + * Use other encryption modes for better security. 
+ * + * @param ciphertext The data to decrypt + * @param ciphertext_len Length of ciphertext in bytes + * @param key The decryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) + * @param key_len Length of key in bytes + * @param plaintext Output buffer for decrypted data + * @return Length of decrypted data in bytes + * @throws std::runtime_error on decryption failure + */ +GANDIVA_EXPORT +int32_t aes_decrypt_ecb(const char* ciphertext, int32_t ciphertext_len, const char* key, + int32_t key_len, unsigned char* plaintext); + +} // namespace gandiva + diff --git a/cpp/src/gandiva/encrypt_utils_test.cc b/cpp/src/gandiva/encrypt_utils_ecb_test.cc similarity index 71% rename from cpp/src/gandiva/encrypt_utils_test.cc rename to cpp/src/gandiva/encrypt_utils_ecb_test.cc index 689f20ab032..d8983e947c1 100644 --- a/cpp/src/gandiva/encrypt_utils_test.cc +++ b/cpp/src/gandiva/encrypt_utils_ecb_test.cc @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. 
-#include "gandiva/encrypt_utils.h" +#include "gandiva/encrypt_utils_ecb.h" #include +#include -TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { +TEST(TestAesEcbEncryptUtils, TestAesEncryptDecrypt) { // 16 bytes key auto* key = "12345678abcdefgh"; auto* to_encrypt = "some test string"; @@ -29,11 +30,11 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_1[64]; - int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_1); + int32_t cipher_1_len = gandiva::aes_encrypt_ecb(to_encrypt, to_encrypt_len, key, key_len, cipher_1); unsigned char decrypted_1[64]; - int32_t decrypted_1_len = gandiva::aes_decrypt(reinterpret_cast(cipher_1), - cipher_1_len, key, key_len, decrypted_1); + int32_t decrypted_1_len = gandiva::aes_decrypt_ecb(reinterpret_cast(cipher_1), + cipher_1_len, key, key_len, decrypted_1); EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), std::string(reinterpret_cast(decrypted_1), decrypted_1_len)); @@ -47,11 +48,11 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_2[64]; - int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_2); + int32_t cipher_2_len = gandiva::aes_encrypt_ecb(to_encrypt, to_encrypt_len, key, key_len, cipher_2); unsigned char decrypted_2[64]; - int32_t decrypted_2_len = gandiva::aes_decrypt(reinterpret_cast(cipher_2), - cipher_2_len, key, key_len, decrypted_2); + int32_t decrypted_2_len = gandiva::aes_decrypt_ecb(reinterpret_cast(cipher_2), + cipher_2_len, key, key_len, decrypted_2); EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), std::string(reinterpret_cast(decrypted_2), decrypted_2_len)); @@ -65,11 +66,11 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_3[64]; - int32_t cipher_3_len = 
gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_3); + int32_t cipher_3_len = gandiva::aes_encrypt_ecb(to_encrypt, to_encrypt_len, key, key_len, cipher_3); unsigned char decrypted_3[64]; - int32_t decrypted_3_len = gandiva::aes_decrypt(reinterpret_cast(cipher_3), - cipher_3_len, key, key_len, decrypted_3); + int32_t decrypted_3_len = gandiva::aes_decrypt_ecb(reinterpret_cast(cipher_3), + cipher_3_len, key, key_len, decrypted_3); EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), std::string(reinterpret_cast(decrypted_3), decrypted_3_len)); @@ -87,11 +88,11 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_4[64]; ASSERT_THROW({ - gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_4); + gandiva::aes_encrypt_ecb(to_encrypt, to_encrypt_len, key, key_len, cipher_4); }, std::runtime_error); ASSERT_THROW({ - gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text); + gandiva::aes_decrypt_ecb(cipher, cipher_len, key, key_len, plain_text); }, std::runtime_error); key = "12345678"; @@ -102,9 +103,10 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_5[64]; ASSERT_THROW({ - gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_5); + gandiva::aes_encrypt_ecb(to_encrypt, to_encrypt_len, key, key_len, cipher_5); }, std::runtime_error); ASSERT_THROW({ - gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text); - }, std::runtime_error); + gandiva::aes_decrypt_ecb(cipher, cipher_len, key, key_len, plain_text); + }, std::runtime_error); } + diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index 2bc6936d77b..fb662b623fb 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -494,12 +494,23 @@ std::vector GetStringFunctionRegistry() { 
kResultNullIfNull, "split_part", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + // ECB mode specific functions + // String-based signatures (UTF8, UTF8) -> UTF8 NativeFunction("aes_encrypt", {}, DataTypeVector{utf8(), utf8()}, utf8(), - kResultNullIfNull, "gdv_fn_aes_encrypt", + kResultNullIfNull, "gdv_fn_aes_encrypt_ecb_legacy", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), NativeFunction("aes_decrypt", {}, DataTypeVector{utf8(), utf8()}, utf8(), - kResultNullIfNull, "gdv_fn_aes_decrypt", + kResultNullIfNull, "gdv_fn_aes_decrypt_ecb_legacy", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + // Binary-based signatures (BINARY, BINARY, UTF8) -> BINARY + NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(), + kResultNullIfNull, "gdv_fn_aes_encrypt_ecb", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(), + kResultNullIfNull, "gdv_fn_aes_decrypt_ecb", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), NativeFunction("mask_first_n", {}, DataTypeVector{utf8(), int32()}, utf8(), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index ac967659ae3..d2610c4ee92 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -19,7 +19,9 @@ #include +#include #include +#include #include #include @@ -28,7 +30,7 @@ #include "arrow/util/double_conversion.h" #include "arrow/util/value_parsing.h" -#include "gandiva/encrypt_utils.h" +#include "gandiva/encrypt_utils_ecb.h" #include "gandiva/engine.h" #include "gandiva/exported_funcs.h" #include "gandiva/in_holder.h" @@ -394,90 +396,195 @@ CAST_NUMERIC_FROM_VARBINARY(double, arrow::DoubleType, FLOAT8) #undef GDV_FN_CAST_VARCHAR_INTEGER #undef GDV_FN_CAST_VARCHAR_REAL + + +// ECB mode specific functions - core implementation +// This 
handles both string and binary inputs (they have the same C signature) GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - int32_t* out_len) { +const char* gdv_fn_aes_encrypt_ecb(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, + int32_t* out_len) { + // Validate mode parameter + if (mode == nullptr) { + std::ostringstream oss; + oss << "Invalid mode parameter for AES encryption"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + std::string mode_str(mode, mode_len); + // Convert to uppercase for comparison + std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); + + if (mode_str != "ECB") { + std::ostringstream oss; + oss << "AES encryption mode mismatch: function signature indicates ECB mode, but '" + << mode_str << "' was provided instead"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + if (data_len < 0) { - gdv_fn_context_set_error_msg(context, "Invalid data length to be encrypted"); + std::ostringstream oss; + oss << "Invalid data length for AES encryption: " << data_len << " (must be >= 0)"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); *out_len = 0; return ""; } - int64_t kAesBlockSize = 0; - if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) { - kAesBlockSize = static_cast(key_data_len); - } else { - gdv_fn_context_set_error_msg(context, "invalid key length"); + if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) { + std::ostringstream oss; + oss << "Invalid key length for AES encryption: " << key_data_len + << " bytes. 
Supported lengths: 16, 24, 32 bytes"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); *out_len = 0; return nullptr; } - + + // AES block size is always 16 bytes (128 bits), regardless of key length + int64_t kAesBlockSize = 16; *out_len = - static_cast(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize)); + static_cast(arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), kAesBlockSize)); char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); if (ret == nullptr) { - std::string err_msg = - "Could not allocate memory for returning aes encrypt cypher text"; - gdv_fn_context_set_error_msg(context, err_msg.data()); - *out_len = 0; + std::ostringstream oss; + oss << "Could not allocate memory for AES encryption output: " << *out_len << " bytes"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; return nullptr; } try { - *out_len = gandiva::aes_encrypt(data, data_len, key_data, key_data_len, - reinterpret_cast(ret)); + *out_len = gandiva::aes_encrypt_ecb(data, data_len, key_data, key_data_len, + reinterpret_cast(ret)); } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); - *out_len = 0; + *out_len = 0; return nullptr; } return ret; } +// Legacy wrapper for string-based signatures (UTF8, UTF8) -> UTF8 +// This is called by the LLVM engine with string calling convention +// WARNING: This function is for backward compatibility only. Encrypted binary data +// is not guaranteed to be valid UTF-8. Use binary signatures for new code. 
GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - int32_t* out_len) { +const char* gdv_fn_aes_encrypt_ecb_legacy(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + int32_t* out_len) { + // Delegate to the core implementation with ECB mode + const char* mode = "ECB"; + int32_t mode_len = 3; + const char* result = gdv_fn_aes_encrypt_ecb(context, data, data_len, key_data, key_data_len, mode, mode_len, out_len); + + // Add null terminator for string compatibility + // Note: This may not be valid UTF-8, but it's needed for string handling + if (result != nullptr) { + char* mutable_result = const_cast(result); + mutable_result[*out_len] = '\0'; + } + + return result; +} + +GANDIVA_EXPORT +const char* gdv_fn_aes_decrypt_ecb(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, + int32_t* out_len) { + // Validate mode parameter + if (mode == nullptr) { + std::ostringstream oss; + oss << "Invalid mode parameter for AES decryption"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + std::string mode_str(mode, mode_len); + // Convert to uppercase for comparison + std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); + + if (mode_str != "ECB") { + std::ostringstream oss; + oss << "AES decryption mode mismatch: function signature indicates ECB mode, but '" + << mode_str << "' was provided instead"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + if (data_len < 0) { - gdv_fn_context_set_error_msg(context, "Invalid data length to be decrypted"); + std::ostringstream oss; + oss << "Invalid data length for AES decryption: " << data_len << " (must be >= 0)"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); *out_len = 0; return 
""; } - int64_t kAesBlockSize = 0; - if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) { - kAesBlockSize = static_cast(key_data_len); - } else { - gdv_fn_context_set_error_msg(context, "invalid key length"); + if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) { + std::ostringstream oss; + oss << "Invalid key length for AES decryption: " << key_data_len + << " bytes. Supported lengths: 16, 24, 32 bytes"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); *out_len = 0; return nullptr; } + // AES block size is always 16 bytes (128 bits), regardless of key length + int64_t kAesBlockSize = 16; *out_len = - static_cast(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize)); + static_cast(arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), kAesBlockSize)); char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); if (ret == nullptr) { - std::string err_msg = - "Could not allocate memory for returning aes encrypt cypher text"; - gdv_fn_context_set_error_msg(context, err_msg.data()); - *out_len = 0; + std::ostringstream oss; + oss << "Could not allocate memory for AES decryption output: " << *out_len << " bytes"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; return nullptr; } try { - *out_len = gandiva::aes_decrypt(data, data_len, key_data, key_data_len, - reinterpret_cast(ret)); + *out_len = gandiva::aes_decrypt_ecb(data, data_len, key_data, key_data_len, + reinterpret_cast(ret)); } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); - *out_len = 0; + *out_len = 0; return nullptr; } - ret[*out_len] = '\0'; + return ret; } +// Legacy wrapper for string-based signatures (UTF8, UTF8) -> UTF8 +// This is called by the LLVM engine with string calling convention +// WARNING: This function is for backward compatibility only. Decrypted data may not be +// valid UTF-8. Use binary signatures for new code. 
+GANDIVA_EXPORT +const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + int32_t* out_len) { + // Delegate to the core implementation with ECB mode + const char* mode = "ECB"; + int32_t mode_len = 3; + const char* result = gdv_fn_aes_decrypt_ecb(context, data, data_len, key_data, key_data_len, mode, mode_len, out_len); + + // Add null terminator for string compatibility + // Note: This may not be valid UTF-8, but it's needed for string handling + if (result != nullptr) { + char* mutable_result = const_cast(result); + mutable_result[*out_len] = '\0'; + } + + return result; +} + + GANDIVA_EXPORT const char* gdv_mask_first_n_utf8_int32(int64_t context, const char* data, int32_t data_len, int32_t n_to_mask, @@ -1122,7 +1229,43 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(gdv_fn_base64_decode_utf8)); - // gdv_fn_aes_encrypt + // gdv_fn_aes_encrypt_ecb + // Note: The mode parameter is passed as a UTF8 string (data + length) + // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, out_len) + args = { + types->i64_type(), // context + types->i8_ptr_type(), // data + types->i32_type(), // data_length + types->i8_ptr_type(), // key_data + types->i32_type(), // key_data_length + types->i8_ptr_type(), // mode (UTF8 string) + types->i32_type(), // mode_length + types->i32_ptr_type() // out_length + }; + + engine->AddGlobalMappingForFunc("gdv_fn_aes_encrypt_ecb", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_aes_encrypt_ecb)); + + // gdv_fn_aes_decrypt_ecb + // Note: The mode parameter is passed as a UTF8 string (data + length) + // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, out_len) + args = { + types->i64_type(), // context + types->i8_ptr_type(), // data + types->i32_type(), // data_length + 
types->i8_ptr_type(), // key_data + types->i32_type(), // key_data_length + types->i8_ptr_type(), // mode (UTF8 string) + types->i32_type(), // mode_length + types->i32_ptr_type() // out_length + }; + + engine->AddGlobalMappingForFunc("gdv_fn_aes_decrypt_ecb", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_aes_decrypt_ecb)); + + // gdv_fn_aes_encrypt_ecb_legacy (wrapper for string-based signatures) args = { types->i64_type(), // context types->i8_ptr_type(), // data @@ -1132,11 +1275,11 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { types->i32_ptr_type() // out_length }; - engine->AddGlobalMappingForFunc("gdv_fn_aes_encrypt", + engine->AddGlobalMappingForFunc("gdv_fn_aes_encrypt_ecb_legacy", types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_encrypt)); + reinterpret_cast(gdv_fn_aes_encrypt_ecb_legacy)); - // gdv_fn_aes_decrypt + // gdv_fn_aes_decrypt_ecb_legacy (wrapper for string-based signatures) args = { types->i64_type(), // context types->i8_ptr_type(), // data @@ -1146,9 +1289,9 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { types->i32_ptr_type() // out_length }; - engine->AddGlobalMappingForFunc("gdv_fn_aes_decrypt", + engine->AddGlobalMappingForFunc("gdv_fn_aes_decrypt_ecb_legacy", types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_decrypt)); + reinterpret_cast(gdv_fn_aes_decrypt_ecb_legacy)); // gdv_mask_first_n and gdv_mask_last_n std::vector mask_args = { diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index 4113f261ad7..5c5452df280 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -189,14 +189,31 @@ float gdv_fn_castFLOAT4_varbinary(gdv_int64 context, const char* in, int32_t in_ GANDIVA_EXPORT double gdv_fn_castFLOAT8_varbinary(gdv_int64 context, const char* in, int32_t in_len); +// ECB mode specific functions GANDIVA_EXPORT -const char* 
gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - int32_t* out_len); +const char* gdv_fn_aes_encrypt_ecb(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, + int32_t* out_len); + +GANDIVA_EXPORT +const char* gdv_fn_aes_decrypt_ecb(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, + int32_t* out_len); + +// Legacy wrappers for string-based signatures GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - int32_t* out_len); +const char* gdv_fn_aes_encrypt_ecb_legacy(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + int32_t* out_len); + +GANDIVA_EXPORT +const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + int32_t* out_len); + + GANDIVA_EXPORT const char* gdv_mask_first_n_utf8_int32(int64_t context, const char* data, diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index 134f7dcd27d..da785d1a511 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -1353,10 +1353,12 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt16) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); + std::string mode = "ECB"; + auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, &cipher_len); - const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, &decrypted_len); + const char* cipher = 
gdv_fn_aes_encrypt_ecb(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), mode_len, &cipher_len); + const char* decrypted_value = gdv_fn_aes_decrypt_ecb(ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, &decrypted_len); EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); } @@ -1369,11 +1371,13 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); + std::string mode = "ECB"; + auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key24.c_str(), key24_len, &cipher_len); + const char* cipher = gdv_fn_aes_encrypt_ecb(ctx_ptr, data.c_str(), data_len, key24.c_str(), key24_len, mode.c_str(), mode_len, &cipher_len); - const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key24.c_str(), key24_len, &decrypted_len); + const char* decrypted_value = gdv_fn_aes_decrypt_ecb(ctx_ptr, cipher, cipher_len, key24.c_str(), key24_len, mode.c_str(), mode_len, &decrypted_len); EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); } @@ -1386,11 +1390,13 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); + std::string mode = "ECB"; + auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key32.c_str(), key32_len, &cipher_len); + const char* cipher = gdv_fn_aes_encrypt_ecb(ctx_ptr, data.c_str(), data_len, key32.c_str(), key32_len, mode.c_str(), mode_len, &cipher_len); - const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key32.c_str(), key32_len, &decrypted_len); + const char* decrypted_value = gdv_fn_aes_decrypt_ecb(ctx_ptr, 
cipher, cipher_len, key32.c_str(), key32_len, mode.c_str(), mode_len, &decrypted_len); EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); } @@ -1402,17 +1408,76 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptValidation) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); + std::string mode = "ECB"; + auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); std::string cipher = "12345678abcdefgh12345678abcdefghb"; auto cipher_len = static_cast(cipher.length()); - gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key33.c_str(), key33_len, &cipher_len); + gdv_fn_aes_encrypt_ecb(ctx_ptr, data.c_str(), data_len, key33.c_str(), key33_len, mode.c_str(), mode_len, &cipher_len); EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("invalid key length")); + ::testing::HasSubstr("Invalid key length for AES encryption")); ctx.Reset(); - gdv_fn_aes_decrypt(ctx_ptr, cipher.c_str(), cipher_len, key33.c_str(), key33_len, &decrypted_len); EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("invalid key length")); + gdv_fn_aes_decrypt_ecb(ctx_ptr, cipher.c_str(), cipher_len, key33.c_str(), key33_len, mode.c_str(), mode_len, &decrypted_len); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Invalid key length for AES decryption")); ctx.Reset(); } + +// Tests for new mode-aware AES functions +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeEcb) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = "ECB"; + auto mode_len = static_cast(mode.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt_ecb(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, mode.c_str(), mode_len, &cipher_len); + 
EXPECT_GT(cipher_len, 0); + + const char* decrypted_value = gdv_fn_aes_decrypt_ecb( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, &decrypted_len); + EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeValidation) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string invalid_mode = "CBC"; + auto invalid_mode_len = static_cast(invalid_mode.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + // Test encrypt with invalid mode + gdv_fn_aes_encrypt_ecb(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, + invalid_mode.c_str(), invalid_mode_len, &cipher_len); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("AES encryption mode mismatch")); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("CBC")); + ctx.Reset(); + + // Test decrypt with invalid mode + std::string cipher = "12345678abcdefgh12345678abcdefgh"; + auto cipher_len_val = static_cast(cipher.length()); + gdv_fn_aes_decrypt_ecb(ctx_ptr, cipher.c_str(), cipher_len_val, key16.c_str(), key16_len, + invalid_mode.c_str(), invalid_mode_len, &decrypted_len); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("AES decryption mode mismatch")); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("CBC")); + ctx.Reset(); +} + } // namespace gandiva From 0cb4be42c6ab860162d70370aa87632290e7ee1e Mon Sep 17 00:00:00 2001 From: Tim Hurski Date: Fri, 14 Nov 2025 17:48:02 +0000 Subject: [PATCH 2/5] DX-108149: Add support for AES CBC encryption mode --- cpp/src/gandiva/CMakeLists.txt | 3 + cpp/src/gandiva/encrypt_utils_cbc.cc | 197 +++++++++++++++++++ cpp/src/gandiva/encrypt_utils_cbc.h | 67 +++++++ cpp/src/gandiva/encrypt_utils_cbc_test.cc | 191 ++++++++++++++++++ 
cpp/src/gandiva/encrypt_utils_common.cc | 33 ++++ cpp/src/gandiva/encrypt_utils_common.h | 33 ++++ cpp/src/gandiva/encrypt_utils_ecb.cc | 9 +- cpp/src/gandiva/function_registry_string.cc | 10 + cpp/src/gandiva/gdv_function_stubs.cc | 202 ++++++++++++++++++++ cpp/src/gandiva/gdv_function_stubs.h | 15 ++ cpp/src/gandiva/gdv_function_stubs_test.cc | 177 +++++++++++++++++ 11 files changed, 929 insertions(+), 8 deletions(-) create mode 100644 cpp/src/gandiva/encrypt_utils_cbc.cc create mode 100644 cpp/src/gandiva/encrypt_utils_cbc.h create mode 100644 cpp/src/gandiva/encrypt_utils_cbc_test.cc create mode 100644 cpp/src/gandiva/encrypt_utils_common.cc create mode 100644 cpp/src/gandiva/encrypt_utils_common.h diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 02c79f6782a..8d198416d90 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -56,7 +56,9 @@ set(SRC_FILES decimal_xlarge.cc engine.cc date_utils.cc + encrypt_utils_common.cc encrypt_utils_ecb.cc + encrypt_utils_cbc.cc expr_decomposer.cc expr_validator.cc expression.cc @@ -257,6 +259,7 @@ add_gandiva_test(internals-test annotator_test.cc tree_expr_test.cc encrypt_utils_ecb_test.cc + encrypt_utils_cbc_test.cc expr_decomposer_test.cc exported_funcs_registry_test.cc expression_registry_test.cc diff --git a/cpp/src/gandiva/encrypt_utils_cbc.cc b/cpp/src/gandiva/encrypt_utils_cbc.cc new file mode 100644 index 00000000000..15ac0413a75 --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_cbc.cc @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/encrypt_utils_cbc.h"
+#include "gandiva/encrypt_utils_common.h"
+#include <openssl/evp.h>
+#include <cstring>
+#include <memory>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+
+namespace gandiva {
+
+namespace {
+
+// Padding schemes accepted by the CBC entry points.
+enum class PaddingMode { PKCS7, NONE };
+
+// Owns an EVP_CIPHER_CTX so it is freed on every exit path, including throws.
+struct CipherCtxDeleter {
+  void operator()(EVP_CIPHER_CTX* ctx) const { EVP_CIPHER_CTX_free(ctx); }
+};
+using CipherCtxPtr = std::unique_ptr<EVP_CIPHER_CTX, CipherCtxDeleter>;
+
+// Maps a key length in bytes to the matching AES-CBC cipher; throws otherwise.
+const EVP_CIPHER* get_cbc_cipher_algo(int32_t key_length) {
+  switch (key_length) {
+    case 16:
+      return EVP_aes_128_cbc();
+    case 24:
+      return EVP_aes_192_cbc();
+    case 32:
+      return EVP_aes_256_cbc();
+    default: {
+      std::ostringstream oss;
+      oss << "Unsupported key length for AES-CBC: " << key_length
+          << " bytes. Supported lengths: 16, 24, 32 bytes";
+      throw std::runtime_error(oss.str());
+    }
+  }
+}
+
+// ASCII case-insensitive comparison of a length-delimited string with an upper-case
+// literal. Replaces strncasecmp(), which is POSIX-only and unavailable under MSVC.
+bool equals_ignore_case(const char* str, int32_t len, const char* upper_literal) {
+  if (static_cast<std::size_t>(len) != std::strlen(upper_literal)) {
+    return false;
+  }
+  for (int32_t i = 0; i < len; ++i) {
+    char c = str[i];
+    if (c >= 'a' && c <= 'z') {
+      c = static_cast<char>(c - 'a' + 'A');
+    }
+    if (c != upper_literal[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Parses the padding argument ("PKCS7" or "NONE", any case); throws on anything else.
+PaddingMode get_padding_mode(const char* padding_str, int32_t padding_len) {
+  if (padding_str == nullptr || padding_len <= 0) {
+    throw std::runtime_error("Invalid padding parameter: null or empty");
+  }
+  if (equals_ignore_case(padding_str, padding_len, "PKCS7")) {
+    return PaddingMode::PKCS7;
+  }
+  if (equals_ignore_case(padding_str, padding_len, "NONE")) {
+    return PaddingMode::NONE;
+  }
+  std::ostringstream oss;
+  oss << "Invalid padding mode: '" << std::string(padding_str, padding_len)
+      << "'. Supported modes: PKCS7, NONE (case-insensitive)";
+  throw std::runtime_error(oss.str());
+}
+
+// Rejects any IV that is not exactly one AES block (16 bytes) long.
+void validate_iv_length(int32_t iv_len) {
+  if (iv_len != 16) {
+    std::ostringstream oss;
+    oss << "Invalid IV length for AES-CBC: " << iv_len
+        << " bytes. IV must be exactly 16 bytes";
+    throw std::runtime_error(oss.str());
+  }
+}
+
+}  // namespace
+
+GANDIVA_EXPORT
+int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char* key,
+                        int32_t key_len, const char* iv, int32_t iv_len,
+                        const char* padding, int32_t padding_len, unsigned char* cipher) {
+  // Validate all arguments up front so nothing is allocated for a doomed call.
+  validate_iv_length(iv_len);
+  const PaddingMode padding_mode = get_padding_mode(padding, padding_len);
+  const EVP_CIPHER* cipher_algo = get_cbc_cipher_algo(key_len);
+
+  CipherCtxPtr en_ctx(EVP_CIPHER_CTX_new());
+  if (!en_ctx) {
+    throw std::runtime_error("Could not create EVP cipher context for encryption: " +
+                             get_openssl_error_string());
+  }
+
+  if (!EVP_EncryptInit_ex(en_ctx.get(), cipher_algo, nullptr,
+                          reinterpret_cast<const unsigned char*>(key),
+                          reinterpret_cast<const unsigned char*>(iv))) {
+    throw std::runtime_error("Could not initialize EVP cipher context for encryption: " +
+                             get_openssl_error_string());
+  }
+
+  // OpenSSL pads by default; a flag of 0 turns padding off for PaddingMode::NONE.
+  const int padding_flag = (padding_mode == PaddingMode::PKCS7) ? 1 : 0;
+  if (!EVP_CIPHER_CTX_set_padding(en_ctx.get(), padding_flag)) {
+    throw std::runtime_error("Could not set padding mode for encryption: " +
+                             get_openssl_error_string());
+  }
+
+  int32_t len = 0;
+  if (!EVP_EncryptUpdate(en_ctx.get(), cipher, &len,
+                         reinterpret_cast<const unsigned char*>(plaintext),
+                         plaintext_len)) {
+    throw std::runtime_error("Could not update EVP cipher context for encryption: " +
+                             get_openssl_error_string());
+  }
+  int32_t cipher_len = len;
+
+  // Flush the last block; its bytes land directly after the bulk output.
+  if (!EVP_EncryptFinal_ex(en_ctx.get(), cipher + cipher_len, &len)) {
+    throw std::runtime_error("Could not finalize EVP cipher context for encryption: " +
+                             get_openssl_error_string());
+  }
+  return cipher_len + len;
+}
+
+GANDIVA_EXPORT
+int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const char* key,
+                        int32_t key_len, const char* iv, int32_t iv_len,
+                        const char* padding, int32_t padding_len, unsigned char* plaintext) {
+  // Validate all arguments up front so nothing is allocated for a doomed call.
+  validate_iv_length(iv_len);
+  const PaddingMode padding_mode = get_padding_mode(padding, padding_len);
+  const EVP_CIPHER* cipher_algo = get_cbc_cipher_algo(key_len);
+
+  CipherCtxPtr de_ctx(EVP_CIPHER_CTX_new());
+  if (!de_ctx) {
+    throw std::runtime_error("Could not create EVP cipher context for decryption: " +
+                             get_openssl_error_string());
+  }
+
+  if (!EVP_DecryptInit_ex(de_ctx.get(), cipher_algo, nullptr,
+                          reinterpret_cast<const unsigned char*>(key),
+                          reinterpret_cast<const unsigned char*>(iv))) {
+    throw std::runtime_error("Could not initialize EVP cipher context for decryption: " +
+                             get_openssl_error_string());
+  }
+
+  // OpenSSL pads by default; a flag of 0 turns padding off for PaddingMode::NONE.
+  const int padding_flag = (padding_mode == PaddingMode::PKCS7) ? 1 : 0;
+  if (!EVP_CIPHER_CTX_set_padding(de_ctx.get(), padding_flag)) {
+    throw std::runtime_error("Could not set padding mode for decryption: " +
+                             get_openssl_error_string());
+  }
+
+  int32_t len = 0;
+  if (!EVP_DecryptUpdate(de_ctx.get(), plaintext, &len,
+                         reinterpret_cast<const unsigned char*>(ciphertext),
+                         ciphertext_len)) {
+    throw std::runtime_error("Could not update EVP cipher context for decryption: " +
+                             get_openssl_error_string());
+  }
+  int32_t plaintext_len = len;
+
+  // Flush the remaining bytes; they land directly after the bulk output.
+  if (!EVP_DecryptFinal_ex(de_ctx.get(), plaintext + plaintext_len, &len)) {
+    throw std::runtime_error("Could not finalize EVP cipher context for decryption: " +
+                             get_openssl_error_string());
+  }
+  return plaintext_len + len;
+}
+
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/encrypt_utils_cbc.h b/cpp/src/gandiva/encrypt_utils_cbc.h
new file mode 100644
index 00000000000..9b5bcaa6ab2
--- /dev/null
+++ b/cpp/src/gandiva/encrypt_utils_cbc.h
@@ -0,0 +1,67 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +#pragma once + +#include +#include +#include "gandiva/visibility.h" + +namespace gandiva { + +/** + * Encrypt data using AES-CBC algorithm with explicit padding mode + * + * @param plaintext The data to encrypt + * @param plaintext_len Length of plaintext in bytes + * @param key The encryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) + * @param key_len Length of key in bytes + * @param iv The initialization vector (must be exactly 16 bytes) + * @param iv_len Length of IV in bytes (must be 16) + * @param padding Padding mode string: "PKCS7" or "NONE" (case-insensitive) + * @param padding_len Length of padding string in bytes + * @param cipher Output buffer for encrypted data + * @return Length of encrypted data in bytes + * @throws std::runtime_error on encryption failure or invalid parameters + */ +GANDIVA_EXPORT +int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char* key, + int32_t key_len, const char* iv, int32_t iv_len, + const char* padding, int32_t padding_len, unsigned char* cipher); + +/** + * Decrypt data using AES-CBC algorithm with explicit padding mode + * + * @param ciphertext The data to decrypt + * @param ciphertext_len Length of ciphertext in bytes + * @param key The decryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) + * @param key_len Length of key in bytes + * @param iv The initialization vector (must be exactly 16 bytes) + * @param iv_len Length of IV in bytes (must be 16) + * @param padding Padding mode string: "PKCS7" or "NONE" (case-insensitive) + * @param padding_len Length of padding string in bytes + * @param plaintext Output buffer for decrypted data + * @return Length of decrypted data in bytes + * @throws std::runtime_error on decryption failure or invalid parameters + */ +GANDIVA_EXPORT +int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const char* key, + int32_t key_len, const char* iv, int32_t iv_len, + const char* padding, int32_t padding_len, unsigned char* 
plaintext); + +} // namespace gandiva + diff --git a/cpp/src/gandiva/encrypt_utils_cbc_test.cc b/cpp/src/gandiva/encrypt_utils_cbc_test.cc new file mode 100644 index 00000000000..88e5403d120 --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_cbc_test.cc @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/encrypt_utils_cbc.h" + +#include +#include + +// Test PKCS#7 padding with 16-byte key +TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_16) { + auto* key = "12345678abcdefgh"; + auto* iv = "1234567890123456"; + auto* to_encrypt = "some test string"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[64]; + + int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, "PKCS7", 5, cipher); + + unsigned char decrypted[64]; + int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), + cipher_len, key, key_len, iv, iv_len, + "PKCS7", 5, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test PKCS#7 padding with 24-byte key +TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_24) { + auto* key = "12345678abcdefgh12345678"; + auto* iv = "1234567890123456"; + auto* to_encrypt = "some\ntest\nstring"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[64]; + + int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, "PKCS7", 5, cipher); + + unsigned char decrypted[64]; + int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), + cipher_len, key, key_len, iv, iv_len, + "PKCS7", 5, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test PKCS#7 padding with 32-byte key +TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_32) { + auto* key = "12345678abcdefgh12345678abcdefgh"; + auto* iv = "1234567890123456"; + auto* to_encrypt = "New\ntest\nstring"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = 
static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[64]; + + int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, "PKCS7", 5, cipher); + + unsigned char decrypted[64]; + int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), + cipher_len, key, key_len, iv, iv_len, + "PKCS7", 5, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test no-padding mode with block-aligned data (16 bytes) +TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptNoPadding_16) { + auto* key = "12345678abcdefgh"; + auto* iv = "1234567890123456"; + auto* to_encrypt = "1234567890123456"; // Exactly 16 bytes + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[64]; + + int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, "NONE", 4, cipher); + + unsigned char decrypted[64]; + int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), + cipher_len, key, key_len, iv, iv_len, + "NONE", 4, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test case-insensitive padding mode +TEST(TestAesCbcEncryptUtils, TestCaseInsensitivePadding) { + auto* key = "12345678abcdefgh"; + auto* iv = "1234567890123456"; + auto* to_encrypt = "test"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher1[64]; + unsigned char cipher2[64]; + + // Test with "pkcs7" (lowercase) + int32_t cipher1_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, "pkcs7", 5, cipher1); + + // Test with "PKCS7" (uppercase) + int32_t cipher2_len 
= gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, "PKCS7", 5, cipher2); + + // Both should produce same ciphertext + EXPECT_EQ(cipher1_len, cipher2_len); + EXPECT_EQ(std::string(reinterpret_cast(cipher1), cipher1_len), + std::string(reinterpret_cast(cipher2), cipher2_len)); +} + +// Test invalid IV length +TEST(TestAesCbcEncryptUtils, TestInvalidIVLength) { + auto* key = "12345678abcdefgh"; + auto* iv = "short"; // Too short + auto* to_encrypt = "test"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[64]; + + ASSERT_THROW({ + gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, "PKCS7", 5, cipher); + }, std::runtime_error); +} + +// Test invalid key length +TEST(TestAesCbcEncryptUtils, TestInvalidKeyLength) { + auto* key = "short"; // Too short + auto* iv = "1234567890123456"; + auto* to_encrypt = "test"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[64]; + + ASSERT_THROW({ + gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, "PKCS7", 5, cipher); + }, std::runtime_error); +} + +// Test invalid padding mode +TEST(TestAesCbcEncryptUtils, TestInvalidPaddingMode) { + auto* key = "12345678abcdefgh"; + auto* iv = "1234567890123456"; + auto* to_encrypt = "test"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[64]; + + ASSERT_THROW({ + gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, "INVALID", 7, cipher); + }, std::runtime_error); +} + diff --git a/cpp/src/gandiva/encrypt_utils_common.cc b/cpp/src/gandiva/encrypt_utils_common.cc new file mode 100644 index 00000000000..b1b7d172949 --- /dev/null +++ 
b/cpp/src/gandiva/encrypt_utils_common.cc
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/encrypt_utils_common.h"
+#include <openssl/err.h>
+#include <string>
+
+namespace gandiva {
+
+std::string get_openssl_error_string() {
+  // ERR_reason_error_string() returns NULL when OpenSSL has no text for the code;
+  // constructing a std::string from a null pointer is undefined behaviour.
+  const unsigned long error_code = ERR_get_error();
+  const char* reason = error_code ? ERR_reason_error_string(error_code) : nullptr;
+  return reason ? std::string(reason) : std::string("Unknown OpenSSL error");
+}
+
+}  // namespace gandiva
+
diff --git a/cpp/src/gandiva/encrypt_utils_common.h b/cpp/src/gandiva/encrypt_utils_common.h
new file mode 100644
index 00000000000..3887368b5a7
--- /dev/null
+++ b/cpp/src/gandiva/encrypt_utils_common.h
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef GANDIVA_ENCRYPT_UTILS_COMMON_H
+#define GANDIVA_ENCRYPT_UTILS_COMMON_H
+
+#include <string>
+
+namespace gandiva {
+
+/// @brief Pop the earliest entry from OpenSSL's error queue and describe it.
+/// @return The reason text OpenSSL reports for that entry, or
+/// "Unknown OpenSSL error" if the queue is empty.
+std::string get_openssl_error_string();
+
+}  // namespace gandiva
+
+#endif  // GANDIVA_ENCRYPT_UTILS_COMMON_H
+
diff --git a/cpp/src/gandiva/encrypt_utils_ecb.cc b/cpp/src/gandiva/encrypt_utils_ecb.cc
index 4d2a310121f..5cc13f335a7 100644
--- a/cpp/src/gandiva/encrypt_utils_ecb.cc
+++ b/cpp/src/gandiva/encrypt_utils_ecb.cc
@@ -16,6 +16,7 @@
 // under the License.
#include "gandiva/encrypt_utils_ecb.h" +#include "gandiva/encrypt_utils_common.h" #include #include #include @@ -26,14 +27,6 @@ namespace gandiva { namespace { -std::string get_openssl_error_string() { - unsigned long error_code = ERR_get_error(); - if (error_code == 0) { - return "Unknown OpenSSL error"; - } - return std::string(ERR_reason_error_string(error_code)); -} - const EVP_CIPHER* get_ecb_cipher_algo(int32_t key_length) { switch (key_length) { case 16: diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index fb662b623fb..ceca4089373 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -513,6 +513,16 @@ std::vector GetStringFunctionRegistry() { kResultNullIfNull, "gdv_fn_aes_decrypt_ecb", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + // CBC mode specific functions + // Binary-based signatures (BINARY, BINARY, BINARY, UTF8, UTF8) -> BINARY + NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), binary(), utf8(), utf8()}, binary(), + kResultNullIfNull, "gdv_fn_aes_encrypt_cbc", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), binary(), utf8(), utf8()}, binary(), + kResultNullIfNull, "gdv_fn_aes_decrypt_cbc", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + NativeFunction("mask_first_n", {}, DataTypeVector{utf8(), int32()}, utf8(), kResultNullIfNull, "gdv_mask_first_n_utf8_int32", NativeFunction::kNeedsContext), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index d2610c4ee92..402c1f2cca4 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -31,6 +31,7 @@ #include "arrow/util/value_parsing.h" #include "gandiva/encrypt_utils_ecb.h" +#include "gandiva/encrypt_utils_cbc.h" #include "gandiva/engine.h" #include 
"gandiva/exported_funcs.h" #include "gandiva/in_holder.h" @@ -584,6 +585,163 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int return result; } +// CBC mode specific functions - core implementation with explicit padding +GANDIVA_EXPORT +const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, + const char* iv_data, int32_t iv_data_len, + const char* padding, int32_t padding_len, + int32_t* out_len) { + // Validate mode parameter + if (mode == nullptr) { + std::ostringstream oss; + oss << "Invalid mode parameter for AES encryption"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + std::string mode_str(mode, mode_len); + // Convert to uppercase for comparison + std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); + + if (mode_str != "CBC") { + std::ostringstream oss; + oss << "AES encryption mode mismatch: function signature indicates CBC mode, but '" + << mode_str << "' was provided instead"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + if (data_len < 0) { + std::ostringstream oss; + oss << "Invalid data length for AES encryption: " << data_len << " (must be >= 0)"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + if (key_data_len < 0 || (key_data_len != 16 && key_data_len != 24 && key_data_len != 32)) { + std::ostringstream oss; + oss << "Invalid key length for AES encryption: " << key_data_len + << " bytes. Supported lengths: 16, 24, 32 bytes"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + if (iv_data_len != 16) { + std::ostringstream oss; + oss << "Invalid IV length for AES-CBC: " << iv_data_len + << " bytes. 
IV must be exactly 16 bytes"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + // Allocate output buffer with padding overhead + int32_t max_out_len = static_cast( + arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len + 16), 16)); + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, max_out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output buffer"); + *out_len = 0; + return nullptr; + } + + try { + *out_len = gandiva::aes_encrypt_cbc(data, data_len, key_data, key_data_len, + iv_data, iv_data_len, padding, padding_len, + reinterpret_cast(ret)); + } catch (const std::runtime_error& e) { + gdv_fn_context_set_error_msg(context, e.what()); + *out_len = 0; + return nullptr; + } + + return ret; +} + +GANDIVA_EXPORT +const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, + const char* iv_data, int32_t iv_data_len, + const char* padding, int32_t padding_len, + int32_t* out_len) { + // Validate mode parameter + if (mode == nullptr) { + std::ostringstream oss; + oss << "Invalid mode parameter for AES decryption"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + std::string mode_str(mode, mode_len); + // Convert to uppercase for comparison + std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); + + if (mode_str != "CBC") { + std::ostringstream oss; + oss << "AES decryption mode mismatch: function signature indicates CBC mode, but '" + << mode_str << "' was provided instead"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + if (data_len < 0) { + std::ostringstream oss; + oss << "Invalid data length for AES decryption: " << data_len << " (must be >= 0)"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + 
*out_len = 0; + return ""; + } + + if (key_data_len < 0 || (key_data_len != 16 && key_data_len != 24 && key_data_len != 32)) { + std::ostringstream oss; + oss << "Invalid key length for AES decryption: " << key_data_len + << " bytes. Supported lengths: 16, 24, 32 bytes"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + if (iv_data_len != 16) { + std::ostringstream oss; + oss << "Invalid IV length for AES-CBC: " << iv_data_len + << " bytes. IV must be exactly 16 bytes"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + // Allocate output buffer + int32_t max_out_len = static_cast( + arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), 16)); + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, max_out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output buffer"); + *out_len = 0; + return nullptr; + } + + try { + *out_len = gandiva::aes_decrypt_cbc(data, data_len, key_data, key_data_len, + iv_data, iv_data_len, padding, padding_len, + reinterpret_cast(ret)); + } catch (const std::runtime_error& e) { + gdv_fn_context_set_error_msg(context, e.what()); + *out_len = 0; + return nullptr; + } + + return ret; +} + GANDIVA_EXPORT const char* gdv_mask_first_n_utf8_int32(int64_t context, const char* data, @@ -1293,6 +1451,50 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(gdv_fn_aes_decrypt_ecb_legacy)); + // gdv_fn_aes_encrypt_cbc + // Note: Mode and IV parameters are passed as binary strings (data + length) + // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, iv, iv_len, padding, padding_len, out_len) + args = { + types->i64_type(), // context + types->i8_ptr_type(), // data + types->i32_type(), // data_length + types->i8_ptr_type(), // key_data + types->i32_type(), // 
key_data_length + types->i8_ptr_type(), // mode (binary string) + types->i32_type(), // mode_length + types->i8_ptr_type(), // iv (binary string) + types->i32_type(), // iv_length + types->i8_ptr_type(), // padding (binary string) + types->i32_type(), // padding_length + types->i32_ptr_type() // out_length + }; + + engine->AddGlobalMappingForFunc("gdv_fn_aes_encrypt_cbc", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_aes_encrypt_cbc)); + + // gdv_fn_aes_decrypt_cbc + // Note: Mode and IV parameters are passed as binary strings (data + length) + // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, iv, iv_len, padding, padding_len, out_len) + args = { + types->i64_type(), // context + types->i8_ptr_type(), // data + types->i32_type(), // data_length + types->i8_ptr_type(), // key_data + types->i32_type(), // key_data_length + types->i8_ptr_type(), // mode (binary string) + types->i32_type(), // mode_length + types->i8_ptr_type(), // iv (binary string) + types->i32_type(), // iv_length + types->i8_ptr_type(), // padding (binary string) + types->i32_type(), // padding_length + types->i32_ptr_type() // out_length + }; + + engine->AddGlobalMappingForFunc("gdv_fn_aes_decrypt_cbc", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_aes_decrypt_cbc)); + // gdv_mask_first_n and gdv_mask_last_n std::vector mask_args = { types->i64_type(), // context diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index 5c5452df280..324ecfb2c03 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -213,7 +213,22 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int const char* key_data, int32_t key_data_len, int32_t* out_len); +// CBC mode specific functions +GANDIVA_EXPORT +const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t 
key_data_len, + const char* mode, int32_t mode_len, + const char* iv_data, int32_t iv_data_len, + const char* padding, int32_t padding_len, + int32_t* out_len); +GANDIVA_EXPORT +const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, + const char* iv_data, int32_t iv_data_len, + const char* padding, int32_t padding_len, + int32_t* out_len); GANDIVA_EXPORT const char* gdv_mask_first_n_utf8_int32(int64_t context, const char* data, diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index da785d1a511..55572880bed 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -1480,4 +1480,181 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeValidation) { ctx.Reset(); } +// Tests for CBC mode encryption/decryption +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbc) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = "CBC"; + auto mode_len = static_cast(mode.length()); + std::string padding = "PKCS7"; + auto padding_len = static_cast(padding.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, mode.c_str(), mode_len, iv.c_str(), + iv_len, padding.c_str(), padding_len, &cipher_len); + EXPECT_GT(cipher_len, 0); + + const char* decrypted_value = gdv_fn_aes_decrypt_cbc( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, + iv.c_str(), iv_len, padding.c_str(), padding_len, &decrypted_len); + EXPECT_EQ(data, 
std::string(reinterpret_cast(decrypted_value), decrypted_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcNoPadding) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "1234567890123456"; // Exactly 16 bytes + auto data_len = static_cast(data.length()); + std::string mode = "CBC"; + auto mode_len = static_cast(mode.length()); + std::string padding = "NONE"; + auto padding_len = static_cast(padding.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, mode.c_str(), mode_len, iv.c_str(), + iv_len, padding.c_str(), padding_len, &cipher_len); + EXPECT_GT(cipher_len, 0); + + const char* decrypted_value = gdv_fn_aes_decrypt_cbc( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, + iv.c_str(), iv_len, padding.c_str(), padding_len, &decrypted_len); + EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcCaseInsensitive) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int32_t cipher_len1 = 0; + int32_t cipher_len2 = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = "CBC"; + auto mode_len = static_cast(mode.length()); + std::string padding_upper = "PKCS7"; + auto padding_upper_len = static_cast(padding_upper.length()); + std::string padding_lower = "pkcs7"; + auto padding_lower_len = static_cast(padding_lower.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher1 = 
gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, mode.c_str(), mode_len, iv.c_str(), + iv_len, padding_upper.c_str(), padding_upper_len, + &cipher_len1); + const char* cipher2 = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, mode.c_str(), mode_len, iv.c_str(), + iv_len, padding_lower.c_str(), padding_lower_len, + &cipher_len2); + + // Both should produce same ciphertext + EXPECT_EQ(cipher_len1, cipher_len2); + EXPECT_EQ(std::string(cipher1, cipher_len1), std::string(cipher2, cipher_len2)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidIV) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "short"; // Too short + auto iv_len = static_cast(iv.length()); + int32_t cipher_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = "CBC"; + auto mode_len = static_cast(mode.length()); + std::string padding = "PKCS7"; + auto padding_len = static_cast(padding.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, + mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), + padding_len, &cipher_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid IV length")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidKey) { + gandiva::ExecutionContext ctx; + std::string key = "short"; // Too short + auto key_len = static_cast(key.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int32_t cipher_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = "CBC"; + auto mode_len = static_cast(mode.length()); + std::string padding = "PKCS7"; + auto padding_len = static_cast(padding.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); 
+ + gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key.c_str(), key_len, + mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), + padding_len, &cipher_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid key length")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidPadding) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int32_t cipher_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = "CBC"; + auto mode_len = static_cast(mode.length()); + std::string padding = "INVALID"; + auto padding_len = static_cast(padding.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, + mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), + padding_len, &cipher_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid padding mode")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcModeValidation) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int32_t cipher_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string invalid_mode = "ECB"; + auto invalid_mode_len = static_cast(invalid_mode.length()); + std::string padding = "PKCS7"; + auto padding_len = static_cast(padding.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + // Test encrypt with invalid mode + gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, + invalid_mode.c_str(), invalid_mode_len, iv.c_str(), iv_len, + padding.c_str(), padding_len, &cipher_len); + EXPECT_THAT(ctx.get_error(), 
::testing::HasSubstr("AES encryption mode mismatch")); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("ECB")); + ctx.Reset(); +} + } // namespace gandiva From f570415bd831b9e4332ff1d8ab1e4d2bacc4a235 Mon Sep 17 00:00:00 2001 From: Tim Hurski Date: Fri, 14 Nov 2025 19:45:12 +0000 Subject: [PATCH 3/5] DX-108149: Add support for GCM encryption mode --- cpp/src/gandiva/CMakeLists.txt | 4 +- cpp/src/gandiva/encrypt_utils_gcm.cc | 303 ++++++++++++++++++++ cpp/src/gandiva/encrypt_utils_gcm.h | 114 ++++++++ cpp/src/gandiva/encrypt_utils_gcm_test.cc | 197 +++++++++++++ cpp/src/gandiva/function_registry_string.cc | 21 ++ cpp/src/gandiva/function_registry_test.cc | 37 +++ cpp/src/gandiva/gdv_function_stubs.cc | 232 +++++++++++---- cpp/src/gandiva/gdv_function_stubs.h | 25 +- cpp/src/gandiva/gdv_function_stubs_test.cc | 210 +++++++------- 9 files changed, 975 insertions(+), 168 deletions(-) create mode 100644 cpp/src/gandiva/encrypt_utils_gcm.cc create mode 100644 cpp/src/gandiva/encrypt_utils_gcm.h create mode 100644 cpp/src/gandiva/encrypt_utils_gcm_test.cc diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 8d198416d90..c66e25b52fd 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -56,9 +56,10 @@ set(SRC_FILES decimal_xlarge.cc engine.cc date_utils.cc - encrypt_utils_common.cc encrypt_utils_ecb.cc + encrypt_utils_common.cc encrypt_utils_cbc.cc + encrypt_utils_gcm.cc expr_decomposer.cc expr_validator.cc expression.cc @@ -260,6 +261,7 @@ add_gandiva_test(internals-test tree_expr_test.cc encrypt_utils_ecb_test.cc encrypt_utils_cbc_test.cc + encrypt_utils_gcm_test.cc expr_decomposer_test.cc exported_funcs_registry_test.cc expression_registry_test.cc diff --git a/cpp/src/gandiva/encrypt_utils_gcm.cc b/cpp/src/gandiva/encrypt_utils_gcm.cc new file mode 100644 index 00000000000..08a8e417b1e --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_gcm.cc @@ -0,0 +1,303 @@ +// Licensed to the Apache Software Foundation 
(ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/encrypt_utils_gcm.h" +#include "gandiva/encrypt_utils_common.h" +#include +#include +#include +#include +#include + +namespace gandiva { + +namespace { + +const EVP_CIPHER* get_gcm_cipher_algo(int32_t key_length) { + switch (key_length) { + case 16: + return EVP_aes_128_gcm(); + case 24: + return EVP_aes_192_gcm(); + case 32: + return EVP_aes_256_gcm(); + default: { + std::ostringstream oss; + oss << "Unsupported key length for AES-GCM: " << key_length + << " bytes. Supported lengths: 16, 24, 32 bytes"; + throw std::runtime_error(oss.str()); + } + } +} + +void validate_gcm_iv(int32_t iv_len) { + if (iv_len != 12) { + std::ostringstream oss; + oss << "Invalid IV length for AES-GCM: " << iv_len + << " bytes. IV must be exactly 12 bytes"; + throw std::runtime_error(oss.str()); + } +} + +void validate_gcm_tag_length(int32_t tag_len) { + if (tag_len < 4 || tag_len > 16) { + std::ostringstream oss; + oss << "Invalid tag length for AES-GCM: " << tag_len + << " bytes. 
Tag length must be between 4 and 16 bytes"; + throw std::runtime_error(oss.str()); + } +} + +} // namespace + +GANDIVA_EXPORT +int32_t aes_encrypt_gcm(const char* plaintext, int32_t plaintext_len, const char* key, + int32_t key_len, const char* iv, int32_t iv_len, + unsigned char* cipher, unsigned char* tag, int32_t tag_len) { + validate_gcm_iv(iv_len); + validate_gcm_tag_length(tag_len); + + int32_t cipher_len = 0; + int32_t len = 0; + EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new(); + const EVP_CIPHER* cipher_algo = get_gcm_cipher_algo(key_len); + + if (!en_ctx) { + throw std::runtime_error("Could not create EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr, + reinterpret_cast(key), + reinterpret_cast(iv))) { + EVP_CIPHER_CTX_free(en_ctx); + throw std::runtime_error("Could not initialize EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + if (!EVP_EncryptUpdate(en_ctx, cipher, &len, + reinterpret_cast(plaintext), + plaintext_len)) { + EVP_CIPHER_CTX_free(en_ctx); + throw std::runtime_error("Could not update EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + cipher_len += len; + + if (!EVP_EncryptFinal_ex(en_ctx, cipher + len, &len)) { + EVP_CIPHER_CTX_free(en_ctx); + throw std::runtime_error("Could not finalize EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + cipher_len += len; + + if (!EVP_CIPHER_CTX_ctrl(en_ctx, EVP_CTRL_GCM_GET_TAG, tag_len, tag)) { + EVP_CIPHER_CTX_free(en_ctx); + throw std::runtime_error("Could not extract GCM tag: " + get_openssl_error_string()); + } + + EVP_CIPHER_CTX_free(en_ctx); + return cipher_len; +} + +GANDIVA_EXPORT +int32_t aes_encrypt_gcm_with_aad(const char* plaintext, int32_t plaintext_len, + const char* key, int32_t key_len, const char* iv, + int32_t iv_len, const char* aad, int32_t aad_len, + unsigned char* cipher, unsigned char* tag, + int32_t tag_len) { + 
validate_gcm_iv(iv_len); + validate_gcm_tag_length(tag_len); + + int32_t cipher_len = 0; + int32_t len = 0; + EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new(); + const EVP_CIPHER* cipher_algo = get_gcm_cipher_algo(key_len); + + if (!en_ctx) { + throw std::runtime_error("Could not create EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr, + reinterpret_cast(key), + reinterpret_cast(iv))) { + EVP_CIPHER_CTX_free(en_ctx); + throw std::runtime_error("Could not initialize EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + // Set AAD if provided + if (aad != nullptr && aad_len > 0) { + if (!EVP_EncryptUpdate(en_ctx, nullptr, &len, + reinterpret_cast(aad), aad_len)) { + EVP_CIPHER_CTX_free(en_ctx); + throw std::runtime_error("Could not set AAD for encryption: " + + get_openssl_error_string()); + } + } + + if (!EVP_EncryptUpdate(en_ctx, cipher, &len, + reinterpret_cast(plaintext), + plaintext_len)) { + EVP_CIPHER_CTX_free(en_ctx); + throw std::runtime_error("Could not update EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + cipher_len += len; + + if (!EVP_EncryptFinal_ex(en_ctx, cipher + len, &len)) { + EVP_CIPHER_CTX_free(en_ctx); + throw std::runtime_error("Could not finalize EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + cipher_len += len; + + if (!EVP_CIPHER_CTX_ctrl(en_ctx, EVP_CTRL_GCM_GET_TAG, tag_len, tag)) { + EVP_CIPHER_CTX_free(en_ctx); + throw std::runtime_error("Could not extract GCM tag: " + get_openssl_error_string()); + } + + EVP_CIPHER_CTX_free(en_ctx); + return cipher_len; +} + +GANDIVA_EXPORT +int32_t aes_decrypt_gcm(const char* ciphertext, int32_t ciphertext_len, const char* key, + int32_t key_len, const char* iv, int32_t iv_len, const char* tag, + int32_t tag_len, unsigned char* plaintext) { + validate_gcm_iv(iv_len); + validate_gcm_tag_length(tag_len); + + int32_t plaintext_len = 0; + 
int32_t len = 0; + EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new(); + const EVP_CIPHER* cipher_algo = get_gcm_cipher_algo(key_len); + + if (!de_ctx) { + throw std::runtime_error("Could not create EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr, + reinterpret_cast(key), + reinterpret_cast(iv))) { + EVP_CIPHER_CTX_free(de_ctx); + throw std::runtime_error("Could not initialize EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + // Set the tag for verification + if (!EVP_CIPHER_CTX_ctrl(de_ctx, EVP_CTRL_GCM_SET_TAG, tag_len, + const_cast(tag))) { + EVP_CIPHER_CTX_free(de_ctx); + throw std::runtime_error("Could not set GCM tag for verification: " + + get_openssl_error_string()); + } + + if (!EVP_DecryptUpdate(de_ctx, plaintext, &len, + reinterpret_cast(ciphertext), + ciphertext_len)) { + EVP_CIPHER_CTX_free(de_ctx); + throw std::runtime_error("Could not update EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + plaintext_len += len; + + if (!EVP_DecryptFinal_ex(de_ctx, plaintext + len, &len)) { + EVP_CIPHER_CTX_free(de_ctx); + throw std::runtime_error("GCM tag verification failed: " + get_openssl_error_string()); + } + + plaintext_len += len; + + EVP_CIPHER_CTX_free(de_ctx); + return plaintext_len; +} + +GANDIVA_EXPORT +int32_t aes_decrypt_gcm_with_aad(const char* ciphertext, int32_t ciphertext_len, + const char* key, int32_t key_len, const char* iv, + int32_t iv_len, const char* aad, int32_t aad_len, + const char* tag, int32_t tag_len, unsigned char* plaintext) { + validate_gcm_iv(iv_len); + validate_gcm_tag_length(tag_len); + + int32_t plaintext_len = 0; + int32_t len = 0; + EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new(); + const EVP_CIPHER* cipher_algo = get_gcm_cipher_algo(key_len); + + if (!de_ctx) { + throw std::runtime_error("Could not create EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + if 
(!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr, + reinterpret_cast(key), + reinterpret_cast(iv))) { + EVP_CIPHER_CTX_free(de_ctx); + throw std::runtime_error("Could not initialize EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + // Set AAD if provided + if (aad != nullptr && aad_len > 0) { + if (!EVP_DecryptUpdate(de_ctx, nullptr, &len, + reinterpret_cast(aad), aad_len)) { + EVP_CIPHER_CTX_free(de_ctx); + throw std::runtime_error("Could not set AAD for decryption: " + + get_openssl_error_string()); + } + } + + // Set the tag for verification + if (!EVP_CIPHER_CTX_ctrl(de_ctx, EVP_CTRL_GCM_SET_TAG, tag_len, + const_cast(tag))) { + EVP_CIPHER_CTX_free(de_ctx); + throw std::runtime_error("Could not set GCM tag for verification: " + + get_openssl_error_string()); + } + + if (!EVP_DecryptUpdate(de_ctx, plaintext, &len, + reinterpret_cast(ciphertext), + ciphertext_len)) { + EVP_CIPHER_CTX_free(de_ctx); + throw std::runtime_error("Could not update EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + plaintext_len += len; + + if (!EVP_DecryptFinal_ex(de_ctx, plaintext + len, &len)) { + EVP_CIPHER_CTX_free(de_ctx); + throw std::runtime_error("GCM tag verification failed: " + get_openssl_error_string()); + } + + plaintext_len += len; + + EVP_CIPHER_CTX_free(de_ctx); + return plaintext_len; +} + +} // namespace gandiva + diff --git a/cpp/src/gandiva/encrypt_utils_gcm.h b/cpp/src/gandiva/encrypt_utils_gcm.h new file mode 100644 index 00000000000..2752069b59b --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_gcm.h @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include "gandiva/visibility.h" + +namespace gandiva { + +/** + * Encrypt data using AES-GCM algorithm without AAD + * + * @param plaintext The data to encrypt + * @param plaintext_len Length of plaintext in bytes + * @param key The encryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) + * @param key_len Length of key in bytes + * @param iv The initialization vector (must be exactly 12 bytes for GCM) + * @param iv_len Length of IV in bytes (must be 12) + * @param cipher Output buffer for encrypted data + * @param tag Output buffer for authentication tag (typically 16 bytes) + * @param tag_len Length of tag in bytes (4-16, typically 16) + * @return Length of encrypted data in bytes + * @throws std::runtime_error on encryption failure or invalid parameters + */ +GANDIVA_EXPORT +int32_t aes_encrypt_gcm(const char* plaintext, int32_t plaintext_len, const char* key, + int32_t key_len, const char* iv, int32_t iv_len, + unsigned char* cipher, unsigned char* tag, int32_t tag_len); + +/** + * Encrypt data using AES-GCM algorithm with AAD + * + * @param plaintext The data to encrypt + * @param plaintext_len Length of plaintext in bytes + * @param key The encryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) + * @param key_len Length of key in bytes + * @param iv The initialization vector (must be exactly 12 bytes for GCM) + * @param iv_len Length of IV in bytes (must be 12) + * @param aad Additional authenticated data + * @param aad_len Length of AAD in bytes + * @param cipher Output buffer for 
encrypted data + * @param tag Output buffer for authentication tag (typically 16 bytes) + * @param tag_len Length of tag in bytes (4-16, typically 16) + * @return Length of encrypted data in bytes + * @throws std::runtime_error on encryption failure or invalid parameters + */ +GANDIVA_EXPORT +int32_t aes_encrypt_gcm_with_aad(const char* plaintext, int32_t plaintext_len, + const char* key, int32_t key_len, const char* iv, + int32_t iv_len, const char* aad, int32_t aad_len, + unsigned char* cipher, unsigned char* tag, + int32_t tag_len); + +/** + * Decrypt data using AES-GCM algorithm without AAD + * + * @param ciphertext The data to decrypt + * @param ciphertext_len Length of ciphertext in bytes + * @param key The decryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) + * @param key_len Length of key in bytes + * @param iv The initialization vector (must be exactly 12 bytes for GCM) + * @param iv_len Length of IV in bytes (must be 12) + * @param tag The authentication tag to verify + * @param tag_len Length of tag in bytes (4-16, typically 16) + * @param plaintext Output buffer for decrypted data + * @return Length of decrypted data in bytes + * @throws std::runtime_error on decryption failure, tag verification failure, or invalid parameters + */ +GANDIVA_EXPORT +int32_t aes_decrypt_gcm(const char* ciphertext, int32_t ciphertext_len, const char* key, + int32_t key_len, const char* iv, int32_t iv_len, const char* tag, + int32_t tag_len, unsigned char* plaintext); + +/** + * Decrypt data using AES-GCM algorithm with AAD + * + * @param ciphertext The data to decrypt + * @param ciphertext_len Length of ciphertext in bytes + * @param key The decryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) + * @param key_len Length of key in bytes + * @param iv The initialization vector (must be exactly 12 bytes for GCM) + * @param iv_len Length of IV in bytes (must be 12) + * @param aad Additional authenticated data + * @param aad_len Length of AAD in bytes + * 
@param tag The authentication tag to verify + * @param tag_len Length of tag in bytes (4-16, typically 16) + * @param plaintext Output buffer for decrypted data + * @return Length of decrypted data in bytes + * @throws std::runtime_error on decryption failure, tag verification failure, or invalid parameters + */ +GANDIVA_EXPORT +int32_t aes_decrypt_gcm_with_aad(const char* ciphertext, int32_t ciphertext_len, + const char* key, int32_t key_len, const char* iv, + int32_t iv_len, const char* aad, int32_t aad_len, + const char* tag, int32_t tag_len, unsigned char* plaintext); + +} // namespace gandiva + diff --git a/cpp/src/gandiva/encrypt_utils_gcm_test.cc b/cpp/src/gandiva/encrypt_utils_gcm_test.cc new file mode 100644 index 00000000000..823ff388574 --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_gcm_test.cc @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/encrypt_utils_gcm.h" + +#include +#include + +// Test encryption/decryption without AAD with 16-byte key +TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptNoAad_16) { + auto* key = "12345678abcdefgh"; + auto* iv = "123456789012"; // 12 bytes for GCM + auto* to_encrypt = "some test string"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[64]; + unsigned char tag[16]; + + int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, cipher, tag, 16); + + unsigned char decrypted[64]; + int32_t decrypted_len = gandiva::aes_decrypt_gcm( + reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, + reinterpret_cast(tag), 16, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test encryption/decryption without AAD with 24-byte key +TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptNoAad_24) { + auto* key = "12345678abcdefgh12345678"; + auto* iv = "123456789012"; + auto* to_encrypt = "some\ntest\nstring"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[64]; + unsigned char tag[16]; + + int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, cipher, tag, 16); + + unsigned char decrypted[64]; + int32_t decrypted_len = gandiva::aes_decrypt_gcm( + reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, + reinterpret_cast(tag), 16, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test encryption/decryption without AAD with 32-byte key +TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptNoAad_32) { + auto* key = "12345678abcdefgh12345678abcdefgh"; + auto* iv = 
"123456789012"; + auto* to_encrypt = "New\ntest\nstring"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[64]; + unsigned char tag[16]; + + int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, cipher, tag, 16); + + unsigned char decrypted[64]; + int32_t decrypted_len = gandiva::aes_decrypt_gcm( + reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, + reinterpret_cast(tag), 16, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test encryption/decryption with AAD +TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptWithAad) { + auto* key = "12345678abcdefgh"; + auto* iv = "123456789012"; + auto* to_encrypt = "secret message"; + auto* aad = "additional data"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + auto aad_len = static_cast(strlen(aad)); + unsigned char cipher[64]; + unsigned char tag[16]; + + int32_t cipher_len = gandiva::aes_encrypt_gcm_with_aad( + to_encrypt, to_encrypt_len, key, key_len, iv, iv_len, aad, aad_len, cipher, tag, 16); + + unsigned char decrypted[64]; + int32_t decrypted_len = gandiva::aes_decrypt_gcm_with_aad( + reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, aad, + aad_len, reinterpret_cast(tag), 16, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test invalid IV length (not 12 bytes) +TEST(TestAesGcmEncryptUtils, TestInvalidIVLength) { + auto* key = "12345678abcdefgh"; + auto* iv = "short"; // Too short + auto* to_encrypt = "test"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned 
char cipher[64]; + unsigned char tag[16]; + + ASSERT_THROW( + { gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, iv, iv_len, + cipher, tag, 16); }, + std::runtime_error); +} + +// Test invalid key length +TEST(TestAesGcmEncryptUtils, TestInvalidKeyLength) { + auto* key = "short"; // Invalid key length + auto* iv = "123456789012"; + auto* to_encrypt = "test"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[64]; + unsigned char tag[16]; + + ASSERT_THROW( + { gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, iv, iv_len, + cipher, tag, 16); }, + std::runtime_error); +} + +// Test invalid tag length +TEST(TestAesGcmEncryptUtils, TestInvalidTagLength) { + auto* key = "12345678abcdefgh"; + auto* iv = "123456789012"; + auto* to_encrypt = "test"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[64]; + unsigned char tag[32]; + + ASSERT_THROW( + { gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, iv, iv_len, + cipher, tag, 32); }, // Tag too long + std::runtime_error); +} + +// Test empty plaintext +TEST(TestAesGcmEncryptUtils, TestEmptyPlaintext) { + auto* key = "12345678abcdefgh"; + auto* iv = "123456789012"; + auto* to_encrypt = ""; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = 0; + unsigned char cipher[64]; + unsigned char tag[16]; + + int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, cipher, tag, 16); + + unsigned char decrypted[64]; + int32_t decrypted_len = gandiva::aes_decrypt_gcm( + reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, + reinterpret_cast(tag), 16, decrypted); + + EXPECT_EQ(0, decrypted_len); +} + diff --git 
a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index ceca4089373..16f8977e7d6 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -523,6 +523,27 @@ std::vector GetStringFunctionRegistry() { kResultNullIfNull, "gdv_fn_aes_decrypt_cbc", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + // GCM mode specific functions + // Binary-based signatures (BINARY, BINARY, UTF8, BINARY) -> BINARY + NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(), + kResultNullIfNull, "gdv_fn_aes_encrypt_gcm", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + // Binary-based signatures (BINARY, BINARY, UTF8, BINARY, BINARY) -> BINARY (with AAD) + NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(), + kResultNullIfNull, "gdv_fn_aes_encrypt_gcm", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + // Binary-based signatures (BINARY, BINARY, UTF8, BINARY, INT32) -> BINARY + NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), int32()}, binary(), + kResultNullIfNull, "gdv_fn_aes_decrypt_gcm", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + // Binary-based signatures (BINARY, BINARY, UTF8, BINARY, INT32, BINARY) -> BINARY (with AAD) + NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), int32(), binary()}, binary(), + kResultNullIfNull, "gdv_fn_aes_decrypt_gcm", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + NativeFunction("mask_first_n", {}, DataTypeVector{utf8(), int32()}, utf8(), kResultNullIfNull, "gdv_mask_first_n_utf8_int32", NativeFunction::kNeedsContext), diff --git a/cpp/src/gandiva/function_registry_test.cc b/cpp/src/gandiva/function_registry_test.cc index bbe72c0ee97..a92c3037af9 100644 --- 
a/cpp/src/gandiva/function_registry_test.cc +++ b/cpp/src/gandiva/function_registry_test.cc @@ -122,4 +122,41 @@ TEST_F(TestFunctionRegistry, TestNoDuplicates) { "different precompiled functions:\n" << stream.str(); } + +// Test that GCM mode function signatures are registered +TEST_F(TestFunctionRegistry, TestAesEncryptGcmSignatures) { + // AES_ENCRYPT(BINARY, BINARY, UTF8, BINARY) -> BINARY (without AAD) + FunctionSignature aes_encrypt_gcm_no_aad("aes_encrypt", + {arrow::binary(), arrow::binary(), arrow::utf8(), arrow::binary()}, + arrow::binary()); + const NativeFunction* function = registry_->LookupSignature(aes_encrypt_gcm_no_aad); + EXPECT_NE(function, nullptr) << "AES_ENCRYPT(BINARY, BINARY, UTF8, BINARY) not found"; + EXPECT_EQ(function->pc_name(), "gdv_fn_aes_encrypt_gcm"); + + // AES_ENCRYPT(BINARY, BINARY, UTF8, BINARY, BINARY) -> BINARY (with AAD) + FunctionSignature aes_encrypt_gcm_with_aad("aes_encrypt", + {arrow::binary(), arrow::binary(), arrow::utf8(), arrow::binary(), arrow::binary()}, + arrow::binary()); + function = registry_->LookupSignature(aes_encrypt_gcm_with_aad); + EXPECT_NE(function, nullptr) << "AES_ENCRYPT(BINARY, BINARY, UTF8, BINARY, BINARY) not found"; + EXPECT_EQ(function->pc_name(), "gdv_fn_aes_encrypt_gcm"); +} + +TEST_F(TestFunctionRegistry, TestAesDecryptGcmSignatures) { + // AES_DECRYPT(BINARY, BINARY, UTF8, BINARY, INT32) -> BINARY (without AAD) + FunctionSignature aes_decrypt_gcm_no_aad("aes_decrypt", + {arrow::binary(), arrow::binary(), arrow::utf8(), arrow::binary(), arrow::int32()}, + arrow::binary()); + const NativeFunction* function = registry_->LookupSignature(aes_decrypt_gcm_no_aad); + EXPECT_NE(function, nullptr) << "AES_DECRYPT(BINARY, BINARY, UTF8, BINARY, INT32) not found"; + EXPECT_EQ(function->pc_name(), "gdv_fn_aes_decrypt_gcm"); + + // AES_DECRYPT(BINARY, BINARY, UTF8, BINARY, INT32, BINARY) -> BINARY (with AAD) + FunctionSignature aes_decrypt_gcm_with_aad("aes_decrypt", + {arrow::binary(), arrow::binary(), 
arrow::utf8(), arrow::binary(), arrow::int32(), arrow::binary()}, + arrow::binary()); + function = registry_->LookupSignature(aes_decrypt_gcm_with_aad); + EXPECT_NE(function, nullptr) << "AES_DECRYPT(BINARY, BINARY, UTF8, BINARY, INT32, BINARY) not found"; + EXPECT_EQ(function->pc_name(), "gdv_fn_aes_decrypt_gcm"); +} } // namespace gandiva diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index 402c1f2cca4..b4edcacd5d4 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -32,6 +32,7 @@ #include "gandiva/encrypt_utils_ecb.h" #include "gandiva/encrypt_utils_cbc.h" +#include "gandiva/encrypt_utils_gcm.h" #include "gandiva/engine.h" #include "gandiva/exported_funcs.h" #include "gandiva/in_holder.h" @@ -585,13 +586,71 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int return result; } -// CBC mode specific functions - core implementation with explicit padding +// CBC mode specific functions - core implementation +// This handles both string and binary inputs (they have the same C signature) GANDIVA_EXPORT const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - const char* iv_data, int32_t iv_data_len, - const char* padding, int32_t padding_len, + const char* iv, int32_t iv_len, const char* padding, + int32_t padding_len, int32_t* out_len) { + // Allocate output buffer (max size: input + 16 bytes for padding) + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, data_len + 16)); + if (ret == nullptr) { + std::ostringstream oss; + oss << "Could not allocate memory for AES-CBC encryption"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + try { + *out_len = gandiva::aes_encrypt_cbc(data, data_len, key_data, key_data_len, iv, iv_len, + padding, padding_len, + reinterpret_cast(ret)); 
+ } catch (const std::runtime_error& e) { + gdv_fn_context_set_error_msg(context, e.what()); + *out_len = 0; + return nullptr; + } + + return ret; +} + +GANDIVA_EXPORT +const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* iv, int32_t iv_len, const char* padding, + int32_t padding_len, int32_t* out_len) { + // Allocate output buffer (max size: input size, since decryption removes padding) + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, data_len)); + if (ret == nullptr) { + std::ostringstream oss; + oss << "Could not allocate memory for AES-CBC decryption"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); + *out_len = 0; + return ""; + } + + try { + *out_len = gandiva::aes_decrypt_cbc(data, data_len, key_data, key_data_len, iv, iv_len, + padding, padding_len, + reinterpret_cast(ret)); + } catch (const std::runtime_error& e) { + gdv_fn_context_set_error_msg(context, e.what()); + *out_len = 0; + return nullptr; + } + + return ret; +} + +// GCM mode specific functions - core implementation +// This handles both string and binary inputs (they have the same C signature) +GANDIVA_EXPORT +const char* gdv_fn_aes_encrypt_gcm(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, const char* iv, + int32_t iv_len, const char* aad, int32_t aad_len, int32_t* out_len) { // Validate mode parameter if (mode == nullptr) { @@ -606,9 +665,9 @@ const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t da // Convert to uppercase for comparison std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); - if (mode_str != "CBC") { + if (mode_str != "GCM") { std::ostringstream oss; - oss << "AES encryption mode mismatch: function signature indicates CBC mode, but '" + oss << "AES encryption mode mismatch: function signature indicates GCM mode, but 
'" << mode_str << "' was provided instead"; gdv_fn_context_set_error_msg(context, oss.str().c_str()); *out_len = 0; @@ -623,38 +682,43 @@ const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t da return ""; } - if (key_data_len < 0 || (key_data_len != 16 && key_data_len != 24 && key_data_len != 32)) { + if (iv_len < 0) { std::ostringstream oss; - oss << "Invalid key length for AES encryption: " << key_data_len - << " bytes. Supported lengths: 16, 24, 32 bytes"; + oss << "Invalid IV length for AES encryption: " << iv_len << " (must be >= 0)"; gdv_fn_context_set_error_msg(context, oss.str().c_str()); *out_len = 0; return ""; } - if (iv_data_len != 16) { + // For GCM, output is ciphertext + tag (typically 16 bytes) + // Allocate space for ciphertext + 16-byte tag + int64_t kGcmTagSize = 16; + *out_len = static_cast( + arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len + kGcmTagSize), 16)); + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { std::ostringstream oss; - oss << "Invalid IV length for AES-CBC: " << iv_data_len - << " bytes. 
IV must be exactly 16 bytes"; + oss << "Could not allocate memory for AES encryption output: " << *out_len << " bytes"; gdv_fn_context_set_error_msg(context, oss.str().c_str()); *out_len = 0; - return ""; - } - - // Allocate output buffer with padding overhead - int32_t max_out_len = static_cast( - arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len + 16), 16)); - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, max_out_len)); - if (ret == nullptr) { - gdv_fn_context_set_error_msg(context, "Could not allocate memory for output buffer"); - *out_len = 0; return nullptr; } try { - *out_len = gandiva::aes_encrypt_cbc(data, data_len, key_data, key_data_len, - iv_data, iv_data_len, padding, padding_len, - reinterpret_cast(ret)); + unsigned char* cipher = reinterpret_cast(ret); + unsigned char* tag = cipher + data_len; + + int32_t cipher_len = 0; + if (aad != nullptr && aad_len > 0) { + cipher_len = gandiva::aes_encrypt_gcm_with_aad( + data, data_len, key_data, key_data_len, iv, iv_len, aad, aad_len, cipher, tag, + kGcmTagSize); + } else { + cipher_len = gandiva::aes_encrypt_gcm(data, data_len, key_data, key_data_len, iv, + iv_len, cipher, tag, kGcmTagSize); + } + + *out_len = cipher_len + kGcmTagSize; } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); *out_len = 0; @@ -665,12 +729,11 @@ const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t da } GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t data_len, +const char* gdv_fn_aes_decrypt_gcm(int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - const char* iv_data, int32_t iv_data_len, - const char* padding, int32_t padding_len, - int32_t* out_len) { + const char* mode, int32_t mode_len, const char* iv, + int32_t iv_len, int32_t tag_length, const char* aad, + int32_t aad_len, int32_t* out_len) { // Validate mode 
parameter if (mode == nullptr) { std::ostringstream oss; @@ -684,9 +747,9 @@ const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t da // Convert to uppercase for comparison std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); - if (mode_str != "CBC") { + if (mode_str != "GCM") { std::ostringstream oss; - oss << "AES decryption mode mismatch: function signature indicates CBC mode, but '" + oss << "AES decryption mode mismatch: function signature indicates GCM mode, but '" << mode_str << "' was provided instead"; gdv_fn_context_set_error_msg(context, oss.str().c_str()); *out_len = 0; @@ -701,38 +764,50 @@ const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t da return ""; } - if (key_data_len < 0 || (key_data_len != 16 && key_data_len != 24 && key_data_len != 32)) { + if (iv_len < 0) { std::ostringstream oss; - oss << "Invalid key length for AES decryption: " << key_data_len - << " bytes. Supported lengths: 16, 24, 32 bytes"; + oss << "Invalid IV length for AES decryption: " << iv_len << " (must be >= 0)"; gdv_fn_context_set_error_msg(context, oss.str().c_str()); *out_len = 0; return ""; } - if (iv_data_len != 16) { + // For GCM, input is ciphertext + tag + // Ciphertext length is data_len - tag_length + int32_t cipher_len = data_len - tag_length; + if (cipher_len < 0) { std::ostringstream oss; - oss << "Invalid IV length for AES-CBC: " << iv_data_len - << " bytes. 
IV must be exactly 16 bytes"; + oss << "Invalid tag length for AES decryption: " << tag_length + << " (must be <= data length " << data_len << ")"; gdv_fn_context_set_error_msg(context, oss.str().c_str()); *out_len = 0; return ""; } - // Allocate output buffer - int32_t max_out_len = static_cast( - arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), 16)); - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, max_out_len)); + // Allocate space for plaintext + *out_len = static_cast( + arrow::bit_util::RoundUpToPowerOf2(static_cast(cipher_len), 16)); + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); if (ret == nullptr) { - gdv_fn_context_set_error_msg(context, "Could not allocate memory for output buffer"); + std::ostringstream oss; + oss << "Could not allocate memory for AES decryption output: " << *out_len << " bytes"; + gdv_fn_context_set_error_msg(context, oss.str().c_str()); *out_len = 0; return nullptr; } try { - *out_len = gandiva::aes_decrypt_cbc(data, data_len, key_data, key_data_len, - iv_data, iv_data_len, padding, padding_len, - reinterpret_cast(ret)); + const char* tag = data + cipher_len; + + if (aad != nullptr && aad_len > 0) { + *out_len = gandiva::aes_decrypt_gcm_with_aad( + data, cipher_len, key_data, key_data_len, iv, iv_len, aad, aad_len, tag, + tag_length, reinterpret_cast(ret)); + } else { + *out_len = gandiva::aes_decrypt_gcm(data, cipher_len, key_data, key_data_len, iv, + iv_len, tag, tag_length, + reinterpret_cast(ret)); + } } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); *out_len = 0; @@ -1452,19 +1527,17 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { reinterpret_cast(gdv_fn_aes_decrypt_ecb_legacy)); // gdv_fn_aes_encrypt_cbc - // Note: Mode and IV parameters are passed as binary strings (data + length) - // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, iv, iv_len, padding, 
padding_len, out_len) + // Note: The IV and padding parameters are passed as binary/UTF8 strings (data + length) + // Function signature: (context, data, data_len, key_data, key_data_len, iv, iv_len, padding, padding_len, out_len) args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length - types->i8_ptr_type(), // mode (binary string) - types->i32_type(), // mode_length - types->i8_ptr_type(), // iv (binary string) + types->i8_ptr_type(), // iv types->i32_type(), // iv_length - types->i8_ptr_type(), // padding (binary string) + types->i8_ptr_type(), // padding types->i32_type(), // padding_length types->i32_ptr_type() // out_length }; @@ -1474,19 +1547,17 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { reinterpret_cast(gdv_fn_aes_encrypt_cbc)); // gdv_fn_aes_decrypt_cbc - // Note: Mode and IV parameters are passed as binary strings (data + length) - // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, iv, iv_len, padding, padding_len, out_len) + // Note: The IV and padding parameters are passed as binary/UTF8 strings (data + length) + // Function signature: (context, data, data_len, key_data, key_data_len, iv, iv_len, padding, padding_len, out_len) args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length - types->i8_ptr_type(), // mode (binary string) - types->i32_type(), // mode_length - types->i8_ptr_type(), // iv (binary string) + types->i8_ptr_type(), // iv types->i32_type(), // iv_length - types->i8_ptr_type(), // padding (binary string) + types->i8_ptr_type(), // padding types->i32_type(), // padding_length types->i32_ptr_type() // out_length }; @@ -1495,6 +1566,51 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { types->i8_ptr_type() 
/*return_type*/, args, reinterpret_cast(gdv_fn_aes_decrypt_cbc)); + // gdv_fn_aes_encrypt_gcm + // Note: The mode parameter is passed as a UTF8 string (data + length) + // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, iv, iv_len, aad, aad_len, out_len) + args = { + types->i64_type(), // context + types->i8_ptr_type(), // data + types->i32_type(), // data_length + types->i8_ptr_type(), // key_data + types->i32_type(), // key_data_length + types->i8_ptr_type(), // mode (UTF8 string) + types->i32_type(), // mode_length + types->i8_ptr_type(), // iv + types->i32_type(), // iv_length + types->i8_ptr_type(), // aad + types->i32_type(), // aad_length + types->i32_ptr_type() // out_length + }; + + engine->AddGlobalMappingForFunc("gdv_fn_aes_encrypt_gcm", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_aes_encrypt_gcm)); + + // gdv_fn_aes_decrypt_gcm + // Note: The mode parameter is passed as a UTF8 string (data + length) + // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, iv, iv_len, tag_length, aad, aad_len, out_len) + args = { + types->i64_type(), // context + types->i8_ptr_type(), // data (ciphertext + tag) + types->i32_type(), // data_length + types->i8_ptr_type(), // key_data + types->i32_type(), // key_data_length + types->i8_ptr_type(), // mode (UTF8 string) + types->i32_type(), // mode_length + types->i8_ptr_type(), // iv + types->i32_type(), // iv_length + types->i32_type(), // tag_length + types->i8_ptr_type(), // aad + types->i32_type(), // aad_length + types->i32_ptr_type() // out_length + }; + + engine->AddGlobalMappingForFunc("gdv_fn_aes_decrypt_gcm", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_aes_decrypt_gcm)); + // gdv_mask_first_n and gdv_mask_last_n std::vector mask_args = { types->i64_type(), // context diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index 324ecfb2c03..d471b9103ef 
100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -217,19 +217,30 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int GANDIVA_EXPORT const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - const char* iv_data, int32_t iv_data_len, - const char* padding, int32_t padding_len, - int32_t* out_len); + const char* iv, int32_t iv_len, const char* padding, + int32_t padding_len, int32_t* out_len); GANDIVA_EXPORT const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - const char* iv_data, int32_t iv_data_len, - const char* padding, int32_t padding_len, + const char* iv, int32_t iv_len, const char* padding, + int32_t padding_len, int32_t* out_len); + +// GCM mode specific functions +GANDIVA_EXPORT +const char* gdv_fn_aes_encrypt_gcm(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, const char* iv, + int32_t iv_len, const char* aad, int32_t aad_len, int32_t* out_len); +GANDIVA_EXPORT +const char* gdv_fn_aes_decrypt_gcm(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, const char* iv, + int32_t iv_len, int32_t tag_length, const char* aad, + int32_t aad_len, int32_t* out_len); + GANDIVA_EXPORT const char* gdv_mask_first_n_utf8_int32(int64_t context, const char* data, int32_t data_len, int32_t n_to_mask, diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index 55572880bed..f13d0861558 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -1480,181 +1480,187 @@ TEST(TestGdvFnStubs, 
TestAesEncryptDecryptModeValidation) { ctx.Reset(); } -// Tests for CBC mode encryption/decryption -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbc) { +// Tests for GCM mode AES functions +TEST(TestGdvFnStubs, TestAesEncryptDecryptGcm16) { gandiva::ExecutionContext ctx; std::string key16 = "12345678abcdefgh"; auto key16_len = static_cast(key16.length()); - std::string iv = "1234567890123456"; + std::string iv = "123456789012"; // 12 bytes for GCM auto iv_len = static_cast(iv.length()); int32_t cipher_len = 0; int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "CBC"; + std::string mode = "GCM"; auto mode_len = static_cast(mode.length()); - std::string padding = "PKCS7"; - auto padding_len = static_cast(padding.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), + const char* cipher = gdv_fn_aes_encrypt_gcm(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), mode_len, iv.c_str(), - iv_len, padding.c_str(), padding_len, &cipher_len); + iv_len, nullptr, 0, &cipher_len); EXPECT_GT(cipher_len, 0); + EXPECT_FALSE(ctx.has_error()); - const char* decrypted_value = gdv_fn_aes_decrypt_cbc( + const char* decrypted_value = gdv_fn_aes_decrypt_gcm( ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, - iv.c_str(), iv_len, padding.c_str(), padding_len, &decrypted_len); + iv.c_str(), iv_len, 16, nullptr, 0, &decrypted_len); EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); + EXPECT_FALSE(ctx.has_error()); } -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcNoPadding) { +TEST(TestGdvFnStubs, TestAesEncryptDecryptGcm24) { gandiva::ExecutionContext ctx; - std::string key16 = "12345678abcdefgh"; - auto key16_len = static_cast(key16.length()); - std::string iv = "1234567890123456"; + std::string key24 = "12345678abcdefgh12345678"; + auto 
key24_len = static_cast(key24.length()); + std::string iv = "123456789012"; auto iv_len = static_cast(iv.length()); int32_t cipher_len = 0; int32_t decrypted_len = 0; - std::string data = "1234567890123456"; // Exactly 16 bytes + std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "CBC"; + std::string mode = "GCM"; auto mode_len = static_cast(mode.length()); - std::string padding = "NONE"; - auto padding_len = static_cast(padding.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), - key16_len, mode.c_str(), mode_len, iv.c_str(), - iv_len, padding.c_str(), padding_len, &cipher_len); + const char* cipher = gdv_fn_aes_encrypt_gcm(ctx_ptr, data.c_str(), data_len, key24.c_str(), + key24_len, mode.c_str(), mode_len, iv.c_str(), + iv_len, nullptr, 0, &cipher_len); EXPECT_GT(cipher_len, 0); + EXPECT_FALSE(ctx.has_error()); - const char* decrypted_value = gdv_fn_aes_decrypt_cbc( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, - iv.c_str(), iv_len, padding.c_str(), padding_len, &decrypted_len); + const char* decrypted_value = gdv_fn_aes_decrypt_gcm( + ctx_ptr, cipher, cipher_len, key24.c_str(), key24_len, mode.c_str(), mode_len, + iv.c_str(), iv_len, 16, nullptr, 0, &decrypted_len); EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); + EXPECT_FALSE(ctx.has_error()); } -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcCaseInsensitive) { +TEST(TestGdvFnStubs, TestAesEncryptDecryptGcm32) { gandiva::ExecutionContext ctx; - std::string key16 = "12345678abcdefgh"; - auto key16_len = static_cast(key16.length()); - std::string iv = "1234567890123456"; + std::string key32 = "12345678abcdefgh12345678abcdefgh"; + auto key32_len = static_cast(key32.length()); + std::string iv = "123456789012"; auto iv_len = static_cast(iv.length()); - int32_t cipher_len1 = 0; - int32_t cipher_len2 = 0; + 
int32_t cipher_len = 0; + int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "CBC"; + std::string mode = "GCM"; auto mode_len = static_cast(mode.length()); - std::string padding_upper = "PKCS7"; - auto padding_upper_len = static_cast(padding_upper.length()); - std::string padding_lower = "pkcs7"; - auto padding_lower_len = static_cast(padding_lower.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher1 = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), - key16_len, mode.c_str(), mode_len, iv.c_str(), - iv_len, padding_upper.c_str(), padding_upper_len, - &cipher_len1); - const char* cipher2 = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), - key16_len, mode.c_str(), mode_len, iv.c_str(), - iv_len, padding_lower.c_str(), padding_lower_len, - &cipher_len2); + const char* cipher = gdv_fn_aes_encrypt_gcm(ctx_ptr, data.c_str(), data_len, key32.c_str(), + key32_len, mode.c_str(), mode_len, iv.c_str(), + iv_len, nullptr, 0, &cipher_len); + EXPECT_GT(cipher_len, 0); + EXPECT_FALSE(ctx.has_error()); - // Both should produce same ciphertext - EXPECT_EQ(cipher_len1, cipher_len2); - EXPECT_EQ(std::string(cipher1, cipher_len1), std::string(cipher2, cipher_len2)); + const char* decrypted_value = gdv_fn_aes_decrypt_gcm( + ctx_ptr, cipher, cipher_len, key32.c_str(), key32_len, mode.c_str(), mode_len, + iv.c_str(), iv_len, 16, nullptr, 0, &decrypted_len); + EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); + EXPECT_FALSE(ctx.has_error()); } -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidIV) { +TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmWithAad) { gandiva::ExecutionContext ctx; std::string key16 = "12345678abcdefgh"; auto key16_len = static_cast(key16.length()); - std::string iv = "short"; // Too short + std::string iv = "123456789012"; auto iv_len = static_cast(iv.length()); + std::string aad = 
"additional data"; + auto aad_len = static_cast(aad.length()); int32_t cipher_len = 0; - std::string data = "test string"; + int32_t decrypted_len = 0; + std::string data = "secret message"; auto data_len = static_cast(data.length()); - std::string mode = "CBC"; + std::string mode = "GCM"; auto mode_len = static_cast(mode.length()); - std::string padding = "PKCS7"; - auto padding_len = static_cast(padding.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, - mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), - padding_len, &cipher_len); - EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid IV length")); - ctx.Reset(); -} - -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidKey) { - gandiva::ExecutionContext ctx; - std::string key = "short"; // Too short - auto key_len = static_cast(key.length()); - std::string iv = "1234567890123456"; - auto iv_len = static_cast(iv.length()); - int32_t cipher_len = 0; - std::string data = "test string"; - auto data_len = static_cast(data.length()); - std::string mode = "CBC"; - auto mode_len = static_cast(mode.length()); - std::string padding = "PKCS7"; - auto padding_len = static_cast(padding.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); + const char* cipher = gdv_fn_aes_encrypt_gcm(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, mode.c_str(), mode_len, iv.c_str(), + iv_len, aad.c_str(), aad_len, &cipher_len); + EXPECT_GT(cipher_len, 0); + EXPECT_FALSE(ctx.has_error()); - gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key.c_str(), key_len, - mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), - padding_len, &cipher_len); - EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid key length")); - ctx.Reset(); + const char* decrypted_value = gdv_fn_aes_decrypt_gcm( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, + iv.c_str(), iv_len, 16, aad.c_str(), 
aad_len, &decrypted_len); + EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); + EXPECT_FALSE(ctx.has_error()); } -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidPadding) { +TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmModeValidation) { gandiva::ExecutionContext ctx; std::string key16 = "12345678abcdefgh"; auto key16_len = static_cast(key16.length()); - std::string iv = "1234567890123456"; + std::string iv = "123456789012"; auto iv_len = static_cast(iv.length()); int32_t cipher_len = 0; + int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "CBC"; - auto mode_len = static_cast(mode.length()); - std::string padding = "INVALID"; - auto padding_len = static_cast(padding.length()); + std::string invalid_mode = "ECB"; + auto invalid_mode_len = static_cast(invalid_mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, - mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), - padding_len, &cipher_len); - EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid padding mode")); + // Test encrypt with invalid mode + gdv_fn_aes_encrypt_gcm(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, + invalid_mode.c_str(), invalid_mode_len, iv.c_str(), iv_len, nullptr, + 0, &cipher_len); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("AES encryption mode mismatch")); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("ECB")); + ctx.Reset(); + + // Test decrypt with invalid mode + std::string cipher = "12345678abcdefgh12345678abcdefgh"; + auto cipher_len_val = static_cast(cipher.length()); + gdv_fn_aes_decrypt_gcm(ctx_ptr, cipher.c_str(), cipher_len_val, key16.c_str(), key16_len, + invalid_mode.c_str(), invalid_mode_len, iv.c_str(), iv_len, 16, + nullptr, 0, &decrypted_len); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("AES decryption mode mismatch")); + 
EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("ECB")); ctx.Reset(); } -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcModeValidation) { +TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmCaseInsensitiveMode) { gandiva::ExecutionContext ctx; std::string key16 = "12345678abcdefgh"; auto key16_len = static_cast(key16.length()); - std::string iv = "1234567890123456"; + std::string iv = "123456789012"; auto iv_len = static_cast(iv.length()); - int32_t cipher_len = 0; + int32_t cipher_len1 = 0; + int32_t cipher_len2 = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string invalid_mode = "ECB"; - auto invalid_mode_len = static_cast(invalid_mode.length()); - std::string padding = "PKCS7"; - auto padding_len = static_cast(padding.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - // Test encrypt with invalid mode - gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, - invalid_mode.c_str(), invalid_mode_len, iv.c_str(), iv_len, - padding.c_str(), padding_len, &cipher_len); - EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("AES encryption mode mismatch")); - EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("ECB")); - ctx.Reset(); + // Test with lowercase mode + std::string mode_lower = "gcm"; + auto mode_lower_len = static_cast(mode_lower.length()); + const char* cipher1 = gdv_fn_aes_encrypt_gcm(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, mode_lower.c_str(), mode_lower_len, + iv.c_str(), iv_len, nullptr, 0, &cipher_len1); + EXPECT_GT(cipher_len1, 0); + EXPECT_FALSE(ctx.has_error()); + + // Test with uppercase mode + std::string mode_upper = "GCM"; + auto mode_upper_len = static_cast(mode_upper.length()); + const char* cipher2 = gdv_fn_aes_encrypt_gcm(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, mode_upper.c_str(), mode_upper_len, + iv.c_str(), iv_len, nullptr, 0, &cipher_len2); + EXPECT_GT(cipher_len2, 0); + EXPECT_FALSE(ctx.has_error()); + + // Both should produce 
same ciphertext + EXPECT_EQ(cipher_len1, cipher_len2); + EXPECT_EQ(std::string(cipher1, cipher_len1), std::string(cipher2, cipher_len2)); } } // namespace gandiva From 77da0d6ac2b6f56e4a348995c3435cdf7897ec0c Mon Sep 17 00:00:00 2001 From: Tim Hurski Date: Tue, 18 Nov 2025 04:43:38 +0000 Subject: [PATCH 4/5] Extract AES mode validation into ensure_mode() helper function - Create ensure_mode() helper that throws std::runtime_error for invalid modes - Update all 6 AES functions (3 encrypt + 3 decrypt) to use ensure_mode() - ECB, CBC, GCM encrypt functions now use centralized mode validation - ECB, CBC, GCM decrypt functions now use centralized mode validation - Remove redundant mode validation code from all functions - Wrap all function bodies in try-catch to handle exceptions from ensure_mode() - Consistent error handling across all modes --- cpp/src/gandiva/function_registry_string.cc | 7 +- cpp/src/gandiva/gdv_function_stubs.cc | 388 ++++++++------------ cpp/src/gandiva/gdv_function_stubs.h | 2 + cpp/src/gandiva/gdv_function_stubs_test.cc | 4 +- 4 files changed, 156 insertions(+), 245 deletions(-) diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index 16f8977e7d6..a1533e0c96f 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -514,12 +514,13 @@ std::vector GetStringFunctionRegistry() { NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), // CBC mode specific functions - // Binary-based signatures (BINARY, BINARY, BINARY, UTF8, UTF8) -> BINARY - NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), binary(), utf8(), utf8()}, binary(), + // Binary-based signatures (BINARY, BINARY, UTF8, BINARY, UTF8) -> BINARY + // Parameters: data, key, mode, iv, padding + NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), utf8()}, binary(), kResultNullIfNull, "gdv_fn_aes_encrypt_cbc", 
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), - NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), binary(), utf8(), utf8()}, binary(), + NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), utf8()}, binary(), kResultNullIfNull, "gdv_fn_aes_decrypt_cbc", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index b4edcacd5d4..046bd8ca7c7 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -399,76 +399,67 @@ CAST_NUMERIC_FROM_VARBINARY(double, arrow::DoubleType, FLOAT8) #undef GDV_FN_CAST_VARCHAR_REAL - -// ECB mode specific functions - core implementation -// This handles both string and binary inputs (they have the same C signature) -GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_ecb(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - int32_t* out_len) { - // Validate mode parameter +// Helper function to validate AES mode parameter +// Throws std::runtime_error if mode is invalid +static void ensure_mode(const char* mode, int32_t mode_len, + const std::string& expected_mode) { if (mode == nullptr) { - std::ostringstream oss; - oss << "Invalid mode parameter for AES encryption"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; + throw std::runtime_error("Invalid mode parameter for AES encryption"); } std::string mode_str(mode, mode_len); // Convert to uppercase for comparison std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); - if (mode_str != "ECB") { + if (mode_str != expected_mode) { std::ostringstream oss; - oss << "AES encryption mode mismatch: function signature indicates ECB mode, but '" - << mode_str << "' was provided instead"; - gdv_fn_context_set_error_msg(context, 
oss.str().c_str()); - *out_len = 0; - return ""; + oss << "AES encryption mode mismatch: function signature indicates " << expected_mode + << " mode, but '" << mode_str << "' was provided instead"; + throw std::runtime_error(oss.str()); } +} - if (data_len < 0) { - std::ostringstream oss; - oss << "Invalid data length for AES encryption: " << data_len << " (must be >= 0)"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } +// ECB mode specific functions - core implementation +// This handles both string and binary inputs (they have the same C signature) +GANDIVA_EXPORT +const char* gdv_fn_aes_encrypt_ecb(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, + int32_t* out_len) { + try { + // Validate mode parameter + ensure_mode(mode, mode_len, "ECB"); - if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) { - std::ostringstream oss; - oss << "Invalid key length for AES encryption: " << key_data_len - << " bytes. 
Supported lengths: 16, 24, 32 bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return nullptr; - } + if (data_len < 0) { + throw std::runtime_error( + std::string("Invalid data length for AES encryption: ") + std::to_string(data_len) + + " (must be >= 0)"); + } - // AES block size is always 16 bytes (128 bits), regardless of key length - int64_t kAesBlockSize = 16; - *out_len = - static_cast(arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), kAesBlockSize)); - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); - if (ret == nullptr) { - std::ostringstream oss; - oss << "Could not allocate memory for AES encryption output: " << *out_len << " bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return nullptr; - } + if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) { + throw std::runtime_error( + std::string("Invalid key length for AES encryption: ") + std::to_string(key_data_len) + + " bytes. 
Supported lengths: 16, 24, 32 bytes"); + } + + // AES block size is always 16 bytes (128 bits), regardless of key length + int64_t kAesBlockSize = 16; + *out_len = static_cast( + arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), kAesBlockSize)); + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + throw std::runtime_error(std::string("Could not allocate memory for AES encryption output: ") + + std::to_string(*out_len) + " bytes"); + } - try { *out_len = gandiva::aes_encrypt_ecb(data, data_len, key_data, key_data_len, reinterpret_cast(ret)); + return ret; } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); *out_len = 0; return nullptr; } - - return ret; } // Legacy wrapper for string-based signatures (UTF8, UTF8) -> UTF8 @@ -499,68 +490,40 @@ const char* gdv_fn_aes_decrypt_ecb(int64_t context, const char* data, int32_t da const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, int32_t* out_len) { - // Validate mode parameter - if (mode == nullptr) { - std::ostringstream oss; - oss << "Invalid mode parameter for AES decryption"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - std::string mode_str(mode, mode_len); - // Convert to uppercase for comparison - std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); - - if (mode_str != "ECB") { - std::ostringstream oss; - oss << "AES decryption mode mismatch: function signature indicates ECB mode, but '" - << mode_str << "' was provided instead"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + try { + // Validate mode parameter + ensure_mode(mode, mode_len, "ECB"); - if (data_len < 0) { - std::ostringstream oss; - oss << "Invalid data length for AES decryption: " << data_len << " (must be >= 0)"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - 
} + if (data_len < 0) { + throw std::runtime_error( + std::string("Invalid data length for AES decryption: ") + std::to_string(data_len) + + " (must be >= 0)"); + } - if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) { - std::ostringstream oss; - oss << "Invalid key length for AES decryption: " << key_data_len - << " bytes. Supported lengths: 16, 24, 32 bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return nullptr; - } + if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) { + throw std::runtime_error( + std::string("Invalid key length for AES decryption: ") + std::to_string(key_data_len) + + " bytes. Supported lengths: 16, 24, 32 bytes"); + } - // AES block size is always 16 bytes (128 bits), regardless of key length - int64_t kAesBlockSize = 16; - *out_len = - static_cast(arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), kAesBlockSize)); - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); - if (ret == nullptr) { - std::ostringstream oss; - oss << "Could not allocate memory for AES decryption output: " << *out_len << " bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return nullptr; - } + // AES block size is always 16 bytes (128 bits), regardless of key length + int64_t kAesBlockSize = 16; + *out_len = static_cast( + arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), kAesBlockSize)); + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + throw std::runtime_error(std::string("Could not allocate memory for AES decryption output: ") + + std::to_string(*out_len) + " bytes"); + } - try { *out_len = gandiva::aes_decrypt_ecb(data, data_len, key_data, key_data_len, reinterpret_cast(ret)); + return ret; } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); *out_len = 0; return nullptr; } - - return ret; } // Legacy wrapper for 
string-based signatures (UTF8, UTF8) -> UTF8 @@ -591,57 +554,55 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int GANDIVA_EXPORT const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, const char* iv, int32_t iv_len, const char* padding, int32_t padding_len, int32_t* out_len) { - // Allocate output buffer (max size: input + 16 bytes for padding) - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, data_len + 16)); - if (ret == nullptr) { - std::ostringstream oss; - oss << "Could not allocate memory for AES-CBC encryption"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - try { + // Validate mode parameter + ensure_mode(mode, mode_len, "CBC"); + + // Allocate output buffer (max size: input + 16 bytes for padding) + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, data_len + 16)); + if (ret == nullptr) { + throw std::runtime_error("Could not allocate memory for AES-CBC encryption"); + } + *out_len = gandiva::aes_encrypt_cbc(data, data_len, key_data, key_data_len, iv, iv_len, padding, padding_len, reinterpret_cast(ret)); + return ret; } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); *out_len = 0; return nullptr; } - - return ret; } GANDIVA_EXPORT const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, const char* iv, int32_t iv_len, const char* padding, int32_t padding_len, int32_t* out_len) { - // Allocate output buffer (max size: input size, since decryption removes padding) - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, data_len)); - if (ret == nullptr) { - std::ostringstream oss; - oss << "Could not allocate memory for AES-CBC decryption"; - 
gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - try { + // Validate mode parameter + ensure_mode(mode, mode_len, "CBC"); + + // Allocate output buffer (max size: input size, since decryption removes padding) + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, data_len)); + if (ret == nullptr) { + throw std::runtime_error("Could not allocate memory for AES-CBC decryption"); + } + *out_len = gandiva::aes_decrypt_cbc(data, data_len, key_data, key_data_len, iv, iv_len, padding, padding_len, reinterpret_cast(ret)); + return ret; } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); *out_len = 0; return nullptr; } - - return ret; } // GCM mode specific functions - core implementation @@ -652,59 +613,33 @@ const char* gdv_fn_aes_encrypt_gcm(int64_t context, const char* data, int32_t da const char* mode, int32_t mode_len, const char* iv, int32_t iv_len, const char* aad, int32_t aad_len, int32_t* out_len) { - // Validate mode parameter - if (mode == nullptr) { - std::ostringstream oss; - oss << "Invalid mode parameter for AES encryption"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - std::string mode_str(mode, mode_len); - // Convert to uppercase for comparison - std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); - - if (mode_str != "GCM") { - std::ostringstream oss; - oss << "AES encryption mode mismatch: function signature indicates GCM mode, but '" - << mode_str << "' was provided instead"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + try { + // Validate mode parameter + ensure_mode(mode, mode_len, "GCM"); - if (data_len < 0) { - std::ostringstream oss; - oss << "Invalid data length for AES encryption: " << data_len << " (must be >= 0)"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + if (data_len < 0) { 
+ throw std::runtime_error( + std::string("Invalid data length for AES encryption: ") + std::to_string(data_len) + + " (must be >= 0)"); + } - if (iv_len < 0) { - std::ostringstream oss; - oss << "Invalid IV length for AES encryption: " << iv_len << " (must be >= 0)"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + if (iv_len < 0) { + throw std::runtime_error( + std::string("Invalid IV length for AES encryption: ") + std::to_string(iv_len) + + " (must be >= 0)"); + } - // For GCM, output is ciphertext + tag (typically 16 bytes) - // Allocate space for ciphertext + 16-byte tag - int64_t kGcmTagSize = 16; - *out_len = static_cast( - arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len + kGcmTagSize), 16)); - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); - if (ret == nullptr) { - std::ostringstream oss; - oss << "Could not allocate memory for AES encryption output: " << *out_len << " bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return nullptr; - } + // For GCM, output is ciphertext + tag (typically 16 bytes) + // Allocate space for ciphertext + 16-byte tag + int64_t kGcmTagSize = 16; + *out_len = static_cast( + arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len + kGcmTagSize), 16)); + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + throw std::runtime_error(std::string("Could not allocate memory for AES encryption output: ") + + std::to_string(*out_len) + " bytes"); + } - try { unsigned char* cipher = reinterpret_cast(ret); unsigned char* tag = cipher + data_len; @@ -719,13 +654,12 @@ const char* gdv_fn_aes_encrypt_gcm(int64_t context, const char* data, int32_t da } *out_len = cipher_len + kGcmTagSize; + return ret; } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); *out_len = 0; return nullptr; } - - return ret; } GANDIVA_EXPORT @@ -734,69 +668,40 @@ 
const char* gdv_fn_aes_decrypt_gcm(int64_t context, const char* data, int32_t da const char* mode, int32_t mode_len, const char* iv, int32_t iv_len, int32_t tag_length, const char* aad, int32_t aad_len, int32_t* out_len) { - // Validate mode parameter - if (mode == nullptr) { - std::ostringstream oss; - oss << "Invalid mode parameter for AES decryption"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - std::string mode_str(mode, mode_len); - // Convert to uppercase for comparison - std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); - - if (mode_str != "GCM") { - std::ostringstream oss; - oss << "AES decryption mode mismatch: function signature indicates GCM mode, but '" - << mode_str << "' was provided instead"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + try { + // Validate mode parameter + ensure_mode(mode, mode_len, "GCM"); - if (data_len < 0) { - std::ostringstream oss; - oss << "Invalid data length for AES decryption: " << data_len << " (must be >= 0)"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + if (data_len < 0) { + throw std::runtime_error( + std::string("Invalid data length for AES decryption: ") + std::to_string(data_len) + + " (must be >= 0)"); + } - if (iv_len < 0) { - std::ostringstream oss; - oss << "Invalid IV length for AES decryption: " << iv_len << " (must be >= 0)"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + if (iv_len < 0) { + throw std::runtime_error( + std::string("Invalid IV length for AES decryption: ") + std::to_string(iv_len) + + " (must be >= 0)"); + } - // For GCM, input is ciphertext + tag - // Ciphertext length is data_len - tag_length - int32_t cipher_len = data_len - tag_length; - if (cipher_len < 0) { - std::ostringstream oss; - oss << "Invalid tag length for AES decryption: " << tag_length - << " (must be 
<= data length " << data_len << ")"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + // For GCM, input is ciphertext + tag + // Ciphertext length is data_len - tag_length + int32_t cipher_len = data_len - tag_length; + if (cipher_len < 0) { + throw std::runtime_error( + std::string("Invalid tag length for AES decryption: ") + std::to_string(tag_length) + + " (must be <= data length " + std::to_string(data_len) + ")"); + } - // Allocate space for plaintext - *out_len = static_cast( - arrow::bit_util::RoundUpToPowerOf2(static_cast(cipher_len), 16)); - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); - if (ret == nullptr) { - std::ostringstream oss; - oss << "Could not allocate memory for AES decryption output: " << *out_len << " bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return nullptr; - } + // Allocate space for plaintext + *out_len = static_cast( + arrow::bit_util::RoundUpToPowerOf2(static_cast(cipher_len), 16)); + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + throw std::runtime_error(std::string("Could not allocate memory for AES decryption output: ") + + std::to_string(*out_len) + " bytes"); + } - try { const char* tag = data + cipher_len; if (aad != nullptr && aad_len > 0) { @@ -808,13 +713,12 @@ const char* gdv_fn_aes_decrypt_gcm(int64_t context, const char* data, int32_t da iv_len, tag, tag_length, reinterpret_cast(ret)); } + return ret; } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); *out_len = 0; return nullptr; } - - return ret; } @@ -1527,14 +1431,16 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { reinterpret_cast(gdv_fn_aes_decrypt_ecb_legacy)); // gdv_fn_aes_encrypt_cbc - // Note: The IV and padding parameters are passed as binary/UTF8 strings (data + length) - // Function signature: (context, data, data_len, 
key_data, key_data_len, iv, iv_len, padding, padding_len, out_len) + // Note: The mode, IV and padding parameters are passed as binary/UTF8 strings (data + length) + // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, iv, iv_len, padding, padding_len, out_len) args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length + types->i8_ptr_type(), // mode (UTF8 string) + types->i32_type(), // mode_length types->i8_ptr_type(), // iv types->i32_type(), // iv_length types->i8_ptr_type(), // padding @@ -1547,14 +1453,16 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { reinterpret_cast(gdv_fn_aes_encrypt_cbc)); // gdv_fn_aes_decrypt_cbc - // Note: The IV and padding parameters are passed as binary/UTF8 strings (data + length) - // Function signature: (context, data, data_len, key_data, key_data_len, iv, iv_len, padding, padding_len, out_len) + // Note: The mode, IV and padding parameters are passed as binary/UTF8 strings (data + length) + // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, iv, iv_len, padding, padding_len, out_len) args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length + types->i8_ptr_type(), // mode (UTF8 string) + types->i32_type(), // mode_length types->i8_ptr_type(), // iv types->i32_type(), // iv_length types->i8_ptr_type(), // padding diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index d471b9103ef..b1f5aff5396 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -217,12 +217,14 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int GANDIVA_EXPORT const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* 
data, int32_t data_len, const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, const char* iv, int32_t iv_len, const char* padding, int32_t padding_len, int32_t* out_len); GANDIVA_EXPORT const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, const char* iv, int32_t iv_len, const char* padding, int32_t padding_len, int32_t* out_len); diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index f13d0861558..90f1654bb6c 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -1474,7 +1474,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeValidation) { gdv_fn_aes_decrypt_ecb(ctx_ptr, cipher.c_str(), cipher_len_val, key16.c_str(), key16_len, invalid_mode.c_str(), invalid_mode_len, &decrypted_len); EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("AES decryption mode mismatch")); + ::testing::HasSubstr("AES encryption mode mismatch")); EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("CBC")); ctx.Reset(); @@ -1622,7 +1622,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmModeValidation) { invalid_mode.c_str(), invalid_mode_len, iv.c_str(), iv_len, 16, nullptr, 0, &decrypted_len); EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("AES decryption mode mismatch")); + ::testing::HasSubstr("AES encryption mode mismatch")); EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("ECB")); ctx.Reset(); From 10060098ba4a3e7cd86781b2e49d49e004da8120 Mon Sep 17 00:00:00 2001 From: Tim Hurski Date: Tue, 18 Nov 2025 17:22:43 +0000 Subject: [PATCH 5/5] Update macOS runner version to 15-Intel --- .github/workflows/csharp.yml | 4 ++-- .github/workflows/java.yml | 4 ++-- .github/workflows/js.yml | 4 ++-- dev/tasks/java-jars/github.yml | 4 ++-- dev/tasks/r/github.packages.yml | 6 +++--- dev/tasks/tasks.yml | 6 +++--- 6 files changed, 14 insertions(+), 
14 deletions(-) diff --git a/.github/workflows/csharp.yml b/.github/workflows/csharp.yml index 5f657e6c1bf..d479456d0d8 100644 --- a/.github/workflows/csharp.yml +++ b/.github/workflows/csharp.yml @@ -94,8 +94,8 @@ jobs: run: ci/scripts/csharp_test.sh $(pwd) macos: - name: AMD64 macOS 13 C# ${{ matrix.dotnet }} - runs-on: macos-13 + name: AMD64 macOS 15 C# ${{ matrix.dotnet }} + runs-on: macos-15-intel if: ${{ !contains(github.event.pull_request.title, 'WIP') }} timeout-minutes: 15 strategy: diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 5766c63bf52..6c0cf099116 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -106,8 +106,8 @@ jobs: run: archery docker push ${{ matrix.image }} macos: - name: AMD64 macOS 13 Java JDK ${{ matrix.jdk }} - runs-on: macos-13 + name: AMD64 macOS 15 Java JDK ${{ matrix.jdk }} + runs-on: macos-15-intel if: ${{ !contains(github.event.pull_request.title, 'WIP') }} timeout-minutes: 30 strategy: diff --git a/.github/workflows/js.yml b/.github/workflows/js.yml index 031310fd402..a51ad867aa7 100644 --- a/.github/workflows/js.yml +++ b/.github/workflows/js.yml @@ -81,8 +81,8 @@ jobs: run: archery docker push debian-js macos: - name: AMD64 macOS 13 NodeJS ${{ matrix.node }} - runs-on: macos-13 + name: AMD64 macOS 15 NodeJS ${{ matrix.node }} + runs-on: macos-15-intel if: ${{ !contains(github.event.pull_request.title, 'WIP') }} timeout-minutes: 30 strategy: diff --git a/dev/tasks/java-jars/github.yml b/dev/tasks/java-jars/github.yml index f7dd177e875..ff1834e63b9 100644 --- a/dev/tasks/java-jars/github.yml +++ b/dev/tasks/java-jars/github.yml @@ -91,7 +91,7 @@ jobs: fail-fast: false matrix: platform: - - { runs_on: ["macos-13"], arch: "x86_64"} + - { runs_on: ["macos-15-intel"], arch: "x86_64"} env: MACOSX_DEPLOYMENT_TARGET: "12.0" steps: @@ -190,7 +190,7 @@ jobs: fail-fast: false matrix: platform: - - { runs_on: ["macos-13"], arch: "x86_64"} + - { runs_on: ["macos-15-intel"], arch: "x86_64"} 
needs: - build-cpp-ubuntu - build-cpp-macos diff --git a/dev/tasks/r/github.packages.yml b/dev/tasks/r/github.packages.yml index 839e3d53410..a0005e19035 100644 --- a/dev/tasks/r/github.packages.yml +++ b/dev/tasks/r/github.packages.yml @@ -66,7 +66,7 @@ jobs: fail-fast: false matrix: platform: - - { runs_on: macos-13, arch: "x86_64" } + - { runs_on: macos-15-intel, arch: "x86_64" } - { runs_on: macos-14, arch: "arm64" } openssl: ['3.0', '1.1'] @@ -208,7 +208,7 @@ jobs: matrix: platform: - { runs_on: 'windows-latest', name: "Windows"} - - { runs_on: macos-13, name: "macOS x86_64"} + - { runs_on: macos-15-intel, name: "macOS x86_64"} - { runs_on: macos-14, name: "macOS arm64" } r_version: [oldrel, release] steps: @@ -389,7 +389,7 @@ jobs: matrix: platform: - {runs_on: "ubuntu-latest", name: "Linux"} - - {runs_on: "macos-13" , name: "macOS"} + - {runs_on: "macos-15-intel" , name: "macOS"} steps: - name: Install R uses: r-lib/actions/setup-r@v2 diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index 91e1c07e1fc..5b889678765 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -425,7 +425,7 @@ tasks: python_version: "{{ python_version }}" python_abi_tag: "{{ abi_tag }}" macos_deployment_target: "12.0" - runs_on: "macos-13" + runs_on: "macos-15-intel" vcpkg_arch: "amd64" artifacts: - pyarrow-{no_rc_version}-{{ python_tag }}-{{ abi_tag }}-macosx_12_0_x86_64.whl @@ -967,7 +967,7 @@ tasks: params: target: {{ target }} use_conda: True - github_runner: "macos-13" + github_runner: "macos-15-intel" {% endfor %} {% for target in ["cpp", @@ -982,7 +982,7 @@ tasks: template: verify-rc/github.macos.yml params: target: {{ target }} - github_runner: "macos-13" + github_runner: "macos-15-intel" {% endfor %} {% for target in ["cpp",