Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ pub enum ReplicationError {
/// be spawned, exited early, or dropped a reply. Transient so the stream
/// retry logic can reconnect.
Backend(String),

/// A parse left input bytes unconsumed.
///
/// `consumed` is how many bytes the parser read, `total` is the input
/// length. Returned by `parse_wal_message` and `parse_wal_message_bytes`
/// when a frame carries bytes beyond one message.
TrailingBytes { consumed: usize, total: usize },
}

impl core::fmt::Display for ReplicationError {
Expand All @@ -78,6 +85,10 @@ impl core::fmt::Display for ReplicationError {
Self::Generic(msg) => write!(f, "Replication error: {msg}"),
Self::Deserialize(msg) => write!(f, "Deserialization error: {msg}"),
Self::Backend(msg) => write!(f, "Backend worker error: {msg}"),
Self::TrailingBytes { consumed, total } => write!(
f,
"Trailing bytes after WAL message: consumed {consumed} of {total} bytes"
),
}
}
}
Expand Down Expand Up @@ -505,4 +516,19 @@ mod tests {
assert!(!err.is_permanent());
assert!(!err.is_cancelled());
}

/// The `TrailingBytes` variant renders both byte counts and is not
/// classified as transient, permanent, or cancelled.
#[test]
fn test_trailing_bytes_display() {
    let trailing = ReplicationError::TrailingBytes {
        consumed: 42,
        total: 44,
    };

    let rendered = trailing.to_string();
    assert_eq!(
        rendered,
        "Trailing bytes after WAL message: consumed 42 of 44 bytes"
    );

    // None of the classification predicates should match this variant.
    assert!(!trailing.is_transient());
    assert!(!trailing.is_permanent());
    assert!(!trailing.is_cancelled());
}
}
121 changes: 114 additions & 7 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -690,22 +690,28 @@ impl LogicalReplicationParser {
self.streaming_context.is_some()
}

/// Parse a WAL data message from the replication stream
/// Parse a WAL data message from the replication stream.
///
/// Returns [`ReplicationError::TrailingBytes`] when the buffer contains bytes
/// the parser did not consume. Each replication frame carries exactly one
/// message, so trailing bytes mean a corrupt or misframed input.
///
/// # Errors
///
/// Returns a protocol error with the message "Empty WAL message" when `data`
/// is empty. Any error produced by the shared reader-based parser is passed
/// through unchanged.
#[inline]
pub fn parse_wal_message(&mut self, data: &[u8]) -> Result<StreamingReplicationMessage> {
if data.is_empty() {
return Err(ReplicationError::protocol("Empty WAL message".to_string()));
}

// Capture the input length up front so the shared parser can report how
// many bytes were consumed if any are left over.
let total = data.len();
let mut reader = BufferReader::new(data);
self.parse_wal_message_from_reader(&mut reader)
self.parse_wal_message_from_reader(&mut reader, total)
}

/// Parse a WAL data message from pre-existing Bytes (zero-copy)
/// Parse a WAL data message from pre-existing Bytes (zero-copy).
///
/// This avoids the copy that `parse_wal_message(&[u8])` performs when
/// constructing the internal `BufferReader`. Use this when you already
/// have a `Bytes` handle (e.g. from `BufferReader::read_bytes_buf`).
/// Avoids the copy that `parse_wal_message(&[u8])` performs when constructing
/// the internal `BufferReader`. Use this when you already have a `Bytes`
/// handle (e.g. from `BufferReader::read_bytes_buf`). Rejects trailing bytes
/// with [`ReplicationError::TrailingBytes`], like `parse_wal_message`.
#[inline]
pub fn parse_wal_message_bytes(
&mut self,
Expand All @@ -715,15 +721,17 @@ impl LogicalReplicationParser {
return Err(ReplicationError::protocol("Empty WAL message".to_string()));
}

let total = data.len();
let mut reader = BufferReader::from_bytes(data);
self.parse_wal_message_from_reader(&mut reader)
self.parse_wal_message_from_reader(&mut reader, total)
}

/// Shared implementation for both `parse_wal_message` and `parse_wal_message_bytes`.
#[inline]
fn parse_wal_message_from_reader(
&mut self,
reader: &mut BufferReader,
total: usize,
) -> Result<StreamingReplicationMessage> {
let message_type = reader.read_u8()?;

Expand Down Expand Up @@ -773,6 +781,13 @@ impl LogicalReplicationParser {
None => StreamingReplicationMessage::new(message),
};

let remaining = reader.remaining();
if remaining != 0 {
return Err(ReplicationError::TrailingBytes {
consumed: total - remaining,
total,
});
}
Ok(streaming_message)
}

