From d0942e00d65529040084d111505a2102fcf59e6f Mon Sep 17 00:00:00 2001 From: iksnerd Date: Sat, 2 May 2026 22:50:50 +0300 Subject: [PATCH] =?UTF-8?q?Harden=20peer=20handshake=20=E2=80=94=20strict?= =?UTF-8?q?=20pstrlen=20+=20magic=20check=20+=20typed=20state=20machine?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes the last LangSec gap in the BitTorrent client: the BEP 3 handshake previously read pstrlen bytes blindly without validating the protocol magic string, and PeerConn had no explicit state guarding which methods were valid when. Now the recognizer enforces the full handshake grammar before any state mutation and a typed state field gates downstream calls. - internal/client/peer.go: - protoMagic constant + handshakeLen = 68 (BEP 3 mandates exactly this). - Read the peer's full 68-byte handshake in one ReadFull, then validate: pstrlen == 19, pstr == "BitTorrent protocol", info_hash matches. Only then promote state to stateHandshook (and stateExtended after BEP 10). - peerState enum {Init, Handshook, Extended}. Handshake requires Init. ReadMessage/WriteMessage require >= Handshook. RequestMetadata requires Extended. Each precondition returns a typed error on misuse. - internal/client/peer_test.go: new tests for bad pstrlen, wrong magic, ReadMessage/WriteMessage before handshake, RequestMetadata before extended handshake. fakePeer helper for handshake-rejection cases. --- internal/client/peer.go | 83 ++++++++++++++++++++------- internal/client/peer_test.go | 105 +++++++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 20 deletions(-) diff --git a/internal/client/peer.go b/internal/client/peer.go index 9ecfb58..008710c 100644 --- a/internal/client/peer.go +++ b/internal/client/peer.go @@ -16,10 +16,30 @@ const ( ExtMetadata = "ut_metadata" ) +// BEP 3 wire constants. The handshake is *exactly* 68 bytes: 1 length byte + +// 19 magic bytes + 8 reserved + 20 info_hash + 20 peer_id. Anything else is +// off-spec and rejected. +const ( + protoMagic = "BitTorrent protocol" + pstrLen = 19 // len(protoMagic) + handshakeLen = 1 + pstrLen + 8 + 20 + 20 +) + +// peerState is the LangSec-style typed state of a PeerConn. Methods that +// operate on a connection require a minimum state and refuse to run otherwise. +type peerState int + +const ( + stateInit peerState = iota // freshly TCP-connected, no handshake yet + stateHandshook // BEP 3 handshake complete + stateExtended // BEP 10 extended handshake complete +) + // PeerConn wraps a TCP connection to a peer and tracks PWP state. type PeerConn struct { - conn net.Conn - addr string + conn net.Conn + addr string + state peerState AmChoking bool AmInterested bool @@ -50,9 +70,17 @@ func Connect(ctx context.Context, addr string) (*PeerConn, error) { // Handshake performs the standard BEP 3 handshake. // Sends: // +// LangSec recognition: the peer's reply must be exactly 68 bytes, +// pstrlen must be exactly 19, and pstr must be exactly "BitTorrent protocol". +// Anything else is off-spec and the connection is dropped before any state +// (PeerExtensions, MetadataSize, choke/interest flags) is touched. +// // NOTE: Even for hybrid v1+v2 torrents, the BitTorrent wire protocol // handshake ALWAYS uses the 20-byte SHA-1 info hash (v1) per BEP 3. func (p *PeerConn) Handshake(ctx context.Context, infoHash []byte, peerID string) error { + if p.state != stateInit { + return fmt.Errorf("handshake called in state %d (expected init)", p.state) + } deadline := time.Now().Add(10 * time.Second) if d, ok := ctx.Deadline(); ok && d.Before(deadline) { deadline = d @@ -67,11 +95,10 @@ func (p *PeerConn) Handshake(ctx context.Context, infoHash []byte, peerID string return fmt.Errorf("peer id must be 20 bytes, got %d", len(peerID)) } - pstr := "BitTorrent protocol" - buf := make([]byte, 1+len(pstr)+8+20+20) - buf[0] = byte(len(pstr)) + buf := make([]byte, handshakeLen) + buf[0] = byte(pstrLen) curr := 1 - curr += copy(buf[curr:], pstr) + curr += copy(buf[curr:], protoMagic) // Reserved bytes (8 bytes) // We set bit 43 (byte 5, bit 0x10) to signal BEP 10 Extension Protocol support. @@ -86,28 +113,30 @@ func (p *PeerConn) Handshake(ctx context.Context, infoHash []byte, peerID string return fmt.Errorf("write handshake: %w", err) } - // Read peer's handshake - resBuf := make([]byte, 1) + // Read the peer's full 68-byte handshake. LangSec: read it whole, then + // validate the entire structure before letting anything downstream touch it. + resBuf := make([]byte, handshakeLen) if _, err := io.ReadFull(p.conn, resBuf); err != nil { - return fmt.Errorf("read pstrlen: %w", err) - } - pstrlen := int(resBuf[0]) - if pstrlen == 0 { - return fmt.Errorf("invalid pstrlen 0") + return fmt.Errorf("read handshake: %w", err) } - resBuf = make([]byte, pstrlen+8+20+20) - if _, err := io.ReadFull(p.conn, resBuf); err != nil { - return fmt.Errorf("read handshake payload: %w", err) + if int(resBuf[0]) != pstrLen { + return fmt.Errorf("invalid pstrlen %d (must be %d per BEP 3)", resBuf[0], pstrLen) + } + if !bytes.Equal(resBuf[1:1+pstrLen], []byte(protoMagic)) { + return fmt.Errorf("invalid protocol magic: got %q", resBuf[1:1+pstrLen]) } - resInfoHash := resBuf[pstrlen+8 : pstrlen+8+20] + resInfoHash := resBuf[1+pstrLen+8 : 1+pstrLen+8+20] if !bytes.Equal(resInfoHash, infoHash) { return fmt.Errorf("info hash mismatch: expected %x, got %x", infoHash, resInfoHash) } + // All recognition passed — promote state. + p.state = stateHandshook + // Check if peer supports BEP 10 - peerReserved := resBuf[pstrlen : pstrlen+8] + peerReserved := resBuf[1+pstrLen : 1+pstrLen+8] supportsBEP10 := (peerReserved[5] & 0x10) != 0 if supportsBEP10 { @@ -120,6 +149,7 @@ func (p *PeerConn) Handshake(ctx context.Context, infoHash []byte, peerID string if err := p.readExtendedHandshake(); err != nil { return fmt.Errorf("read extended handshake: %w", err) } + p.state = stateExtended } return nil @@ -158,7 +188,11 @@ func (p *PeerConn) readExtendedHandshake() error { } // RequestMetadata sends a BEP 9 metadata request for the given piece. +// Requires that the BEP 10 extended handshake has completed. func (p *PeerConn) RequestMetadata(piece int) error { + if p.state < stateExtended { + return fmt.Errorf("RequestMetadata called in state %d (expected extended)", p.state) + } extID, ok := p.PeerExtensions[ExtMetadata] if !ok { return fmt.Errorf("peer does not support %s", ExtMetadata) @@ -202,15 +236,24 @@ func (p *PeerConn) sendExtendedHandshake() error { }) } -// ReadMessage reads the next message from the peer. +// ReadMessage reads the next message from the peer. Requires the BEP 3 +// handshake to have completed — until then, raw bytes on the wire don't +// frame as PWP messages. func (p *PeerConn) ReadMessage() (*Message, error) { + if p.state < stateHandshook { + return nil, fmt.Errorf("ReadMessage called in state %d (expected handshook)", p.state) + } // Set a reasonable read timeout to avoid hanging forever p.conn.SetReadDeadline(time.Now().Add(2 * time.Minute)) return ReadMessage(p.conn) } -// WriteMessage writes a message to the peer. +// WriteMessage writes a message to the peer. Requires the BEP 3 handshake +// to have completed. func (p *PeerConn) WriteMessage(m *Message) error { + if p.state < stateHandshook { + return fmt.Errorf("WriteMessage called in state %d (expected handshook)", p.state) + } p.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) return WriteMessage(p.conn, m) } diff --git a/internal/client/peer_test.go b/internal/client/peer_test.go index 899ac59..13eabc2 100644 --- a/internal/client/peer_test.go +++ b/internal/client/peer_test.go @@ -131,3 +131,108 @@ func TestHandshakeMismatch(t *testing.T) { t.Errorf("unexpected error: %v", err) } } + +// fakePeer accepts one connection, reads the client's handshake, and replies +// with a custom 68-byte handshake reply. Returns the listener's addr. +func fakePeer(t *testing.T, reply []byte) string { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + t.Cleanup(func() { ln.Close() }) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + // Drain the client's 68-byte handshake. + io.ReadFull(conn, make([]byte, 68)) + conn.Write(reply) + }() + return ln.Addr().String() +} + +func TestHandshakeRejectsBadPstrlen(t *testing.T) { + infoHash := make([]byte, 20) + copy(infoHash, "infohash123456789012") + + // Reply claims pstrlen=18 instead of 19 — handshake reads exactly 68 + // bytes (1+19+8+20+20), so we need to construct a 68-byte buffer with + // a bogus first byte. The parser should reject on the length check. + reply := make([]byte, 68) + reply[0] = 18 // wrong + copy(reply[1:], "BitTorrent protocol") // would be valid magic if length matched + copy(reply[1+19+8:], infoHash) + + addr := fakePeer(t, reply) + p, err := Connect(context.Background(), addr) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer p.Close() + + err = p.Handshake(context.Background(), infoHash, "-WL0001-123456789012") + if err == nil { + t.Fatal("expected pstrlen rejection") + } + if !bytes.Contains([]byte(err.Error()), []byte("pstrlen")) { + t.Errorf("expected pstrlen error, got %v", err) + } +} + +func TestHandshakeRejectsBadMagic(t *testing.T) { + infoHash := make([]byte, 20) + copy(infoHash, "infohash123456789012") + + // pstrlen=19 (valid) but magic is wrong. + reply := make([]byte, 68) + reply[0] = 19 + copy(reply[1:], "WrongTorrent magic!") // 19 bytes, wrong content + copy(reply[1+19+8:], infoHash) + + addr := fakePeer(t, reply) + p, err := Connect(context.Background(), addr) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer p.Close() + + err = p.Handshake(context.Background(), infoHash, "-WL0001-123456789012") + if err == nil { + t.Fatal("expected magic-string rejection") + } + if !bytes.Contains([]byte(err.Error()), []byte("protocol magic")) { + t.Errorf("expected protocol magic error, got %v", err) + } +} + +func TestReadMessageBeforeHandshakeRejected(t *testing.T) { + // Set up a connection but never handshake. + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + p := &PeerConn{conn: client, state: stateInit} + + if _, err := p.ReadMessage(); err == nil { + t.Error("ReadMessage in stateInit should fail") + } + if err := p.WriteMessage(&Message{ID: 0}); err == nil { + t.Error("WriteMessage in stateInit should fail") + } +} + +func TestRequestMetadataBeforeExtendedRejected(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + // Handshook but not Extended. + p := &PeerConn{conn: client, state: stateHandshook} + + if err := p.RequestMetadata(0); err == nil { + t.Error("RequestMetadata before extended handshake should fail") + } +}