diff --git a/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/UnauthorizedMessagePacketHandler.java b/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/UnauthorizedMessagePacketHandler.java index b314639f1..082d35911 100644 --- a/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/UnauthorizedMessagePacketHandler.java +++ b/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/UnauthorizedMessagePacketHandler.java @@ -41,10 +41,21 @@ public void handle(Envelope envelope) { envelope.getIdString())); NodeSession session = envelope.get(Field.SESSION); + OrdinaryMessagePacket unknownPacket = envelope.get(Field.UNAUTHORIZED_PACKET_MESSAGE); + Bytes12 msgNonce = unknownPacket.getHeader().getStaticHeader().getNonce(); + + // If already awaiting handshake completion, resend the original WHOAREYOU for retransmissions + // of the same packet (same nonce). For a new nonce, fall through and issue a fresh WHOAREYOU + // so the initiator's nonce check can pass. + if (session.getState() == SessionState.WHOAREYOU_SENT) { + if (session.getPendingWhoAreYouNonce().map(msgNonce::equals).orElse(false)) { + session.resendOutgoingWhoAreYou(); + return; + } + } + try { - // packet it either random or message packet if session is expired - Bytes12 msgNonce = unknownPacket.getHeader().getStaticHeader().getNonce(); Bytes16 idNonce = Bytes16.random(Functions.getRandom()); Header header = diff --git a/src/main/java/org/ethereum/beacon/discovery/schema/NodeSession.java b/src/main/java/org/ethereum/beacon/discovery/schema/NodeSession.java index 7bdb1f3bf..aca82d6e2 100644 --- a/src/main/java/org/ethereum/beacon/discovery/schema/NodeSession.java +++ b/src/main/java/org/ethereum/beacon/discovery/schema/NodeSession.java @@ -69,6 +69,7 @@ public class NodeSession { private final Signer signer; private Optional reportedExternalAddress = Optional.empty(); private Optional whoAreYouChallenge = Optional.empty(); + private Optional pendingWhoAreYouPacket = Optional.empty(); private Optional lastOutboundNonce = Optional.empty(); private boolean active = true; private final Function nonceGenerator; @@ -161,10 +162,33 @@ public void sendOutgoingRandom(final Bytes randomData) { sendOutgoing(generateMaskingIV(), packet); } - public void sendOutgoingWhoAreYou(final WhoAreYouPacket packet) { + public synchronized void sendOutgoingWhoAreYou(final WhoAreYouPacket packet) { LOG.trace( () -> String.format("Sending outgoing WhoAreYou message %s in session %s", packet, this)); Bytes16 maskingIV = generateMaskingIV(); + pendingWhoAreYouPacket = Optional.of(packet); + dispatchWhoAreYou(maskingIV, packet); + } + + public synchronized Optional getPendingWhoAreYouNonce() { + return pendingWhoAreYouPacket.map(p -> p.getHeader().getStaticHeader().getNonce()); + } + + public synchronized void resendOutgoingWhoAreYou() { + pendingWhoAreYouPacket.ifPresent( + packet -> { + LOG.trace( + () -> + String.format( + "Resending outgoing WhoAreYou message %s in session %s", packet, this)); + // Reuse the original maskingIV so the stored challenge remains stable; the initiator + // may have already signed against it. + Bytes16 maskingIV = Bytes16.wrap(whoAreYouChallenge.orElseThrow().slice(0, 16)); + sendOutgoing(maskingIV, packet); + }); + } + + private void dispatchWhoAreYou(final Bytes16 maskingIV, final WhoAreYouPacket packet) { whoAreYouChallenge = Optional.of(Bytes.wrap(maskingIV, packet.getHeader().getBytes())); sendOutgoing(maskingIV, packet); } @@ -228,6 +252,7 @@ public synchronized RequestInfo createNextRequest(final Request request) { private synchronized void resetHandshakeState() { if (state == SessionState.WHOAREYOU_SENT || state == SessionState.RANDOM_PACKET_SENT) { + pendingWhoAreYouPacket = Optional.empty(); setState(SessionState.INITIAL); } } diff --git a/src/test/java/org/ethereum/beacon/discovery/pipeline/handler/UnauthorizedMessagePacketHandlerTest.java b/src/test/java/org/ethereum/beacon/discovery/pipeline/handler/UnauthorizedMessagePacketHandlerTest.java new file mode 100644 index 000000000..bad0e1cdc --- /dev/null +++ b/src/test/java/org/ethereum/beacon/discovery/pipeline/handler/UnauthorizedMessagePacketHandlerTest.java @@ -0,0 +1,136 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.ethereum.beacon.discovery.pipeline.handler; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Optional; +import org.apache.tuweni.bytes.Bytes; +import org.apache.tuweni.bytes.Bytes32; +import org.ethereum.beacon.discovery.packet.Header; +import org.ethereum.beacon.discovery.packet.OrdinaryMessagePacket; +import org.ethereum.beacon.discovery.packet.OrdinaryMessagePacket.OrdinaryAuthData; +import org.ethereum.beacon.discovery.packet.WhoAreYouPacket; +import org.ethereum.beacon.discovery.pipeline.Envelope; +import org.ethereum.beacon.discovery.pipeline.Field; +import org.ethereum.beacon.discovery.schema.NodeSession; +import org.ethereum.beacon.discovery.schema.NodeSession.SessionState; +import org.ethereum.beacon.discovery.type.Bytes12; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +class UnauthorizedMessagePacketHandlerTest { + + private final UnauthorizedMessagePacketHandler handler = new UnauthorizedMessagePacketHandler(); + + @Test + void shouldResendExistingWhoAreYouWhenInWhoAreYouSentState() { + final NodeSession session = mock(NodeSession.class); + when(session.getState()).thenReturn(SessionState.WHOAREYOU_SENT); + + final OrdinaryMessagePacket packet = createOrdinaryPacket(); + final Bytes12 nonce = packet.getHeader().getStaticHeader().getNonce(); + when(session.getPendingWhoAreYouNonce()).thenReturn(Optional.of(nonce)); + + handler.handle(envelopeWith(session, packet)); + + verify(session).resendOutgoingWhoAreYou(); + verify(session, never()).sendOutgoingWhoAreYou(any()); + verify(session, never()).setState(any()); + } + + @Test + void shouldNotChangeStateWhenResendingInWhoAreYouSentState() { + final NodeSession session = mock(NodeSession.class); + when(session.getState()).thenReturn(SessionState.WHOAREYOU_SENT); + + final OrdinaryMessagePacket packet = createOrdinaryPacket(); + final Bytes12 nonce = packet.getHeader().getStaticHeader().getNonce(); + when(session.getPendingWhoAreYouNonce()).thenReturn(Optional.of(nonce)); + + handler.handle(envelopeWith(session, packet)); + + verify(session, never()).setState(any()); + } + + @Test + void shouldSendNewWhoAreYouWhenInWhoAreYouSentStateButDifferentNonce() { + final NodeSession session = mock(NodeSession.class); + when(session.getState()).thenReturn(SessionState.WHOAREYOU_SENT); + when(session.getNodeRecord()).thenReturn(Optional.empty()); + // Pending WhoAreYou was for a different nonce + when(session.getPendingWhoAreYouNonce()) + .thenReturn(Optional.of(Bytes12.wrap(Bytes.random(12)))); + + final OrdinaryMessagePacket packet = createOrdinaryPacket(); + handler.handle(envelopeWith(session, packet)); + + verify(session, never()).resendOutgoingWhoAreYou(); + final ArgumentCaptor captor = ArgumentCaptor.forClass(WhoAreYouPacket.class); + verify(session).sendOutgoingWhoAreYou(captor.capture()); + final Bytes12 expectedNonce = packet.getHeader().getStaticHeader().getNonce(); + assertThat(captor.getValue().getHeader().getStaticHeader().getNonce()).isEqualTo(expectedNonce); + } + + @Test + void shouldSendNewWhoAreYouWithIncomingNonceWhenInInitialState() { + final NodeSession session = mock(NodeSession.class); + when(session.getState()).thenReturn(SessionState.INITIAL); + when(session.getNodeRecord()).thenReturn(Optional.empty()); + + final OrdinaryMessagePacket packet = createOrdinaryPacket(); + handler.handle(envelopeWith(session, packet)); + + final ArgumentCaptor captor = ArgumentCaptor.forClass(WhoAreYouPacket.class); + verify(session).sendOutgoingWhoAreYou(captor.capture()); + verify(session, never()).resendOutgoingWhoAreYou(); + verify(session).setState(SessionState.WHOAREYOU_SENT); + + // The WHOAREYOU nonce must echo the incoming packet's nonce so the initiator can + // match it to their pending request. + final Bytes12 expectedNonce = packet.getHeader().getStaticHeader().getNonce(); + assertThat(captor.getValue().getHeader().getStaticHeader().getNonce()).isEqualTo(expectedNonce); + } + + @Test + void shouldSkipWhenUnauthorizedPacketMessageFieldAbsent() { + final NodeSession session = mock(NodeSession.class); + final Envelope envelope = new Envelope(); + envelope.put(Field.SESSION, session); + + handler.handle(envelope); + + verify(session, never()).resendOutgoingWhoAreYou(); + verify(session, never()).sendOutgoingWhoAreYou(any()); + } + + @Test + void shouldSkipWhenSessionFieldAbsent() { + final Envelope envelope = new Envelope(); + envelope.put(Field.UNAUTHORIZED_PACKET_MESSAGE, createOrdinaryPacket()); + + // Should not throw even without a session. + handler.handle(envelope); + } + + private static OrdinaryMessagePacket createOrdinaryPacket() { + final Bytes12 nonce = Bytes12.wrap(Bytes.random(12)); + final Header header = Header.createOrdinaryHeader(Bytes32.ZERO, nonce); + return OrdinaryMessagePacket.createRandom(header, Bytes.random(20)); + } + + private static Envelope envelopeWith( + final NodeSession session, final OrdinaryMessagePacket packet) { + final Envelope envelope = new Envelope(); + envelope.put(Field.UNAUTHORIZED_PACKET_MESSAGE, packet); + envelope.put(Field.SESSION, session); + return envelope; + } +} diff --git a/src/test/java/org/ethereum/beacon/discovery/schema/NodeSessionTest.java b/src/test/java/org/ethereum/beacon/discovery/schema/NodeSessionTest.java index 0d355e7d8..cc200575e 100644 --- a/src/test/java/org/ethereum/beacon/discovery/schema/NodeSessionTest.java +++ b/src/test/java/org/ethereum/beacon/discovery/schema/NodeSessionTest.java @@ -5,10 +5,13 @@ package org.ethereum.beacon.discovery.schema; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -22,11 +25,16 @@ import java.util.function.Consumer; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; +import org.apache.tuweni.units.bigints.UInt64; import org.ethereum.beacon.discovery.SimpleIdentitySchemaInterpreter; import org.ethereum.beacon.discovery.crypto.DefaultSigner; import org.ethereum.beacon.discovery.crypto.Signer; import org.ethereum.beacon.discovery.message.V5Message; import org.ethereum.beacon.discovery.network.NetworkParcel; +import org.ethereum.beacon.discovery.network.NetworkParcelV5; +import org.ethereum.beacon.discovery.packet.Header; +import org.ethereum.beacon.discovery.packet.WhoAreYouPacket; +import org.ethereum.beacon.discovery.packet.WhoAreYouPacket.WhoAreYouAuthData; import org.ethereum.beacon.discovery.pipeline.handler.NodeSessionManager; import org.ethereum.beacon.discovery.pipeline.info.Request; import org.ethereum.beacon.discovery.pipeline.info.RequestInfo; @@ -36,6 +44,8 @@ import org.ethereum.beacon.discovery.storage.LocalNodeRecordStore; import org.ethereum.beacon.discovery.storage.NewAddressHandler; import org.ethereum.beacon.discovery.storage.NodeRecordListener; +import org.ethereum.beacon.discovery.type.Bytes12; +import org.ethereum.beacon.discovery.type.Bytes16; import org.ethereum.beacon.discovery.util.Functions; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -187,6 +197,73 @@ void createNextRequest_shouldNotResetAuthenticatedStatesWhenRequestTimesOut() { assertThat(session.getState()).isEqualTo(SessionState.AUTHENTICATED); } + @Test + void resendOutgoingWhoAreYou_shouldSendPacketWhenPendingPacketExists() { + session.sendOutgoingWhoAreYou(createWhoAreYouPacket(Bytes12.wrap(Bytes.random(12)))); + + final ArgumentCaptor firstCaptor = + ArgumentCaptor.forClass(NetworkParcelV5.class); + verify(outgoingPipeline).accept(firstCaptor.capture()); + + session.resendOutgoingWhoAreYou(); + + final ArgumentCaptor secondCaptor = + ArgumentCaptor.forClass(NetworkParcelV5.class); + verify(outgoingPipeline, times(2)).accept(secondCaptor.capture()); + // A packet must actually be sent on resend. + assertThat(secondCaptor.getAllValues()).hasSize(2); + } + + @Test + void resendOutgoingWhoAreYou_shouldDoNothingWhenNoPendingPacket() { + session.resendOutgoingWhoAreYou(); + + verify(outgoingPipeline, never()).accept(any()); + } + + @Test + void resendOutgoingWhoAreYou_shouldDoNothingAfterHandshakeStateReset() { + final Request request = createRequestMock(); + final RequestInfo requestInfo = session.createNextRequest(request); + + final ArgumentCaptor timeoutHandlerCaptor = ArgumentCaptor.forClass(Runnable.class); + verify(expirationScheduler).put(eq(requestInfo.getRequestId()), timeoutHandlerCaptor.capture()); + + session.sendOutgoingWhoAreYou(createWhoAreYouPacket(Bytes12.wrap(Bytes.random(12)))); + session.setState(SessionState.WHOAREYOU_SENT); + + // Simulate request timeout which resets the handshake state. + timeoutHandlerCaptor.getValue().run(); + assertThat(session.getState()).isEqualTo(SessionState.INITIAL); + + session.resendOutgoingWhoAreYou(); + + // sendOutgoingWhoAreYou was called once above; resend should not add another send. + verify(outgoingPipeline, times(1)).accept(any()); + } + + @Test + void resendOutgoingWhoAreYou_shouldPreserveOriginalNonce() { + final Bytes12 originalNonce = Bytes12.wrap(Bytes.random(12)); + final WhoAreYouPacket originalPacket = createWhoAreYouPacket(originalNonce); + session.sendOutgoingWhoAreYou(originalPacket); + + final Bytes challengeAfterSend = session.getWhoAreYouChallenge().orElseThrow(); + + session.resendOutgoingWhoAreYou(); + + // Challenge must be unchanged after resend so a handshake signed against the original + // challenge remains valid. + assertThat(session.getWhoAreYouChallenge()).contains(challengeAfterSend); + } + + private static WhoAreYouPacket createWhoAreYouPacket(final Bytes12 nonce) { + final Bytes16 idNonce = Bytes16.wrap(Bytes.random(16)); + final Header header = + Header.createWhoAreYouHeader(nonce, idNonce, UInt64.ZERO); + return WhoAreYouPacket.create(header); + } + private Request createRequestMock() { final Request request = mock(Request.class); when(request.getResultPromise()).thenReturn(new CompletableFuture<>());