Expand Down Expand Up @@ -3187,4 +3202,96 @@ mod tests {
let mut parser = LogicalReplicationParser::with_protocol_version(1);
assert!(parser.parse_wal_message(&bytes).is_err());
}

// ========================================
// Trailing-byte rejection (strict by default)
// ========================================

/// Builds a well-formed INSERT frame: the message tag, a big-endian
/// relation id, the new-tuple marker, and two columns (the text value
/// `"test"` followed by a null column).
fn valid_insert_frame() -> Vec<u8> {
    let tag = [message_types::INSERT];
    let relation_id = 12345u32.to_be_bytes();
    let column_count = 2u16.to_be_bytes();
    let text_len = 4u32.to_be_bytes();

    let parts: [&[u8]; 8] = [
        &tag,
        &relation_id,
        b"N", // new tuple
        &column_count,
        b"t", // text column
        &text_len,
        b"test",
        b"n", // null column
    ];
    parts.concat()
}

/// Builds a well-formed BEGIN frame: the message tag followed by the
/// final LSN, the timestamp, and the transaction id, each big-endian.
fn valid_begin_frame() -> Vec<u8> {
    let tag = [message_types::BEGIN];
    let final_lsn = 0x0100_0000u64.to_be_bytes();
    let timestamp = 1_700_000_000_000_000i64.to_be_bytes();
    let xid = 42u32.to_be_bytes();

    let parts: [&[u8]; 4] = [&tag, &final_lsn, &timestamp, &xid];
    parts.concat()
}

/// Builds a well-formed COMMIT frame: the message tag and a zero flags
/// byte, followed by the commit LSN, the end LSN, and the timestamp, each
/// big-endian.
fn valid_commit_frame() -> Vec<u8> {
    let header = [message_types::COMMIT, 0]; // tag, flags
    let commit_lsn = 0x0100_0000u64.to_be_bytes();
    let end_lsn = 0x0100_0010u64.to_be_bytes();
    let timestamp = 1_700_000_000_000_000i64.to_be_bytes();

    let parts: [&[u8]; 4] = [&header, &commit_lsn, &end_lsn, &timestamp];
    parts.concat()
}

/// Frames with no bytes beyond the message itself must parse successfully.
#[test]
fn parse_wal_message_accepts_clean_frames() {
    let frames = [
        valid_begin_frame(),
        valid_insert_frame(),
        valid_commit_frame(),
    ];

    for frame in frames.iter() {
        // A fresh parser per frame keeps the cases independent of each other.
        let mut parser = LogicalReplicationParser::with_protocol_version(1);
        let outcome = parser.parse_wal_message(frame);
        assert!(outcome.is_ok());
    }
}

/// Appending two extra bytes to any valid frame must yield `TrailingBytes`
/// reporting the clean length as consumed and the padded length as total.
#[test]
fn parse_wal_message_rejects_trailing_bytes() {
    let frames = [
        valid_begin_frame(),
        valid_insert_frame(),
        valid_commit_frame(),
    ];

    for frame in frames.iter() {
        let clean_len = frame.len();
        let mut padded = frame.clone();
        padded.extend_from_slice(&[0xAA, 0xBB]);

        let mut parser = LogicalReplicationParser::with_protocol_version(1);
        match parser.parse_wal_message(&padded) {
            Err(ReplicationError::TrailingBytes { consumed, total }) => {
                assert_eq!(consumed, clean_len);
                assert_eq!(total, clean_len + 2);
            }
            _ => panic!("expected TrailingBytes for a frame with trailing bytes"),
        }
    }
}

/// The `Bytes`-based entry point must reject trailing bytes in the same way
/// as the slice-based one.
#[test]
fn parse_wal_message_bytes_rejects_trailing_bytes() {
    let clean = valid_insert_frame();
    let clean_len = clean.len();

    let mut padded = clean;
    padded.extend_from_slice(&[0xAA, 0xBB]);
    let input = bytes::Bytes::copy_from_slice(&padded);

    let mut parser = LogicalReplicationParser::with_protocol_version(1);
    match parser.parse_wal_message_bytes(input) {
        Err(ReplicationError::TrailingBytes { consumed, total }) => {
            assert_eq!(consumed, clean_len);
            assert_eq!(total, clean_len + 2);
        }
        _ => panic!("expected TrailingBytes from the bytes path"),
    }
}

/// An empty buffer is reported as a protocol error, not as trailing bytes.
#[test]
fn parse_wal_message_rejects_empty_input() {
    let mut parser = LogicalReplicationParser::with_protocol_version(1);
    let outcome = parser.parse_wal_message(&[]);
    assert!(matches!(outcome, Err(ReplicationError::Protocol(_))));
}
}
Loading