From c63ccc963cca8408292e79d9027f286ca145e183 Mon Sep 17 00:00:00 2001
From: LucaCappelletti94
Date: Sat, 27 Jun 2026 10:46:27 +0200
Subject: [PATCH 1/2] Add TrailingBytes error variant for unconsumed WAL input

---
 src/error.rs | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/src/error.rs b/src/error.rs
index 06073e0..43a6448 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -56,6 +56,13 @@ pub enum ReplicationError {
     /// be spawned, exited early, or dropped a reply. Transient so the stream
     /// retry logic can reconnect.
     Backend(String),
+
+    /// A parse left input bytes unconsumed.
+    ///
+    /// `consumed` is how many bytes the parser read, `total` is the input
+    /// length. Returned by `parse_wal_message` and `parse_wal_message_bytes`
+    /// when a frame carries bytes beyond one message.
+    TrailingBytes { consumed: usize, total: usize },
 }
 
 impl core::fmt::Display for ReplicationError {
@@ -78,6 +85,10 @@ impl core::fmt::Display for ReplicationError {
             Self::Generic(msg) => write!(f, "Replication error: {msg}"),
             Self::Deserialize(msg) => write!(f, "Deserialization error: {msg}"),
             Self::Backend(msg) => write!(f, "Backend worker error: {msg}"),
+            Self::TrailingBytes { consumed, total } => write!(
+                f,
+                "Trailing bytes after WAL message: consumed {consumed} of {total} bytes"
+            ),
         }
     }
 }
@@ -505,4 +516,19 @@ mod tests {
         assert!(!err.is_permanent());
         assert!(!err.is_cancelled());
     }
+
+    #[test]
+    fn test_trailing_bytes_display() {
+        let err = ReplicationError::TrailingBytes {
+            consumed: 42,
+            total: 44,
+        };
+        assert_eq!(
+            err.to_string(),
+            "Trailing bytes after WAL message: consumed 42 of 44 bytes"
+        );
+        assert!(!err.is_transient());
+        assert!(!err.is_permanent());
+        assert!(!err.is_cancelled());
+    }
 }
From b51ef2f8fdbf044ad83f6af657b978b56cd1f8f9 Mon Sep 17 00:00:00 2001
From: LucaCappelletti94
Date: Sat, 27 Jun 2026 10:46:27 +0200
Subject: [PATCH 2/2] Reject trailing bytes in parse_wal_message and parse_wal_message_bytes

---
 src/protocol.rs | 121 +++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 114 insertions(+), 7 deletions(-)

diff --git a/src/protocol.rs b/src/protocol.rs
index 160a6b5..925d094 100644
--- a/src/protocol.rs
+++ b/src/protocol.rs
@@ -690,22 +690,28 @@ impl LogicalReplicationParser {
         self.streaming_context.is_some()
     }
 
-    /// Parse a WAL data message from the replication stream
+    /// Parse a WAL data message from the replication stream.
+    ///
+    /// Returns [`ReplicationError::TrailingBytes`] when the buffer contains bytes
+    /// the parser did not consume. Each replication frame carries exactly one
+    /// message, so trailing bytes mean a corrupt or misframed input.
     #[inline]
     pub fn parse_wal_message(&mut self, data: &[u8]) -> Result<StreamingReplicationMessage> {
         if data.is_empty() {
             return Err(ReplicationError::protocol("Empty WAL message".to_string()));
         }
 
+        let total = data.len();
         let mut reader = BufferReader::new(data);
-        self.parse_wal_message_from_reader(&mut reader)
+        self.parse_wal_message_from_reader(&mut reader, total)
     }
 
-    /// Parse a WAL data message from pre-existing Bytes (zero-copy)
+    /// Parse a WAL data message from pre-existing Bytes (zero-copy).
     ///
-    /// This avoids the copy that `parse_wal_message(&[u8])` performs when
-    /// constructing the internal `BufferReader`. Use this when you already
-    /// have a `Bytes` handle (e.g. from `BufferReader::read_bytes_buf`).
+    /// Avoids the copy that `parse_wal_message(&[u8])` performs when constructing
+    /// the internal `BufferReader`. Use this when you already have a `Bytes`
+    /// handle (e.g. from `BufferReader::read_bytes_buf`). Rejects trailing bytes
+    /// with [`ReplicationError::TrailingBytes`], like `parse_wal_message`.
     #[inline]
     pub fn parse_wal_message_bytes(
         &mut self,
@@ -715,8 +721,9 @@ impl LogicalReplicationParser {
             return Err(ReplicationError::protocol("Empty WAL message".to_string()));
         }
 
+        let total = data.len();
         let mut reader = BufferReader::from_bytes(data);
-        self.parse_wal_message_from_reader(&mut reader)
+        self.parse_wal_message_from_reader(&mut reader, total)
     }
 
     /// Shared implementation for both `parse_wal_message` and `parse_wal_message_bytes`.
