From 2904bdff9d1def91a5b5a44e27b9881161177a07 Mon Sep 17 00:00:00 2001 From: Will Scott Date: Wed, 1 Apr 2026 08:57:06 +0200 Subject: [PATCH 1/5] Add support for extensions of the multiaddr interface. This introduces a feature flag `Custom` that can be used to allow for parsing and management of protocols that are not part of the hard-coded set of known multiaddrs. This parallels the more permissive support that is found in js and golang implementations. by default, unknown protocols will not be parsed, since the semantics of how to parse their arguments cannot be known, but they can be manually constructed and then serialized to string. Specific extension protocols can be registered in a protocol registry to extend the default set and allow for handling of multiaddrs using those additional protocols more naturally. --- Cargo.toml | 1 + src/custom.rs | 363 ++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 7 +- src/protocol.rs | 106 +++++++++++++- 4 files changed, 475 insertions(+), 2 deletions(-) create mode 100644 src/custom.rs diff --git a/Cargo.toml b/Cargo.toml index 72e1d04..71fafe6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ version = "0.18.3" [features] default = ["url"] +custom = [] [dependencies] arrayref = "0.3" diff --git a/src/custom.rs b/src/custom.rs new file mode 100644 index 0000000..3311e00 --- /dev/null +++ b/src/custom.rs @@ -0,0 +1,363 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use crate::{Error, Multiaddr, Protocol, Result}; + +/// A transcoder defines how to encode and decode a custom protocol's data +/// between its binary representation and its human-readable string representation. +pub trait Transcoder: Send + Sync { + /// Attempts to parse the human-readable string component of a protocol into bytes. + fn string_to_bytes( + &self, + s: &str, + ) -> std::result::Result, Box>; + + /// Attempts to format the binary representation of a protocol's data into a human-readable string. + fn bytes_to_string( + &self, + bytes: &[u8], + ) -> std::result::Result>; +} + +/// A custom protocol definition. +pub struct CustomProtocolDef { + pub name: &'static str, + pub code: u32, + /// The length of the binary payload. + /// `0` means no data. `> 0` means a fixed data length. `-1` denotes a length-prefixed protocol. + pub size: i32, + pub path: bool, + pub transcoder: Option>, +} + +impl std::cmp::PartialEq for CustomProtocolDef { + fn eq(&self, other: &Self) -> bool { + self.code == other.code && self.name == other.name + } +} + +impl std::cmp::Eq for CustomProtocolDef {} + +impl std::fmt::Debug for CustomProtocolDef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CustomProtocolDef") + .field("name", &self.name) + .field("code", &self.code) + .field("size", &self.size) + .field("path", &self.path) + .finish() + } +} + +/// A registry mapping protocol codes and names to their custom definitions. +#[derive(Clone)] +pub struct Registry { + by_code: HashMap>, + by_name: HashMap>, +} + +impl Default for Registry { + fn default() -> Self { + let mut r = Self { + by_code: HashMap::new(), + by_name: HashMap::new(), + }; + r.register_builtins(); + r + } +} + +impl Registry { + /// Create a new, empty protocol registry. Wait, new() actually uses default and adds built-ins. + pub fn new() -> Self { + Self::default() + } + + /// Add all built-in standard protocols to the registry + fn register_builtins(&mut self) { + for &(name, code, size, path) in crate::protocol::BUILT_IN_PROTOCOLS.iter() { + self.register(CustomProtocolDef { + name, + code, + size, + path, + transcoder: None, + }); + } + } + + /// Add a custom protocol definition to this registry. + pub fn register(&mut self, mut def: CustomProtocolDef) { + if def.path && def.name.starts_with('/') { + def.name = def.name.trim_start_matches('/'); + } + let name = def.name.to_string(); + let code = def.code; + let arc = Arc::new(def); + self.by_code.insert(code, arc.clone()); + self.by_name.insert(name, arc); + } + + /// Returns a registered custom protocol by its integer code. + pub fn get_by_code(&self, code: u32) -> Option> { + self.by_code.get(&code).cloned() + } + + /// Returns a registered custom protocol by its string name. + pub fn get_by_name(&self, name: &str) -> Option> { + self.by_name.get(name).cloned() + } + + /// Unregisters a protocol by its string name. + pub fn unregister_by_name(&mut self, name: &str) { + if let Some(def) = self.by_name.remove(name) { + self.by_code.remove(&def.code); + } + } + + /// Unregisters a protocol by its integer code. + pub fn unregister_by_code(&mut self, code: u32) { + if let Some(def) = self.by_code.remove(&code) { + self.by_name.remove(def.name); + } + } + + /// Iterate over the protocols in a `Multiaddr` using this registry. + pub fn iter<'a>(&'a self, ma: &'a Multiaddr) -> RegistryIter<'a> { + RegistryIter { + registry: self, + data: ma.as_ref(), + } + } +} + +/// Iterator over protocols using a registry. +pub struct RegistryIter<'a> { + registry: &'a Registry, + data: &'a [u8], +} + +impl<'a> Iterator for RegistryIter<'a> { + type Item = Protocol<'a>; + + fn next(&mut self) -> Option { + if self.data.is_empty() { + return None; + } + + let (p, next_data) = self.registry.parse_protocol_from_bytes(self.data).ok()?; + self.data = next_data; + Some(p) + } +} + +impl Registry { + /// Try parsing a single Protocol from bytes using the registry. + pub fn parse_protocol_from_bytes<'a>( + &self, + input: &'a [u8], + ) -> Result<(Protocol<'a>, &'a [u8])> { + let n_input = input; + let id_res = unsigned_varint::decode::u32(n_input); + if let Ok((id, _rest)) = id_res { + if !self.by_code.contains_key(&id) { + return Err(Error::UnknownProtocolId(id)); + } + } + + if let Ok(res) = Protocol::from_bytes(input) { + return Ok(res); + } + + let n_input = input; + let id_res = unsigned_varint::decode::u32(n_input); + if let Ok((id, rest)) = id_res { + if let Some(def) = self.get_by_code(id) { + let (data, out_rest) = if def.size == 0 { + (std::borrow::Cow::Borrowed(&rest[..0]), rest) + } else if def.size > 0 { + let fixed = def.size as usize; + if rest.len() < fixed { + return Err(Error::DataLessThanLen); + } + let (d, r) = rest.split_at(fixed); + (std::borrow::Cow::Borrowed(d), r) + } else { + let (len, r) = + unsigned_varint::decode::usize(rest).map_err(|_| Error::DataLessThanLen)?; + if r.len() < len { + return Err(Error::DataLessThanLen); + } + let (d, r2) = r.split_at(len); + (std::borrow::Cow::Borrowed(d), r2) + }; + return Ok((Protocol::Custom { def, data }, out_rest)); + } + } + + Err(Error::UnknownProtocolId( + id_res.map(|(i, _)| i).unwrap_or(0), + )) + } + + /// Try parsing a single Protocol from string parts using the registry. + pub fn parse_protocol_from_str_parts<'a, I>(&self, iter: &mut I) -> Result> + where + I: Iterator + Clone, + { + let mut peek_iter = iter.clone(); + if let Some(tag) = peek_iter.next() { + if !self.by_name.contains_key(tag) { + return Err(Error::UnknownProtocolString(tag.to_string())); + } + } + + let mut native_iter = iter.clone(); + if let Ok(p) = Protocol::from_str_parts(&mut native_iter) { + *iter = native_iter; + return Ok(p); + } + + let mut peek_iter = iter.clone(); + if let Some(tag) = peek_iter.next() { + if let Some(def) = self.get_by_name(tag) { + iter.next(); // consume the tag + let data = if def.size == 0 { + vec![] + } else if let Some(t) = &def.transcoder { + let part = iter.next().ok_or(Error::InvalidProtocolString)?; + t.string_to_bytes(part) + .map_err(|_| Error::InvalidProtocolString)? + } else if def.path { + let part = iter.next().ok_or(Error::InvalidProtocolString)?; + percent_encoding::percent_decode(part.as_bytes()).collect::>() + } else { + let part = iter.next().ok_or(Error::InvalidProtocolString)?; + multibase::Base::Base58Btc + .decode(part) + .map_err(|_| Error::InvalidProtocolString)? + }; + return Ok(Protocol::Custom { + def, + data: std::borrow::Cow::Owned(data), + }); + } + } + + let mut final_try = iter.clone(); + if let Some(tag) = final_try.next() { + Err(Error::UnknownProtocolString(tag.to_string())) + } else { + Err(Error::InvalidProtocolString) + } + } + + /// Parse a Multiaddr string using this registry + pub fn try_from_str(&self, input: &str) -> Result { + let mut addr = Multiaddr::empty(); + let mut parts = input.split('/'); + + if Some("") != parts.next() { + return Err(Error::InvalidMultiaddr); + } + + while parts.clone().peekable().peek().is_some() { + let p = self.parse_protocol_from_str_parts(&mut parts)?; + addr = addr.with(p); + } + + Ok(addr) + } + + /// Parse a Multiaddr from bytes using this registry + pub fn try_from_bytes(&self, mut input: &[u8]) -> Result { + let mut addr = Multiaddr::empty(); + while !input.is_empty() { + let (p, rest) = self.parse_protocol_from_bytes(input)?; + addr = addr.with(p); + input = rest; + } + Ok(addr) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct SimpleTranscoder; + impl Transcoder for SimpleTranscoder { + fn string_to_bytes( + &self, + s: &str, + ) -> std::result::Result, Box> { + Ok(s.as_bytes().to_vec()) + } + fn bytes_to_string( + &self, + bytes: &[u8], + ) -> std::result::Result> { + Ok(String::from_utf8(bytes.to_vec())?) + } + } + + #[test] + fn test_custom_protocol_registry() { + let mut registry = Registry::new(); + registry.register(CustomProtocolDef { + name: "my-custom", + code: 999, + size: -1, + path: false, + transcoder: Some(Box::new(SimpleTranscoder)), + }); + + let addr = registry + .try_from_str("/ip4/127.0.0.1/my-custom/helloworld") + .unwrap(); + + // Output via normal iter should panic because the global parser doesn't know 999, + // wait, we modified the normal fmt::Display to iterate, BUT Display iterates the multiaddr. + // If we try `addr.to_string()`, it will panic if it's not a generic iterator. + + let vec = addr.to_vec(); + // Parse back from vec + let parsed = registry.try_from_bytes(&vec).unwrap(); + + let mut iter = registry.iter(&parsed); + if let Some(Protocol::Ip4(ip)) = iter.next() { + assert_eq!(ip, std::net::Ipv4Addr::new(127, 0, 0, 1)); + } else { + panic!("expected ip4"); + } + + if let Some(Protocol::Custom { def, data }) = iter.next() { + assert_eq!(def.code, 999); + assert_eq!(def.name, "my-custom"); + assert_eq!(data.as_ref(), b"helloworld"); + } else { + panic!("expected custom protocol"); + } + } + + #[test] + fn test_unregister_builtin() { + let mut registry = Registry::default(); + + // Assert tcp works + let addr = registry.try_from_str("/ip4/127.0.0.1/tcp/80").unwrap(); + let mut iter = registry.iter(&addr); + assert!(matches!(iter.next(), Some(Protocol::Ip4(_)))); + assert!(matches!(iter.next(), Some(Protocol::Tcp(80)))); + + // Unregister tcp + registry.unregister_by_name("tcp"); + + // Assert tcp fails now + assert!(registry.try_from_str("/ip4/127.0.0.1/tcp/80").is_err()); + + // And similarly from bytes + let vec = addr.to_vec(); + assert!(registry.try_from_bytes(&vec).is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs index b6b0ad4..264fb11 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,9 +7,14 @@ mod errors; mod onion_addr; mod protocol; +#[cfg(feature = "custom")] +mod custom; + #[cfg(feature = "url")] mod from_url; +#[cfg(feature = "custom")] +pub use self::custom::{CustomProtocolDef, Registry, Transcoder}; pub use self::errors::{Error, Result}; pub use self::onion_addr::Onion3Addr; pub use self::protocol::Protocol; @@ -223,7 +228,7 @@ impl Multiaddr { /// Returns &str identifiers for the protocol names themselves. /// This omits specific info like addresses, ports, peer IDs, and the like. /// Example: `"/ip4/127.0.0.1/tcp/5001"` would return `["ip4", "tcp"]` - pub fn protocol_stack(&self) -> ProtoStackIter { + pub fn protocol_stack(&self) -> ProtoStackIter<'_> { ProtoStackIter { parts: self.iter() } } } diff --git a/src/protocol.rs b/src/protocol.rs index f0fa1d4..5137919 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -60,6 +60,49 @@ const P2P_STARDUST: u32 = 277; // Deprecated const WEBRTC: u32 = 281; const HTTP_PATH: u32 = 481; +#[cfg(feature = "custom")] +pub(crate) const BUILT_IN_PROTOCOLS: &[(&str, u32, i32, bool)] = &[ + ("ip4", IP4, 4, false), + ("tcp", TCP, 2, false), + ("udp", UDP, 2, false), + ("dccp", DCCP, 2, false), + ("ip6", IP6, 16, false), + ("p2p", P2P, -1, false), + ("ipfs", P2P, -1, false), + ("http", HTTP, 0, false), + ("https", HTTPS, 0, false), + ("onion", ONION, 12, false), + ("onion3", ONION3, 37, false), + ("quic", QUIC, 0, false), + ("quic-v1", QUIC_V1, 0, false), + ("ws", WS, 0, false), + ("wss", WSS, 0, false), + ("p2p-websocket-star", P2P_WEBSOCKET_STAR, 0, false), + ("webrtc-direct", WEBRTC_DIRECT, 0, false), + ("p2p-webrtc-direct", P2P_WEBRTC_DIRECT, 0, false), + ("certhash", CERTHASH, -1, false), + ("p2p-circuit", P2P_CIRCUIT, 0, false), + ("sctp", SCTP, 2, false), + ("udt", UDT, 0, false), + ("utp", UTP, 0, false), + ("unix", UNIX, -1, true), + ("dns", DNS, -1, false), + ("dns4", DNS4, -1, false), + ("dns6", DNS6, -1, false), + ("dnsaddr", DNSADDR, -1, false), + ("tls", TLS, 0, false), + ("noise", NOISE, 0, false), + ("webtransport", WEBTRANSPORT, 0, false), + ("ip6zone", IP6ZONE, -1, true), + ("ipcidr", IPCIDR, 1, false), + ("garlic64", GARLIC64, -1, false), + ("garlic32", GARLIC32, -1, false), + ("sni", SNI, -1, false), + ("webrtc", WEBRTC, 0, false), + ("http-path", HTTP_PATH, -1, true), + ("memory", MEMORY, 8, false), +]; + /// Type-alias for how multi-addresses use `Multihash`. /// /// The `64` defines the allocation size for the digest within the `Multihash`. @@ -130,6 +173,13 @@ pub enum Protocol<'a> { P2pStardust, WebRTC, HttpPath(Cow<'a, str>), + #[cfg(feature = "custom")] + Custom { + def: std::sync::Arc, + data: Cow<'a, [u8]>, + }, + #[cfg(feature = "custom")] + Unknown(u32, Cow<'a, [u8]>), } impl<'a> Protocol<'a> { @@ -625,6 +675,19 @@ impl<'a> Protocol<'a> { w.write_all(encode::usize(bytes.len(), &mut encode::usize_buffer()))?; w.write_all(bytes)? } + #[cfg(feature = "custom")] + Protocol::Custom { def, data } => { + w.write_all(encode::u32(def.code, &mut buf))?; + if def.size == -1 { + w.write_all(encode::usize(data.len(), &mut encode::usize_buffer()))?; + } + w.write_all(data.as_ref())? + } + #[cfg(feature = "custom")] + Protocol::Unknown(code, data) => { + w.write_all(encode::u32(*code, &mut buf))?; + w.write_all(data.as_ref())? + } } Ok(()) } @@ -673,6 +736,13 @@ impl<'a> Protocol<'a> { P2pStardust => P2pStardust, WebRTC => WebRTC, HttpPath(cow) => HttpPath(Cow::Owned(cow.into_owned())), + #[cfg(feature = "custom")] + Custom { def, data } => Custom { + def: def.clone(), + data: Cow::Owned(data.into_owned()), + }, + #[cfg(feature = "custom")] + Unknown(code, data) => Unknown(code, Cow::Owned(data.into_owned())), } } @@ -721,6 +791,10 @@ impl<'a> Protocol<'a> { P2pStardust => "p2p-stardust", WebRTC => "webrtc", HttpPath(_) => "http-path", + #[cfg(feature = "custom")] + Custom { def, .. } => def.name, + #[cfg(feature = "custom")] + Unknown(_, _) => "unknown", } } } @@ -728,7 +802,11 @@ impl<'a> Protocol<'a> { impl fmt::Display for Protocol<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use self::Protocol::*; - write!(f, "/{}", self.tag())?; + match self { + #[cfg(feature = "custom")] + Unknown(code, _) => write!(f, "/unknown-{code}"), + _ => write!(f, "/{}", self.tag()), + }?; match self { Dccp(port) => write!(f, "/{port}"), Dns(s) => write!(f, "/{s}"), @@ -783,6 +861,32 @@ impl fmt::Display for Protocol<'_> { percent_encoding::percent_encode(s.as_bytes(), PATH_SEGMENT_ENCODE_SET); write!(f, "/{encoded}") } + #[cfg(feature = "custom")] + Custom { def, data } => { + if let Some(t) = &def.transcoder { + let s = t.bytes_to_string(data.as_ref()).map_err(|_| fmt::Error)?; + if !s.is_empty() { + write!(f, "/{s}")?; + } + Ok(()) + } else if data.is_empty() { + Ok(()) + } else if def.path { + let s = std::str::from_utf8(data.as_ref()).map_err(|_| fmt::Error)?; + let encoded = + percent_encoding::percent_encode(s.as_bytes(), PATH_SEGMENT_ENCODE_SET); + write!(f, "/{encoded}") + } else { + write!(f, "/{}", multibase::Base::Base58Btc.encode(data.as_ref())) + } + } + #[cfg(feature = "custom")] + Unknown(_, data) => { + if !data.is_empty() { + write!(f, "/{}", multibase::Base::Base58Btc.encode(data.as_ref()))?; + } + Ok(()) + } _ => Ok(()), } } From 8aceb4daec0c3f1612ac6762188750d446859463 Mon Sep 17 00:00:00 2001 From: Will Scott Date: Mon, 13 Apr 2026 14:33:27 +0200 Subject: [PATCH 2/5] extend the handling of custom protocols to be graceful when handling binary<>string format round trips --- src/custom.rs | 69 ++++++++++++++++++++++++++++++++++++++++++------- src/protocol.rs | 22 ++++++++++++++++ tests/lib.rs | 8 ++++++ 3 files changed, 89 insertions(+), 10 deletions(-) diff --git a/src/custom.rs b/src/custom.rs index 3311e00..835e04e 100644 --- a/src/custom.rs +++ b/src/custom.rs @@ -159,18 +159,14 @@ impl Registry { ) -> Result<(Protocol<'a>, &'a [u8])> { let n_input = input; let id_res = unsigned_varint::decode::u32(n_input); - if let Ok((id, _rest)) = id_res { - if !self.by_code.contains_key(&id) { - return Err(Error::UnknownProtocolId(id)); - } - } if let Ok(res) = Protocol::from_bytes(input) { - return Ok(res); + let is_unknown = matches!(&res.0, Protocol::Unknown(_, _)); + if !is_unknown { + return Ok(res); + } } - let n_input = input; - let id_res = unsigned_varint::decode::u32(n_input); if let Ok((id, rest)) = id_res { if let Some(def) = self.get_by_code(id) { let (data, out_rest) = if def.size == 0 { @@ -193,6 +189,11 @@ impl Registry { }; return Ok((Protocol::Custom { def, data }, out_rest)); } + + return Ok(( + Protocol::Unknown(id, std::borrow::Cow::Borrowed(rest)), + [].as_ref(), + )); } Err(Error::UnknownProtocolId( @@ -279,6 +280,15 @@ impl Registry { } Ok(addr) } + + /// Format a Multiaddr into a string using this registry + pub fn to_string(&self, addr: &Multiaddr) -> String { + let mut s = String::new(); + for p in self.iter(addr) { + s.push_str(&p.to_string()); + } + s + } } #[cfg(test)] @@ -356,8 +366,47 @@ mod tests { // Assert tcp fails now assert!(registry.try_from_str("/ip4/127.0.0.1/tcp/80").is_err()); - // And similarly from bytes + // And similarly from bytes, it will now parse as Tcp since we natively fallback to standard protocols let vec = addr.to_vec(); - assert!(registry.try_from_bytes(&vec).is_err()); + let parsed_unknown = registry.try_from_bytes(&vec).unwrap(); + let mut parsed_iter = registry.iter(&parsed_unknown); + assert!(matches!(parsed_iter.next(), Some(Protocol::Ip4(_)))); + assert!(matches!(parsed_iter.next(), Some(Protocol::Tcp(80)))); + } + + #[test] + fn test_custom_protocol_registry_printing() { + let mut registry = Registry::new(); + registry.register(CustomProtocolDef { + name: "my-custom", + code: 999, + size: -1, + path: false, + transcoder: Some(Box::new(SimpleTranscoder)), + }); + + // Parsed string multi addr with a custom protocol + let addr = registry + .try_from_str("/ip4/127.0.0.1/my-custom/helloworld") + .unwrap(); + + // 1. Printing with Registry works as expected, displaying the registered custom format + let registry_printed = registry.to_string(&addr); + assert_eq!(registry_printed, "/ip4/127.0.0.1/my-custom/helloworld"); + + // 2. Native Multiaddr printing gracefully falls back to unknown without panicking + let native_printed = addr.to_string(); + // Native printing uses base58 for the rest of the bytes (the length varint and data). + // For size=-1, the length varint `10` followed by "helloworld" becomes '3ah4EQvnau95Y8K' + assert_eq!(native_printed, "/ip4/127.0.0.1/unknown-999/3ah4EQvnau95Y8K"); + + // 3. Confirm that the final 'unknown-999' round-trips on parse back to the same multiaddr + let parsed_back = native_printed + .parse::() + .expect("Should parse unknown protocol formatting natively"); + assert_eq!( + parsed_back, addr, + "Round-trip multiaddr bytes must match the original instance precisely" + ); } } diff --git a/src/protocol.rs b/src/protocol.rs index 5137919..2999ca3 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -333,6 +333,22 @@ impl<'a> Protocol<'a> { let decoded = percent_encoding::percent_decode(s.as_bytes()).decode_utf8()?; Ok(Protocol::HttpPath(decoded)) } + #[cfg(feature = "custom")] + unknown if unknown.starts_with("unknown-") => { + let id_str = &unknown["unknown-".len()..]; + let id: u32 = id_str + .parse() + .map_err(|_| Error::UnknownProtocolString(unknown.to_string()))?; + let data = match iter.next() { + Some("") => vec![], + Some(s) => match multibase::Base::Base58Btc.decode(s) { + Ok(d) => d, + Err(_) => return Err(Error::InvalidProtocolString), + }, + None => vec![], + }; + Ok(Protocol::Unknown(id, std::borrow::Cow::Owned(data))) + } unknown => Err(Error::UnknownProtocolString(unknown.to_string())), } } @@ -522,6 +538,12 @@ impl<'a> Protocol<'a> { rest, )) } + #[cfg(feature = "custom")] + _ => Ok(( + Protocol::Unknown(id, std::borrow::Cow::Borrowed(input)), + [].as_ref(), + )), + #[cfg(not(feature = "custom"))] _ => Err(Error::UnknownProtocolId(id)), } } diff --git a/tests/lib.rs b/tests/lib.rs index 936809e..794da2e 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -598,7 +598,15 @@ fn to_multiaddr() { #[test] fn from_bytes_fail() { let bytes = vec![1, 2, 3, 4]; + #[cfg(not(feature = "custom"))] assert!(Multiaddr::try_from(bytes).is_err()); + #[cfg(feature = "custom")] + { + let multiaddr = Multiaddr::try_from(bytes) + .expect("Should parse as Unknown protocol when custom feature is on"); + let mut iter = multiaddr.iter(); + assert_eq!("unknown", iter.next().unwrap().tag()); + } } #[test] From 394c1c5832be9da9067b191cd360a5ffc2653b99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Oliveira?= Date: Tue, 21 Apr 2026 17:46:21 +0100 Subject: [PATCH 3/5] hide Arcs --- src/custom.rs | 74 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 24 deletions(-) diff --git a/src/custom.rs b/src/custom.rs index 835e04e..df8c5b6 100644 --- a/src/custom.rs +++ b/src/custom.rs @@ -20,6 +20,7 @@ pub trait Transcoder: Send + Sync { } /// A custom protocol definition. +#[derive(Clone)] pub struct CustomProtocolDef { pub name: &'static str, pub code: u32, @@ -27,7 +28,27 @@ pub struct CustomProtocolDef { /// `0` means no data. `> 0` means a fixed data length. `-1` denotes a length-prefixed protocol. pub size: i32, pub path: bool, - pub transcoder: Option>, + pub transcoder: Option>, +} + +impl CustomProtocolDef { + /// Create a new custom protocol definition. + pub fn new( + name: &'static str, + code: u32, + size: i32, + path: bool, + transcoder: Option, + ) -> Self { + let transcoder = transcoder.map(|t| Arc::new(t) as Arc); + Self { + name, + code, + size, + path, + transcoder, + } + } } impl std::cmp::PartialEq for CustomProtocolDef { @@ -52,8 +73,8 @@ impl std::fmt::Debug for CustomProtocolDef { /// A registry mapping protocol codes and names to their custom definitions. #[derive(Clone)] pub struct Registry { - by_code: HashMap>, - by_name: HashMap>, + by_code: HashMap, + by_name: HashMap, } impl Default for Registry { @@ -93,18 +114,17 @@ impl Registry { } let name = def.name.to_string(); let code = def.code; - let arc = Arc::new(def); - self.by_code.insert(code, arc.clone()); - self.by_name.insert(name, arc); + self.by_code.insert(code, def.clone()); + self.by_name.insert(name, def.clone()); } /// Returns a registered custom protocol by its integer code. - pub fn get_by_code(&self, code: u32) -> Option> { + pub fn get_by_code(&self, code: u32) -> Option { self.by_code.get(&code).cloned() } /// Returns a registered custom protocol by its string name. - pub fn get_by_name(&self, name: &str) -> Option> { + pub fn get_by_name(&self, name: &str) -> Option { self.by_name.get(name).cloned() } @@ -187,7 +207,13 @@ impl Registry { let (d, r2) = r.split_at(len); (std::borrow::Cow::Borrowed(d), r2) }; - return Ok((Protocol::Custom { def, data }, out_rest)); + return Ok(( + Protocol::Custom { + def: Arc::new(def), + data, + }, + out_rest, + )); } return Ok(( @@ -239,7 +265,7 @@ impl Registry { .map_err(|_| Error::InvalidProtocolString)? }; return Ok(Protocol::Custom { - def, + def: Arc::new(def), data: std::borrow::Cow::Owned(data), }); } @@ -314,13 +340,13 @@ mod tests { #[test] fn test_custom_protocol_registry() { let mut registry = Registry::new(); - registry.register(CustomProtocolDef { - name: "my-custom", - code: 999, - size: -1, - path: false, - transcoder: Some(Box::new(SimpleTranscoder)), - }); + registry.register(CustomProtocolDef::new( + "my-custom", + 999, + -1, + false, + Some(SimpleTranscoder), + )); let addr = registry .try_from_str("/ip4/127.0.0.1/my-custom/helloworld") @@ -377,13 +403,13 @@ mod tests { #[test] fn test_custom_protocol_registry_printing() { let mut registry = Registry::new(); - registry.register(CustomProtocolDef { - name: "my-custom", - code: 999, - size: -1, - path: false, - transcoder: Some(Box::new(SimpleTranscoder)), - }); + registry.register(CustomProtocolDef::new( + "my-custom", + 999, + -1, + false, + Some(SimpleTranscoder), + )); // Parsed string multi addr with a custom protocol let addr = registry From 368608504882c07b6ef3c69394c66db1f6e23da3 Mon Sep 17 00:00:00 2001 From: Will Scott Date: Wed, 22 Apr 2026 16:58:41 -0700 Subject: [PATCH 4/5] address reviewer comments --- src/custom.rs | 204 +++++++++++++++++++++++++++++++------------------- 1 file changed, 126 insertions(+), 78 deletions(-) diff --git a/src/custom.rs b/src/custom.rs index df8c5b6..2452627 100644 --- a/src/custom.rs +++ b/src/custom.rs @@ -22,12 +22,16 @@ pub trait Transcoder: Send + Sync { /// A custom protocol definition. #[derive(Clone)] pub struct CustomProtocolDef { + /// The string identifier for the protocol (e.g. `tcp`, `http`, or `my-custom`). pub name: &'static str, + /// The unique unsigned integer code for the protocol. pub code: u32, /// The length of the binary payload. - /// `0` means no data. `> 0` means a fixed data length. `-1` denotes a length-prefixed protocol. + /// `0` means no data. `> 0` means a fixed data length. `-1` denotes a length-prefixed protocol or custom encoding. pub size: i32, + /// Whether the protocol's string representation is a file path (e.g., `/unix/tmp/socket` instead of a base-encoded payload). pub path: bool, + /// An optional transcoder used to encode and decode the protocol's data between binary and human-readable string formats. pub transcoder: Option>, } @@ -94,7 +98,7 @@ impl Registry { Self::default() } - /// Add all built-in standard protocols to the registry + /// Add all built-in standard protocols to the registry. fn register_builtins(&mut self) { for &(name, code, size, path) in crate::protocol::BUILT_IN_PROTOCOLS.iter() { self.register(CustomProtocolDef { @@ -143,7 +147,7 @@ impl Registry { } /// Iterate over the protocols in a `Multiaddr` using this registry. - pub fn iter<'a>(&'a self, ma: &'a Multiaddr) -> RegistryIter<'a> { + pub fn parse_addr<'a>(&'a self, ma: &'a Multiaddr) -> RegistryIter<'a> { RegistryIter { registry: self, data: ma.as_ref(), @@ -187,56 +191,68 @@ impl Registry { } } - if let Ok((id, rest)) = id_res { - if let Some(def) = self.get_by_code(id) { - let (data, out_rest) = if def.size == 0 { - (std::borrow::Cow::Borrowed(&rest[..0]), rest) - } else if def.size > 0 { - let fixed = def.size as usize; - if rest.len() < fixed { - return Err(Error::DataLessThanLen); - } - let (d, r) = rest.split_at(fixed); - (std::borrow::Cow::Borrowed(d), r) - } else { - let (len, r) = - unsigned_varint::decode::usize(rest).map_err(|_| Error::DataLessThanLen)?; - if r.len() < len { - return Err(Error::DataLessThanLen); - } - let (d, r2) = r.split_at(len); - (std::borrow::Cow::Borrowed(d), r2) - }; + let (id, rest) = match id_res { + Ok((id, rest)) => (id, rest), + Err(_) => return Err(Error::UnknownProtocolId(0)), + }; + + let def = match self.get_by_code(id) { + Some(def) => def, + None => { + // If the protocol isn't registered, we just return it as Unknown + // so that we can gracefully iterate over or reserialize it later. return Ok(( - Protocol::Custom { - def: Arc::new(def), - data, - }, - out_rest, + Protocol::Unknown(id, std::borrow::Cow::Borrowed(rest)), + [].as_ref(), )); } + }; - return Ok(( - Protocol::Unknown(id, std::borrow::Cow::Borrowed(rest)), - [].as_ref(), - )); - } + // Extract the protocol data based on the registered size definition + let (data, out_rest) = if def.size == 0 { + // Protocol has no data payload expected + (std::borrow::Cow::Borrowed(&rest[..0]), rest) + } else if def.size > 0 { + // Protocol has a fixed-length data payload + let fixed = def.size as usize; + if rest.len() < fixed { + return Err(Error::DataLessThanLen); + } + let (d, r) = rest.split_at(fixed); + (std::borrow::Cow::Borrowed(d), r) + } else { + // Protocol size is -1, meaning it is length-prefixed. + // Decode the varint representing the length of the upcoming data. + let (len, r) = + unsigned_varint::decode::usize(rest).map_err(|_| Error::DataLessThanLen)?; + if r.len() < len { + return Err(Error::DataLessThanLen); + } + let (d, r2) = r.split_at(len); + (std::borrow::Cow::Borrowed(d), r2) + }; - Err(Error::UnknownProtocolId( - id_res.map(|(i, _)| i).unwrap_or(0), + Ok(( + Protocol::Custom { + def: Arc::new(def), + data, + }, + out_rest, )) } /// Try parsing a single Protocol from string parts using the registry. - pub fn parse_protocol_from_str_parts<'a, I>(&self, iter: &mut I) -> Result> + pub fn parse_protocol_from_str_parts<'a, I>( + &self, + iter: &mut std::iter::Peekable, + ) -> Result> where I: Iterator + Clone, { - let mut peek_iter = iter.clone(); - if let Some(tag) = peek_iter.next() { - if !self.by_name.contains_key(tag) { - return Err(Error::UnknownProtocolString(tag.to_string())); - } + let &tag = iter.peek().ok_or(Error::InvalidProtocolString)?; + + if !self.by_name.contains_key(tag) { + return Err(Error::UnknownProtocolString(tag.to_string())); } let mut native_iter = iter.clone(); @@ -245,50 +261,39 @@ impl Registry { return Ok(p); } - let mut peek_iter = iter.clone(); - if let Some(tag) = peek_iter.next() { - if let Some(def) = self.get_by_name(tag) { - iter.next(); // consume the tag - let data = if def.size == 0 { - vec![] - } else if let Some(t) = &def.transcoder { - let part = iter.next().ok_or(Error::InvalidProtocolString)?; - t.string_to_bytes(part) - .map_err(|_| Error::InvalidProtocolString)? - } else if def.path { - let part = iter.next().ok_or(Error::InvalidProtocolString)?; - percent_encoding::percent_decode(part.as_bytes()).collect::>() - } else { - let part = iter.next().ok_or(Error::InvalidProtocolString)?; - multibase::Base::Base58Btc - .decode(part) - .map_err(|_| Error::InvalidProtocolString)? - }; - return Ok(Protocol::Custom { - def: Arc::new(def), - data: std::borrow::Cow::Owned(data), - }); - } - } - - let mut final_try = iter.clone(); - if let Some(tag) = final_try.next() { - Err(Error::UnknownProtocolString(tag.to_string())) + let def = self.get_by_name(tag).unwrap(); + iter.next(); // consume the tag + let data = if def.size == 0 { + vec![] + } else if let Some(t) = &def.transcoder { + let part = iter.next().ok_or(Error::InvalidProtocolString)?; + t.string_to_bytes(part) + .map_err(|_| Error::InvalidProtocolString)? + } else if def.path { + let part = iter.next().ok_or(Error::InvalidProtocolString)?; + percent_encoding::percent_decode(part.as_bytes()).collect::>() } else { - Err(Error::InvalidProtocolString) - } + let part = iter.next().ok_or(Error::InvalidProtocolString)?; + multibase::Base::Base58Btc + .decode(part) + .map_err(|_| Error::InvalidProtocolString)? + }; + Ok(Protocol::Custom { + def: Arc::new(def), + data: std::borrow::Cow::Owned(data), + }) } /// Parse a Multiaddr string using this registry pub fn try_from_str(&self, input: &str) -> Result { let mut addr = Multiaddr::empty(); - let mut parts = input.split('/'); + let mut parts = input.split('/').peekable(); if Some("") != parts.next() { return Err(Error::InvalidMultiaddr); } - while parts.clone().peekable().peek().is_some() { + while parts.peek().is_some() { let p = self.parse_protocol_from_str_parts(&mut parts)?; addr = addr.with(p); } @@ -310,7 +315,7 @@ impl Registry { /// Format a Multiaddr into a string using this registry pub fn to_string(&self, addr: &Multiaddr) -> String { let mut s = String::new(); - for p in self.iter(addr) { + for p in self.parse_addr(addr) { s.push_str(&p.to_string()); } s @@ -360,7 +365,7 @@ mod tests { // Parse back from vec let parsed = registry.try_from_bytes(&vec).unwrap(); - let mut iter = registry.iter(&parsed); + let mut iter = registry.parse_addr(&parsed); if let Some(Protocol::Ip4(ip)) = iter.next() { assert_eq!(ip, std::net::Ipv4Addr::new(127, 0, 0, 1)); } else { @@ -382,7 +387,7 @@ mod tests { // Assert tcp works let addr = registry.try_from_str("/ip4/127.0.0.1/tcp/80").unwrap(); - let mut iter = registry.iter(&addr); + let mut iter = registry.parse_addr(&addr); assert!(matches!(iter.next(), Some(Protocol::Ip4(_)))); assert!(matches!(iter.next(), Some(Protocol::Tcp(80)))); @@ -395,7 +400,7 @@ mod tests { // And similarly from bytes, it will now parse as Tcp since we natively fallback to standard protocols let vec = addr.to_vec(); let parsed_unknown = registry.try_from_bytes(&vec).unwrap(); - let mut parsed_iter = registry.iter(&parsed_unknown); + let mut parsed_iter = registry.parse_addr(&parsed_unknown); assert!(matches!(parsed_iter.next(), Some(Protocol::Ip4(_)))); assert!(matches!(parsed_iter.next(), Some(Protocol::Tcp(80)))); } @@ -435,4 +440,47 @@ mod tests { "Round-trip multiaddr bytes must match the original instance precisely" ); } + + #[test] + fn test_custom_protocol_size_zero() { + let mut registry = Registry::new(); + // Register a custom protocol with size 0 (no data payload expected) + registry.register(CustomProtocolDef::new( + "my-empty", + 1000, + 0, + false, + None::, + )); + + // It should parse without needing an additional value + let addr = registry.try_from_str("/ip4/127.0.0.1/my-empty").unwrap(); + + let vec = addr.to_vec(); + let parsed = registry.try_from_bytes(&vec).unwrap(); + + let mut iter = registry.parse_addr(&parsed); + if let Some(Protocol::Ip4(ip)) = iter.next() { + assert_eq!(ip, std::net::Ipv4Addr::new(127, 0, 0, 1)); + } else { + panic!("expected ip4"); + } + + if let Some(Protocol::Custom { def, data }) = iter.next() { + assert_eq!(def.code, 1000); + assert_eq!(def.name, "my-empty"); + assert!(data.is_empty()); + } else { + panic!("expected custom protocol"); + } + + // Ensure that a subsequent protocol is parsed correctly, not consumed as data + let addr2 = registry.try_from_str("/my-empty/tcp/80").unwrap(); + let mut iter2 = registry.parse_addr(&addr2); + + assert!( + matches!(iter2.next(), Some(Protocol::Custom { def, .. }) if def.name == "my-empty") + ); + assert!(matches!(iter2.next(), Some(Protocol::Tcp(80)))); + } } From 71ddb831fc7d3a6ef988dc2f2f8d741b054e9514 Mon Sep 17 00:00:00 2001 From: Will Scott Date: Sat, 25 Apr 2026 11:17:58 +0200 Subject: [PATCH 5/5] use base64 --- src/custom.rs | 8 ++++---- src/protocol.rs | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/custom.rs b/src/custom.rs index 2452627..49b07fd 100644 --- a/src/custom.rs +++ b/src/custom.rs @@ -274,7 +274,7 @@ impl Registry { percent_encoding::percent_decode(part.as_bytes()).collect::>() } else { let part = iter.next().ok_or(Error::InvalidProtocolString)?; - multibase::Base::Base58Btc + multibase::Base::Base64Url .decode(part) .map_err(|_| Error::InvalidProtocolString)? }; @@ -427,9 +427,9 @@ mod tests { // 2. Native Multiaddr printing gracefully falls back to unknown without panicking let native_printed = addr.to_string(); - // Native printing uses base58 for the rest of the bytes (the length varint and data). - // For size=-1, the length varint `10` followed by "helloworld" becomes '3ah4EQvnau95Y8K' - assert_eq!(native_printed, "/ip4/127.0.0.1/unknown-999/3ah4EQvnau95Y8K"); + // Native printing uses base64url for the rest of the bytes (the length varint and data). + // For size=-1, the length varint `10` followed by "helloworld" becomes 'CmhlbGxvd29ybGQ' + assert_eq!(native_printed, "/ip4/127.0.0.1/unknown-999/CmhlbGxvd29ybGQ"); // 3. Confirm that the final 'unknown-999' round-trips on parse back to the same multiaddr let parsed_back = native_printed diff --git a/src/protocol.rs b/src/protocol.rs index 2999ca3..6b46e68 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -341,7 +341,7 @@ impl<'a> Protocol<'a> { .map_err(|_| Error::UnknownProtocolString(unknown.to_string()))?; let data = match iter.next() { Some("") => vec![], - Some(s) => match multibase::Base::Base58Btc.decode(s) { + Some(s) => match multibase::Base::Base64Url.decode(s) { Ok(d) => d, Err(_) => return Err(Error::InvalidProtocolString), }, @@ -899,13 +899,13 @@ impl fmt::Display for Protocol<'_> { percent_encoding::percent_encode(s.as_bytes(), PATH_SEGMENT_ENCODE_SET); write!(f, "/{encoded}") } else { - write!(f, "/{}", multibase::Base::Base58Btc.encode(data.as_ref())) + write!(f, "/{}", multibase::Base::Base64Url.encode(data.as_ref())) } } #[cfg(feature = "custom")] Unknown(_, data) => { if !data.is_empty() { - write!(f, "/{}", multibase::Base::Base58Btc.encode(data.as_ref()))?; + write!(f, "/{}", multibase::Base::Base64Url.encode(data.as_ref()))?; } Ok(()) }