diff --git a/.github/workflows/msbuild.yml b/.github/workflows/msbuild.yml
index f8cf54a..4df7084 100644
--- a/.github/workflows/msbuild.yml
+++ b/.github/workflows/msbuild.yml
@@ -73,3 +73,20 @@ jobs:
msm\*.msi
msm\*.msm
gui\build\Release\*.*
+
+ tests:
+ runs-on: windows-2022
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ submodules: true
+
+ - name: Configure CMake tests
+ run: cmake -S tests -B tests/build -A x64
+
+ - name: Build CMake tests
+ run: cmake --build tests/build --config Release
+
+ - name: Run CMake tests
+ run: ctest --test-dir tests/build --build-config Release --output-on-failure
diff --git a/Driver.cpp b/Driver.cpp
index 9805951..2387c4b 100644
--- a/Driver.cpp
+++ b/Driver.cpp
@@ -727,7 +727,7 @@ VOID OvpnEvtDeviceCleanup(WDFOBJECT obj) {
// OvpnCryptoUninitAlgHandles called outside of lock because
// it requires PASSIVE_LEVEL.
- OvpnCryptoUninitAlgHandles(device->AesAlgHandle, device->ChachaAlgHandle);
+ OvpnCryptoUninitAlgHandles(device->AesAlgHandle, device->ChachaAlgHandle, device->HkdfAlgHandle);
// delete control device if there are no devices left
POVPN_DRIVER driverCtx = OvpnGetDriverContext(WdfGetDriver());
@@ -892,7 +892,7 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) {
device->IRoutesIPV4.Init(FALSE);
device->IRoutesIPV6.Init(TRUE);
- GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoInitAlgHandles(&device->AesAlgHandle, &device->ChachaAlgHandle));
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoInitAlgHandles(&device->AesAlgHandle, &device->ChachaAlgHandle, &device->HkdfAlgHandle));
// Initialize peers tables
RtlInitializeGenericTable(&device->Peers, OvpnPeerCompareByPeerIdRoutine, OvpnPeerAllocateRoutine, OvpnPeerFreeRoutine, NULL);
@@ -901,7 +901,6 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) {
RtlInitializeGenericTable(&device->PeersByTransport, OvpnPeerCompareByTransportRoutine, OvpnPeerAllocateRoutine, OvpnPeerFreeRoutine, NULL);
LOG_IF_NOT_NT_SUCCESS(status = OvpnAdapterCreate(device));
-
done:
LOG_EXIT();
diff --git a/Driver.h b/Driver.h
index 4377fd0..f61921d 100644
--- a/Driver.h
+++ b/Driver.h
@@ -86,6 +86,7 @@ struct OVPN_DEVICE {
BCRYPT_ALG_HANDLE AesAlgHandle;
BCRYPT_ALG_HANDLE ChachaAlgHandle;
+ BCRYPT_ALG_HANDLE HkdfAlgHandle;
_Guarded_by_(SpinLock)
OvpnSocket Socket;
diff --git a/PropertySheet.props b/PropertySheet.props
index 1d4fed5..82ed55a 100644
--- a/PropertySheet.props
+++ b/PropertySheet.props
@@ -3,8 +3,8 @@
2
- 7
- 1
+ 8
+ 0
diff --git a/adapter.h b/adapter.h
index dfaa0fc..5035347 100644
--- a/adapter.h
+++ b/adapter.h
@@ -56,5 +56,3 @@ OvpnAdapterNotifyRx(NETADAPTER netAdapter);
VOID
OvpnAdapterSetLinkState(_In_ POVPN_ADAPTER adapter, NET_IF_MEDIA_CONNECT_STATE state);
-
-#define OVPN_PAYLOAD_BACKFILL 26 // 2 + 4 + 4 + 16 -> tcp packet size + data_v2 + pktid + auth-tag;
\ No newline at end of file
diff --git a/crypto.cpp b/crypto.cpp
index 4f0a639..ba4aa97 100644
--- a/crypto.cpp
+++ b/crypto.cpp
@@ -21,11 +21,13 @@
#include
#include
+#include
#include "crypto.h"
#include "trace.h"
#include "pktid.h"
#include "socket.h"
+#include "peer.h"
UINT
OvpnCryptoOpCompose(UINT opcode, UINT keyId)
@@ -33,6 +35,120 @@ OvpnCryptoOpCompose(UINT opcode, UINT keyId)
return (opcode << OVPN_OPCODE_SHIFT) | keyId;
}
+static
+VOID
+OvpnCryptoUpgradeLock(_In_ OvpnPeerContext* peer, _In_ BOOLEAN atDpcLevel, _Inout_opt_ PKIRQL kirql)
+{
+ if (atDpcLevel) {
+ ExReleaseSpinLockSharedFromDpcLevel(&peer->SpinLock);
+ ExAcquireSpinLockExclusiveAtDpcLevel(&peer->SpinLock);
+ }
+ else {
+ NT_ASSERT(kirql != nullptr);
+
+ KIRQL previousIrql = *kirql;
+
+ ExReleaseSpinLockShared(&peer->SpinLock, previousIrql);
+
+ KIRQL acquiredIrql = ExAcquireSpinLockExclusive(&peer->SpinLock);
+ NT_ASSERT(acquiredIrql == previousIrql);
+ *kirql = acquiredIrql;
+ }
+}
+
+_Use_decl_annotations_
+NTSTATUS
+OvpnCryptoCallWithRetry(
+ OvpnPeerContext* peer,
+ BOOLEAN atDpcLevel,
+ PBOOLEAN exclusive,
+ PKIRQL kirql,
+ POVPN_CRYPTO_RETRY_ROUTINE routine,
+ PVOID context)
+{
+ NT_ASSERT(peer != nullptr);
+ NT_ASSERT(routine != nullptr);
+
+ BOOLEAN allowRekey = (exclusive != nullptr) && *exclusive;
+
+ for (;;) {
+ OvpnCryptoContext* cryptoContext = &peer->CryptoContext;
+ NTSTATUS status = routine(cryptoContext, allowRekey, context);
+
+ if (status != STATUS_OVPN_CRYPTO_RETRY) {
+ return status;
+ }
+
+ if (!allowRekey) {
+ // Drop the shared lock before upgrading. The caller owns a peer
+ // reference, so the context remains valid while we temporarily
+ // release the lock and the loop below re-reads the crypto state
+ // under exclusive ownership.
+ OvpnCryptoUpgradeLock(peer, atDpcLevel, kirql);
+ allowRekey = TRUE;
+
+ if (exclusive != nullptr) {
+ *exclusive = TRUE;
+ }
+
+ continue;
+ }
+
+ LOG_ERROR("Crypto helper requested retry under exclusive lock");
+ return STATUS_UNSUCCESSFUL;
+ }
+}
+
+_Use_decl_annotations_
+NTSTATUS
+OvpnCryptoInvokeEncrypt(
+ OvpnCryptoContext* cryptoContext,
+ BOOLEAN allowRekey,
+ PVOID context)
+{
+ auto params = reinterpret_cast(context);
+
+ if ((cryptoContext == nullptr) || (cryptoContext->Encrypt == nullptr)) {
+ return STATUS_INVALID_DEVICE_STATE;
+ }
+
+ NT_ASSERT(params != nullptr);
+
+ return cryptoContext->Encrypt(&cryptoContext->Primary, params->Buffer, params->Length, &cryptoContext->Options, allowRekey);
+}
+
+_Use_decl_annotations_
+NTSTATUS
+OvpnCryptoInvokeDecrypt(
+ OvpnCryptoContext* cryptoContext,
+ BOOLEAN allowRekey,
+ PVOID context)
+{
+ auto params = reinterpret_cast(context);
+
+ if ((cryptoContext == nullptr) || (cryptoContext->Decrypt == nullptr)) {
+ return STATUS_INVALID_DEVICE_STATE;
+ }
+
+ NT_ASSERT(params != nullptr);
+
+ OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(cryptoContext, params->KeyId);
+ if (keySlot == nullptr) {
+ LOG_ERROR("keyId not found", TraceLoggingValue(params->KeyId, "keyId"));
+ return STATUS_INVALID_DEVICE_STATE;
+ }
+
+ NTSTATUS status = cryptoContext->Decrypt(
+ keySlot,
+ params->CipherText,
+ params->Length,
+ params->PlainText,
+ &cryptoContext->Options,
+ allowRekey);
+
+ return status;
+}
+
static
UINT
OvpnProtoOp32Compose(UINT opcode, UINT keyId, UINT opPeerId)
@@ -48,12 +164,15 @@ OvpnProtoOp32Compose(UINT opcode, UINT keyId, UINT opPeerId)
OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptNone;
_Use_decl_annotations_
-NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions)
+NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, OvpnCryptoOptions* opts, BOOLEAN allowRekey)
{
UNREFERENCED_PARAMETER(keySlot);
+ UNREFERENCED_PARAMETER(allowRekey);
- BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
- BOOLEAN cryptoOverhead = OVPN_DATA_V2_LEN + pktId64bit ? 8 : 4;
+ BOOLEAN useEpoch = (opts != NULL) && opts->UseEpoch;
+ SIZE_T pktIdLen = useEpoch ? 8 : 4;
+ SIZE_T authTagFront = useEpoch ? 0 : AEAD_AUTH_TAG_LEN;
+ SIZE_T cryptoOverhead = OVPN_DATA_V2_LEN + pktIdLen + authTagFront;
if (len < cryptoOverhead) {
LOG_WARN("Packet too short", TraceLoggingValue(len, "len"));
@@ -69,11 +188,12 @@ OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptNone;
_Use_decl_annotations_
NTSTATUS
-OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions)
+OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, OvpnCryptoOptions* opts, BOOLEAN allowRekey)
{
UNREFERENCED_PARAMETER(keySlot);
UNREFERENCED_PARAMETER(len);
- UNREFERENCED_PARAMETER(cryptoOptions);
+ UNREFERENCED_PARAMETER(opts);
+ UNREFERENCED_PARAMETER(allowRekey);
// prepend with opcode, key-id and peer-id
UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, 0, 0);
@@ -90,12 +210,15 @@ OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32
_Use_decl_annotations_
NTSTATUS
-OvpnCryptoInitAlgHandles(BCRYPT_ALG_HANDLE* aesAlgHandle, BCRYPT_ALG_HANDLE* chachaAlgHandle)
+OvpnCryptoInitAlgHandles(BCRYPT_ALG_HANDLE* aesAlgHandle, BCRYPT_ALG_HANDLE* chachaAlgHandle, BCRYPT_ALG_HANDLE* hkdfAlgHandle)
{
NTSTATUS status;
GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptOpenAlgorithmProvider(aesAlgHandle, BCRYPT_AES_ALGORITHM, NULL, BCRYPT_PROV_DISPATCH));
GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptSetProperty(*aesAlgHandle, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0));
+ // used by epoch data channel
+ GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptOpenAlgorithmProvider(hkdfAlgHandle, BCRYPT_HKDF_ALGORITHM, NULL, BCRYPT_PROV_DISPATCH));
+
// available starting from Windows 11
LOG_IF_NOT_NT_SUCCESS(BCryptOpenAlgorithmProvider(chachaAlgHandle, BCRYPT_CHACHA20_POLY1305_ALGORITHM, NULL, BCRYPT_PROV_DISPATCH));
done:
@@ -104,7 +227,7 @@ OvpnCryptoInitAlgHandles(BCRYPT_ALG_HANDLE* aesAlgHandle, BCRYPT_ALG_HANDLE* cha
_Use_decl_annotations_
VOID
-OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDLE chachaAlgHandle)
+OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDLE chachaAlgHandle, BCRYPT_ALG_HANDLE hkdfAlgHandle)
{
if (aesAlgHandle) {
LOG_IF_NOT_NT_SUCCESS(BCryptCloseAlgorithmProvider(aesAlgHandle, 0));
@@ -113,6 +236,10 @@ OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDL
if (chachaAlgHandle) {
LOG_IF_NOT_NT_SUCCESS(BCryptCloseAlgorithmProvider(chachaAlgHandle, 0));
}
+
+ if (hkdfAlgHandle) {
+ LOG_IF_NOT_NT_SUCCESS(BCryptCloseAlgorithmProvider(hkdfAlgHandle, 0));
+ }
}
#define GET_SYSTEM_ADDRESS_MDL(buf, mdl) { \
@@ -123,151 +250,245 @@ OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDL
} \
}
-static
NTSTATUS
-OvpnCryptoAEADDoWork(BOOLEAN encrypt, OvpnCryptoKeySlot* keySlot, UCHAR *bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions)
+OvpnCryptoCheckReplay(OvpnCryptoKeySlot* keySlot, ULONG64 packet_id_net, UINT16 epoch, OvpnCryptoOptions *opts, BOOLEAN allowRekey)
{
- /*
- AEAD Nonce :
+ OvpnPktidRecv* recv = NULL;
- [Packet ID] [HMAC keying material]
- [4/8 bytes] [8/4 bytes ]
- [AEAD nonce total : 12 bytes ]
+ if (epoch == 0 || keySlot->Decrypt.Epoch == epoch) {
+ recv = &keySlot->PktidRecv;
+ }
+ else if (epoch == keySlot->RetiringEpochDataReceiveKey.Epoch) {
+ recv = &keySlot->PktidRecvRetiring;
+ }
+ else {
+ if (!allowRekey) {
+ return STATUS_OVPN_CRYPTO_RETRY;
+ }
- TLS wire protocol :
+ /* We have an epoch that is neither current or old recv key but
+ * is authenticated, ie we need to move to a new current recv key */
+ LOG_INFO("Received data packet with new epoch. Updating receive key", TraceLoggingValue(epoch, "epoch"));
+ OvpnCryptoEpochReplaceUpdateRecvKey(keySlot, epoch, opts);
+ recv = &keySlot->PktidRecv;
+ }
- Packet ID is 8 bytes long with CRYPTO_OPTIONS_64BIT_PKTID.
+ return OvpnPktidRecvVerify(recv, packet_id_net);
+}
- [DATA_V2 opcode] [Packet ID] [AEAD Auth tag] [ciphertext]
- [4 bytes ] [4/8 bytes] [16 bytes ]
- [AEAD additional data(AD) ]
+/*
+AEAD Nonce :
- With CRYPTO_OPTIONS_AEAD_TAG_END AEAD Auth tag is placed after ciphertext:
+ [Packet ID] [HMAC keying material]
+ [4 bytes ] [4 bytes ]
+ [AEAD nonce total : 12 bytes ]
- [DATA_V2 opcode] [Packet ID] [ciphertext] [AEAD Auth tag]
- [4 bytes ] [4/8 bytes] [16 bytes ]
- [AEAD additional data(AD) ]
- */
+TLS wire protocol :
- NTSTATUS status = STATUS_SUCCESS;
+ [DATA_V2 opcode] [Packet ID] [AEAD Auth tag] [ciphertext]
+ [4 bytes ] [4 bytes ] [16 bytes ]
+ [AEAD additional data(AD) ]
- BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
+New data format, with epoch keys and 64bit packet id:
- SIZE_T cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4);
+ struct aead_packet {
+ int opcode:5;
+ int key_id:3;
+ int peer_id:24;
+ uint64_t packet_id;
+ uint8_t* encrypted_payload;
+ uint8_t[16] authentication_tag;
+ }
- if (len < cryptoOverhead) {
- LOG_WARN("Packet too short", TraceLoggingValue(len, "len"));
- return STATUS_DATA_ERROR;
+ struct packet_id {
+ uint epoch:16;
+ uint epoch_counter:48;
}
- UCHAR nonce[12];
- if (encrypt) {
- // prepend with opcode, key-id and peer-id
- UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, keySlot->KeyId, keySlot->PeerId);
- op = RtlUlongByteSwap(op);
- *reinterpret_cast(bufOut) = op;
+ authenticated_data = opcode| key_id | peer_id | packet_id
+*/
- if (pktId64bit)
- {
- // calculate pktid
- UINT64 pktid;
- GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, true));
- ULONG64 pktidNetwork = RtlUlonglongByteSwap(pktid);
- // calculate nonce, which is pktid + nonce_tail
- RtlCopyMemory(nonce, &pktidNetwork, 8);
- RtlCopyMemory(nonce + 8, keySlot->EncNonceTail, 4);
+OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptAEAD;
- // prepend with pktid
- *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork;
- }
- else
- {
- // calculate pktid
- UINT32 pktid;
- GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, false));
- ULONG pktidNetwork = RtlUlongByteSwap(pktid);
+_Use_decl_annotations_
+NTSTATUS
+OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, OvpnCryptoOptions* opts, BOOLEAN allowRekey)
+{
+ NTSTATUS status = STATUS_SUCCESS;
- // calculate nonce, which is pktid + nonce_tail
- RtlCopyMemory(nonce, &pktidNetwork, 4);
- RtlCopyMemory(nonce + 4, keySlot->EncNonceTail, 8);
+ BOOLEAN authTagEnd = opts->UseEpoch;
+ ULONG pktidLen = opts->UseEpoch ? 8 : 4;
+ ULONG cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + pktidLen;
- // prepend with pktid
- *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork;
- }
+ if (len < cryptoOverhead) {
+ LOG_WARN("Packet too short", TraceLoggingValue(len, "len"));
+ return STATUS_DATA_ERROR;
}
- else {
- ULONG64 pktId;
- RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, pktId64bit ? 8 : 4);
- RtlCopyMemory(nonce + (pktId64bit ? 8 : 4), &keySlot->DecNonceTail, pktId64bit ? 4 : 8);
- if (pktId64bit)
- {
- pktId = RtlUlonglongByteSwap(*reinterpret_cast(nonce));
- }
- else
- {
- pktId = static_cast(RtlUlongByteSwap(*reinterpret_cast(nonce)));
- }
+ // we prepended buf with crypto overhead
+ len -= cryptoOverhead;
+
+ OvpnCryptoKeyContext* decryptKey = &keySlot->Decrypt;
+ UINT64 packet_id = 0;
+ UINT16 rx_epoch = 0;
+
+ UCHAR nonce[12];
- status = OvpnPktidRecvVerify(&keySlot->PktidRecv, pktId);
+ if (opts->UseEpoch) {
+ // read packet_id
+ UINT64 packet_id_net;
+ RtlCopyMemory(&packet_id_net, bufIn + OVPN_DATA_V2_LEN, sizeof(packet_id_net));
+ packet_id = RtlUlonglongByteSwap(packet_id_net);
- if (!NT_SUCCESS(status)) {
- LOG_ERROR("Invalid pktId", TraceLoggingUInt64(pktId, "pktId"));
+ // get epoch number and counter
+ rx_epoch = (UINT16)(packet_id >> 48);
+ if (rx_epoch == 0) {
+ LOG_ERROR("Invalid epoch 0");
return STATUS_DATA_ERROR;
}
- }
- // we prepended buf with crypto overhead
- len -= cryptoOverhead;
+ decryptKey = OvpnCryptoEpochLookupDecryptKey(keySlot, rx_epoch);
+ if (decryptKey == NULL) {
+ LOG_ERROR("Data packet with unknown epoch", TraceLoggingValue(rx_epoch, "epoch"));
+ return STATUS_DATA_ERROR;
+ }
+
+ OvpnCryptoMakeEpochNonce(decryptKey->ImplicitIV, packet_id_net, nonce);
+ }
+ else {
+ RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, 4);
+ RtlCopyMemory(nonce + 4, decryptKey->ImplicitIV + 4, 8);
- BOOLEAN aeadTagEnd = cryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
+ packet_id = static_cast(RtlUlongByteSwap(*reinterpret_cast(nonce)));
+ }
BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO authInfo;
BCRYPT_INIT_AUTH_MODE_INFO(authInfo);
authInfo.pbNonce = nonce;
authInfo.cbNonce = sizeof(nonce);
- authInfo.pbTag = (encrypt ? bufOut : bufIn) + OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? len : 0);
+ authInfo.pbTag = bufIn + OVPN_DATA_V2_LEN + pktidLen + (authTagEnd ? len : 0);
authInfo.cbTag = AEAD_AUTH_TAG_LEN;
- authInfo.pbAuthData = (encrypt ? bufOut : bufIn);
- authInfo.cbAuthData = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4);
+ authInfo.pbAuthData = bufIn;
+ authInfo.cbAuthData = OVPN_DATA_V2_LEN + pktidLen;
- auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN);
+ auto payloadOffset = OVPN_DATA_V2_LEN + pktidLen + (authTagEnd ? 0 : AEAD_AUTH_TAG_LEN);
bufOut += payloadOffset;
bufIn += payloadOffset;
// non-chaining mode
ULONG bytesDone = 0;
- GOTO_IF_NOT_NT_SUCCESS(done, status, encrypt ?
- BCryptEncrypt(keySlot->EncKey, bufIn, (ULONG)len, &authInfo, NULL, 0, bufOut, (ULONG)len, &bytesDone, 0) :
- BCryptDecrypt(keySlot->DecKey, bufIn, (ULONG)len, &authInfo, NULL, 0, bufOut, (ULONG)len, &bytesDone, 0)
- );
+ GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptDecrypt(decryptKey->Key, bufIn, (ULONG)len, &authInfo, NULL, 0, bufOut, (ULONG)len, &bytesDone, 0));
+
+ status = OvpnCryptoCheckReplay(keySlot, packet_id, rx_epoch, opts, allowRekey);
+ if (status == STATUS_OVPN_CRYPTO_RETRY) {
+ return status;
+ }
+
+ if (!NT_SUCCESS(status)) {
+ LOG_ERROR("Invalid packet_id", TraceLoggingUInt64(packet_id, "packet_id"));
+ return STATUS_DATA_ERROR;
+ }
done:
return status;
}
-OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptAEAD;
+OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptAEAD;
_Use_decl_annotations_
NTSTATUS
-OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions)
+OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, OvpnCryptoOptions* opts, BOOLEAN allowRekey)
{
- return OvpnCryptoAEADDoWork(FALSE, keySlot, bufIn, len, bufOut, cryptoOptions);
-}
+ NTSTATUS status = STATUS_SUCCESS;
-OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptAEAD;
+ BOOLEAN authTagEnd = opts->UseEpoch;
+ ULONG pktidLen = opts->UseEpoch ? 8 : 4;
+ ULONG cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + pktidLen;
-_Use_decl_annotations_
-NTSTATUS
-OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions)
-{
- return OvpnCryptoAEADDoWork(TRUE, keySlot, buf, len, buf, cryptoOptions);
+ if (len < cryptoOverhead) {
+ LOG_WARN("Packet too short", TraceLoggingValue(len, "len"));
+ return STATUS_DATA_ERROR;
+ }
+
+ // we prepended buf with crypto overhead
+ len -= cryptoOverhead;
+
+ UINT64 packet_id = 0;
+
+ UCHAR nonce[12];
+
+ // prepend with opcode, key-id and peer-id
+ UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, keySlot->KeyId, keySlot->PeerId);
+ op = RtlUlongByteSwap(op);
+ RtlCopyMemory(buf, &op, sizeof(op));
+
+ if (opts->UseEpoch) {
+ if (keySlot->EpochKeySend.Epoch == UINT16_MAX) {
+ return STATUS_BUFFER_OVERFLOW;
+ }
+
+ if (OvpnCryptoAeadUsageLimitReached(opts->AeadUsageLimit, keySlot->Encrypt.PlaintextBlocks, keySlot->PktidXmit.SeqNum) || (keySlot->PktidXmit.SeqNum == PACKET_ID_EPOCH_MAX)) {
+ if (!allowRekey) {
+ return STATUS_OVPN_CRYPTO_RETRY;
+ }
+
+ OvpnCryptoEpochIterateSendKey(keySlot, opts);
+ }
+
+ // calculate 64-bit packet-id = (epoch << 48) | ctr48
+ // the overflow of pktid is checked above
+ UINT64 ctr48 = InterlockedIncrementNoFence64(&keySlot->PktidXmit.SeqNum) & PACKET_ID_EPOCH_MAX;
+ packet_id = ((UINT64)keySlot->Encrypt.Epoch << 48) | ctr48;
+
+ // prepend with pktid
+ UINT64 packet_id_net = RtlUlonglongByteSwap(packet_id);
+ RtlCopyMemory(buf + OVPN_DATA_V2_LEN, &packet_id_net, sizeof(packet_id_net));
+
+ OvpnCryptoMakeEpochNonce(keySlot->Encrypt.ImplicitIV, packet_id_net, nonce);
+ }
+ else {
+ // calculate pktid
+ UINT32 packet_id_32;
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &packet_id_32));
+ ULONG packet_id_net = RtlUlongByteSwap(packet_id_32);
+
+ // calculate nonce, which is pktid + nonce_tail
+ RtlCopyMemory(nonce, &packet_id_net, 4);
+ RtlCopyMemory(nonce + 4, keySlot->Encrypt.ImplicitIV + 4, 8);
+
+ // prepend with pktid
+ RtlCopyMemory(buf + OVPN_DATA_V2_LEN, &packet_id_net, sizeof(packet_id_net));
+ }
+
+ // update number of plaintext blocks encrypted. Use the (x + (n-1))/n trick to round up the result to the number of blocks used
+ const ULONGLONG blocksize = AEAD_LIMIT_BLOCKSIZE;
+ ULONGLONG inc = ((ULONGLONG)len + (blocksize - 1)) / blocksize;
+ InterlockedAdd64((volatile LONG64*)&keySlot->Encrypt.PlaintextBlocks, (LONG64)inc);
+
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO authInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(authInfo);
+ authInfo.pbNonce = nonce;
+ authInfo.cbNonce = sizeof(nonce);
+ authInfo.pbTag = buf + OVPN_DATA_V2_LEN + pktidLen + (authTagEnd ? len : 0);
+ authInfo.cbTag = AEAD_AUTH_TAG_LEN;
+ authInfo.pbAuthData = buf;
+ authInfo.cbAuthData = OVPN_DATA_V2_LEN + pktidLen;
+
+ auto payloadOffset = OVPN_DATA_V2_LEN + pktidLen + (authTagEnd ? 0 : AEAD_AUTH_TAG_LEN);
+ buf += payloadOffset;
+
+ // non-chaining mode
+ ULONG bytesDone = 0;
+ GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptEncrypt(keySlot->Encrypt.Key, buf, (ULONG)len, &authInfo, NULL, 0, buf, (ULONG)len, &bytesDone, 0));
+
+done:
+ return status;
}
_Use_decl_annotations_
NTSTATUS
-OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDataV2, BCRYPT_ALG_HANDLE algHandle)
+OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDataV2, BCRYPT_ALG_HANDLE algHandle, BCRYPT_ALG_HANDLE hkdfAlgHandle)
{
OvpnCryptoKeySlot* keySlot = NULL;
NTSTATUS status = STATUS_SUCCESS;
@@ -285,25 +506,16 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDa
return STATUS_INVALID_DEVICE_REQUEST;
}
- if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID)
- {
- cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_64BIT_PKTID;
- }
- if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END)
- {
- cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_AEAD_TAG_END;
- }
-
if ((cryptoData->CipherAlg == OVPN_CIPHER_ALG_AES_GCM) || (cryptoData->CipherAlg == OVPN_CIPHER_ALG_CHACHA20_POLY1305)) {
// destroy previous keys
- if (keySlot->EncKey) {
- BCryptDestroyKey(keySlot->EncKey);
- keySlot->EncKey = NULL;
+ if (keySlot->Encrypt.Key) {
+ BCryptDestroyKey(keySlot->Encrypt.Key);
+ keySlot->Encrypt.Key = NULL;
}
- if (keySlot->DecKey) {
- BCryptDestroyKey(keySlot->DecKey);
- keySlot->DecKey = NULL;
+ if (keySlot->Decrypt.Key) {
+ BCryptDestroyKey(keySlot->Decrypt.Key);
+ keySlot->Decrypt.Key = NULL;
}
if ((cryptoData->Encrypt.KeyLen > 32) || (cryptoData->Decrypt.KeyLen > 32))
@@ -314,24 +526,53 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDa
goto done;
}
- // generate keys from key materials
- GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptGenerateSymmetricKey(algHandle, &keySlot->EncKey, NULL, 0, cryptoData->Encrypt.Key, cryptoData->Encrypt.KeyLen, 0));
- GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptGenerateSymmetricKey(algHandle, &keySlot->DecKey, NULL, 0, cryptoData->Decrypt.Key, cryptoData->Decrypt.KeyLen, 0));
+ cryptoContext->Options.KeyLen = cryptoData->Encrypt.KeyLen;
- // copy nonce tails
- RtlCopyMemory(keySlot->EncNonceTail, cryptoData->Encrypt.NonceTail, sizeof(cryptoData->Encrypt.NonceTail));
- RtlCopyMemory(keySlot->DecNonceTail, cryptoData->Decrypt.NonceTail, sizeof(cryptoData->Decrypt.NonceTail));
+ if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_EPOCH) {
+ cryptoContext->Options.AeadUsageLimit = OvpnCryptoAeadUsageLimit(cryptoData->CipherAlg);
+ cryptoContext->Options.UseEpoch = TRUE;
+ cryptoContext->Options.HkdfAlgHandle = hkdfAlgHandle;
+ cryptoContext->Options.AeadAlgHangle = algHandle;
- cryptoContext->Encrypt = OvpnCryptoEncryptAEAD;
- cryptoContext->Decrypt = OvpnCryptoDecryptAEAD;
+ keySlot->EpochKeySend.Epoch = 1;
+ RtlCopyMemory(keySlot->EpochKeySend.EpochKey, cryptoData->Encrypt.Key, 32);
+
+ keySlot->EpochKeyRecv.Epoch = 1;
+ RtlCopyMemory(keySlot->EpochKeyRecv.EpochKey, cryptoData->Decrypt.Key, 32);
+
+ OvpnCryptoEpochInitKey(&keySlot->Encrypt, &keySlot->EpochKeySend, &cryptoContext->Options);
+ OvpnCryptoEpochInitKey(&keySlot->Decrypt, &keySlot->EpochKeyRecv, &cryptoContext->Options);
+
+ RtlZeroMemory(keySlot->FutureEpochKeys, sizeof(keySlot->FutureEpochKeys));
+ OvpnCryptoEpochGenerateFutureRecvKeys(keySlot, &cryptoContext->Options);
+ }
+ else {
+ // generate keys from key materials
+ GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptGenerateSymmetricKey(algHandle, &keySlot->Encrypt.Key, NULL, 0, cryptoData->Encrypt.Key, cryptoData->Encrypt.KeyLen, 0));
+ GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptGenerateSymmetricKey(algHandle, &keySlot->Decrypt.Key, NULL, 0, cryptoData->Decrypt.Key, cryptoData->Decrypt.KeyLen, 0));
+
+ // copy nonce tails
+ RtlCopyMemory(keySlot->Encrypt.ImplicitIV + 4, cryptoData->Encrypt.NonceTail, sizeof(cryptoData->Encrypt.NonceTail));
+ RtlCopyMemory(keySlot->Decrypt.ImplicitIV + 4, cryptoData->Decrypt.NonceTail, sizeof(cryptoData->Decrypt.NonceTail));
+ }
keySlot->KeyId = cryptoData->KeyId;
keySlot->PeerId = cryptoData->PeerId;
+ cryptoContext->Encrypt = OvpnCryptoEncryptAEAD;
+ cryptoContext->Decrypt = OvpnCryptoDecryptAEAD;
+
LOG_INFO("New key", TraceLoggingValue(cryptoData->CipherAlg == OVPN_CIPHER_ALG_AES_GCM ? "aes-gcm" : "chacha20-poly1305", "alg"),
TraceLoggingValue(cryptoData->KeyId, "KeyId"), TraceLoggingValue(cryptoData->PeerId, "PeerId"));
}
else if (cryptoData->CipherAlg == OVPN_CIPHER_ALG_NONE) {
+ OvpnCryptoEpochUninitSlot(keySlot);
+
+ keySlot->KeyId = cryptoData->KeyId;
+ keySlot->PeerId = cryptoData->PeerId;
+
+ RtlZeroMemory(&cryptoContext->Options, sizeof(cryptoContext->Options));
+
cryptoContext->Encrypt = OvpnCryptoEncryptNone;
cryptoContext->Decrypt = OvpnCryptoDecryptNone;
@@ -343,6 +584,7 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDa
goto done;
}
+ OvpnCryptoDescribePacketLayout(cryptoContext, &cryptoContext->Layout);
// reset pktid for a new key
RtlZeroMemory(&keySlot->PktidXmit, sizeof(keySlot->PktidXmit));
RtlZeroMemory(&keySlot->PktidRecv, sizeof(keySlot->PktidRecv));
@@ -351,6 +593,32 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDa
return status;
}
+_Use_decl_annotations_
+VOID
+OvpnCryptoDescribePacketLayout(const OvpnCryptoContext* cryptoContext, OvpnCryptoPacketLayout* layout)
+{
+ NT_ASSERT(layout != NULL);
+
+ layout->FrontLen = OVPN_DATA_V2_LEN + 4;
+ layout->TailLen = 0;
+
+ if ((cryptoContext == NULL) || (cryptoContext->Encrypt == NULL)) {
+ return;
+ }
+
+ if (cryptoContext->Encrypt == OvpnCryptoEncryptAEAD) {
+ BOOLEAN useEpoch = cryptoContext->Options.UseEpoch;
+
+ layout->FrontLen = OVPN_DATA_V2_LEN + (useEpoch ? 8 : 4);
+ if (!useEpoch) {
+ layout->FrontLen += AEAD_AUTH_TAG_LEN;
+ }
+
+ layout->TailLen = useEpoch ? AEAD_AUTH_TAG_LEN : 0;
+ }
+}
+
+
_Use_decl_annotations_
OvpnCryptoKeySlot*
OvpnCryptoKeySlotFromKeyId(OvpnCryptoContext* cryptoContext, unsigned int keyId)
@@ -383,21 +651,8 @@ _Use_decl_annotations_
VOID
OvpnCryptoUninit(OvpnCryptoContext* cryptoContext)
{
- if (cryptoContext->Primary.EncKey) {
- BCryptDestroyKey(cryptoContext->Primary.EncKey);
- }
-
- if (cryptoContext->Primary.DecKey) {
- BCryptDestroyKey(cryptoContext->Primary.DecKey);
- }
-
- if (cryptoContext->Secondary.EncKey) {
- BCryptDestroyKey(cryptoContext->Secondary.EncKey);
- }
-
- if (cryptoContext->Secondary.DecKey) {
- BCryptDestroyKey(cryptoContext->Secondary.DecKey);
- }
+ OvpnCryptoEpochUninitSlot(&cryptoContext->Primary);
+ OvpnCryptoEpochUninitSlot(&cryptoContext->Secondary);
RtlZeroMemory(cryptoContext, sizeof(OvpnCryptoContext));
}
diff --git a/crypto.h b/crypto.h
index 73d286b..636faae 100644
--- a/crypto.h
+++ b/crypto.h
@@ -25,40 +25,35 @@
#include
#include
+#include "crypto_epoch.h"
#include "pktid.h"
#include "uapi\ovpn-dco.h"
#include "socket.h"
+struct OvpnPeerContext;
+
#define OVPN_DATA_V2_LEN 4
#define AEAD_AUTH_TAG_LEN 16
+#define AEAD_LIMIT_BLOCKSIZE 16
+
+// The crypto helper uses this failure status to indicate that the caller must
+// retry the operation while holding the peer spinlock exclusively so key-slot
+// mutation can proceed safely.
+#define STATUS_OVPN_CRYPTO_RETRY ((NTSTATUS)0xC0E44001L)
+
// packet opcode (high 5 bits) and key-id (low 3 bits) are combined in one byte
#define OVPN_OP_DATA_V2 9
#define OVPN_KEY_ID_MASK 0x07
#define OVPN_OPCODE_SHIFT 3
#define OVPN_PEER_ID_MASK 0x00FFFFFF
-struct OvpnCryptoKeySlot
-{
- BCRYPT_KEY_HANDLE EncKey;
- BCRYPT_KEY_HANDLE DecKey;
-
- UCHAR EncNonceTail[8];
- UCHAR DecNonceTail[8];
-
- UCHAR KeyId;
- INT32 PeerId;
-
- OvpnPktidXmit PktidXmit;
- OvpnPktidRecv PktidRecv;
-};
-
_Function_class_(OVPN_CRYPTO_ENCRYPT)
_IRQL_requires_max_(DISPATCH_LEVEL)
_Must_inspect_result_
typedef
NTSTATUS
-OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len, _In_ INT32 CryptoOptions);
+OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len, _In_ OvpnCryptoOptions* opts, BOOLEAN allowRekey);
typedef OVPN_CRYPTO_ENCRYPT* POVPN_CRYPTO_ENCRYPT;
_Function_class_(OVPN_CRYPTO_DECRYPT)
@@ -66,9 +61,15 @@ _IRQL_requires_max_(DISPATCH_LEVEL)
_Must_inspect_result_
typedef
NTSTATUS
-OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut, _In_ INT32 CryptoOptions);
+OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut, _In_ OvpnCryptoOptions* opts, BOOLEAN allowRekey);
typedef OVPN_CRYPTO_DECRYPT* POVPN_CRYPTO_DECRYPT;
+struct OvpnCryptoPacketLayout
+{
+ ULONG FrontLen;
+ ULONG TailLen;
+};
+
struct OvpnCryptoContext
{
OvpnCryptoKeySlot Primary;
@@ -77,24 +78,75 @@ struct OvpnCryptoContext
POVPN_CRYPTO_ENCRYPT Encrypt;
POVPN_CRYPTO_DECRYPT Decrypt;
- INT32 CryptoOptions;
+ OvpnCryptoOptions Options;
+ OvpnCryptoPacketLayout Layout;
+};
+
+
+VOID
+OvpnCryptoDescribePacketLayout(_In_ const OvpnCryptoContext* cryptoContext, _Out_ OvpnCryptoPacketLayout* layout);
+
+typedef
+NTSTATUS
+OVPN_CRYPTO_RETRY_ROUTINE(_In_ OvpnCryptoContext* cryptoContext, _In_ BOOLEAN allowRekey, _Inout_opt_ PVOID context);
+typedef OVPN_CRYPTO_RETRY_ROUTINE* POVPN_CRYPTO_RETRY_ROUTINE;
+
+struct OvpnCryptoEncryptParams
+{
+ PUCHAR Buffer;
+ SIZE_T Length;
+};
+
+struct OvpnCryptoDecryptParams
+{
+ UCHAR KeyId;
+ PUCHAR CipherText;
+ SIZE_T Length;
+ PUCHAR PlainText;
};
+_Must_inspect_result_
+_IRQL_requires_max_(DISPATCH_LEVEL)
+NTSTATUS
+OvpnCryptoCallWithRetry(
+ _In_ OvpnPeerContext* peer,
+ _In_ BOOLEAN atDpcLevel,
+ _Inout_opt_ PBOOLEAN exclusive,
+ _Inout_opt_ PKIRQL kirql,
+ _In_ POVPN_CRYPTO_RETRY_ROUTINE routine,
+ _Inout_opt_ PVOID context);
+
+_Must_inspect_result_
+_IRQL_requires_max_(DISPATCH_LEVEL)
+NTSTATUS
+OvpnCryptoInvokeEncrypt(
+ _In_ OvpnCryptoContext* cryptoContext,
+ _In_ BOOLEAN allowRekey,
+ _Inout_opt_ PVOID context);
+
+_Must_inspect_result_
+_IRQL_requires_max_(DISPATCH_LEVEL)
+NTSTATUS
+OvpnCryptoInvokeDecrypt(
+ _In_ OvpnCryptoContext* cryptoContext,
+ _In_ BOOLEAN allowRekey,
+ _Inout_opt_ PVOID context);
+
_Must_inspect_result_
_IRQL_requires_(PASSIVE_LEVEL)
NTSTATUS
-OvpnCryptoInitAlgHandles(_Outptr_ BCRYPT_ALG_HANDLE* aesAlgHandle, _Outptr_ BCRYPT_ALG_HANDLE* chachaAlgHandle);
+OvpnCryptoInitAlgHandles(_Outptr_ BCRYPT_ALG_HANDLE* aesAlgHandle, _Outptr_ BCRYPT_ALG_HANDLE* chachaAlgHandle, _Outptr_ BCRYPT_ALG_HANDLE* hkdfAlgHandle);
_IRQL_requires_(PASSIVE_LEVEL)
VOID
-OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDLE chachaAlgHandle);
+OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDLE chachaAlgHandle, BCRYPT_ALG_HANDLE hkdfAlgHandle);
VOID
OvpnCryptoUninit(_In_ OvpnCryptoContext* cryptoContext);
_Must_inspect_result_
NTSTATUS
-OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA_V2 cryptoData, _In_opt_ BCRYPT_ALG_HANDLE algHandle);
+OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA_V2 cryptoData, _In_opt_ BCRYPT_ALG_HANDLE algHandle, _In_opt_ BCRYPT_ALG_HANDLE hkdfAlgHandle);
_Must_inspect_result_
OvpnCryptoKeySlot*
@@ -115,3 +167,30 @@ UCHAR OvpnCryptoOpcodeExtract(UCHAR op)
{
return op >> OVPN_OPCODE_SHIFT;
}
+
+static inline
+BOOLEAN
+OvpnCryptoAeadUsageLimitReached(UINT64 limit, UINT64 plaintextBlocks, UINT64 highestPid)
+{
+ /* This is the q + s <= p^(1/2) * 2^(129/2) - 1 calculation where
+ * q is the number of protected messages (highest_pid)
+ * s Total plaintext length in all messages (in blocks) */
+ return ((limit > 0) && (plaintextBlocks + highestPid) > limit);
+}
+
+static inline
+UINT64
+OvpnCryptoAeadUsageLimit(OVPN_CIPHER_ALG alg)
+{
+ switch (alg)
+ {
+ case OVPN_CIPHER_ALG_NONE:
+ return 0;
+
+ case OVPN_CIPHER_ALG_CHACHA20_POLY1305:
+ return 0;
+
+ default:
+ return (1ull << 36) - 1; // limit for AES-GCM
+ }
+}
diff --git a/crypto_epoch.cpp b/crypto_epoch.cpp
new file mode 100644
index 0000000..c62a06e
--- /dev/null
+++ b/crypto_epoch.cpp
@@ -0,0 +1,318 @@
+/*
+ * ovpn-dco-win OpenVPN protocol accelerator for Windows
+ *
+ * Copyright (C) 2025- OpenVPN Inc
+ *
+ * Author: Lev Stipakov
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 2
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program; if not, write to the Free Software Foundation, Inc.,
+ * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "crypto_epoch.h"
+#include "trace.h"
+
+// Derive bytes via HKDF-Expand with PRK = E_i, using bcrypt HKDF.
+// info = OvpnMakeLabel(L, label)
+_Use_decl_annotations_
+NTSTATUS OvpnCryptoExpandLabel(
+ BCRYPT_ALG_HANDLE hkdfAlg,
+ const UCHAR* E_i, // PRK (32 bytes for SHA-256)
+ USHORT outLen, // bytes to derive
+ const char* label, // "data_key" / "data_iv" / "datakey upd"
+ UCHAR* outBytes
+)
+{
+ NTSTATUS status = STATUS_SUCCESS;
+
+ BCRYPT_KEY_HANDLE hKey = NULL;
+
+ // create key handle with PRK bytes
+ GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptGenerateSymmetricKey(hkdfAlg, &hKey, NULL, 0, (PUCHAR)E_i, 32, 0));
+
+ // select SHA-256
+ // BCRYPT_SHA256_ALGORITHM is a wide literal; sizeof(..) includes the NUL in bytes.
+ GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptSetProperty(hKey, BCRYPT_HKDF_HASH_ALGORITHM, (PUCHAR)BCRYPT_SHA256_ALGORITHM, (ULONG)sizeof(BCRYPT_SHA256_ALGORITHM), 0));
+
+ // tell HKDF we're already supplying the PRK in the key handle:
+ // passing NULL,0 just switches to "PRK is finalized" mode.
+ GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptSetProperty(hKey, BCRYPT_HKDF_PRK_AND_FINALIZE, NULL, 0, 0));
+
+ // build info = OvpnLabel(outLen, "data_key"/"data_iv")
+ UCHAR info[2 + 1 + 64 + 1 + 255]; // enough for our labels
+ ULONG infoLen = 0;
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoMakeLabel(info, (ULONG)sizeof(info), &infoLen, outLen, label));
+
+ // prepare KDF params
+ BCryptBuffer infoBuf;
+ BCryptBufferDesc desc;
+
+ RtlZeroMemory(&infoBuf, sizeof(infoBuf));
+ RtlZeroMemory(&desc, sizeof(desc));
+
+ infoBuf.cbBuffer = infoLen;
+ infoBuf.BufferType = KDF_HKDF_INFO;
+ infoBuf.pvBuffer = info;
+
+ desc.ulVersion = BCRYPTBUFFER_VERSION;
+ desc.cBuffers = 1;
+ desc.pBuffers = &infoBuf;
+
+ // derive
+ ULONG got = 0;
+ GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptKeyDerivation(hKey, &desc, outBytes, outLen, &got, 0));
+ if (got != outLen) {
+ status = STATUS_INTERNAL_ERROR;
+ }
+
+done:
+ if (hKey) BCryptDestroyKey(hKey);
+ return status;
+}
+
+// Build TLS 1.3-style label with "ovpn " prefix, into caller buffer.
+// struct {
+// uint16 length = L;
+// opaque label<6..255> = "ovpn " + Label;
+// opaque context<0..255>;
+// } OvpnLabel;
+_Use_decl_annotations_
+NTSTATUS OvpnCryptoMakeLabel(
+ UCHAR* out,
+ ULONG cbOut,
+ ULONG* pcbWritten,
+ USHORT L,
+ const char* label)
+{
+ NTSTATUS status = STATUS_SUCCESS;
+
+ static const char prefix[] = "ovpn ";
+
+ const size_t prefixLen = sizeof(prefix) - 1;
+ size_t labelLen = 0;
+
+ GOTO_IF_NOT_NT_SUCCESS(done, status, RtlStringCbLengthA(label, 256, &labelLen)); // labels are tiny, 256 is safe
+
+ *pcbWritten = 0;
+
+ // Total encoded label length = "ovpn " + label
+ size_t totalLabelLen = prefixLen + labelLen;
+ if (totalLabelLen < 6 || totalLabelLen > 255) {
+ status = STATUS_INVALID_PARAMETER;
+ goto done;
+ }
+
+ // total = 2(length) + 1(totalLabelLen) + totalLabelLen + 1(ctxLen=0)
+ ULONG need = 2 + 1 + (ULONG)totalLabelLen + 1;
+ if (cbOut < need) {
+ status = STATUS_BUFFER_TOO_SMALL;
+ goto done;
+ }
+
+ ULONG p = 0;
+ out[p++] = (UCHAR)(L >> 8);
+ out[p++] = (UCHAR)(L & 0xFF);
+ out[p++] = (UCHAR)totalLabelLen;
+
+ // "ovpn "
+ RtlCopyMemory(out + p, prefix, prefixLen);
+ p += (ULONG)prefixLen;
+
+
+ // Label
+ RtlCopyMemory(out + p, label, labelLen);
+ p += (ULONG)labelLen;
+
+ // context length = 0
+ out[p++] = 0;
+
+ *pcbWritten = p;
+
+done:
+ return status;
+}
+
+VOID
+OvpnCryptoEpochInitKey(OvpnCryptoKeyContext* ctx, OvpnCryptoEpochKey* epochKey, OvpnCryptoOptions* opts)
+{
+ LOG_INFO("Epoch Data Key", TraceLoggingValue(epochKey->Epoch, "epoch"));
+
+ OvpnCryptoKeyParameters key{0};
+ OvpnCryptoEpochDataKeyDerive(&key, epochKey, opts->HkdfAlgHandle, opts->AeadAlgHangle, opts->KeyLen);
+ ctx->Epoch = key.Epoch;
+ ctx->Key = key.KeyHandle;
+ RtlCopyMemory(ctx->ImplicitIV, key.IV, sizeof(ctx->ImplicitIV));
+
+ RtlSecureZeroMemory(&key, sizeof(OvpnCryptoKeyParameters));
+}
+
+NTSTATUS
+OvpnCryptoEpochDataKeyDerive(OvpnCryptoKeyParameters* key, OvpnCryptoEpochKey* epochKey, BCRYPT_ALG_HANDLE hkdfAlgHandle, BCRYPT_ALG_HANDLE algHandle, UCHAR cipherSize)
+{
+ NTSTATUS status;
+
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoExpandLabel(hkdfAlgHandle, epochKey->EpochKey, cipherSize, "data_key", key->Cipher));
+ GOTO_IF_NOT_NT_SUCCESS(done, status, BCryptGenerateSymmetricKey(algHandle, &key->KeyHandle, NULL, 0, key->Cipher, cipherSize, 0));
+
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoExpandLabel(hkdfAlgHandle, epochKey->EpochKey, 12, "data_iv", key->IV));
+
+ key->Epoch = epochKey->Epoch;
+
+done:
+ return status;
+}
+
+VOID
+OvpnCryptoEpochKeyIterate(OvpnCryptoEpochKey* epochKey, BCRYPT_ALG_HANDLE hkdfAlgHandle)
+{
+ ++epochKey->Epoch;
+ OvpnCryptoExpandLabel(hkdfAlgHandle, epochKey->EpochKey, sizeof(epochKey->EpochKey), "datakey upd", epochKey->EpochKey);
+}
+
+VOID
+OvpnCryptoEpochGenerateFutureRecvKeys(OvpnCryptoKeySlot* keySlot, OvpnCryptoOptions* opts)
+{
+ UINT16 currentDecryptEpoch = keySlot->Decrypt.Epoch;
+
+ // free unused keys
+ for (int i = 0; i < FUTURE_EPOCH_KEYS_COUNT; ++i) {
+ auto key = &keySlot->FutureEpochKeys[i];
+ if ((key->Epoch > 0) && (key->Epoch < currentDecryptEpoch)) {
+ BCryptDestroyKey(key->Key);
+ RtlZeroMemory(key, sizeof(*key));
+ }
+ }
+
+ auto highestFutureKey = &keySlot->FutureEpochKeys[FUTURE_EPOCH_KEYS_COUNT - 1];
+
+ UINT16 currentHighestKey = highestFutureKey->Epoch ? highestFutureKey->Epoch : 1;
+ UINT16 desiredHighestKey = currentDecryptEpoch + FUTURE_EPOCH_KEYS_COUNT;
+ UINT16 numKeysGenerate = desiredHighestKey - currentHighestKey;
+
+ RtlMoveMemory(keySlot->FutureEpochKeys, &keySlot->FutureEpochKeys[numKeysGenerate], (FUTURE_EPOCH_KEYS_COUNT - numKeysGenerate) * sizeof(OvpnCryptoKeyContext));
+
+ for (int i = 16 - numKeysGenerate; i < FUTURE_EPOCH_KEYS_COUNT; ++i)
+ {
+ RtlSecureZeroMemory(&keySlot->FutureEpochKeys[i], sizeof(OvpnCryptoKeyContext));
+
+ OvpnCryptoEpochKeyIterate(&keySlot->EpochKeyRecv, opts->HkdfAlgHandle);
+ OvpnCryptoEpochInitKey(&keySlot->FutureEpochKeys[i], &keySlot->EpochKeyRecv, opts);
+ }
+}
+
+VOID
+OvpnCryptoEpochReplaceUpdateRecvKey(OvpnCryptoKeySlot* keySlot, UINT16 new_epoch, OvpnCryptoOptions* opts)
+{
+ // Find the key of the new epoch in future keys
+ UINT16 fki;
+ for (fki = 0; fki < FUTURE_EPOCH_KEYS_COUNT; fki++) {
+ if (keySlot->FutureEpochKeys[fki].Epoch == new_epoch) {
+ break;
+ }
+ }
+
+ OvpnCryptoKeyContext* ctx = &keySlot->FutureEpochKeys[fki];
+
+ // Check if the new recv key epoch is higher than the send key epoch. If yes we will replace the send key as well
+ if (keySlot->Encrypt.Epoch < new_epoch) {
+ BCryptDestroyKey(keySlot->Encrypt.Key);
+ RtlZeroMemory(&keySlot->Encrypt, sizeof(OvpnCryptoKeyContext));
+
+ // Update the epoch_key for send to match the current key being used
+ while (keySlot->EpochKeySend.Epoch < new_epoch) {
+ OvpnCryptoEpochKeyIterate(&keySlot->EpochKeySend, opts->HkdfAlgHandle);
+ }
+ OvpnCryptoEpochInitKey(&keySlot->Encrypt, &keySlot->EpochKeySend, opts);
+ }
+
+ // Replace receive key
+ BCryptDestroyKey(keySlot->RetiringEpochDataReceiveKey.Key);
+ RtlZeroMemory(&keySlot->RetiringEpochDataReceiveKey, sizeof(OvpnCryptoKeyContext));
+
+ keySlot->RetiringEpochDataReceiveKey = keySlot->Decrypt;
+
+ keySlot->Decrypt = *ctx;
+
+ RtlZeroMemory(ctx, sizeof(*ctx));
+
+ // Generate new future keys
+ OvpnCryptoEpochGenerateFutureRecvKeys(keySlot, opts);
+}
+
+OvpnCryptoKeyContext*
+OvpnCryptoEpochLookupDecryptKey(OvpnCryptoKeySlot* keySlot, UINT16 epoch)
+{
+ /* Current decrypt key is the most likely one */
+ if (keySlot->Decrypt.Epoch == epoch) {
+ return &keySlot->Decrypt;
+ }
+ else if (keySlot->RetiringEpochDataReceiveKey.Epoch && keySlot->RetiringEpochDataReceiveKey.Epoch == epoch) {
+ return &keySlot->RetiringEpochDataReceiveKey;
+ }
+ else if (epoch > keySlot->Decrypt.Epoch && epoch <= keySlot->Decrypt.Epoch + FUTURE_EPOCH_KEYS_COUNT) {
+ // Key in the range of future keys
+ int index = epoch - (keySlot->Decrypt.Epoch + 1);
+
+ if (epoch > (UINT16_MAX - FUTURE_EPOCH_KEYS_COUNT - 1)) {
+ return NULL;
+ }
+ else {
+ return &keySlot->FutureEpochKeys[index];
+ }
+ }
+ else {
+ return NULL;
+ }
+}
+
+VOID
+OvpnCryptoMakeEpochNonce(UCHAR* epochIv, UINT64 packet_id_net, UCHAR* nonce)
+{
+ // first 8 bytes of IV (aka nonce) is pktid_net XOR 8 bytes of epoch's implicit IV
+ UINT64 iv0;
+ RtlCopyMemory(&iv0, epochIv, sizeof(iv0));
+ iv0 ^= packet_id_net;
+ RtlCopyMemory(nonce, &iv0, sizeof(iv0));
+
+ // last 4 bytes of IV are from epoch's implicit IV
+ RtlCopyMemory(nonce + 8, epochIv + 8, 4);
+}
+
+VOID
+OvpnCryptoEpochIterateSendKey(OvpnCryptoKeySlot* keySlot, OvpnCryptoOptions* opts)
+{
+ OvpnCryptoEpochKeyIterate(&keySlot->EpochKeySend, opts->HkdfAlgHandle);
+
+ BCryptDestroyKey(keySlot->Encrypt.Key);
+ RtlSecureZeroMemory(&keySlot->Encrypt, sizeof(OvpnCryptoKeyContext));
+ OvpnCryptoEpochInitKey(&keySlot->Encrypt, &keySlot->EpochKeySend, opts);
+
+ RtlZeroMemory(&keySlot->PktidXmit, sizeof(keySlot->PktidXmit));
+}
+
+VOID
+OvpnCryptoEpochUninitSlot(OvpnCryptoKeySlot* slot)
+{
+ if (slot->Encrypt.Key) {
+ BCryptDestroyKey(slot->Encrypt.Key);
+ }
+ for (int i = 0; i < FUTURE_EPOCH_KEYS_COUNT; ++i) {
+ if (slot->FutureEpochKeys[i].Key) {
+ BCryptDestroyKey(slot->FutureEpochKeys[i].Key);
+ }
+ }
+ if (slot->RetiringEpochDataReceiveKey.Key) {
+ BCryptDestroyKey(slot->RetiringEpochDataReceiveKey.Key);
+ }
+ RtlSecureZeroMemory(slot, sizeof(OvpnCryptoKeySlot));
+}
\ No newline at end of file
diff --git a/crypto_epoch.h b/crypto_epoch.h
new file mode 100644
index 0000000..808bcc3
--- /dev/null
+++ b/crypto_epoch.h
@@ -0,0 +1,157 @@
+/*
+ * ovpn-dco-win OpenVPN protocol accelerator for Windows
+ *
+ * Copyright (C) 2025- OpenVPN Inc
+ *
+ * Author: Lev Stipakov
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 2
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program; if not, write to the Free Software Foundation, Inc.,
+ * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#if defined(_KERNEL_MODE)
+#include
+#include
+#include "trace.h"
+#else
+#define WIN32_NO_STATUS // keep windows.h from redefining STATUS_*
+#include
+#undef WIN32_NO_STATUS
+#include
+#include // exposes STATUS_SUCCESS, NT_SUCCESS, etc.
+#include
+#define RtlStringCbLengthA StringCbLengthA
+#ifndef UINT16_MAX
+#define UINT16_MAX 0xFFFF
+#endif
+#endif
+
+#include
+
+#include "pktid.h"
+
+#define PACKET_ID_EPOCH_MAX 0x0000FFFFFFFFFFFFull
+
+#define FUTURE_EPOCH_KEYS_COUNT 16
+
+struct OvpnCryptoKeyContext
+{
+ BCRYPT_KEY_HANDLE Key;
+ UCHAR ImplicitIV[12];
+
+ // number of plaintext blocks encrypted using this key
+ UINT64 PlaintextBlocks;
+ UINT16 Epoch;
+};
+
+struct OvpnCryptoEpochKey
+{
+ UCHAR EpochKey[32];
+ UINT16 Epoch;
+};
+
+struct OvpnCryptoKeyParameters
+{
+ UCHAR Cipher[32];
+ UCHAR IV[12];
+ BCRYPT_KEY_HANDLE KeyHandle;
+ UINT16 Epoch;
+};
+
+struct OvpnCryptoOptions {
+ // Limit for AEAD cipher, sum of packets + blocks. Will switch to the new epoch when reached.
+ UINT64 AeadUsageLimit;
+
+ BOOLEAN UseEpoch;
+
+ UCHAR KeyLen;
+
+ BCRYPT_ALG_HANDLE HkdfAlgHandle;
+ BCRYPT_ALG_HANDLE AeadAlgHangle;
+};
+
+NTSTATUS OvpnCryptoExpandLabel(
+ BCRYPT_ALG_HANDLE hkdfAlg,
+ _In_reads_bytes_(32) const UCHAR* E_i, // PRK (32 bytes for SHA-256)
+ _In_ USHORT outLen, // bytes to derive
+ _In_z_ const char* label, // "data_key" / "data_iv" / "datakey upd"
+ _Out_writes_bytes_(outLen) UCHAR* outBytes
+);
+
+_IRQL_requires_max_(PASSIVE_LEVEL)
+static
+NTSTATUS OvpnCryptoMakeLabel(
+ _Out_writes_bytes_to_(cbOut, *pcbWritten) UCHAR* out,
+ _In_ ULONG cbOut,
+ _Out_ ULONG* pcbWritten,
+ _In_ USHORT L,
+ _In_z_ const char* label);
+
+struct OvpnCryptoKeySlot
+{
+ OvpnCryptoKeyContext Encrypt;
+ OvpnCryptoKeyContext Decrypt;
+
+ // last epoch key used for generating current send data keys
+ OvpnCryptoEpochKey EpochKeySend;
+
+ // epoch key used for the highest receive epoch keys
+ OvpnCryptoEpochKey EpochKeyRecv;
+
+ UCHAR KeyId;
+ INT32 PeerId;
+
+ OvpnPktidXmit PktidXmit;
+ OvpnPktidRecv PktidRecv;
+
+ // future epoch data keys for decryption
+ OvpnCryptoKeyContext FutureEpochKeys[FUTURE_EPOCH_KEYS_COUNT];
+
+ OvpnPktidRecv PktidRecvRetiring;
+ OvpnCryptoKeyContext RetiringEpochDataReceiveKey;
+};
+
+// Initialises data channel key/IV using the provided epoch key
+VOID
+OvpnCryptoEpochInitKey(OvpnCryptoKeyContext* ctx, OvpnCryptoEpochKey* epochKey, OvpnCryptoOptions* opts);
+
+// Generates a data channel key/IV from the epoch key
+NTSTATUS
+OvpnCryptoEpochDataKeyDerive(OvpnCryptoKeyParameters* key, OvpnCryptoEpochKey* epochKey, BCRYPT_ALG_HANDLE hkdfAlgHandle, BCRYPT_ALG_HANDLE algHandle, UCHAR cipherSize);
+
+VOID
+OvpnCryptoEpochKeyIterate(OvpnCryptoEpochKey* epochKey, BCRYPT_ALG_HANDLE hkdfAlgHandle);
+/**
+ * Generates and fills the FutureEpochKeys with next valid future keys
+ * using the epoch of the key in keySlot->EpochKeyRecv as starting point
+ */
+VOID
+OvpnCryptoEpochGenerateFutureRecvKeys(OvpnCryptoKeySlot* keySlot, OvpnCryptoOptions* opts);
+
+// This is called when the peer uses a new send key that is not the default key
+VOID
+OvpnCryptoEpochReplaceUpdateRecvKey(OvpnCryptoKeySlot* keySlot, UINT16 new_epoch, OvpnCryptoOptions* opts);
+
+// retrieve decryption key context that matches the epoch
+OvpnCryptoKeyContext*
+OvpnCryptoEpochLookupDecryptKey(OvpnCryptoKeySlot* keySlot, UINT16 epoch);
+
+VOID
+OvpnCryptoMakeEpochNonce(UCHAR* epochIv, UINT64 packet_id_net, UCHAR* nonce);
+
+// Updates the send key and keySlot->EpochKeySend to use the next epoch
+VOID
+OvpnCryptoEpochIterateSendKey(OvpnCryptoKeySlot* keySlot, OvpnCryptoOptions* opts);
+
+VOID
+OvpnCryptoEpochUninitSlot(OvpnCryptoKeySlot* slot);
\ No newline at end of file
diff --git a/ovpn-dco-win.vcxproj b/ovpn-dco-win.vcxproj
index bcfff8e..fbadd69 100644
--- a/ovpn-dco-win.vcxproj
+++ b/ovpn-dco-win.vcxproj
@@ -71,6 +71,7 @@
+
@@ -87,6 +88,7 @@
+
@@ -98,7 +100,6 @@
-
diff --git a/ovpn-dco-win.vcxproj.filters b/ovpn-dco-win.vcxproj.filters
index adf6018..c1cf7e4 100644
--- a/ovpn-dco-win.vcxproj.filters
+++ b/ovpn-dco-win.vcxproj.filters
@@ -52,9 +52,6 @@
Header Files
-
- Header Files
-
Header Files\uapi
@@ -76,6 +73,9 @@
Header Files
+
+ Header Files
+
@@ -120,6 +120,9 @@
Source Files
+
+ Source Files
+
diff --git a/peer.cpp b/peer.cpp
index 41fe3d9..4a61124 100644
--- a/peer.cpp
+++ b/peer.cpp
@@ -250,13 +250,7 @@ OvpnFindPeer(POVPN_DEVICE device, INT32 PeerId, BOOLEAN dpc)
OvpnPeerContext* peer = nullptr;
OvpnPeerContext** ptr = nullptr;
- KIRQL kirql = 0;
- if (dpc) {
- ExAcquireSpinLockSharedAtDpcLevel(&device->SpinLock);
- }
- else {
- kirql = ExAcquireSpinLockShared(&device->SpinLock);
- }
+ KIRQL kirql = OvpnAcquireSpinLock(dpc, &device->SpinLock, FALSE);
if (device->Mode == OVPN_MODE_P2P) {
ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0);
@@ -275,12 +269,7 @@ OvpnFindPeer(POVPN_DEVICE device, INT32 PeerId, BOOLEAN dpc)
InterlockedIncrement(&peer->RefCounter);
}
- if (dpc) {
- ExReleaseSpinLockSharedFromDpcLevel(&device->SpinLock);
- }
- else {
- ExReleaseSpinLockShared(&device->SpinLock, kirql);
- }
+ OvpnReleaseSpinLock(dpc, kirql, &device->SpinLock, FALSE);
return peer;
}
@@ -292,13 +281,7 @@ OvpnFindPeerVPN4(POVPN_DEVICE device, IN_ADDR addr, BOOLEAN dpc)
OvpnPeerContext* peer = nullptr;
OvpnPeerContext** ptr = nullptr;
- KIRQL kirql = 0;
- if (dpc) {
- ExAcquireSpinLockSharedAtDpcLevel(&device->SpinLock);
- }
- else {
- kirql = ExAcquireSpinLockShared(&device->SpinLock);
- }
+ KIRQL kirql = OvpnAcquireSpinLock(dpc, &device->SpinLock, FALSE);
if (device->Mode == OVPN_MODE_P2P) {
ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0);
@@ -316,12 +299,7 @@ OvpnFindPeerVPN4(POVPN_DEVICE device, IN_ADDR addr, BOOLEAN dpc)
InterlockedIncrement(&peer->RefCounter);
}
- if (dpc) {
- ExReleaseSpinLockSharedFromDpcLevel(&device->SpinLock);
- }
- else {
- ExReleaseSpinLockShared(&device->SpinLock, kirql);
- }
+ OvpnReleaseSpinLock(dpc, kirql, &device->SpinLock, FALSE);
return peer;
}
@@ -333,13 +311,7 @@ OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr, BOOLEAN dpc)
OvpnPeerContext* peer = nullptr;
OvpnPeerContext** ptr = nullptr;
- KIRQL kirql = 0;
- if (dpc) {
- ExAcquireSpinLockSharedAtDpcLevel(&device->SpinLock);
- }
- else {
- kirql = ExAcquireSpinLockShared(&device->SpinLock);
- }
+ KIRQL kirql = OvpnAcquireSpinLock(dpc, &device->SpinLock, FALSE);
if (device->Mode == OVPN_MODE_P2P) {
ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0);
@@ -357,12 +329,7 @@ OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr, BOOLEAN dpc)
InterlockedIncrement(&peer->RefCounter);
}
- if (dpc) {
- ExReleaseSpinLockSharedFromDpcLevel(&device->SpinLock);
- }
- else {
- ExReleaseSpinLockShared(&device->SpinLock, kirql);
- }
+ OvpnReleaseSpinLock(dpc, kirql, &device->SpinLock, FALSE);
return peer;
}
@@ -377,13 +344,7 @@ OvpnFindPeerTransport(POVPN_DEVICE device, PSOCKADDR sa, BOOLEAN dpc)
OvpnPeerContext* peer = nullptr;
OvpnPeerContext** ptr = nullptr;
- KIRQL kirql = 0;
- if (dpc) {
- ExAcquireSpinLockSharedAtDpcLevel(&device->SpinLock);
- }
- else {
- kirql = ExAcquireSpinLockShared(&device->SpinLock);
- }
+ KIRQL kirql = OvpnAcquireSpinLock(dpc, &device->SpinLock, FALSE);
if (device->Mode == OVPN_MODE_P2P) {
ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0);
@@ -404,12 +365,7 @@ OvpnFindPeerTransport(POVPN_DEVICE device, PSOCKADDR sa, BOOLEAN dpc)
InterlockedIncrement(&peer->RefCounter);
}
- if (dpc) {
- ExReleaseSpinLockSharedFromDpcLevel(&device->SpinLock);
- }
- else {
- ExReleaseSpinLockShared(&device->SpinLock, kirql);
- }
+ OvpnReleaseSpinLock(dpc, kirql, &device->SpinLock, FALSE);
return peer;
}
@@ -930,7 +886,7 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request)
}
RtlCopyMemory(&cryptoDataV2.V1, cryptoData, sizeof(OVPN_CRYPTO_DATA));
- GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, &cryptoDataV2, algHandle));
+ GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, &cryptoDataV2, algHandle, NULL));
done:
if (peer != nullptr) {
@@ -965,7 +921,7 @@ OvpnPeerNewKeyV2(POVPN_DEVICE device, WDFREQUEST request)
}
KIRQL irql = ExAcquireSpinLockExclusive(&peer->SpinLock);
- LOG_IF_NOT_NT_SUCCESS(status = OvpnCryptoNewKey(&peer->CryptoContext, cryptoDataV2, algHandle));
+ LOG_IF_NOT_NT_SUCCESS(status = OvpnCryptoNewKey(&peer->CryptoContext, cryptoDataV2, algHandle, device->HkdfAlgHandle));
ExReleaseSpinLockExclusive(&peer->SpinLock, irql);
done:
@@ -1133,20 +1089,12 @@ OvpnPeerHandleFloat(OVPN_DEVICE* device, OvpnPeerContext *peer, PSOCKADDR sa, BO
TraceLoggingIPv6Address(&peer->TransportAddrs.Remote.IPv6.sin6_addr, "src"),
TraceLoggingIPv6Address(&((SOCKADDR_IN6*)sa)->sin6_addr, "dst"));
- KIRQL kirql = 0;
-
// remove peer from by-transport-address hashtable
OvpnDeletePeerFromTable(device, &device->PeersByTransport, peer, "transport");
// modify peer's transport address
{
- // exclusive-lock peer
- if (dpc) {
- ExAcquireSpinLockExclusiveAtDpcLevel(&peer->SpinLock);
- }
- else {
- kirql = ExAcquireSpinLockExclusive(&peer->SpinLock);
- }
+ KIRQL kirql = OvpnAcquireSpinLock(dpc, &peer->SpinLock, TRUE);
// update peer's transport address
if (sa->sa_family == AF_INET)
@@ -1154,13 +1102,7 @@ OvpnPeerHandleFloat(OVPN_DEVICE* device, OvpnPeerContext *peer, PSOCKADDR sa, BO
else
RtlCopyMemory(&peer->TransportAddrs.Remote.IPv6, sa, sizeof(SOCKADDR_IN6));
- // exclusive-unlock peer
- if (dpc) {
- ExReleaseSpinLockExclusiveFromDpcLevel(&peer->SpinLock);
- }
- else {
- ExReleaseSpinLockExclusive(&peer->SpinLock, kirql);
- }
+ OvpnReleaseSpinLock(dpc, kirql, &peer->SpinLock, TRUE);
}
// add peer back to by-transport-address hashtable
diff --git a/peer.h b/peer.h
index ac5e707..75505b9 100644
--- a/peer.h
+++ b/peer.h
@@ -182,3 +182,48 @@ OvpnPeerGetDelReasonString(OVPN_DEL_PEER_REASON reason);
NTSTATUS
OvpnPeerHandleFloat(OVPN_DEVICE* device, OvpnPeerContext* peer, PSOCKADDR sa, BOOLEAN dpc);
+
+static inline KIRQL
+OvpnAcquireSpinLock(BOOLEAN dpc, PEX_SPIN_LOCK spinLock, BOOLEAN exclusive)
+{
+ KIRQL kirql = 0;
+ if (dpc) {
+ if (exclusive) {
+ ExAcquireSpinLockExclusiveAtDpcLevel(spinLock);
+ }
+ else {
+ ExAcquireSpinLockSharedAtDpcLevel(spinLock);
+ }
+ }
+ else {
+ if (exclusive) {
+ kirql = ExAcquireSpinLockExclusive(spinLock);
+ }
+ else {
+ kirql = ExAcquireSpinLockShared(spinLock);
+ }
+ }
+
+ return kirql;
+}
+
+static inline VOID
+OvpnReleaseSpinLock(BOOLEAN dpc, KIRQL kirql, PEX_SPIN_LOCK spinLock, BOOLEAN exclusive)
+{
+ if (dpc) {
+ if (exclusive) {
+ ExReleaseSpinLockExclusiveFromDpcLevel(spinLock);
+ }
+ else {
+ ExReleaseSpinLockSharedFromDpcLevel(spinLock);
+ }
+ }
+ else {
+ if (exclusive) {
+ ExReleaseSpinLockExclusive(spinLock, kirql);
+ }
+ else {
+ ExReleaseSpinLockShared(spinLock, kirql);
+ }
+ }
+}
diff --git a/pktid.cpp b/pktid.cpp
index 4d94ee3..2116600 100644
--- a/pktid.cpp
+++ b/pktid.cpp
@@ -28,23 +28,17 @@
#define PKTID_WRAP_WARN 0xf0000000ULL
_Use_decl_annotations_
-NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, VOID* pktId, BOOLEAN pktId64bit)
+NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, VOID* pktId)
{
ULONG64 seqNum = InterlockedIncrementNoFence64(&px->SeqNum);
- if (pktId64bit) {
- *static_cast(pktId) = seqNum;
+ *static_cast(pktId) = static_cast(seqNum);
+ if (seqNum >= PKTID_WRAP_WARN) {
+ LOG_ERROR("Pktid wrapped");
+ return STATUS_INTEGER_OVERFLOW;
+ } else {
+ return STATUS_SUCCESS;
}
- else
- {
- *static_cast(pktId) = static_cast(seqNum);
- if (seqNum >= PKTID_WRAP_WARN) {
- LOG_ERROR("Pktid wrapped");
- return STATUS_INTEGER_OVERFLOW;
- }
- }
-
- return STATUS_SUCCESS;
}
#define PKTID_RECV_EXPIRE ((30 * WDF_TIMEOUT_TO_SEC) / KeQueryTimeIncrement())
diff --git a/pktid.h b/pktid.h
index dcc4be8..6d2518b 100644
--- a/pktid.h
+++ b/pktid.h
@@ -21,7 +21,9 @@
#pragma once
+#if defined(_KERNEL_MODE)
#include
+#endif
struct OvpnPktidXmit
{
@@ -56,8 +58,8 @@ struct OvpnPktidRecv
UINT64 IdFloor;
};
-/* Get the next packet ID for xmit */
-NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ VOID* pktId, BOOLEAN pktId64bit);
+/* Get the next packet ID for xmit. Used only for non-epoch crypto. */
+NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ VOID* pktId);
/* Packet replay detection.
diff --git a/socket.cpp b/socket.cpp
index de82870..61d613e 100644
--- a/socket.cpp
+++ b/socket.cpp
@@ -211,43 +211,26 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 pee
return;
}
- // If we're at dispatch level, we can use a small optimization and use function
- // which is not calling KeRaiseIRQL to raise the IRQL to DISPATCH_LEVEL before attempting to acquire the lock
- KIRQL kirql = 0;
- if (dpc) {
- ExAcquireSpinLockSharedAtDpcLevel(&peer->SpinLock);
- }
- else {
- kirql = ExAcquireSpinLockShared(&peer->SpinLock);
- }
+ KIRQL kirql = OvpnAcquireSpinLock(dpc, &peer->SpinLock, FALSE);
+
+ BOOLEAN exclusive = FALSE;
OvpnCryptoContext* cryptoContext = &peer->CryptoContext;
if (cryptoContext->Decrypt) {
UCHAR keyId = OvpnCryptoKeyIdExtract(op);
- OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(cryptoContext, keyId);
- if (!keySlot) {
- status = STATUS_INVALID_DEVICE_STATE;
- LOG_ERROR("keyId not found", TraceLoggingValue(keyId, "keyId"));
- }
- else {
- // extend data area in the buffer for plaintext and crypto overhead
- OvpnBufferPut(buffer, len);
+ // extend data area in the buffer for plaintext and crypto overhead
+ OvpnBufferPut(buffer, len);
- // decrypt into plaintext buffer
- status = cryptoContext->Decrypt(keySlot, cipherTextBuf, len, buffer->Data, cryptoContext->CryptoOptions);
+ OvpnCryptoDecryptParams decryptParams = { keyId, cipherTextBuf, len, buffer->Data };
+ status = OvpnCryptoCallWithRetry(peer, dpc, &exclusive, dpc ? nullptr : &kirql, OvpnCryptoInvokeDecrypt, &decryptParams);
- // trim AEAD tag an the end
- auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
- if (aeadTagEnd) {
- OvpnBufferTrim(buffer, len - AEAD_AUTH_TAG_LEN);
- }
+ if (NT_SUCCESS(status)) {
+ const OvpnCryptoPacketLayout layout = cryptoContext->Layout;
- // remove crypto overhead in front
- auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
- auto cryptoOverheadFront = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN);
- OvpnBufferPull(buffer, cryptoOverheadFront);
+ OvpnBufferTrim(buffer, len - layout.TailLen);
+ OvpnBufferPull(buffer, layout.FrontLen);
}
}
else {
@@ -265,13 +248,7 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, UINT32 pee
auto mss = peer->MSS;
- // don't forget to release spinlock
- if (dpc) {
- ExReleaseSpinLockSharedFromDpcLevel(&peer->SpinLock);
- }
- else {
- ExReleaseSpinLockShared(&peer->SpinLock, kirql);
- }
+ OvpnReleaseSpinLock(dpc, kirql, &peer->SpinLock, exclusive);
// decrypt failed - don't proceed
if (!NT_SUCCESS(status)) {
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
new file mode 100644
index 0000000..47ed056
--- /dev/null
+++ b/tests/CMakeLists.txt
@@ -0,0 +1,28 @@
+cmake_minimum_required(VERSION 3.14)
+
+project(tests)
+
+include(FetchContent)
+FetchContent_Declare(
+ googletest
+ URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip
+ DOWNLOAD_EXTRACT_TIMESTAMP true)
+
+set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
+FetchContent_MakeAvailable(googletest)
+
+enable_testing()
+
+add_executable(
+ tests
+ tests.cpp
+ ../crypto_epoch.cpp
+)
+target_link_libraries(
+ tests
+ GTest::gtest_main
+ bcrypt
+)
+
+include(GoogleTest)
+gtest_discover_tests(tests)
\ No newline at end of file
diff --git a/tests/tests.cpp b/tests/tests.cpp
new file mode 100644
index 0000000..98ee58d
--- /dev/null
+++ b/tests/tests.cpp
@@ -0,0 +1,219 @@
+#include
+#include
+
+#include "../crypto_epoch.h"
+
+class CryptoTest : public testing::Test
+{
+protected:
+ void SetUp() override
+ {
+ RtlZeroMemory(&opts, sizeof(opts));
+ RtlZeroMemory(&keySlot, sizeof(keySlot));
+
+ opts.KeyLen = 32;
+ opts.UseEpoch = 1;
+
+ ASSERT_EQ(BCryptOpenAlgorithmProvider(&opts.HkdfAlgHandle, BCRYPT_HKDF_ALGORITHM, NULL, 0), STATUS_SUCCESS);
+ ASSERT_EQ(BCryptOpenAlgorithmProvider(&opts.AeadAlgHangle, BCRYPT_AES_ALGORITHM, NULL, 0), STATUS_SUCCESS);
+
+ keySlot.EpochKeySend.Epoch = 1;
+ keySlot.EpochKeyRecv.Epoch = 1;
+
+ OvpnCryptoEpochInitKey(&keySlot.Encrypt, &keySlot.EpochKeySend, &opts);
+ OvpnCryptoEpochInitKey(&keySlot.Decrypt, &keySlot.EpochKeyRecv, &opts);
+
+ RtlZeroMemory(keySlot.FutureEpochKeys, sizeof(keySlot.FutureEpochKeys));
+ OvpnCryptoEpochGenerateFutureRecvKeys(&keySlot, &opts);
+ }
+
+ OvpnCryptoOptions opts;
+ OvpnCryptoKeySlot keySlot;
+};
+
+TEST_F(CryptoTest, HkdfExpand) {
+ uint8_t secret[32] = { 0x07, 0x77, 0x09, 0x36, 0x2c, 0x2e, 0x32, 0xdf, 0x0d, 0xdc, 0x3f,
+ 0x0d, 0xc4, 0x7b, 0xba, 0x63, 0x90, 0xb6, 0xc7, 0x3b, 0xb5, 0x0f,
+ 0x9c, 0x31, 0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5 };
+
+ const char* label = "unit test";
+
+ uint8_t out_expected[16] = { 0x18, 0x5e, 0xaa, 0x1c, 0x7f, 0x22, 0x8a, 0xb8,
+ 0xeb, 0x29, 0x77, 0x32, 0x14, 0xd9, 0x20, 0x46 };
+
+ uint8_t out[16];
+
+ ASSERT_EQ(OvpnCryptoExpandLabel(opts.HkdfAlgHandle, secret, sizeof(out), label, out), STATUS_SUCCESS);
+ ASSERT_EQ(0, std::memcmp(out, out_expected, sizeof(out)));
+}
+
+TEST_F(CryptoTest, EpochKeyGeneration) {
+ // check that the keys look like expected
+ ASSERT_EQ(keySlot.FutureEpochKeys[0].Epoch, 2);
+ ASSERT_EQ(keySlot.FutureEpochKeys[15].Epoch, 17);
+ ASSERT_EQ(keySlot.EpochKeySend.Epoch, 1);
+ ASSERT_EQ(keySlot.EpochKeyRecv.Epoch, 17);
+
+ // Now replace the recv key with the 6th future key (epoch = 8)
+ BCryptDestroyKey(keySlot.Decrypt.Key);
+ RtlZeroMemory(&keySlot.Decrypt, sizeof(keySlot.Decrypt));
+ ASSERT_EQ(keySlot.FutureEpochKeys[6].Epoch, 8);
+ keySlot.Decrypt = keySlot.FutureEpochKeys[6];
+ RtlZeroMemory(&keySlot.FutureEpochKeys[6].Epoch, sizeof(OvpnCryptoKeyContext));
+
+ OvpnCryptoEpochGenerateFutureRecvKeys(&keySlot, &opts);
+ ASSERT_EQ(keySlot.FutureEpochKeys[0].Epoch, 9);
+ ASSERT_EQ(keySlot.FutureEpochKeys[15].Epoch, 24);
+}
+
+TEST_F(CryptoTest, EpochKeyRotation) {
+ /* should replace send + key recv */
+ OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, 9, &opts);
+
+ ASSERT_EQ(keySlot.Decrypt.Epoch, 9);
+ ASSERT_EQ(keySlot.Encrypt.Epoch, 9);
+ ASSERT_EQ(keySlot.EpochKeySend.Epoch, 9);
+ ASSERT_EQ(keySlot.RetiringEpochDataReceiveKey.Epoch, 1);
+
+ /* Iterate the data send key four times to get it to 13 */
+ for (int i = 0; i < 4; i++)
+ {
+ OvpnCryptoEpochKeyIterate(&keySlot.EpochKeySend, opts.HkdfAlgHandle);
+
+ BCryptDestroyKey(&keySlot.Encrypt.Key);
+ RtlZeroMemory(&keySlot.Encrypt, sizeof(OvpnCryptoKeyContext));
+
+ OvpnCryptoEpochInitKey(&keySlot.Encrypt, &keySlot.EpochKeySend, &opts);
+ }
+ ASSERT_EQ(keySlot.Encrypt.Epoch, 13);
+
+ OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, 10, &opts);
+ ASSERT_EQ(keySlot.Decrypt.Epoch, 10);
+ ASSERT_EQ(keySlot.Encrypt.Epoch, 13);
+ ASSERT_EQ(keySlot.EpochKeySend.Epoch, 13);
+ ASSERT_EQ(keySlot.RetiringEpochDataReceiveKey.Epoch, 9);
+
+ OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, 12, &opts);
+ ASSERT_EQ(keySlot.Decrypt.Epoch, 12);
+ ASSERT_EQ(keySlot.Encrypt.Epoch, 13);
+ ASSERT_EQ(keySlot.EpochKeySend.Epoch, 13);
+ ASSERT_EQ(keySlot.RetiringEpochDataReceiveKey.Epoch, 10);
+
+ OvpnCryptoEpochKeyIterate(&keySlot.EpochKeySend, opts.HkdfAlgHandle);
+
+ BCryptDestroyKey(&keySlot.Encrypt.Key);
+ RtlZeroMemory(&keySlot.Encrypt, sizeof(OvpnCryptoKeyContext));
+
+ OvpnCryptoEpochInitKey(&keySlot.Encrypt, &keySlot.EpochKeySend, &opts);
+
+ ASSERT_EQ(keySlot.Encrypt.Epoch, 14);
+}
+
+TEST_F(CryptoTest, EpochKeyReceiveLookup)
+{
+ /* lookup some wacky things that should fail */
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 2000), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, -1), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 0xefff), nullptr);
+
+ /* Lookup the edges of the current window */
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 0), nullptr);
+ ASSERT_EQ(keySlot.RetiringEpochDataReceiveKey.Epoch, 0);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 1)->Epoch, 1);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 2)->Epoch, 2);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 16)->Epoch, 16);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 17)->Epoch, 17);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 18), nullptr);
+
+ /* Should move 1 to retiring key but leave 2-6 undefined, 7 as
+ * active and 8-23 as future keys*/
+ OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, 7, &opts);
+
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 0), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 1)->Epoch, 1);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 1), &keySlot.RetiringEpochDataReceiveKey);
+
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 2), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 3), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 4), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 5), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 6), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 21)->Epoch, 21);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 22)->Epoch, 22);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 23)->Epoch, 23);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 24), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 25), nullptr);
+
+ /* Should move 7 to retiring key and have 8 as active key and
+ * 9-24 as future keys */
+ OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, 8, &opts);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 0), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 1), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 2), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 3), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 4), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 5), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 6), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 7)->Epoch, 7);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 7), &keySlot.RetiringEpochDataReceiveKey);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 8)->Epoch, 8);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 23)->Epoch, 23);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 24)->Epoch, 24);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 25), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, 26), nullptr);
+}
+
+TEST_F(CryptoTest, EpochKeyOverflow)
+{
+ /* Modify the receive epoch and keys to have a very high epoch to test
+ * the end of array. Iterating through all 65k keys takes a 2-3s, so we
+ * avoid this for the unit test */
+ keySlot.Decrypt.Epoch = 65516;
+ keySlot.Encrypt.Epoch = 65516;
+
+ keySlot.EpochKeySend.Epoch = 65516;
+ keySlot.EpochKeyRecv.Epoch = 65516 + FUTURE_EPOCH_KEYS_COUNT;
+
+ for (int i = 0; i < FUTURE_EPOCH_KEYS_COUNT; ++i) {
+ keySlot.FutureEpochKeys[i].Epoch = 65517 + i;
+ }
+
+ /* Move the last few keys until we are close to the limit */
+ while (keySlot.Decrypt.Epoch < (UINT16_MAX - 24))
+ {
+ OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, keySlot.Decrypt.Epoch + 10, &opts);
+ }
+
+ /* Looking up this key should still work as it will not break the limit
+ * when generating keys */
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX - 18)->Epoch, UINT16_MAX - 18);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX - 17)->Epoch, UINT16_MAX - 17);
+
+ /* This key is no longer eligible for decrypting as the 16 future keys
+ * would be larger than uint16_t maximum */
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX - FUTURE_EPOCH_KEYS_COUNT), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX), nullptr);
+
+ /* Check that moving to the last possible epoch works */
+ OvpnCryptoEpochReplaceUpdateRecvKey(&keySlot, UINT16_MAX - 17, &opts);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX - 17)->Epoch, UINT16_MAX - 17);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX - 16), nullptr);
+ ASSERT_EQ(OvpnCryptoEpochLookupDecryptKey(&keySlot, UINT16_MAX), nullptr);
+}
+
+TEST_F(CryptoTest, EpochDeriveDataKey)
+{
+ OvpnCryptoKeyParameters kp;
+ OvpnCryptoEpochKey e17{ {19, 12}, 17};
+ OvpnCryptoEpochDataKeyDerive(&kp, &e17, opts.HkdfAlgHandle, opts.AeadAlgHangle, 24);
+
+ uint8_t exp_cipherkey[24] = { 0xed, 0x85, 0x33, 0xdb, 0x1c, 0x28, 0xac, 0xe4,
+ 0x18, 0xe9, 0x00, 0x6a, 0xb2, 0x9c, 0x17, 0x41,
+ 0x7d, 0x60, 0xeb, 0xe6, 0xcd, 0x90, 0xbf, 0x0a };
+
+ uint8_t exp_impl_iv[12] = { 0x86, 0x89, 0x0a, 0xab, 0xf0, 0x32,
+ 0xcb, 0x59, 0xf4, 0xcf, 0xa3, 0x4e };
+
+ ASSERT_EQ(0, std::memcmp(kp.Cipher, exp_cipherkey, sizeof(exp_cipherkey)));
+ ASSERT_EQ(0, std::memcmp(kp.IV, exp_impl_iv, sizeof(exp_impl_iv)));
+}
diff --git a/timer.cpp b/timer.cpp
index 213e1ff..20ae520 100644
--- a/timer.cpp
+++ b/timer.cpp
@@ -80,25 +80,29 @@ static VOID OvpnTimerXmit(WDFTIMER timer)
ExAcquireSpinLockSharedAtDpcLevel(&peer->SpinLock);
auto peerId = peer->PeerId;
- SOCKADDR_STORAGE sa;
- OvpnSocketCopyRemoteToSockaddr(peer->TransportAddrs.Remote, &sa);
+ SOCKADDR_STORAGE sa = {0};
+ BOOLEAN exclusive = FALSE;
OvpnCryptoContext* cryptoContext = &peer->CryptoContext;
if (cryptoContext->Encrypt) {
- // make space to crypto overhead
- BOOLEAN pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
- BOOLEAN aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
-
- OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN));
-
- // in-place encrypt, always with primary key
- status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions);
+ const OvpnCryptoPacketLayout layout = cryptoContext->Layout;
+
+ OvpnTxBufferPush(buffer, layout.FrontLen);
+ OvpnBufferPut(buffer, layout.TailLen);
+
+ OvpnCryptoEncryptParams encryptParams = { buffer->Data, buffer->Len };
+ status = OvpnCryptoCallWithRetry(peer, TRUE, &exclusive, nullptr, OvpnCryptoInvokeEncrypt, &encryptParams);
+
+ if (NT_SUCCESS(status)) {
+ OvpnSocketCopyRemoteToSockaddr(peer->TransportAddrs.Remote, &sa);
+ }
}
else {
status = STATUS_INVALID_DEVICE_STATE;
// LOG_WARN("CryptoContext not initialized");
}
- ExReleaseSpinLockSharedFromDpcLevel(&peer->SpinLock);
+
+ OvpnReleaseSpinLock(TRUE, 0, &peer->SpinLock, exclusive);
if (NT_SUCCESS(status)) {
// start async send, completion handler will return ciphertext buffer to the pool
diff --git a/trace.h b/trace.h
index fc1c8a9..985d917 100644
--- a/trace.h
+++ b/trace.h
@@ -21,6 +21,7 @@
#pragma once
+#if defined(_KERNEL_MODE)
#include
#include
#include
@@ -30,6 +31,12 @@
TRACELOGGING_DECLARE_PROVIDER(g_hOvpnEtwProvider);
+#else
+#include
+#endif
+
+#if defined(_KERNEL_MODE)
+
#define TraceLoggingFunctionName() TraceLoggingWideString(__FUNCTIONW__, "Func")
#define LOG_NTSTATUS(Status, ...) do {\
@@ -106,16 +113,35 @@ TRACELOGGING_DECLARE_PROVIDER(g_hOvpnEtwProvider);
} \
} while(0,0)
-#define GOTO_IF_NOT_NT_SUCCESS(Label, StatusLValue, Expression, ...) do {\
- StatusLValue = (Expression); \
- if (!NT_SUCCESS(StatusLValue)) \
- { \
- LOG_NTSTATUS(StatusLValue, \
- TraceLoggingWideString(L#Expression, "Expression"), \
- __VA_ARGS__); \
- goto Label; \
- } \
-} while(0,0)
+#define GOTO_IF_NOT_NT_SUCCESS(Label, StatusLValue, Expression, ...) \
+ do \
+ { \
+ StatusLValue = (Expression); \
+ if (!NT_SUCCESS(StatusLValue)) \
+ { \
+ LOG_NTSTATUS(StatusLValue, \
+ TraceLoggingString(#Expression, "Expression"), \
+ __VA_ARGS__); \
+ goto Label; \
+ } \
+ } while (0)
+
+#else /* user-mode / tests */
+
+#define LOG_INFO(Info, ...)
+
+#define GOTO_IF_NOT_NT_SUCCESS(Label, StatusLValue, Expression, ...) \
+ do \
+ { \
+ StatusLValue = (Expression); \
+ if (!NT_SUCCESS(StatusLValue)) \
+ { \
+ goto Label; \
+ } \
+ } while (0)
+
+#endif
+
#ifndef TraceLoggingIPv4Address
#define TraceLoggingIPv4Address(value, ...) _tlgArgScalarVal(UINT32, value, TlgInUINT32, (TlgOutIPV4), __VA_ARGS__)
diff --git a/txqueue.cpp b/txqueue.cpp
index 95f26a5..9d97ca9 100644
--- a/txqueue.cpp
+++ b/txqueue.cpp
@@ -235,29 +235,26 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET
InterlockedExchangeAddNoFence64(&peer->VpnTxBytes, buffer->Len);
auto irql = ExAcquireSpinLockShared(&peer->SpinLock);
+ BOOLEAN exclusive = FALSE;
OvpnCryptoContext* cryptoContext = &peer->CryptoContext;
auto remoteAddr = peer->TransportAddrs.Remote;
if (cryptoContext->Encrypt) {
- auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END;
- auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID;
-
- // make space to crypto overhead
- OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN));
- if (aeadTagEnd)
- {
- OvpnBufferPut(buffer, AEAD_AUTH_TAG_LEN);
- }
+ const OvpnCryptoPacketLayout layout = cryptoContext->Layout;
+
+ OvpnTxBufferPush(buffer, layout.FrontLen);
+ OvpnBufferPut(buffer, layout.TailLen);
- // in-place encrypt, always with primary key
- status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions);
+ OvpnCryptoEncryptParams encryptParams = { buffer->Data, buffer->Len };
+ status = OvpnCryptoCallWithRetry(peer, FALSE, &exclusive, &irql, OvpnCryptoInvokeEncrypt, &encryptParams);
}
else {
status = STATUS_INVALID_DEVICE_STATE;
// LOG_WARN("CryptoContext not initialized");
}
- ExReleaseSpinLockShared(&peer->SpinLock, irql);
+
+ OvpnReleaseSpinLock(FALSE, irql, &peer->SpinLock, exclusive);
if (NT_SUCCESS(status)) {
InterlockedExchangeAddNoFence64(&peer->LinkTxBytes, buffer->Len);
diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h
index 6dab25f..a7b28fb 100644
--- a/uapi/ovpn-dco.h
+++ b/uapi/ovpn-dco.h
@@ -119,8 +119,7 @@ typedef struct _OVPN_CRYPTO_DATA {
int PeerId;
} OVPN_CRYPTO_DATA, * POVPN_CRYPTO_DATA;
-#define CRYPTO_OPTIONS_AEAD_TAG_END (1<<1)
-#define CRYPTO_OPTIONS_64BIT_PKTID (1<<2)
+#define CRYPTO_OPTIONS_EPOCH (1<<1)
typedef struct _OVPN_CRYPTO_DATA_V2 {
OVPN_CRYPTO_DATA V1;