@@ -724,6 +731,7 @@ impl LogicalReplicationParser {
     fn parse_wal_message_from_reader(
         &mut self,
         reader: &mut BufferReader,
+        total: usize,
     ) -> Result<StreamingReplicationMessage> {
         let message_type = reader.read_u8()?;
 
@@ -773,6 +781,13 @@ impl LogicalReplicationParser {
             None => StreamingReplicationMessage::new(message),
         };
 
+        let remaining = reader.remaining();
+        if remaining != 0 {
+            return Err(ReplicationError::TrailingBytes {
+                consumed: total - remaining,
+                total,
+            });
+        }
         Ok(streaming_message)
     }
 
@@ -3187,4 +3202,96 @@ mod tests {
         let mut parser = LogicalReplicationParser::with_protocol_version(1);
         assert!(parser.parse_wal_message(&bytes).is_err());
     }
+
+    // ========================================
+    // Trailing-byte rejection (strict by default)
+    // ========================================
+
+    fn valid_insert_frame() -> Vec<u8> {
+        let mut data = vec![message_types::INSERT];
+        data.extend_from_slice(&12345u32.to_be_bytes()); // relation_id
+        data.push(b'N'); // new tuple
+        data.extend_from_slice(&[0x00, 0x02]); // two columns
+        data.push(b't'); // text column
+        data.extend_from_slice(&4u32.to_be_bytes());
+        data.extend_from_slice(b"test");
+        data.push(b'n'); // null column
+        data
+    }
+
+    fn valid_begin_frame() -> Vec<u8> {
+        let mut data = vec![message_types::BEGIN];
+        data.extend_from_slice(&0x0100_0000u64.to_be_bytes()); // final_lsn
+        data.extend_from_slice(&1_700_000_000_000_000i64.to_be_bytes()); // timestamp
+        data.extend_from_slice(&42u32.to_be_bytes()); // xid
+        data
+    }
+
+    fn valid_commit_frame() -> Vec<u8> {
+        let mut data = vec![message_types::COMMIT];
+        data.push(0); // flags
+        data.extend_from_slice(&0x0100_0000u64.to_be_bytes()); // commit_lsn
+        data.extend_from_slice(&0x0100_0010u64.to_be_bytes()); // end_lsn
+        data.extend_from_slice(&1_700_000_000_000_000i64.to_be_bytes()); // timestamp
+        data
+    }
+
+    #[test]
+    fn parse_wal_message_accepts_clean_frames() {
+        for frame in [
+            valid_begin_frame(),
+            valid_insert_frame(),
+            valid_commit_frame(),
+        ] {
+            let mut parser = LogicalReplicationParser::with_protocol_version(1);
+            assert!(parser.parse_wal_message(&frame).is_ok());
+        }
+    }
+
+    #[test]
+    fn parse_wal_message_rejects_trailing_bytes() {
+        for frame in [
+            valid_begin_frame(),
+            valid_insert_frame(),
+            valid_commit_frame(),
+        ] {
+            let clean_len = frame.len();
+            let mut with_trailing = frame.clone();
+            with_trailing.extend_from_slice(&[0xAA, 0xBB]);
+
+            let mut parser = LogicalReplicationParser::with_protocol_version(1);
+            let Err(ReplicationError::TrailingBytes { consumed, total }) =
+                parser.parse_wal_message(&with_trailing)
+            else {
+                panic!("expected TrailingBytes for a frame with trailing bytes");
+            };
+            assert_eq!(consumed, clean_len);
+            assert_eq!(total, clean_len + 2);
+        }
+    }
+
+    #[test]
+    fn parse_wal_message_bytes_rejects_trailing_bytes() {
+        let mut frame = valid_insert_frame();
+        let clean_len = frame.len();
+        frame.extend_from_slice(&[0xAA, 0xBB]);
+
+        let mut parser = LogicalReplicationParser::with_protocol_version(1);
+        let Err(ReplicationError::TrailingBytes { consumed, total }) =
+            parser.parse_wal_message_bytes(bytes::Bytes::copy_from_slice(&frame))
+        else {
+            panic!("expected TrailingBytes from the bytes path");
+        };
+        assert_eq!(consumed, clean_len);
+        assert_eq!(total, clean_len + 2);
+    }
+
+    #[test]
+    fn parse_wal_message_rejects_empty_input() {
+        let mut parser = LogicalReplicationParser::with_protocol_version(1);
+        assert!(matches!(
+            parser.parse_wal_message(&[]),
+            Err(ReplicationError::Protocol(_))
+        ));
+    }
 }