diff --git a/.github/workflows/msbuild.yml b/.github/workflows/msbuild.yml index f8cf54a..4df7084 100644 --- a/.github/workflows/msbuild.yml +++ b/.github/workflows/msbuild.yml @@ -73,3 +73,20 @@ jobs: msm\*.msi msm\*.msm gui\build\Release\*.* + + tests: + runs-on: windows-2022 + + steps: + - uses: actions/checkout@v2 + with: + submodules: true + + - name: Configure CMake tests + run: cmake -S tests -B tests/build -A x64 + + - name: Build CMake tests + run: cmake --build tests/build --config Release + + - name: Run CMake tests + run: ctest --test-dir tests/build --build-config Release --output-on-failure diff --git a/Driver.cpp b/Driver.cpp index 9805951..2387c4b 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -727,7 +727,7 @@ VOID OvpnEvtDeviceCleanup(WDFOBJECT obj) { // OvpnCryptoUninitAlgHandles called outside of lock because // it requires PASSIVE_LEVEL. - OvpnCryptoUninitAlgHandles(device->AesAlgHandle, device->ChachaAlgHandle); + OvpnCryptoUninitAlgHandles(device->AesAlgHandle, device->ChachaAlgHandle, device->HkdfAlgHandle); // delete control device if there are no devices left POVPN_DRIVER driverCtx = OvpnGetDriverContext(WdfGetDriver()); @@ -892,7 +892,7 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { device->IRoutesIPV4.Init(FALSE); device->IRoutesIPV6.Init(TRUE); - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoInitAlgHandles(&device->AesAlgHandle, &device->ChachaAlgHandle)); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoInitAlgHandles(&device->AesAlgHandle, &device->ChachaAlgHandle, &device->HkdfAlgHandle)); // Initialize peers tables RtlInitializeGenericTable(&device->Peers, OvpnPeerCompareByPeerIdRoutine, OvpnPeerAllocateRoutine, OvpnPeerFreeRoutine, NULL); @@ -901,7 +901,6 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { RtlInitializeGenericTable(&device->PeersByTransport, OvpnPeerCompareByTransportRoutine, OvpnPeerAllocateRoutine, OvpnPeerFreeRoutine, NULL); LOG_IF_NOT_NT_SUCCESS(status = OvpnAdapterCreate(device)); - done: LOG_EXIT(); diff --git a/Driver.h b/Driver.h index 4377fd0..f61921d 100644 --- a/Driver.h +++ b/Driver.h @@ -86,6 +86,7 @@ struct OVPN_DEVICE { BCRYPT_ALG_HANDLE AesAlgHandle; BCRYPT_ALG_HANDLE ChachaAlgHandle; + BCRYPT_ALG_HANDLE HkdfAlgHandle; _Guarded_by_(SpinLock) OvpnSocket Socket; diff --git a/PropertySheet.props b/PropertySheet.props index 1d4fed5..82ed55a 100644 --- a/PropertySheet.props +++ b/PropertySheet.props @@ -3,8 +3,8 @@ 2 - 7 - 1 + 8 + 0 diff --git a/adapter.h b/adapter.h index dfaa0fc..5035347 100644 --- a/adapter.h +++ b/adapter.h @@ -56,5 +56,3 @@ OvpnAdapterNotifyRx(NETADAPTER netAdapter); VOID OvpnAdapterSetLinkState(_In_ POVPN_ADAPTER adapter, NET_IF_MEDIA_CONNECT_STATE state); - -#define OVPN_PAYLOAD_BACKFILL 26 // 2 + 4 + 4 + 16 -> tcp packet size + data_v2 + pktid + auth-tag; \ No newline at end of file diff --git a/crypto.cpp b/crypto.cpp index 4f0a639..ba4aa97 100644 --- a/crypto.cpp +++ b/crypto.cpp @@ -21,11 +21,13 @@ #include #include +#include #include "crypto.h" #include "trace.h" #include "pktid.h" #include "socket.h" +#include "peer.h" UINT OvpnCryptoOpCompose(UINT opcode, UINT keyId) @@ -33,6 +35,120 @@ OvpnCryptoOpCompose(UINT opcode, UINT keyId) return (opcode << OVPN_OPCODE_SHIFT) | keyId; } +static +VOID +OvpnCryptoUpgradeLock(_In_ OvpnPeerContext* peer, _In_ BOOLEAN atDpcLevel, _Inout_opt_ PKIRQL kirql) +{ + if (atDpcLevel) { + ExReleaseSpinLockSharedFromDpcLevel(&peer->SpinLock); + ExAcquireSpinLockExclusiveAtDpcLevel(&peer->SpinLock); + } + else { + NT_ASSERT(kirql != nullptr); + + KIRQL previousIrql = *kirql; + + ExReleaseSpinLockShared(&peer->SpinLock, previousIrql); + + KIRQL acquiredIrql = ExAcquireSpinLockExclusive(&peer->SpinLock); + NT_ASSERT(acquiredIrql == previousIrql); + *kirql = acquiredIrql; + } +} + +_Use_decl_annotations_ +NTSTATUS +OvpnCryptoCallWithRetry( + OvpnPeerContext* peer, + BOOLEAN atDpcLevel, + PBOOLEAN exclusive, + PKIRQL kirql, + POVPN_CRYPTO_RETRY_ROUTINE routine, + PVOID context) +{ + NT_ASSERT(peer != nullptr); + NT_ASSERT(routine != nullptr); + + BOOLEAN allowRekey = (exclusive != nullptr) && *exclusive; + + for (;;) { + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; + NTSTATUS status = routine(cryptoContext, allowRekey, context); + + if (status != STATUS_OVPN_CRYPTO_RETRY) { + return status; + } + + if (!allowRekey) { + // Drop the shared lock before upgrading. The caller owns a peer + // reference, so the context remains valid while we temporarily + // release the lock and the loop below re-reads the crypto state + // under exclusive ownership. + OvpnCryptoUpgradeLock(peer, atDpcLevel, kirql); + allowRekey = TRUE; + + if (exclusive != nullptr) { + *exclusive = TRUE; + } + + continue; + } + + LOG_ERROR("Crypto helper requested retry under exclusive lock"); + return STATUS_UNSUCCESSFUL; + } +} + +_Use_decl_annotations_ +NTSTATUS +OvpnCryptoInvokeEncrypt( + OvpnCryptoContext* cryptoContext, + BOOLEAN allowRekey, + PVOID context) +{ + auto params = reinterpret_cast(context); + + if ((cryptoContext == nullptr) || (cryptoContext->Encrypt == nullptr)) { + return STATUS_INVALID_DEVICE_STATE; + } + + NT_ASSERT(params != nullptr); + + return cryptoContext->Encrypt(&cryptoContext->Primary, params->Buffer, params->Length, &cryptoContext->Options, allowRekey); +} + +_Use_decl_annotations_ +NTSTATUS +OvpnCryptoInvokeDecrypt( + OvpnCryptoContext* cryptoContext, + BOOLEAN allowRekey, + PVOID context) +{ + auto params = reinterpret_cast(context); + + if ((cryptoContext == nullptr) || (cryptoContext->Decrypt == nullptr)) { + return STATUS_INVALID_DEVICE_STATE; + } + + NT_ASSERT(params != nullptr); + + OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(cryptoContext, params->KeyId); + if (keySlot == nullptr) { + LOG_ERROR("keyId not found", TraceLoggingValue(params->KeyId, "keyId")); + return STATUS_INVALID_DEVICE_STATE; + } + + NTSTATUS status = cryptoContext->Decrypt( + keySlot, + params->CipherText, + params->Length, + params->PlainText, + &cryptoContext->Options, + allowRekey); + + return status; +} + static UINT OvpnProtoOp32Compose(UINT opcode, UINT keyId, UINT opPeerId) @@ -48,12 +164,15 @@ OvpnProtoOp32Compose(UINT opcode, UINT keyId, UINT opPeerId) OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptNone; _Use_decl_annotations_ -NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions) +NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, OvpnCryptoOptions* opts, BOOLEAN allowRekey) { UNREFERENCED_PARAMETER(keySlot); + UNREFERENCED_PARAMETER(allowRekey); - BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - BOOLEAN cryptoOverhead = OVPN_DATA_V2_LEN + pktId64bit ? 8 : 4; + BOOLEAN useEpoch = (opts != NULL) && opts->UseEpoch; + SIZE_T pktIdLen = useEpoch ? 8 : 4; + SIZE_T authTagFront = useEpoch ? 0 : AEAD_AUTH_TAG_LEN; + SIZE_T cryptoOverhead = OVPN_DATA_V2_LEN + pktIdLen + authTagFront; if (len < cryptoOverhead) { LOG_WARN("Packet too short", TraceLoggingValue(len, "len")); @@ -69,11 +188,12 @@ OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptNone; _Use_decl_annotations_ NTSTATUS -OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions) +OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, OvpnCryptoOptions* opts, BOOLEAN allowRekey) { UNREFERENCED_PARAMETER(keySlot); UNREFERENCED_PARAMETER(len); - UNREFERENCED_PARAMETER(cryptoOptions); + UNREFERENCED_PARAMETER(opts); + UNREFERENCED_PARAMETER(allowRekey); // prepend with opcode, key-id and peer-id UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, 0, 0); @@ -90,12 +210,15 @@ OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 _Use_decl_annotations_ NTSTATUS -OvpnCryptoInitAlgHandles(BCRYPT_ALG_HANDLE* aesAlgHandle, BCRYPT_ALG_HANDLE* chachaAlgHandle) +OvpnCryptoInitAlgHandles(BCRYPT_ALG_HANDLE* aesAlgHandle, BCRYPT_ALG_HANDLE* chachaAlgHandle, BCRYPT_ALG_HANDLE* hkdfAlgHandle) { NTSTATUS status; GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptOpenAlgorithmProvider(aesAlgHandle, BCRYPT_AES_ALGORITHM, NULL, BCRYPT_PROV_DISPATCH)); GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptSetProperty(*aesAlgHandle, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0)); + // used by epoch data channel + GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptOpenAlgorithmProvider(hkdfAlgHandle, BCRYPT_HKDF_ALGORITHM, NULL, BCRYPT_PROV_DISPATCH)); + // available starting from Windows 11 LOG_IF_NOT_NT_SUCCESS(BCryptOpenAlgorithmProvider(chachaAlgHandle, BCRYPT_CHACHA20_POLY1305_ALGORITHM, NULL, BCRYPT_PROV_DISPATCH)); done: @@ -104,7 +227,7 @@ OvpnCryptoInitAlgHandles(BCRYPT_ALG_HANDLE* aesAlgHandle, BCRYPT_ALG_HANDLE* cha _Use_decl_annotations_ VOID -OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDLE chachaAlgHandle) +OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDLE chachaAlgHandle, BCRYPT_ALG_HANDLE hkdfAlgHandle) { if (aesAlgHandle) { LOG_IF_NOT_NT_SUCCESS(BCryptCloseAlgorithmProvider(aesAlgHandle, 0)); @@ -113,6 +236,10 @@ OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDL if (chachaAlgHandle) { LOG_IF_NOT_NT_SUCCESS(BCryptCloseAlgorithmProvider(chachaAlgHandle, 0)); } + + if (hkdfAlgHandle) { + LOG_IF_NOT_NT_SUCCESS(BCryptCloseAlgorithmProvider(hkdfAlgHandle, 0)); + } } #define GET_SYSTEM_ADDRESS_MDL(buf, mdl) { \ @@ -123,151 +250,245 @@ OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDL } \ } -static NTSTATUS -OvpnCryptoAEADDoWork(BOOLEAN encrypt, OvpnCryptoKeySlot* keySlot, UCHAR *bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions) +OvpnCryptoCheckReplay(OvpnCryptoKeySlot* keySlot, ULONG64 packet_id_net, UINT16 epoch, OvpnCryptoOptions *opts, BOOLEAN allowRekey) { - /* - AEAD Nonce : + OvpnPktidRecv* recv = NULL; - [Packet ID] [HMAC keying material] - [4/8 bytes] [8/4 bytes ] - [AEAD nonce total : 12 bytes ] + if (epoch == 0 || keySlot->Decrypt.Epoch == epoch) { + recv = &keySlot->PktidRecv; + } + else if (epoch == keySlot->RetiringEpochDataReceiveKey.Epoch) { + recv = &keySlot->PktidRecvRetiring; + } + else { + if (!allowRekey) { + return STATUS_OVPN_CRYPTO_RETRY; + } - TLS wire protocol : + /* We have an epoch that is neither current or old recv key but + * is authenticated, ie we need to move to a new current recv key */ + LOG_INFO("Received data packet with new epoch. Updating receive key", TraceLoggingValue(epoch, "epoch")); + OvpnCryptoEpochReplaceUpdateRecvKey(keySlot, epoch, opts); + recv = &keySlot->PktidRecv; + } - Packet ID is 8 bytes long with CRYPTO_OPTIONS_64BIT_PKTID. + return OvpnPktidRecvVerify(recv, packet_id_net); +} - [DATA_V2 opcode] [Packet ID] [AEAD Auth tag] [ciphertext] - [4 bytes ] [4/8 bytes] [16 bytes ] - [AEAD additional data(AD) ] +/* +AEAD Nonce : - With CRYPTO_OPTIONS_AEAD_TAG_END AEAD Auth tag is placed after ciphertext: + [Packet ID] [HMAC keying material] + [4 bytes ] [4 bytes ] + [AEAD nonce total : 12 bytes ] - [DATA_V2 opcode] [Packet ID] [ciphertext] [AEAD Auth tag] - [4 bytes ] [4/8 bytes] [16 bytes ] - [AEAD additional data(AD) ] - */ +TLS wire protocol : - NTSTATUS status = STATUS_SUCCESS; + [DATA_V2 opcode] [Packet ID] [AEAD Auth tag] [ciphertext] + [4 bytes ] [4 bytes ] [16 bytes ] + [AEAD additional data(AD) ] - BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; +New data format, with epoch keys and 64bit packet id: - SIZE_T cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4); + struct aead_packet { + int opcode:5; + int key_id:3; + int peer_id:24; + uint64_t packet_id; + uint8_t* encrypted_payload; + uint8_t[16] authentication_tag; + } - if (len < cryptoOverhead) { - LOG_WARN("Packet too short", TraceLoggingValue(len, "len")); - return STATUS_DATA_ERROR; + struct packet_id { + uint epoch:16; + uint epoch_counter:48; } - UCHAR nonce[12]; - if (encrypt) { - // prepend with opcode, key-id and peer-id - UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, keySlot->KeyId, keySlot->PeerId); - op = RtlUlongByteSwap(op); - *reinterpret_cast(bufOut) = op; + authenticated_data = opcode| key_id | peer_id | packet_id +*/ - if (pktId64bit) - { - // calculate pktid - UINT64 pktid; - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, true)); - ULONG64 pktidNetwork = RtlUlonglongByteSwap(pktid); - // calculate nonce, which is pktid + nonce_tail - RtlCopyMemory(nonce, &pktidNetwork, 8); - RtlCopyMemory(nonce + 8, keySlot->EncNonceTail, 4); +OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptAEAD; - // prepend with pktid - *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork; - } - else - { - // calculate pktid - UINT32 pktid; - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, false)); - ULONG pktidNetwork = RtlUlongByteSwap(pktid); +_Use_decl_annotations_ +NTSTATUS +OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, OvpnCryptoOptions* opts, BOOLEAN allowRekey) +{ + NTSTATUS status = STATUS_SUCCESS; - // calculate nonce, which is pktid + nonce_tail - RtlCopyMemory(nonce, &pktidNetwork, 4); - RtlCopyMemory(nonce + 4, keySlot->EncNonceTail, 8); + BOOLEAN authTagEnd = opts->UseEpoch; + ULONG pktidLen = opts->UseEpoch ? 8 : 4; + ULONG cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + pktidLen; - // prepend with pktid - *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork; - } + if (len < cryptoOverhead) { + LOG_WARN("Packet too short", TraceLoggingValue(len, "len")); + return STATUS_DATA_ERROR; } - else { - ULONG64 pktId; - RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, pktId64bit ? 8 : 4); - RtlCopyMemory(nonce + (pktId64bit ? 8 : 4), &keySlot->DecNonceTail, pktId64bit ? 4 : 8); - if (pktId64bit) - { - pktId = RtlUlonglongByteSwap(*reinterpret_cast(nonce)); - } - else - { - pktId = static_cast(RtlUlongByteSwap(*reinterpret_cast(nonce))); - } + // we prepended buf with crypto overhead + len -= cryptoOverhead; + + OvpnCryptoKeyContext* decryptKey = &keySlot->Decrypt; + UINT64 packet_id = 0; + UINT16 rx_epoch = 0; + + UCHAR nonce[12]; - status = OvpnPktidRecvVerify(&keySlot->PktidRecv, pktId); + if (opts->UseEpoch) { + // read packet_id + UINT64 packet_id_net; + RtlCopyMemory(&packet_id_net, bufIn + OVPN_DATA_V2_LEN, sizeof(packet_id_net)); + packet_id = RtlUlonglongByteSwap(packet_id_net); - if (!NT_SUCCESS(status)) { - LOG_ERROR("Invalid pktId", TraceLoggingUInt64(pktId, "pktId")); + // get epoch number and counter + rx_epoch = (UINT16)(packet_id >> 48); + if (rx_epoch == 0) { + LOG_ERROR("Invalid epoch 0"); return STATUS_DATA_ERROR; } - } - // we prepended buf with crypto overhead - len -= cryptoOverhead; + decryptKey = OvpnCryptoEpochLookupDecryptKey(keySlot, rx_epoch); + if (decryptKey == NULL) { + LOG_ERROR("Data packet with unknown epoch", TraceLoggingValue(rx_epoch, "epoch")); + return STATUS_DATA_ERROR; + } + + OvpnCryptoMakeEpochNonce(decryptKey->ImplicitIV, packet_id_net, nonce); + } + else { + RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, 4); + RtlCopyMemory(nonce + 4, decryptKey->ImplicitIV + 4, 8); - BOOLEAN aeadTagEnd = cryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + packet_id = static_cast(RtlUlongByteSwap(*reinterpret_cast(nonce))); + } BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO authInfo; BCRYPT_INIT_AUTH_MODE_INFO(authInfo); authInfo.pbNonce = nonce; authInfo.cbNonce = sizeof(nonce); - authInfo.pbTag = (encrypt ? bufOut : bufIn) + OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? len : 0); + authInfo.pbTag = bufIn + OVPN_DATA_V2_LEN + pktidLen + (authTagEnd ? len : 0); authInfo.cbTag = AEAD_AUTH_TAG_LEN; - authInfo.pbAuthData = (encrypt ? bufOut : bufIn); - authInfo.cbAuthData = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4); + authInfo.pbAuthData = bufIn; + authInfo.cbAuthData = OVPN_DATA_V2_LEN + pktidLen; - auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + auto payloadOffset = OVPN_DATA_V2_LEN + pktidLen + (authTagEnd ? 0 : AEAD_AUTH_TAG_LEN); bufOut += payloadOffset; bufIn += payloadOffset; // non-chaining mode ULONG bytesDone = 0; - GOTO_IF_NOT_NT_SUCCESS(done, status, encrypt ? - BCryptEncrypt(keySlot->EncKey, bufIn, (ULONG)len, &authInfo, NULL, 0, bufOut, (ULONG)len, &bytesDone, 0) : - BCryptDecrypt(keySlot->DecKey, bufIn, (ULONG)len, &authInfo, NULL, 0, bufOut, (ULONG)len, &bytesDone, 0) - ); + GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptDecrypt(decryptKey->Key, bufIn, (ULONG)len, &authInfo, NULL, 0, bufOut, (ULONG)len, &bytesDone, 0)); + + status = OvpnCryptoCheckReplay(keySlot, packet_id, rx_epoch, opts, allowRekey); + if (status == STATUS_OVPN_CRYPTO_RETRY) { + return status; + } + + if (!NT_SUCCESS(status)) { + LOG_ERROR("Invalid packet_id", TraceLoggingUInt64(packet_id, "packet_id")); + return STATUS_DATA_ERROR; + } done: return status; } -OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptAEAD; +OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptAEAD; _Use_decl_annotations_ NTSTATUS -OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions) +OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, OvpnCryptoOptions* opts, BOOLEAN allowRekey) { - return OvpnCryptoAEADDoWork(FALSE, keySlot, bufIn, len, bufOut, cryptoOptions); -} + NTSTATUS status = STATUS_SUCCESS; -OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptAEAD; + BOOLEAN authTagEnd = opts->UseEpoch; + ULONG pktidLen = opts->UseEpoch ? 8 : 4; + ULONG cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + pktidLen; -_Use_decl_annotations_ -NTSTATUS -OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions) -{ - return OvpnCryptoAEADDoWork(TRUE, keySlot, buf, len, buf, cryptoOptions); + if (len < cryptoOverhead) { + LOG_WARN("Packet too short", TraceLoggingValue(len, "len")); + return STATUS_DATA_ERROR; + } + + // we prepended buf with crypto overhead + len -= cryptoOverhead; + + UINT64 packet_id = 0; + + UCHAR nonce[12]; + + // prepend with opcode, key-id and peer-id + UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, keySlot->KeyId, keySlot->PeerId); + op = RtlUlongByteSwap(op); + RtlCopyMemory(buf, &op, sizeof(op)); + + if (opts->UseEpoch) { + if (keySlot->EpochKeySend.Epoch == UINT16_MAX) { + return STATUS_BUFFER_OVERFLOW; + } + + if (OvpnCryptoAeadUsageLimitReached(opts->AeadUsageLimit, keySlot->Encrypt.PlaintextBlocks, keySlot->PktidXmit.SeqNum) || (keySlot->PktidXmit.SeqNum == PACKET_ID_EPOCH_MAX)) { + if (!allowRekey) { + return STATUS_OVPN_CRYPTO_RETRY; + } + + OvpnCryptoEpochIterateSendKey(keySlot, opts); + } + + // calculate 64-bit packet-id = (epoch << 48) | ctr48 + // the overflow of pktid is checked above + UINT64 ctr48 = InterlockedIncrementNoFence64(&keySlot->PktidXmit.SeqNum) & PACKET_ID_EPOCH_MAX; + packet_id = ((UINT64)keySlot->Encrypt.Epoch << 48) | ctr48; + + // prepend with pktid + UINT64 packet_id_net = RtlUlonglongByteSwap(packet_id); + RtlCopyMemory(buf + OVPN_DATA_V2_LEN, &packet_id_net, sizeof(packet_id_net)); + + OvpnCryptoMakeEpochNonce(keySlot->Encrypt.ImplicitIV, packet_id_net, nonce); + } + else { + // calculate pktid + UINT32 packet_id_32; + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &packet_id_32)); + ULONG packet_id_net = RtlUlongByteSwap(packet_id_32); + + // calculate nonce, which is pktid + nonce_tail + RtlCopyMemory(nonce, &packet_id_net, 4); + RtlCopyMemory(nonce + 4, keySlot->Encrypt.ImplicitIV + 4, 8); + + // prepend with pktid + RtlCopyMemory(buf + OVPN_DATA_V2_LEN, &packet_id_net, sizeof(packet_id_net)); + } + + // update number of plaintext blocks encrypted. Use the (x + (n-1))/n trick to round up the result to the number of blocks used + const ULONGLONG blocksize = AEAD_LIMIT_BLOCKSIZE; + ULONGLONG inc = ((ULONGLONG)len + (blocksize - 1)) / blocksize; + InterlockedAdd64((volatile LONG64*)&keySlot->Encrypt.PlaintextBlocks, (LONG64)inc); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO authInfo; + BCRYPT_INIT_AUTH_MODE_INFO(authInfo); + authInfo.pbNonce = nonce; + authInfo.cbNonce = sizeof(nonce); + authInfo.pbTag = buf + OVPN_DATA_V2_LEN + pktidLen + (authTagEnd ? len : 0); + authInfo.cbTag = AEAD_AUTH_TAG_LEN; + authInfo.pbAuthData = buf; + authInfo.cbAuthData = OVPN_DATA_V2_LEN + pktidLen; + + auto payloadOffset = OVPN_DATA_V2_LEN + pktidLen + (authTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + buf += payloadOffset; + + // non-chaining mode + ULONG bytesDone = 0; + GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptEncrypt(keySlot->Encrypt.Key, buf, (ULONG)len, &authInfo, NULL, 0, buf, (ULONG)len, &bytesDone, 0)); + +done: + return status; } _Use_decl_annotations_ NTSTATUS -OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDataV2, BCRYPT_ALG_HANDLE algHandle) +OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDataV2, BCRYPT_ALG_HANDLE algHandle, BCRYPT_ALG_HANDLE hkdfAlgHandle) { OvpnCryptoKeySlot* keySlot = NULL; NTSTATUS status = STATUS_SUCCESS; @@ -285,25 +506,16 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDa return STATUS_INVALID_DEVICE_REQUEST; } - if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID) - { - cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_64BIT_PKTID; - } - if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END) - { - cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_AEAD_TAG_END; - } - if ((cryptoData->CipherAlg == OVPN_CIPHER_ALG_AES_GCM) || (cryptoData->CipherAlg == OVPN_CIPHER_ALG_CHACHA20_POLY1305)) { // destroy previous keys - if (keySlot->EncKey) { - BCryptDestroyKey(keySlot->EncKey); - keySlot->EncKey = NULL; + if (keySlot->Encrypt.Key) { + BCryptDestroyKey(keySlot->Encrypt.Key); + keySlot->Encrypt.Key = NULL; } - if (keySlot->DecKey) { - BCryptDestroyKey(keySlot->DecKey); - keySlot->DecKey = NULL; + if (keySlot->Decrypt.Key) { + BCryptDestroyKey(keySlot->Decrypt.Key); + keySlot->Decrypt.Key = NULL; } if ((cryptoData->Encrypt.KeyLen > 32) || (cryptoData->Decrypt.KeyLen > 32)) @@ -314,24 +526,53 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDa goto done; } - // generate keys from key materials - GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptGenerateSymmetricKey(algHandle, &keySlot->EncKey, NULL, 0, cryptoData->Encrypt.Key, cryptoData->Encrypt.KeyLen, 0)); - GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptGenerateSymmetricKey(algHandle, &keySlot->DecKey, NULL, 0, cryptoData->Decrypt.Key, cryptoData->Decrypt.KeyLen, 0)); + cryptoContext->Options.KeyLen = cryptoData->Encrypt.KeyLen; - // copy nonce tails - RtlCopyMemory(keySlot->EncNonceTail, cryptoData->Encrypt.NonceTail, sizeof(cryptoData->Encrypt.NonceTail)); - RtlCopyMemory(keySlot->DecNonceTail, cryptoData->Decrypt.NonceTail, sizeof(cryptoData->Decrypt.NonceTail)); + if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_EPOCH) { + cryptoContext->Options.AeadUsageLimit = OvpnCryptoAeadUsageLimit(cryptoData->CipherAlg); + cryptoContext->Options.UseEpoch = TRUE; + cryptoContext->Options.HkdfAlgHandle = hkdfAlgHandle; + cryptoContext->Options.AeadAlgHangle = algHandle; - cryptoContext->Encrypt = OvpnCryptoEncryptAEAD; - cryptoContext->Decrypt = OvpnCryptoDecryptAEAD; + keySlot->EpochKeySend.Epoch = 1; + RtlCopyMemory(keySlot->EpochKeySend.EpochKey, cryptoData->Encrypt.Key, 32); + + keySlot->EpochKeyRecv.Epoch = 1; + RtlCopyMemory(keySlot->EpochKeyRecv.EpochKey, cryptoData->Decrypt.Key, 32); + + OvpnCryptoEpochInitKey(&keySlot->Encrypt, &keySlot->EpochKeySend, &cryptoContext->Options); + OvpnCryptoEpochInitKey(&keySlot->Decrypt, &keySlot->EpochKeyRecv, &cryptoContext->Options); + + RtlZeroMemory(keySlot->FutureEpochKeys, sizeof(keySlot->FutureEpochKeys)); + OvpnCryptoEpochGenerateFutureRecvKeys(keySlot, &cryptoContext->Options); + } + else { + // generate keys from key materials + GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptGenerateSymmetricKey(algHandle, &keySlot->Encrypt.Key, NULL, 0, cryptoData->Encrypt.Key, cryptoData->Encrypt.KeyLen, 0)); + GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptGenerateSymmetricKey(algHandle, &keySlot->Decrypt.Key, NULL, 0, cryptoData->Decrypt.Key, cryptoData->Decrypt.KeyLen, 0)); + + // copy nonce tails + RtlCopyMemory(keySlot->Encrypt.ImplicitIV + 4, cryptoData->Encrypt.NonceTail, sizeof(cryptoData->Encrypt.NonceTail)); + RtlCopyMemory(keySlot->Decrypt.ImplicitIV + 4, cryptoData->Decrypt.NonceTail, sizeof(cryptoData->Decrypt.NonceTail)); + } keySlot->KeyId = cryptoData->KeyId; keySlot->PeerId = cryptoData->PeerId; + cryptoContext->Encrypt = OvpnCryptoEncryptAEAD; + cryptoContext->Decrypt = OvpnCryptoDecryptAEAD; + LOG_INFO("New key", TraceLoggingValue(cryptoData->CipherAlg == OVPN_CIPHER_ALG_AES_GCM ? "aes-gcm" : "chacha20-poly1305", "alg"), TraceLoggingValue(cryptoData->KeyId, "KeyId"), TraceLoggingValue(cryptoData->PeerId, "PeerId")); } else if (cryptoData->CipherAlg == OVPN_CIPHER_ALG_NONE) { + OvpnCryptoEpochUninitSlot(keySlot); + + keySlot->KeyId = cryptoData->KeyId; + keySlot->PeerId = cryptoData->PeerId; + + RtlZeroMemory(&cryptoContext->Options, sizeof(cryptoContext->Options)); + cryptoContext->Encrypt = OvpnCryptoEncryptNone; cryptoContext->Decrypt = OvpnCryptoDecryptNone; @@ -343,6 +584,7 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDa goto done; } + OvpnCryptoDescribePacketLayout(cryptoContext, &cryptoContext->Layout); // reset pktid for a new key RtlZeroMemory(&keySlot->PktidXmit, sizeof(keySlot->PktidXmit)); RtlZeroMemory(&keySlot->PktidRecv, sizeof(keySlot->PktidRecv)); @@ -351,6 +593,32 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDa return status; } +_Use_decl_annotations_ +VOID +OvpnCryptoDescribePacketLayout(const OvpnCryptoContext* cryptoContext, OvpnCryptoPacketLayout* layout) +{ + NT_ASSERT(layout != NULL); + + layout->FrontLen = OVPN_DATA_V2_LEN + 4; + layout->TailLen = 0; + + if ((cryptoContext == NULL) || (cryptoContext->Encrypt == NULL)) { + return; + } + + if (cryptoContext->Encrypt == OvpnCryptoEncryptAEAD) { + BOOLEAN useEpoch = cryptoContext->Options.UseEpoch; + + layout->FrontLen = OVPN_DATA_V2_LEN + (useEpoch ? 8 : 4); + if (!useEpoch) { + layout->FrontLen += AEAD_AUTH_TAG_LEN; + } + + layout->TailLen = useEpoch ? AEAD_AUTH_TAG_LEN : 0; + } +} + + _Use_decl_annotations_ OvpnCryptoKeySlot* OvpnCryptoKeySlotFromKeyId(OvpnCryptoContext* cryptoContext, unsigned int keyId) @@ -383,21 +651,8 @@ _Use_decl_annotations_ VOID OvpnCryptoUninit(OvpnCryptoContext* cryptoContext) { - if (cryptoContext->Primary.EncKey) { - BCryptDestroyKey(cryptoContext->Primary.EncKey); - } - - if (cryptoContext->Primary.DecKey) { - BCryptDestroyKey(cryptoContext->Primary.DecKey); - } - - if (cryptoContext->Secondary.EncKey) { - BCryptDestroyKey(cryptoContext->Secondary.EncKey); - } - - if (cryptoContext->Secondary.DecKey) { - BCryptDestroyKey(cryptoContext->Secondary.DecKey); - } + OvpnCryptoEpochUninitSlot(&cryptoContext->Primary); + OvpnCryptoEpochUninitSlot(&cryptoContext->Secondary); RtlZeroMemory(cryptoContext, sizeof(OvpnCryptoContext)); } diff --git a/crypto.h b/crypto.h index 73d286b..636faae 100644 --- a/crypto.h +++ b/crypto.h @@ -25,40 +25,35 @@ #include #include +#include "crypto_epoch.h" #include "pktid.h" #include "uapi\ovpn-dco.h" #include "socket.h" +struct OvpnPeerContext; + #define OVPN_DATA_V2_LEN 4 #define AEAD_AUTH_TAG_LEN 16 +#define AEAD_LIMIT_BLOCKSIZE 16 + +// The crypto helper uses this failure status to indicate that the caller must +// retry the operation while holding the peer spinlock exclusively so key-slot +// mutation can proceed safely. +#define STATUS_OVPN_CRYPTO_RETRY ((NTSTATUS)0xC0E44001L) + // packet opcode (high 5 bits) and key-id (low 3 bits) are combined in one byte #define OVPN_OP_DATA_V2 9 #define OVPN_KEY_ID_MASK 0x07 #define OVPN_OPCODE_SHIFT 3 #define OVPN_PEER_ID_MASK 0x00FFFFFF -struct OvpnCryptoKeySlot -{ - BCRYPT_KEY_HANDLE EncKey; - BCRYPT_KEY_HANDLE DecKey; - - UCHAR EncNonceTail[8]; - UCHAR DecNonceTail[8]; - - UCHAR KeyId; - INT32 PeerId; - - OvpnPktidXmit PktidXmit; - OvpnPktidRecv PktidRecv; -}; - _Function_class_(OVPN_CRYPTO_ENCRYPT) _IRQL_requires_max_(DISPATCH_LEVEL) _Must_inspect_result_ typedef NTSTATUS -OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len, _In_ INT32 CryptoOptions); +OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len, _In_ OvpnCryptoOptions* opts, BOOLEAN allowRekey); typedef OVPN_CRYPTO_ENCRYPT* POVPN_CRYPTO_ENCRYPT; _Function_class_(OVPN_CRYPTO_DECRYPT) @@ -66,9 +61,15 @@ _IRQL_requires_max_(DISPATCH_LEVEL) _Must_inspect_result_ typedef NTSTATUS -OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut, _In_ INT32 CryptoOptions); +OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut, _In_ OvpnCryptoOptions* opts, BOOLEAN allowRekey); typedef OVPN_CRYPTO_DECRYPT* POVPN_CRYPTO_DECRYPT; +struct OvpnCryptoPacketLayout +{ + ULONG FrontLen; + ULONG TailLen; +}; + struct OvpnCryptoContext { OvpnCryptoKeySlot Primary; @@ -77,24 +78,75 @@ struct OvpnCryptoContext POVPN_CRYPTO_ENCRYPT Encrypt; POVPN_CRYPTO_DECRYPT Decrypt; - INT32 CryptoOptions; + OvpnCryptoOptions Options; + OvpnCryptoPacketLayout Layout; +}; + + +VOID +OvpnCryptoDescribePacketLayout(_In_ const OvpnCryptoContext* cryptoContext, _Out_ OvpnCryptoPacketLayout* layout); + +typedef +NTSTATUS +OVPN_CRYPTO_RETRY_ROUTINE(_In_ OvpnCryptoContext* cryptoContext, _In_ BOOLEAN allowRekey, _Inout_opt_ PVOID context); +typedef OVPN_CRYPTO_RETRY_ROUTINE* POVPN_CRYPTO_RETRY_ROUTINE; + +struct OvpnCryptoEncryptParams +{ + PUCHAR Buffer; + SIZE_T Length; +}; + +struct OvpnCryptoDecryptParams +{ + UCHAR KeyId; + PUCHAR CipherText; + SIZE_T Length; + PUCHAR PlainText; }; +_Must_inspect_result_ +_IRQL_requires_max_(DISPATCH_LEVEL) +NTSTATUS +OvpnCryptoCallWithRetry( + _In_ OvpnPeerContext* peer, + _In_ BOOLEAN atDpcLevel, + _Inout_opt_ PBOOLEAN exclusive, + _Inout_opt_ PKIRQL kirql, + _In_ POVPN_CRYPTO_RETRY_ROUTINE routine, + _Inout_opt_ PVOID context); + +_Must_inspect_result_ +_IRQL_requires_max_(DISPATCH_LEVEL) +NTSTATUS +OvpnCryptoInvokeEncrypt( + _In_ OvpnCryptoContext* cryptoContext, + _In_ BOOLEAN allowRekey, + _Inout_opt_ PVOID context); + +_Must_inspect_result_ +_IRQL_requires_max_(DISPATCH_LEVEL) +NTSTATUS +OvpnCryptoInvokeDecrypt( + _In_ OvpnCryptoContext* cryptoContext, + _In_ BOOLEAN allowRekey, + _Inout_opt_ PVOID context); + _Must_inspect_result_ _IRQL_requires_(PASSIVE_LEVEL) NTSTATUS -OvpnCryptoInitAlgHandles(_Outptr_ BCRYPT_ALG_HANDLE* aesAlgHandle, _Outptr_ BCRYPT_ALG_HANDLE* chachaAlgHandle); +OvpnCryptoInitAlgHandles(_Outptr_ BCRYPT_ALG_HANDLE* aesAlgHandle, _Outptr_ BCRYPT_ALG_HANDLE* chachaAlgHandle, _Outptr_ BCRYPT_ALG_HANDLE* hkdfAlgHandle); _IRQL_requires_(PASSIVE_LEVEL) VOID -OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDLE chachaAlgHandle); +OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDLE chachaAlgHandle, BCRYPT_ALG_HANDLE hkdfAlgHandle); VOID OvpnCryptoUninit(_In_ OvpnCryptoContext* cryptoContext); _Must_inspect_result_ NTSTATUS -OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA_V2 cryptoData, _In_opt_ BCRYPT_ALG_HANDLE algHandle); +OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA_V2 cryptoData, _In_opt_ BCRYPT_ALG_HANDLE algHandle, _In_opt_ BCRYPT_ALG_HANDLE hkdfAlgHandle); _Must_inspect_result_ OvpnCryptoKeySlot* @@ -115,3 +167,30 @@ UCHAR OvpnCryptoOpcodeExtract(UCHAR op) { return op >> OVPN_OPCODE_SHIFT; } + +static inline +BOOLEAN +OvpnCryptoAeadUsageLimitReached(UINT64 limit, UINT64 plaintextBlocks, UINT64 highestPid) +{ + /* This is the q + s <= p^(1/2) * 2^(129/2) - 1 calculation where + * q is the number of protected messages (highest_pid) + * s Total plaintext length in all messages (in blocks) */ + return ((limit > 0) && (plaintextBlocks + highestPid) > limit); +} + +static inline +UINT64 +OvpnCryptoAeadUsageLimit(OVPN_CIPHER_ALG alg) +{ + switch (alg) + { + case OVPN_CIPHER_ALG_NONE: + return 0; + + case OVPN_CIPHER_ALG_CHACHA20_POLY1305: + return 0; + + default: + return (1ull << 36) - 1; // limit for AES-GCM + } +} diff --git a/crypto_epoch.cpp b/crypto_epoch.cpp new file mode 100644 index 0000000..c62a06e --- /dev/null +++ b/crypto_epoch.cpp @@ -0,0 +1,318 @@ +/* + * ovpn-dco-win OpenVPN protocol accelerator for Windows + * + * Copyright (C) 2025- OpenVPN Inc + * + * Author: Lev Stipakov + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "crypto_epoch.h" +#include "trace.h" + +// Derive bytes via HKDF-Expand with PRK = E_i, using bcrypt HKDF. +// info = OvpnMakeLabel(L, label) +_Use_decl_annotations_ +NTSTATUS OvpnCryptoExpandLabel( + BCRYPT_ALG_HANDLE hkdfAlg, + const UCHAR* E_i, // PRK (32 bytes for SHA-256) + USHORT outLen, // bytes to derive + const char* label, // "data_key" / "data_iv" / "datakey upd" + UCHAR* outBytes +) +{ + NTSTATUS status = STATUS_SUCCESS; + + BCRYPT_KEY_HANDLE hKey = NULL; + + // create key handle with PRK bytes + GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptGenerateSymmetricKey(hkdfAlg, &hKey, NULL, 0, (PUCHAR)E_i, 32, 0)); + + // select SHA-256 + // BCRYPT_SHA256_ALGORITHM is a wide literal; sizeof(..) includes the NUL in bytes. + GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptSetProperty(hKey, BCRYPT_HKDF_HASH_ALGORITHM, (PUCHAR)BCRYPT_SHA256_ALGORITHM, (ULONG)sizeof(BCRYPT_SHA256_ALGORITHM), 0)); + + // tell HKDF we're already supplying the PRK in the key handle: + // passing NULL,0 just switches to "PRK is finalized" mode. + GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptSetProperty(hKey, BCRYPT_HKDF_PRK_AND_FINALIZE, NULL, 0, 0)); + + // build info = OvpnLabel(outLen, "data_key"/"data_iv") + UCHAR info[2 + 1 + 64 + 1 + 255]; // enough for our labels + ULONG infoLen = 0; + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoMakeLabel(info, (ULONG)sizeof(info), &infoLen, outLen, label)); + + // prepare KDF params + BCryptBuffer infoBuf; + BCryptBufferDesc desc; + + RtlZeroMemory(&infoBuf, sizeof(infoBuf)); + RtlZeroMemory(&desc, sizeof(desc)); + + infoBuf.cbBuffer = infoLen; + infoBuf.BufferType = KDF_HKDF_INFO; + infoBuf.pvBuffer = info; + + desc.ulVersion = BCRYPTBUFFER_VERSION; + desc.cBuffers = 1; + desc.pBuffers = &infoBuf; + + // derive + ULONG got = 0; + GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptKeyDerivation(hKey, &desc, outBytes, outLen, &got, 0)); + if (got != outLen) { + status = STATUS_INTERNAL_ERROR; + } + +done: + if (hKey) BCryptDestroyKey(hKey); + return status; +} + +// Build TLS 1.3-style label with "ovpn " prefix, into caller buffer. +// struct { +// uint16 length = L; +// opaque label<6..255> = "ovpn " + Label; +// opaque context<0..255>; +// } OvpnLabel; +_Use_decl_annotations_ +NTSTATUS OvpnCryptoMakeLabel( + UCHAR* out, + ULONG cbOut, + ULONG* pcbWritten, + USHORT L, + const char* label) +{ + NTSTATUS status = STATUS_SUCCESS; + + static const char prefix[] = "ovpn "; + + const size_t prefixLen = sizeof(prefix) - 1; + size_t labelLen = 0; + + GOTO_IF_NOT_NT_SUCCESS(done, status, RtlStringCbLengthA(label, 256, &labelLen)); // labels are tiny, 256 is safe + + *pcbWritten = 0; + + // Total encoded label length = "ovpn " + label + size_t totalLabelLen = prefixLen + labelLen; + if (totalLabelLen < 6 || totalLabelLen > 255) { + status = STATUS_INVALID_PARAMETER; + goto done; + } + + // total = 2(length) + 1(totalLabelLen) + totalLabelLen + 1(ctxLen=0) + ULONG need = 2 + 1 + (ULONG)totalLabelLen + 1; + if (cbOut < need) { + status = STATUS_BUFFER_TOO_SMALL; + goto done; + } + + ULONG p = 0; + out[p++] = (UCHAR)(L >> 8); + out[p++] = (UCHAR)(L & 0xFF); + out[p++] = (UCHAR)totalLabelLen; + + // "ovpn " + RtlCopyMemory(out + p, prefix, prefixLen); + p += (ULONG)prefixLen; + + + // Label + RtlCopyMemory(out + p, label, labelLen); + p += (ULONG)labelLen; + + // context length = 0 + out[p++] = 0; + + *pcbWritten = p; + +done: + return status; +} + +VOID +OvpnCryptoEpochInitKey(OvpnCryptoKeyContext* ctx, OvpnCryptoEpochKey* epochKey, OvpnCryptoOptions* opts) +{ + LOG_INFO("Epoch Data Key", TraceLoggingValue(epochKey->Epoch, "epoch")); + + OvpnCryptoKeyParameters key{0}; + OvpnCryptoEpochDataKeyDerive(&key, epochKey, opts->HkdfAlgHandle, opts->AeadAlgHangle, opts->KeyLen); + ctx->Epoch = key.Epoch; + ctx->Key = key.KeyHandle; + RtlCopyMemory(ctx->ImplicitIV, key.IV, sizeof(ctx->ImplicitIV)); + + RtlSecureZeroMemory(&key, sizeof(OvpnCryptoKeyParameters)); +} + +NTSTATUS +OvpnCryptoEpochDataKeyDerive(OvpnCryptoKeyParameters* key, OvpnCryptoEpochKey* epochKey, BCRYPT_ALG_HANDLE hkdfAlgHandle, BCRYPT_ALG_HANDLE algHandle, UCHAR cipherSize) +{ + NTSTATUS status; + + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoExpandLabel(hkdfAlgHandle, epochKey->EpochKey, cipherSize, "data_key", key->Cipher)); + GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptGenerateSymmetricKey(algHandle, &key->KeyHandle, NULL, 0, key->Cipher, cipherSize, 0)); + + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoExpandLabel(hkdfAlgHandle, epochKey->EpochKey, 12, "data_iv", key->IV)); + + key->Epoch = epochKey->Epoch; + +done: + return status; +} + +VOID +OvpnCryptoEpochKeyIterate(OvpnCryptoEpochKey* epochKey, BCRYPT_ALG_HANDLE hkdfAlgHandle) +{ + ++epochKey->Epoch; + OvpnCryptoExpandLabel(hkdfAlgHandle, epochKey->EpochKey, sizeof(epochKey->EpochKey), "datakey upd", epochKey->EpochKey); +} + +VOID +OvpnCryptoEpochGenerateFutureRecvKeys(OvpnCryptoKeySlot* keySlot, OvpnCryptoOptions* opts) +{ + UINT16 currentDecryptEpoch = keySlot->Decrypt.Epoch; + + // free unused keys + for (int i = 0; i < FUTURE_EPOCH_KEYS_COUNT; ++i) { + auto key = &keySlot->FutureEpochKeys[i]; + if ((key->Epoch > 0) && (key->Epoch < currentDecryptEpoch)) { + BCryptDestroyKey(key->Key); + RtlZeroMemory(key, sizeof(*key)); + } + } + + auto highestFutureKey = &keySlot->FutureEpochKeys[FUTURE_EPOCH_KEYS_COUNT - 1]; + + UINT16 currentHighestKey = highestFutureKey->Epoch ? highestFutureKey->Epoch : 1; + UINT16 desiredHighestKey = currentDecryptEpoch + FUTURE_EPOCH_KEYS_COUNT; + UINT16 numKeysGenerate = desiredHighestKey - currentHighestKey; + + RtlMoveMemory(keySlot->FutureEpochKeys, &keySlot->FutureEpochKeys[numKeysGenerate], (FUTURE_EPOCH_KEYS_COUNT - numKeysGenerate) * sizeof(OvpnCryptoKeyContext)); + + for (int i = 16 - numKeysGenerate; i < FUTURE_EPOCH_KEYS_COUNT; ++i) + { + RtlSecureZeroMemory(&keySlot->FutureEpochKeys[i], sizeof(OvpnCryptoKeyContext)); + + OvpnCryptoEpochKeyIterate(&keySlot->EpochKeyRecv, opts->HkdfAlgHandle); + OvpnCryptoEpochInitKey(&keySlot->FutureEpochKeys[i], &keySlot->EpochKeyRecv, opts); + } +} + +VOID +OvpnCryptoEpochReplaceUpdateRecvKey(OvpnCryptoKeySlot* keySlot, UINT16 new_epoch, OvpnCryptoOptions* opts) +{ + // Find the key of the new epoch in future keys + UINT16 fki; + for (fki = 0; fki < FUTURE_EPOCH_KEYS_COUNT; fki++) { + if (keySlot->FutureEpochKeys[fki].Epoch == new_epoch) { + break; + } + } + + OvpnCryptoKeyContext* ctx = &keySlot->FutureEpochKeys[fki]; + + // Check if the new recv key epoch is higher than the send key epoch. If yes we will replace the send key as well + if (keySlot->Encrypt.Epoch < new_epoch) { + BCryptDestroyKey(keySlot->Encrypt.Key); + RtlZeroMemory(&keySlot->Encrypt, sizeof(OvpnCryptoKeyContext)); + + // Update the epoch_key for send to match the current key being used + while (keySlot->EpochKeySend.Epoch < new_epoch) { + OvpnCryptoEpochKeyIterate(&keySlot->EpochKeySend, opts->HkdfAlgHandle); + } + OvpnCryptoEpochInitKey(&keySlot->Encrypt, &keySlot->EpochKeySend, opts); + } + + // Replace receive key + BCryptDestroyKey(keySlot->RetiringEpochDataReceiveKey.Key); + RtlZeroMemory(&keySlot->RetiringEpochDataReceiveKey, sizeof(OvpnCryptoKeyContext)); + + keySlot->RetiringEpochDataReceiveKey = keySlot->Decrypt; + + keySlot->Decrypt = *ctx; + + RtlZeroMemory(ctx, sizeof(*ctx)); + + // Generate new future keys + OvpnCryptoEpochGenerateFutureRecvKeys(keySlot, opts); +} + +OvpnCryptoKeyContext* +OvpnCryptoEpochLookupDecryptKey(OvpnCryptoKeySlot* keySlot, UINT16 epoch) +{ + /* Current decrypt key is the most likely one */ + if (keySlot->Decrypt.Epoch == epoch) { + return &keySlot->Decrypt; + } + else if (keySlot->RetiringEpochDataReceiveKey.Epoch && keySlot->RetiringEpochDataReceiveKey.Epoch == epoch) { + return &keySlot->RetiringEpochDataReceiveKey; + } + else if (epoch > keySlot->Decrypt.Epoch && epoch <= keySlot->Decrypt.Epoch + FUTURE_EPOCH_KEYS_COUNT) { + // Key in the range of future keys + int index = epoch - (keySlot->Decrypt.Epoch + 1); + + if (epoch > (UINT16_MAX - FUTURE_EPOCH_KEYS_COUNT - 1)) { + return NULL; + } + else { + return &keySlot->FutureEpochKeys[index]; + } + } + else { + return NULL; + } +} + +VOID +OvpnCryptoMakeEpochNonce(UCHAR* epochIv, UINT64 packet_id_net, UCHAR* nonce) +{ + // first 8 bytes of IV (aka nonce) is pktid_net XOR 8 bytes of epoch's implicit IV + UINT64 iv0; + RtlCopyMemory(&iv0, epochIv, sizeof(iv0)); + iv0 ^= packet_id_net; + RtlCopyMemory(nonce, &iv0, sizeof(iv0)); + + // last 4 bytes of IV are from epoch's implicit IV + RtlCopyMemory(nonce + 8, epochIv + 8, 4); +} + +VOID +OvpnCryptoEpochIterateSendKey(OvpnCryptoKeySlot* keySlot, OvpnCryptoOptions* opts) +{ + OvpnCryptoEpochKeyIterate(&keySlot->EpochKeySend, opts->HkdfAlgHandle); + + BCryptDestroyKey(keySlot->Encrypt.Key); + RtlSecureZeroMemory(&keySlot->Encrypt, sizeof(OvpnCryptoKeyContext)); + OvpnCryptoEpochInitKey(&keySlot->Encrypt, &keySlot->EpochKeySend, opts); + + RtlZeroMemory(&keySlot->PktidXmit, sizeof(keySlot->PktidXmit)); +} + +VOID +OvpnCryptoEpochUninitSlot(OvpnCryptoKeySlot* slot) +{ + if (slot->Encrypt.Key) { + BCryptDestroyKey(slot->Encrypt.Key); + } + for (int i = 0; i < FUTURE_EPOCH_KEYS_COUNT; ++i) { + if (slot->FutureEpochKeys[i].Key) { + BCryptDestroyKey(slot->FutureEpochKeys[i].Key); + } + } + if (slot->RetiringEpochDataReceiveKey.Key) { + BCryptDestroyKey(slot->RetiringEpochDataReceiveKey.Key); + } + RtlSecureZeroMemory(slot, sizeof(OvpnCryptoKeySlot)); +} \ No newline at end of file diff --git a/crypto_epoch.h b/crypto_epoch.h new file mode 100644 index 0000000..808bcc3 --- /dev/null +++ b/crypto_epoch.h @@ -0,0 +1,157 @@ +/* + * ovpn-dco-win OpenVPN protocol accelerator for Windows + * + * Copyright (C) 2025- OpenVPN Inc + * + * Author: Lev Stipakov + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#if defined(_KERNEL_MODE) +#include +#include +#include "trace.h" +#else +#define WIN32_NO_STATUS // keep windows.h from redefining STATUS_* +#include +#undef WIN32_NO_STATUS +#include +#include // exposes STATUS_SUCCESS, NT_SUCCESS, etc. +#include +#define RtlStringCbLengthA StringCbLengthA +#ifndef UINT16_MAX +#define UINT16_MAX 0xFFFF +#endif +#endif + +#include + +#include "pktid.h" + +#define PACKET_ID_EPOCH_MAX 0x0000FFFFFFFFFFFFull + +#define FUTURE_EPOCH_KEYS_COUNT 16 + +struct OvpnCryptoKeyContext +{ + BCRYPT_KEY_HANDLE Key; + UCHAR ImplicitIV[12]; + + // number of plaintext blocks encrypted using this key + UINT64 PlaintextBlocks; + UINT16 Epoch; +}; + +struct OvpnCryptoEpochKey +{ + UCHAR EpochKey[32]; + UINT16 Epoch; +}; + +struct OvpnCryptoKeyParameters +{ + UCHAR Cipher[32]; + UCHAR IV[12]; + BCRYPT_KEY_HANDLE KeyHandle; + UINT16 Epoch; +}; + +struct OvpnCryptoOptions { + // Limit for AEAD cipher, sum of packets + blocks. Will switch to the new epoch when reached. + UINT64 AeadUsageLimit; + + BOOLEAN UseEpoch; + + UCHAR KeyLen; + + BCRYPT_ALG_HANDLE HkdfAlgHandle; + BCRYPT_ALG_HANDLE AeadAlgHangle; +}; + +NTSTATUS OvpnCryptoExpandLabel( + BCRYPT_ALG_HANDLE hkdfAlg, + _In_reads_bytes_(32) const UCHAR* E_i, // PRK (32 bytes for SHA-256) + _In_ USHORT outLen, // bytes to derive + _In_z_ const char* label, // "data_key" / "data_iv" / "datakey upd" + _Out_writes_bytes_(outLen) UCHAR* outBytes +); + +_IRQL_requires_max_(PASSIVE_LEVEL) +static +NTSTATUS OvpnCryptoMakeLabel( + _Out_writes_bytes_to_(cbOut, *pcbWritten) UCHAR* out, + _In_ ULONG cbOut, + _Out_ ULONG* pcbWritten, + _In_ USHORT L, + _In_z_ const char* label); + +struct OvpnCryptoKeySlot +{ + OvpnCryptoKeyContext Encrypt; + OvpnCryptoKeyContext Decrypt; + + // last epoch key used for generating current send data keys + OvpnCryptoEpochKey EpochKeySend; + + // epoch key used for the highest receive epoch keys + OvpnCryptoEpochKey EpochKeyRecv; + + UCHAR KeyId; + INT32 PeerId; + + OvpnPktidXmit PktidXmit; + OvpnPktidRecv PktidRecv; + + // future epoch data keys for decryption + OvpnCryptoKeyContext FutureEpochKeys[FUTURE_EPOCH_KEYS_COUNT]; + + OvpnPktidRecv PktidRecvRetiring; + OvpnCryptoKeyContext RetiringEpochDataReceiveKey; +}; + +// Initialises data channel key/IV using the provided epoch key +VOID +OvpnCryptoEpochInitKey(OvpnCryptoKeyContext* ctx, OvpnCryptoEpochKey* epochKey, OvpnCryptoOptions* opts); + +// Generates a data channel key/IV from the epoch key +NTSTATUS +OvpnCryptoEpochDataKeyDerive(OvpnCryptoKeyParameters* key, OvpnCryptoEpochKey* epochKey, BCRYPT_ALG_HANDLE hkdfAlgHandle, BCRYPT_ALG_HANDLE algHandle, UCHAR cipherSize); + +VOID +OvpnCryptoEpochKeyIterate(OvpnCryptoEpochKey* epochKey, BCRYPT_ALG_HANDLE hkdfAlgHandle); +/** + * Generates and fills the FutureEpochKeys with next valid future keys + * using the epoch of the key in keySlot->EpochKeyRecv as starting point + */ +VOID +OvpnCryptoEpochGenerateFutureRecvKeys(OvpnCryptoKeySlot* keySlot, OvpnCryptoOptions* opts); + +// This is called when the peer uses a new send key that is not the default key +VOID +OvpnCryptoEpochReplaceUpdateRecvKey(OvpnCryptoKeySlot* keySlot, UINT16 new_epoch, OvpnCryptoOptions* opts); + +// retrieve decryption key context that matches the epoch +OvpnCryptoKeyContext* +OvpnCryptoEpochLookupDecryptKey(OvpnCryptoKeySlot* keySlot, UINT16 epoch); + +VOID +OvpnCryptoMakeEpochNonce(UCHAR* epochIv, UINT64 packet_id_net, UCHAR* nonce); + +// Updates the send key and keySlot->EpochKeySend to use the next epoch +VOID +OvpnCryptoEpochIterateSendKey(OvpnCryptoKeySlot* keySlot, OvpnCryptoOptions* opts); + +VOID +OvpnCryptoEpochUninitSlot(OvpnCryptoKeySlot* slot); \ No newline at end of file diff --git a/ovpn-dco-win.vcxproj b/ovpn-dco-win.vcxproj index bcfff8e..fbadd69 100644 --- a/ovpn-dco-win.vcxproj +++ b/ovpn-dco-win.vcxproj @@ -71,6 +71,7 @@ + @@ -87,6 +88,7 @@ + @@ -98,7 +100,6 @@ - diff --git a/ovpn-dco-win.vcxproj.filters b/ovpn-dco-win.vcxproj.filters index adf6018..c1cf7e4 100644 --- a/ovpn-dco-win.vcxproj.filters +++ b/ovpn-dco-win.vcxproj.filters @@ -52,9 +52,6 @@ Header Files - - Header Files - Header Files\uapi @@ -76,6 +73,9 @@ Header Files + + Header Files + @@ -120,6 +120,9 @@ Source Files + + Source Files + diff --git a/peer.cpp b/peer.cpp index 41fe3d9..4a61124 100644 --- a/peer.cpp +++ b/peer.cpp @@ -250,13 +250,7 @@ OvpnFindPeer(POVPN_DEVICE device, INT32 PeerId, BOOLEAN dpc) OvpnPeerContext* peer = nullptr; OvpnPeerContext** ptr = nullptr; - KIRQL kirql = 0; - if (dpc) { - ExAcquireSpinLockSharedAtDpcLevel(&device->SpinLock); - } - else { - kirql = ExAcquireSpinLockShared(&device->SpinLock); - } + KIRQL kirql = OvpnAcquireSpinLock(dpc, &device->SpinLock, FALSE); if (device->Mode == OVPN_MODE_P2P) { ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); @@ -275,12 +269,7 @@ OvpnFindPeer(POVPN_DEVICE device, INT32 PeerId, BOOLEAN dpc) InterlockedIncrement(&peer->RefCounter); } - if (dpc) { - ExReleaseSpinLockSharedFromDpcLevel(&device->SpinLock); - } - else { - ExReleaseSpinLockShared(&device->SpinLock, kirql); - } + OvpnReleaseSpinLock(dpc, kirql, &device->SpinLock, FALSE); return peer; } @@ -292,13 +281,7 @@ OvpnFindPeerVPN4(POVPN_DEVICE device, IN_ADDR addr, BOOLEAN dpc) OvpnPeerContext* peer = nullptr; OvpnPeerContext** ptr = nullptr; - KIRQL kirql = 0; - if (dpc) { - ExAcquireSpinLockSharedAtDpcLevel(&device->SpinLock); - } - else { - kirql = ExAcquireSpinLockShared(&device->SpinLock); - } + KIRQL kirql = OvpnAcquireSpinLock(dpc, &device->SpinLock, FALSE); if (device->Mode == OVPN_MODE_P2P) { ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); @@ -316,12 +299,7 @@ OvpnFindPeerVPN4(POVPN_DEVICE device, IN_ADDR addr, BOOLEAN dpc) InterlockedIncrement(&peer->RefCounter); } - if (dpc) { - ExReleaseSpinLockSharedFromDpcLevel(&device->SpinLock); - } - else { - ExReleaseSpinLockShared(&device->SpinLock, kirql); - } + OvpnReleaseSpinLock(dpc, kirql, &device->SpinLock, FALSE); return peer; } @@ -333,13 +311,7 @@ OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr, BOOLEAN dpc) OvpnPeerContext* peer = nullptr; OvpnPeerContext** ptr = nullptr; - KIRQL kirql = 0; - if (dpc) { - ExAcquireSpinLockSharedAtDpcLevel(&device->SpinLock); - } - else { - kirql = ExAcquireSpinLockShared(&device->SpinLock); - } + KIRQL kirql = OvpnAcquireSpinLock(dpc, &device->SpinLock, FALSE); if (device->Mode == OVPN_MODE_P2P) { ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); @@ -357,12 +329,7 @@ OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr, BOOLEAN dpc) InterlockedIncrement(&peer->RefCounter); } - if (dpc) { - ExReleaseSpinLockSharedFromDpcLevel(&device->SpinLock); - } - else { - ExReleaseSpinLockShared(&device->SpinLock, kirql); - } + OvpnReleaseSpinLock(dpc, kirql, &device->SpinLock, FALSE); return peer; } @@ -377,13 +344,7 @@ OvpnFindPeerTransport(POVPN_DEVICE device, PSOCKADDR sa, BOOLEAN dpc) OvpnPeerContext* peer = nullptr; OvpnPeerContext** ptr = nullptr; - KIRQL kirql = 0; - if (dpc) { - ExAcquireSpinLockSharedAtDpcLevel(&device->SpinLock); - } - else { - kirql = ExAcquireSpinLockShared(&device->SpinLock); - } + KIRQL kirql = OvpnAcquireSpinLock(dpc, &device->SpinLock, FALSE); if (device->Mode == OVPN_MODE_P2P) { ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); @@ -404,12 +365,7 @@ OvpnFindPeerTransport(POVPN_DEVICE device, PSOCKADDR sa, BOOLEAN dpc) InterlockedIncrement(&peer->RefCounter); } - if (dpc) { - ExReleaseSpinLockSharedFromDpcLevel(&device->SpinLock); - } - else { - ExReleaseSpinLockShared(&device->SpinLock, kirql); - } + OvpnReleaseSpinLock(dpc, kirql, &device->SpinLock, FALSE); return peer; } @@ -930,7 +886,7 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) } RtlCopyMemory(&cryptoDataV2.V1, cryptoData, sizeof(OVPN_CRYPTO_DATA)); - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, &cryptoDataV2, algHandle)); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, &cryptoDataV2, algHandle, NULL)); done: if (peer != nullptr) { @@ -965,7 +921,7 @@ OvpnPeerNewKeyV2(POVPN_DEVICE device, WDFREQUEST request) } KIRQL irql = ExAcquireSpinLockExclusive(&peer->SpinLock); - LOG_IF_NOT_NT_SUCCESS(status = OvpnCryptoNewKey(&peer->CryptoContext, cryptoDataV2, algHandle)); + LOG_IF_NOT_NT_SUCCESS(status = OvpnCryptoNewKey(&peer->CryptoContext, cryptoDataV2, algHandle, device->HkdfAlgHandle)); ExReleaseSpinLockExclusive(&peer->SpinLock, irql); done: @@ -1133,20 +1089,12 @@ OvpnPeerHandleFloat(OVPN_DEVICE* device, OvpnPeerContext *peer, PSOCKADDR sa, BO TraceLoggingIPv6Address(&peer->TransportAddrs.Remote.IPv6.sin6_addr, "src"), TraceLoggingIPv6Address(&((SOCKADDR_IN6*)sa)->sin6_addr, "dst")); - KIRQL kirql = 0; - // remove peer from by-transport-address hashtable OvpnDeletePeerFromTable(device, &device->PeersByTransport, peer, "transport"); // modify peer's transport address { - // exclusive-lock peer - if (dpc) { - ExAcquireSpinLockExclusiveAtDpcLevel(&peer->SpinLock); - } - else { - kirql = ExAcquireSpinLockExclusive(&peer->SpinLock); - } + KIRQL kirql = OvpnAcquireSpinLock(dpc, &peer->SpinLock, TRUE); // update peer's transport address if (sa->sa_family == AF_INET) @@ -1154,13 +1102,7 @@ OvpnPeerHandleFloat(OVPN_DEVICE* device, OvpnPeerContext *peer, PSOCKADDR sa, BO else RtlCopyMemory(&peer->TransportAddrs.Remote.IPv6, sa, sizeof(SOCKADDR_IN6)); - // exclusive-unlock peer - if (dpc) { - ExReleaseSpinLockExclusiveFromDpcLevel(&peer->SpinLock); - } - else { - ExReleaseSpinLockExclusive(&peer->SpinLock, kirql); - } + OvpnReleaseSpinLock(dpc, kirql, &peer->SpinLock, TRUE); } // add peer back to by-transport-address hashtable diff --git a/peer.h b/peer.h index ac5e707..75505b9 100644 --- a/peer.h +++ b/peer.h @@ -182,3 +182,48 @@ OvpnPeerGetDelReasonString(OVPN_DEL_PEER_REASON reason); NTSTATUS OvpnPeerHandleFloat(OVPN_DEVICE* device, OvpnPeerContext* peer, PSOCKADDR sa, BOOLEAN dpc); + +static inline KIRQL +OvpnAcquireSpinLock(BOOLEAN dpc, PEX_SPIN_LOCK spinLock, BOOLEAN exclusive) +{ + KIRQL kirql = 0; + if (dpc) { + if (exclusive) { + ExAcquireSpinLockExclusiveAtDpcLevel(spinLock); + } + else { + ExAcquireSpinLockSharedAtDpcLevel(spinLock); + } + } + else { + if (exclusive) { + kirql = ExAcquireSpinLockExclusive(spinLock); + } + else { + kirql = ExAcquireSpinLockShared(spinLock); + } + } + + return kirql; +} + +static inline VOID +OvpnReleaseSpinLock(BOOLEAN dpc, KIRQL kirql, PEX_SPIN_LOCK spinLock, BOOLEAN exclusive) +{ + if (dpc) { + if (exclusive) { + ExReleaseSpinLockExclusiveFromDpcLevel(spinLock); + } + else { + ExReleaseSpinLockSharedFromDpcLevel(spinLock); + } + } + else { + if (exclusive) { + ExReleaseSpinLockExclusive(spinLock, kirql); + } + else { + ExReleaseSpinLockShared(spinLock, kirql); + } + } +} diff --git a/pktid.cpp b/pktid.cpp index 4d94ee3..2116600 100644 --- a/pktid.cpp +++ b/pktid.cpp @@ -28,23 +28,17 @@ #define PKTID_WRAP_WARN 0xf0000000ULL _Use_decl_annotations_ -NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, VOID* pktId, BOOLEAN pktId64bit) +NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, VOID* pktId) { ULONG64 seqNum = InterlockedIncrementNoFence64(&px->SeqNum); - if (pktId64bit) { - *static_cast(pktId) = seqNum; + *static_cast(pktId) = static_cast(seqNum); + if (seqNum >= PKTID_WRAP_WARN) { + LOG_ERROR("Pktid wrapped"); + return STATUS_INTEGER_OVERFLOW; + } else { + return STATUS_SUCCESS; } - else - { - *static_cast(pktId) = static_cast(seqNum); - if (seqNum >= PKTID_WRAP_WARN) { - LOG_ERROR("Pktid wrapped"); - return STATUS_INTEGER_OVERFLOW; - } - } - - return STATUS_SUCCESS; } #define PKTID_RECV_EXPIRE ((30 * WDF_TIMEOUT_TO_SEC) / KeQueryTimeIncrement()) diff --git a/pktid.h b/pktid.h index dcc4be8..6d2518b 100644 --- a/pktid.h +++ b/pktid.h @@ -21,7 +21,9 @@ #pragma once +#if defined(_KERNEL_MODE) #include +#endif struct OvpnPktidXmit { @@ -56,8 +58,8 @@ struct OvpnPktidRecv UINT64 IdFloor; }; -/* Get the next packet ID for xmit */ -NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ VOID* pktId, BOOLEAN pktId64bit); +/* Get the next packet ID for xmit. Used only for non-epoch crypto. */ +NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ VOID* pktId); /* Packet replay detection. diff --git a/socket.cpp b/socket.cpp index de82870..61d613e 100644 --- a/socket.cpp +++ b/socket.cpp @@ -211,43 +211,26 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 pee return; } - // If we're at dispatch level, we can use a small optimization and use function - // which is not calling KeRaiseIRQL to raise the IRQL to DISPATCH_LEVEL before attempting to acquire the lock - KIRQL kirql = 0; - if (dpc) { - ExAcquireSpinLockSharedAtDpcLevel(&peer->SpinLock); - } - else { - kirql = ExAcquireSpinLockShared(&peer->SpinLock); - } + KIRQL kirql = OvpnAcquireSpinLock(dpc, &peer->SpinLock, FALSE); + + BOOLEAN exclusive = FALSE; OvpnCryptoContext* cryptoContext = &peer->CryptoContext; if (cryptoContext->Decrypt) { UCHAR keyId = OvpnCryptoKeyIdExtract(op); - OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(cryptoContext, keyId); - if (!keySlot) { - status = STATUS_INVALID_DEVICE_STATE; - LOG_ERROR("keyId not found", TraceLoggingValue(keyId, "keyId")); - } - else { - // extend data area in the buffer for plaintext and crypto overhead - OvpnBufferPut(buffer, len); + // extend data area in the buffer for plaintext and crypto overhead + OvpnBufferPut(buffer, len); - // decrypt into plaintext buffer - status = cryptoContext->Decrypt(keySlot, cipherTextBuf, len, buffer->Data, cryptoContext->CryptoOptions); + OvpnCryptoDecryptParams decryptParams = { keyId, cipherTextBuf, len, buffer->Data }; + status = OvpnCryptoCallWithRetry(peer, dpc, &exclusive, dpc ? nullptr : &kirql, OvpnCryptoInvokeDecrypt, &decryptParams); - // trim AEAD tag an the end - auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; - if (aeadTagEnd) { - OvpnBufferTrim(buffer, len - AEAD_AUTH_TAG_LEN); - } + if (NT_SUCCESS(status)) { + const OvpnCryptoPacketLayout layout = cryptoContext->Layout; - // remove crypto overhead in front - auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - auto cryptoOverheadFront = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); - OvpnBufferPull(buffer, cryptoOverheadFront); + OvpnBufferTrim(buffer, len - layout.TailLen); + OvpnBufferPull(buffer, layout.FrontLen); } } else { @@ -265,13 +248,7 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 pee auto mss = peer->MSS; - // don't forget to release spinlock - if (dpc) { - ExReleaseSpinLockSharedFromDpcLevel(&peer->SpinLock); - } - else { - ExReleaseSpinLockShared(&peer->SpinLock, kirql); - } + OvpnReleaseSpinLock(dpc, kirql, &peer->SpinLock, exclusive); // decrypt failed - don't proceed if (!NT_SUCCESS(status)) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 0000000..47ed056 --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,28 @@ +cmake_minimum_required(VERSION 3.14) + +project(tests) + +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip + DOWNLOAD_EXTRACT_TIMESTAMP true) + +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) + +enable_testing() + +add_executable( + tests + tests.cpp + ../crypto_epoch.cpp +) +target_link_libraries( + tests + GTest::gtest_main + bcrypt +) + +include(GoogleTest) +gtest_discover_tests(tests) \ No newline at end of file diff --git a/tests/tests.cpp b/tests/tests.cpp new file mode 100644 index 0000000..98ee58d --- /dev/null +++ b/tests/tests.cpp @@ -0,0 +1,219 @@ +#include +#include + +#include "../crypto_epoch.h" + +class CryptoTest : public testing::Test +{ +protected: + void SetUp() override + { + RtlZeroMemory(&opts, sizeof(opts)); + RtlZeroMemory(&keySlot, sizeof(keySlot)); + + opts.KeyLen = 32; + opts.UseEpoch = 1; + + ASSERT_EQ(BCryptOpenAlgorithmProvider(&opts.HkdfAlgHandle, BCRYPT_HKDF_ALGORITHM, NULL, 0), STATUS_SUCCESS); + ASSERT_EQ(BCryptOpenAlgorithmProvider(&opts.AeadAlgHangle, BCRYPT_AES_ALGORITHM, NULL, 0), STATUS_SUCCESS); + + keySlot.EpochKeySend.Epoch = 1; + keySlot.EpochKeyRecv.Epoch = 1; + + OvpnCryptoEpochInitKey(&keySlot.Encrypt, &keySlot.EpochKeySend, &opts); + OvpnCryptoEpochInitKey(&keySlot.Decrypt, &keySlot.EpochKeyRecv, &opts); + + RtlZeroMemory(keySlot.FutureEpochKeys, sizeof(keySlot.FutureEpochKeys)); + OvpnCryptoEpochGenerateFutureRecvKeys(&keySlot, &opts); + } + + OvpnCryptoOptions opts; + OvpnCryptoKeySlot keySlot; +}; + +TEST_F(CryptoTest, HkdfExpand) { + uint8_t secret[32] = { 0x07, 0x77, 0x09, 0x36, 0x2c, 0x2e, 0x32, 0xdf, 0x0d, 0xdc, 0x3f, + 0x0d, 0xc4, 0x7b, 0xba, 0x63, 0x90, 0xb6, 0xc7, 0x3b, 0xb5, 0x0f, + 0x9c, 0x31, 0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5 }; + + const char* label = "unit test"; + + uint8_t out_expected[16] = { 0x18, 0x5e, 0xaa, 0x1c, 0x7f, 0x22, 0x8a, 0xb8, + 0xeb, 0x29, 0x77, 0x32, 0x14, 0xd9, 0x20, 0x46 }; + + uint8_t out[16]; + + ASSERT_EQ(OvpnCryptoExpandLabel(opts.HkdfAlgHandle, secret, sizeof(out), label, out), STATUS_SUCCESS); + ASSERT_EQ(0, std::memcmp(out, out_expected, sizeof(out))); +} + +TEST_F(CryptoTest, EpochKeyGeneration) { + // check that the keys look like expected + ASSERT_EQ(keySlot.FutureEpochKeys[0].Epoch, 2); + ASSERT_EQ(keySlot.FutureEpochKeys[15].Epoch, 17); + ASSERT_EQ(keySlot.EpochKeySend.Epoch, 1); + ASSERT_EQ(keySlot.EpochKeyRecv.Epoch, 17); + + // Now replace the recv key with the 6th future key (epoch = 8) + BCryptDestroyKey(keySlot.Decrypt.Key); + RtlZeroMemory(&keySlot.Decrypt, sizeof(keySlot.Decrypt)); + ASSERT_EQ(keySlot.FutureEpochKeys[6].Epoch, 8); + keySlot.Decrypt = keySlot.FutureEpochKeys[6]; + RtlZeroMemory(&keySlot.FutureEpochKeys[6].Epoch, sizeof(OvpnCryptoKeyContext)); + + OvpnCryptoEpochGenerateFutureRecvKeys(&keySlot, &opts); + ASSERT_EQ(keySlot.FutureEpochKeys[0].Epoch, 9); + ASSERT_EQ(keySlot.FutureEpochKeys[15].Epoch, 24); +} + +TEST_F(CryptoTest, EpochKeyRotation) { + /* should replace send + key recv */ + OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, 9, &opts); + + ASSERT_EQ(keySlot.Decrypt.Epoch, 9); + ASSERT_EQ(keySlot.Encrypt.Epoch, 9); + ASSERT_EQ(keySlot.EpochKeySend.Epoch, 9); + ASSERT_EQ(keySlot.RetiringEpochDataReceiveKey.Epoch, 1); + + /* Iterate the data send key four times to get it to 13 */ + for (int i = 0; i < 4; i++) + { + OvpnCryptoEpochKeyIterate(&keySlot.EpochKeySend, opts.HkdfAlgHandle); + + BCryptDestroyKey(&keySlot.Encrypt.Key); + RtlZeroMemory(&keySlot.Encrypt, sizeof(OvpnCryptoKeyContext)); + + OvpnCryptoEpochInitKey(&keySlot.Encrypt, &keySlot.EpochKeySend, &opts); + } + ASSERT_EQ(keySlot.Encrypt.Epoch, 13); + + OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, 10, &opts); + ASSERT_EQ(keySlot.Decrypt.Epoch, 10); + ASSERT_EQ(keySlot.Encrypt.Epoch, 13); + ASSERT_EQ(keySlot.EpochKeySend.Epoch, 13); + ASSERT_EQ(keySlot.RetiringEpochDataReceiveKey.Epoch, 9); + + OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, 12, &opts); + ASSERT_EQ(keySlot.Decrypt.Epoch, 12); + ASSERT_EQ(keySlot.Encrypt.Epoch, 13); + ASSERT_EQ(keySlot.EpochKeySend.Epoch, 13); + ASSERT_EQ(keySlot.RetiringEpochDataReceiveKey.Epoch, 10); + + OvpnCryptoEpochKeyIterate(&keySlot.EpochKeySend, opts.HkdfAlgHandle); + + BCryptDestroyKey(&keySlot.Encrypt.Key); + RtlZeroMemory(&keySlot.Encrypt, sizeof(OvpnCryptoKeyContext)); + + OvpnCryptoEpochInitKey(&keySlot.Encrypt, &keySlot.EpochKeySend, &opts); + + ASSERT_EQ(keySlot.Encrypt.Epoch, 14); +} + +TEST_F(CryptoTest, EpochKeyReceiveLookup) +{ + /* lookup some wacky things that should fail */ + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 2000), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, -1), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 0xefff), nullptr); + + /* Lookup the edges of the current window */ + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 0), nullptr); + ASSERT_EQ(keySlot.RetiringEpochDataReceiveKey.Epoch, 0); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 1)->Epoch, 1); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 2)->Epoch, 2); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 16)->Epoch, 16); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 17)->Epoch, 17); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 18), nullptr); + + /* Should move 1 to retiring key but leave 2-6 undefined, 7 as + * active and 8-23 as future keys*/ + OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, 7, &opts); + + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 0), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 1)->Epoch, 1); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 1), &keySlot.RetiringEpochDataReceiveKey); + + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 2), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 3), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 4), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 5), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 6), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 21)->Epoch, 21); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 22)->Epoch, 22); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 23)->Epoch, 23); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 24), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 25), nullptr); + + /* Should move 7 to retiring key and have 8 as active key and + * 9-24 as future keys */ + OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, 8, &opts); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 0), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 1), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 2), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 3), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 4), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 5), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 6), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 7)->Epoch, 7); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 7), &keySlot.RetiringEpochDataReceiveKey); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 8)->Epoch, 8); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 23)->Epoch, 23); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 24)->Epoch, 24); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 25), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 26), nullptr); +} + +TEST_F(CryptoTest, EpochKeyOverflow) +{ + /* Modify the receive epoch and keys to have a very high epoch to test + * the end of array. Iterating through all 65k keys takes a 2-3s, so we + * avoid this for the unit test */ + keySlot.Decrypt.Epoch = 65516; + keySlot.Encrypt.Epoch = 65516; + + keySlot.EpochKeySend.Epoch = 65516; + keySlot.EpochKeyRecv.Epoch = 65516 + FUTURE_EPOCH_KEYS_COUNT; + + for (int i = 0; i < FUTURE_EPOCH_KEYS_COUNT; ++i) { + keySlot.FutureEpochKeys[i].Epoch = 65517 + i; + } + + /* Move the last few keys until we are close to the limit */ + while (keySlot.Decrypt.Epoch < (UINT16_MAX - 24)) + { + OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, keySlot.Decrypt.Epoch + 10, &opts); + } + + /* Looking up this key should still work as it will not break the limit + * when generating keys */ + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX - 18)->Epoch, UINT16_MAX - 18); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX - 17)->Epoch, UINT16_MAX - 17); + + /* This key is no longer eligible for decrypting as the 16 future keys + * would be larger than uint16_t maximum */ + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX - FUTURE_EPOCH_KEYS_COUNT), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX), nullptr); + + /* Check that moving to the last possible epoch works */ + OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, UINT16_MAX - 17, &opts); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX - 17)->Epoch, UINT16_MAX - 17); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX - 16), nullptr); + ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX), nullptr); +} + +TEST_F(CryptoTest, EpochDeriveDataKey) +{ + OvpnCryptoKeyParameters kp; + OvpnCryptoEpochKey e17{ {19, 12}, 17}; + OvpnCryptoEpochDataKeyDerive(&kp, &e17, opts.HkdfAlgHandle, opts.AeadAlgHangle, 24); + + uint8_t exp_cipherkey[24] = { 0xed, 0x85, 0x33, 0xdb, 0x1c, 0x28, 0xac, 0xe4, + 0x18, 0xe9, 0x00, 0x6a, 0xb2, 0x9c, 0x17, 0x41, + 0x7d, 0x60, 0xeb, 0xe6, 0xcd, 0x90, 0xbf, 0x0a }; + + uint8_t exp_impl_iv[12] = { 0x86, 0x89, 0x0a, 0xab, 0xf0, 0x32, + 0xcb, 0x59, 0xf4, 0xcf, 0xa3, 0x4e }; + + ASSERT_EQ(0, std::memcmp(kp.Cipher, exp_cipherkey, sizeof(exp_cipherkey))); + ASSERT_EQ(0, std::memcmp(kp.IV, exp_impl_iv, sizeof(exp_impl_iv))); +} diff --git a/timer.cpp b/timer.cpp index 213e1ff..20ae520 100644 --- a/timer.cpp +++ b/timer.cpp @@ -80,25 +80,29 @@ static VOID OvpnTimerXmit(WDFTIMER timer) ExAcquireSpinLockSharedAtDpcLevel(&peer->SpinLock); auto peerId = peer->PeerId; - SOCKADDR_STORAGE sa; - OvpnSocketCopyRemoteToSockaddr(peer->TransportAddrs.Remote, &sa); + SOCKADDR_STORAGE sa = {0}; + BOOLEAN exclusive = FALSE; OvpnCryptoContext* cryptoContext = &peer->CryptoContext; if (cryptoContext->Encrypt) { - // make space to crypto overhead - BOOLEAN pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - BOOLEAN aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; - - OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN)); - - // in-place encrypt, always with primary key - status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions); + const OvpnCryptoPacketLayout layout = cryptoContext->Layout; + + OvpnTxBufferPush(buffer, layout.FrontLen); + OvpnBufferPut(buffer, layout.TailLen); + + OvpnCryptoEncryptParams encryptParams = { buffer->Data, buffer->Len }; + status = OvpnCryptoCallWithRetry(peer, TRUE, &exclusive, nullptr, OvpnCryptoInvokeEncrypt, &encryptParams); + + if (NT_SUCCESS(status)) { + OvpnSocketCopyRemoteToSockaddr(peer->TransportAddrs.Remote, &sa); + } } else { status = STATUS_INVALID_DEVICE_STATE; // LOG_WARN("CryptoContext not initialized"); } - ExReleaseSpinLockSharedFromDpcLevel(&peer->SpinLock); + + OvpnReleaseSpinLock(TRUE, 0, &peer->SpinLock, exclusive); if (NT_SUCCESS(status)) { // start async send, completion handler will return ciphertext buffer to the pool diff --git a/trace.h b/trace.h index fc1c8a9..985d917 100644 --- a/trace.h +++ b/trace.h @@ -21,6 +21,7 @@ #pragma once +#if defined(_KERNEL_MODE) #include #include #include @@ -30,6 +31,12 @@ TRACELOGGING_DECLARE_PROVIDER(g_hOvpnEtwProvider); +#else +#include +#endif + +#if defined(_KERNEL_MODE) + #define TraceLoggingFunctionName() TraceLoggingWideString(__FUNCTIONW__, "Func") #define LOG_NTSTATUS(Status, ...) do {\ @@ -106,16 +113,35 @@ TRACELOGGING_DECLARE_PROVIDER(g_hOvpnEtwProvider); } \ } while(0,0) -#define GOTO_IF_NOT_NT_SUCCESS(Label, StatusLValue, Expression, ...) do {\ - StatusLValue = (Expression); \ - if (!NT_SUCCESS(StatusLValue)) \ - { \ - LOG_NTSTATUS(StatusLValue, \ - TraceLoggingWideString(L#Expression, "Expression"), \ - __VA_ARGS__); \ - goto Label; \ - } \ -} while(0,0) +#define GOTO_IF_NOT_NT_SUCCESS(Label, StatusLValue, Expression, ...) \ + do \ + { \ + StatusLValue = (Expression); \ + if (!NT_SUCCESS(StatusLValue)) \ + { \ + LOG_NTSTATUS(StatusLValue, \ + TraceLoggingString(#Expression, "Expression"), \ + __VA_ARGS__); \ + goto Label; \ + } \ + } while (0) + +#else /* user-mode / tests */ + +#define LOG_INFO(Info, ...) + +#define GOTO_IF_NOT_NT_SUCCESS(Label, StatusLValue, Expression, ...) \ + do \ + { \ + StatusLValue = (Expression); \ + if (!NT_SUCCESS(StatusLValue)) \ + { \ + goto Label; \ + } \ + } while (0) + +#endif + #ifndef TraceLoggingIPv4Address #define TraceLoggingIPv4Address(value, ...) _tlgArgScalarVal(UINT32, value, TlgInUINT32, (TlgOutIPV4), __VA_ARGS__) diff --git a/txqueue.cpp b/txqueue.cpp index 95f26a5..9d97ca9 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -235,29 +235,26 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET InterlockedExchangeAddNoFence64(&peer->VpnTxBytes, buffer->Len); auto irql = ExAcquireSpinLockShared(&peer->SpinLock); + BOOLEAN exclusive = FALSE; OvpnCryptoContext* cryptoContext = &peer->CryptoContext; auto remoteAddr = peer->TransportAddrs.Remote; if (cryptoContext->Encrypt) { - auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; - auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; - - // make space to crypto overhead - OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN)); - if (aeadTagEnd) - { - OvpnBufferPut(buffer, AEAD_AUTH_TAG_LEN); - } + const OvpnCryptoPacketLayout layout = cryptoContext->Layout; + + OvpnTxBufferPush(buffer, layout.FrontLen); + OvpnBufferPut(buffer, layout.TailLen); - // in-place encrypt, always with primary key - status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions); + OvpnCryptoEncryptParams encryptParams = { buffer->Data, buffer->Len }; + status = OvpnCryptoCallWithRetry(peer, FALSE, &exclusive, &irql, OvpnCryptoInvokeEncrypt, &encryptParams); } else { status = STATUS_INVALID_DEVICE_STATE; // LOG_WARN("CryptoContext not initialized"); } - ExReleaseSpinLockShared(&peer->SpinLock, irql); + + OvpnReleaseSpinLock(FALSE, irql, &peer->SpinLock, exclusive); if (NT_SUCCESS(status)) { InterlockedExchangeAddNoFence64(&peer->LinkTxBytes, buffer->Len); diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h index 6dab25f..a7b28fb 100644 --- a/uapi/ovpn-dco.h +++ b/uapi/ovpn-dco.h @@ -119,8 +119,7 @@ typedef struct _OVPN_CRYPTO_DATA { int PeerId; } OVPN_CRYPTO_DATA, * POVPN_CRYPTO_DATA; -#define CRYPTO_OPTIONS_AEAD_TAG_END (1<<1) -#define CRYPTO_OPTIONS_64BIT_PKTID (1<<2) +#define CRYPTO_OPTIONS_EPOCH (1<<1) typedef struct _OVPN_CRYPTO_DATA_V2 { OVPN_CRYPTO_DATA V1;