diff --git a/.github/workflows/publish-crates.yml b/.github/workflows/publish-crates.yml index bfdffde..68f417e 100644 --- a/.github/workflows/publish-crates.yml +++ b/.github/workflows/publish-crates.yml @@ -22,11 +22,20 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Install Rust stable toolchain + - name: Install Rust nightly toolchain uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + toolchain: nightly + components: rustfmt, clippy + + - name: Check formatting + run: cargo +nightly fmt -- --check + + - name: Run clippy + run: cargo +nightly clippy --all-targets --all-features -- -D warnings - name: Test crate - run: cargo test --all-features --all-targets + run: cargo +nightly test --all-features --all-targets - name: Authenticate to crates.io if: github.ref_type == 'tag' && startsWith(github.ref_name, 'v') @@ -47,7 +56,7 @@ jobs: fi package_name=dyns - package_version="$(cargo metadata --no-deps --format-version 1 | python3 -c 'import json, sys; print(json.load(sys.stdin)["packages"][0]["version"])')" + package_version="$(cargo +nightly metadata --no-deps --format-version 1 | python3 -c 'import json, sys; print(json.load(sys.stdin)["packages"][0]["version"])')" crate_state="$( python3 - <<'PY' "$package_name" "$package_version" @@ -97,9 +106,9 @@ jobs: if [[ "$mode" == "dry-run" ]]; then echo "dry-run $package_name $package_version" - cargo publish --dry-run --locked + cargo +nightly publish --dry-run exit 0 fi echo "publish $package_name $package_version" - cargo publish --locked + cargo +nightly publish diff --git a/Cargo.toml b/Cargo.toml index abb3316..264ecc5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "dyns" description = "DNS discovery and resolver support for DHTTP applications" -version = "0.4.0" +version = "0.5.0" edition = "2024" license = "Apache-2.0" repository = "https://github.com/genmeta/ddns" @@ -20,7 +20,7 @@ bitfield-struct = "0.13" bytes = "1" dashmap = { version = "6", optional = true } dhttp-identity = "0.2.0" -dquic = "0.5.1" +dquic = "0.6.0" flume = { version = "0.12", optional = true } futures = "0.3" libc = { version = "0.2", optional = true } @@ -47,7 +47,7 @@ tokio = { version = "1", features = [ tracing = "0.1" x509-parser = { version = "0.18", features = ["verify"] } -h3x = { version = "0.4.0", default-features = false, optional = true } +h3x = { version = "0.5.0", default-features = false, optional = true } http = { version = "1", optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } @@ -80,7 +80,7 @@ mdns = ["dep:dashmap", "dep:flume", "dep:libc", "dep:socket2"] [dev-dependencies] clap = { version = "4", features = ["derive"] } -h3x = { version = "0.4.0", default-features = false, features = ["dquic"] } +h3x = { version = "0.5.0", default-features = false, features = ["dquic"] } shellexpand = "3" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/src/core/wire.rs b/src/core/wire.rs index 028e455..57b5d52 100644 --- a/src/core/wire.rs +++ b/src/core/wire.rs @@ -8,8 +8,12 @@ /// +-----------+ | u32 BE | ... | u32 BE | ... | /// +-----------+------+-----------+------+ /// ``` +use std::borrow::Borrow; + use bytes::BufMut; +use dhttp_identity::certificate::CertificateChainKey; use nom::{IResult, bytes::streaming::take, number::streaming::be_u32}; +use rustls::pki_types::CertificateDer; use crate::core::signature::SignatureFields; @@ -48,6 +52,17 @@ impl ResponseRecord { let digest = digest(&SHA256, &self.cert); Some(digest.as_ref().iter().map(|b| format!("{b:02x}")).collect()) } + + pub fn publisher_certificate_chain_key(&self) -> Option { + if self.cert.is_empty() { + return None; + } + + let cert = CertificateDer::from(self.cert.clone()); + dhttp_identity::identity::extract_dhttp_subject_key_identifier(std::slice::from_ref(&cert)) + .ok() + .map(|ski| ski.chain().clone()) + } } /// HTTP response body carrying zero or more DNS records. @@ -86,6 +101,26 @@ impl MultiResponse { buf.put_multi_response(self); buf } + + pub fn encode_records(records: I) -> Vec + where + I: IntoIterator, + R: Borrow, + { + let mut buf = Vec::new(); + buf.put_u32(0); + + let mut count = 0u32; + for record in records { + count = count + .checked_add(1) + .expect("multi response record count exceeds u32 range"); + put_response_record(&mut buf, record.borrow()); + } + + buf[..4].copy_from_slice(&count.to_be_bytes()); + buf + } } pub trait WriteMultiResponse { @@ -96,15 +131,19 @@ impl WriteMultiResponse for B { fn put_multi_response(&mut self, response: &MultiResponse) { self.put_u32(response.records.len() as u32); for record in &response.records { - put_field(self, &record.signature_fields.content_digest); - put_field(self, &record.signature_fields.signature_input); - put_field(self, &record.signature_fields.signature); - put_field(self, &record.dns); - put_field(self, &record.cert); + put_response_record(self, record); } } } +fn put_response_record(buf: &mut B, record: &ResponseRecord) { + put_field(buf, &record.signature_fields.content_digest); + put_field(buf, &record.signature_fields.signature_input); + put_field(buf, &record.signature_fields.signature); + put_field(buf, &record.dns); + put_field(buf, &record.cert); +} + fn put_field(buf: &mut B, value: &[u8]) { buf.put_u32(value.len() as u32); buf.put_slice(value); @@ -164,4 +203,36 @@ mod tests { assert!(remain.is_empty()); assert_eq!(decoded, response); } + + #[test] + fn encode_records_matches_multi_response_encoding_for_owned_records() { + let records = [ + ResponseRecord::unsigned(vec![1, 2, 3], vec![4, 5]), + ResponseRecord::new( + SignatureFields { + content_digest: b"sha-256=:abc:".to_vec(), + signature_input: b"dns=(\"content-digest\")".to_vec(), + signature: b"dns=:sig:".to_vec(), + }, + vec![6, 7], + Vec::new(), + ), + ]; + let response = MultiResponse::new(records.clone()); + + assert_eq!(MultiResponse::encode_records(records), response.encode()); + } + + #[test] + fn encode_records_matches_multi_response_encoding_for_borrowed_records() { + let response = MultiResponse::new([ + ResponseRecord::unsigned(vec![1, 2, 3], vec![4, 5]), + ResponseRecord::unsigned(vec![6, 7], Vec::new()), + ]); + + assert_eq!( + MultiResponse::encode_records(response.records.iter()), + response.encode() + ); + } } diff --git a/src/h3.rs b/src/h3.rs index 760e89b..2898783 100644 --- a/src/h3.rs +++ b/src/h3.rs @@ -248,7 +248,7 @@ mod tests { )); let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); resolver.cache.insert_positive( - "nat.genmeta.net", + "nat.genmeta.net:20004", vec![EndpointAddr::direct("192.0.2.10:21000".parse().unwrap())], ); diff --git a/src/h3/cache.rs b/src/h3/cache.rs index 3f143ad..3aab7f4 100644 --- a/src/h3/cache.rs +++ b/src/h3/cache.rs @@ -76,4 +76,20 @@ mod tests { assert!(cache.negative_hit("missing.dhttp.net")); } + + #[test] + fn positive_cache_hit_keeps_selector_entries_separate() { + let cache = LookupCache::default(); + cache.insert_positive("demo.dhttp.net", vec![endpoint("192.0.2.10:4433")]); + cache.insert_positive("demo.dhttp.net:1", vec![endpoint("192.0.2.11:4433")]); + + assert_eq!( + cache.positive_hit("demo.dhttp.net").unwrap(), + vec![endpoint("192.0.2.10:4433")] + ); + assert_eq!( + cache.positive_hit("demo.dhttp.net:1").unwrap(), + vec![endpoint("192.0.2.11:4433")] + ); + } } diff --git a/src/h3/lookup.rs b/src/h3/lookup.rs index 0f99bf6..902e505 100644 --- a/src/h3/lookup.rs +++ b/src/h3/lookup.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use dhttp_identity::certificate::CertificateSequence; use dquic::qresolve::{RecordStream, Source}; use futures::{StreamExt, stream}; use h3x::quic; @@ -15,11 +16,16 @@ use crate::core::{parser::packet::be_packet, wire::be_multi_response}; const LOOKUP_API_PATH: &str = "/api/v2/lookup"; -fn lookup_url(base_url: &url::Url, name: &str) -> url::Url { +fn lookup_url(base_url: &url::Url, name: &str, sequence: Option) -> url::Url { let mut url = base_url .join(LOOKUP_API_PATH) .expect("h3 dns lookup api path must be valid"); url.query_pairs_mut().append_pair("host", name); + if let Some(sequence) = sequence { + let sequence_text = sequence.get().to_string(); + url.query_pairs_mut() + .append_pair("sequence", &sequence_text); + } url } @@ -29,7 +35,11 @@ pub(super) struct LookupRecords { } impl LookupRecords { - pub(super) fn decode(domain: &str, response: &[u8]) -> Result { + pub(super) fn decode( + domain: &str, + sequence: Option, + response: &[u8], + ) -> Result { use crate::core::parser::record; let (remain, multi) = match be_multi_response(response) { @@ -42,6 +52,7 @@ impl LookupRecords { let mut endpoint_records = Vec::new(); for r in multi.records { + let publisher_chain_key = r.publisher_certificate_chain_key(); if !r.signature_fields.is_empty() { match r.signature_fields.verify(&r.dns, &r.cert) { Ok(true) => {} @@ -79,7 +90,7 @@ impl LookupRecords { ); return None; } - Some(ep.clone()) + Some((ep.clone(), publisher_chain_key.clone())) } _ => { tracing::debug!(?answer, "ignored record"); @@ -90,7 +101,16 @@ impl LookupRecords { } Ok(Self { - endpoints: crate::resolvers::endpoint_group::selected_endpoint_addrs(endpoint_records), + endpoints: + crate::resolvers::endpoint_group::selected_endpoint_records_with_fallback_chain_keys( + endpoint_records + .into_iter() + .map(|(endpoint, fallback_chain_key)| ((), endpoint, fallback_chain_key)), + sequence, + ) + .into_iter() + .map(|((), endpoint)| endpoint) + .collect(), }) } } @@ -175,45 +195,46 @@ where let server = Arc::from(self.base_url.origin().ascii_serialization()); let source = Source::H3 { server }; - let Some(domain) = crate::resolvers::resolvable_name(name) else { + let Some((domain, sequence)) = crate::resolvers::endpoint_lookup_name_and_sequence(name) + else { return Err(H3LookupError::NoRecordFound); }; let now = Instant::now(); self.cache.prune_expired(now); - if self.cache.negative_hit(domain) { + if self.cache.negative_hit(name) { return Err(H3LookupError::NoRecordFound); } - if let Some(addrs) = self.cache.positive_hit(domain) { + if let Some(addrs) = self.cache.positive_hit(name) { let stream = stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))); return Ok(stream.boxed()); } - let url = lookup_url(&self.base_url, domain); + let url = lookup_url(&self.base_url, domain, sequence); let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); tracing::trace!("sending lookup request to {}", self.base_url); let response = match self.lookup_response_with_retry(uri).await { Ok(response) => response, Err(H3LookupError::NoRecordFound) => { - self.cache.insert_negative(domain); + self.cache.insert_negative(name); return Err(H3LookupError::NoRecordFound); } Err(error) => return Err(error), }; - let records = LookupRecords::decode(domain, response.as_ref()) + let records = LookupRecords::decode(domain, sequence, response.as_ref()) .context(h3_lookup_error::DecodeSnafu)?; let addrs = records.endpoints; if addrs.is_empty() { - self.cache.insert_negative(domain); + self.cache.insert_negative(name); return Err(H3LookupError::NoRecordFound); } - self.cache.insert_positive(domain, addrs.clone()); + self.cache.insert_positive(name, addrs.clone()); Ok(stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))).boxed()) } @@ -250,7 +271,7 @@ mod tests { #[test] fn h3_lookup_url_targets_v2_api_from_origin_base() { let base_url = url::Url::parse("https://dns.example.test:4433").expect("url"); - let url = lookup_url(&base_url, "demo.dhttp.net"); + let url = lookup_url(&base_url, "demo.dhttp.net", None); assert_eq!( url.as_str(), @@ -261,7 +282,7 @@ mod tests { #[test] fn h3_lookup_url_does_not_duplicate_v2_base_path() { let base_url = url::Url::parse("https://dns.example.test:4433/api/v2/").expect("url"); - let url = lookup_url(&base_url, "demo.dhttp.net"); + let url = lookup_url(&base_url, "demo.dhttp.net", None); assert_eq!( url.as_str(), @@ -269,6 +290,21 @@ mod tests { ); } + #[test] + fn h3_lookup_url_appends_sequence_query() { + let base_url = url::Url::parse("https://dns.example.test:4433").expect("url"); + let url = lookup_url( + &base_url, + "demo.dhttp.net", + Some(CertificateSequence::from(3u8)), + ); + + assert_eq!( + url.as_str(), + "https://dns.example.test:4433/api/v2/lookup?host=demo.dhttp.net&sequence=3" + ); + } + #[test] fn lookup_records_select_primary_group() { let response = response_for( @@ -281,7 +317,7 @@ mod tests { ], ); - let records = LookupRecords::decode("demo.dhttp.net", &response).expect("records"); + let records = LookupRecords::decode("demo.dhttp.net", None, &response).expect("records"); assert_eq!(records.endpoints.len(), 2); assert_eq!( @@ -298,8 +334,34 @@ mod tests { fn lookup_records_ignore_answer_name_mismatch() { let response = response_for("other.dhttp.net", vec![direct("192.0.2.50:4433", true, 1)]); - let records = LookupRecords::decode("demo.dhttp.net", &response).expect("records"); + let records = LookupRecords::decode("demo.dhttp.net", None, &response).expect("records"); assert!(records.endpoints.is_empty()); } + + #[test] + fn lookup_records_filter_requested_primary_sequence() { + let response = response_for( + "demo.dhttp.net", + vec![ + direct("192.0.2.10:4433", true, 0), + direct("192.0.2.11:4433", true, 0), + direct("192.0.2.20:4433", true, 1), + ], + ); + + let records = LookupRecords::decode( + "demo.dhttp.net", + Some(CertificateSequence::from(1u8)), + &response, + ) + .expect("records"); + + assert_eq!( + records.endpoints, + vec![dquic::qbase::net::addr::EndpointAddr::direct( + "192.0.2.20:4433".parse().unwrap() + )] + ); + } } diff --git a/src/http.rs b/src/http.rs index f6c0090..1a00659 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,6 +1,7 @@ use std::{fmt::Display, io, sync::Arc}; use dashmap::DashMap; +use dhttp_identity::certificate::CertificateSequence; use dquic::{ qbase::net::addr::EndpointAddr, qresolve::{Publish, PublishFuture, Resolve, ResolveFuture, Source}, @@ -31,8 +32,14 @@ pub struct HttpResolver { cached_records: DashMap, } -fn lookup_url(base_url: &Url, name: &str) -> Url { - api_url(base_url, LOOKUP_API_PATH, name) +fn lookup_url(base_url: &Url, name: &str, sequence: Option) -> Url { + let mut url = api_url(base_url, LOOKUP_API_PATH, name); + if let Some(sequence) = sequence { + let sequence_text = sequence.get().to_string(); + url.query_pairs_mut() + .append_pair("sequence", &sequence_text); + } + url } fn publish_url(base_url: &Url, name: &str) -> Url { @@ -189,28 +196,30 @@ impl Publish for HttpResolver { impl Resolve for HttpResolver { fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { let lookup = async move { - let Some(domain) = crate::resolvers::resolvable_name(name) else { + let Some((domain, sequence)) = + crate::resolvers::endpoint_lookup_name_and_sequence(name) + else { return Err(Error::NoRecordFound); }; let now = Instant::now(); let server = Arc::from(self.base_url.host_str().unwrap_or("")); - let soource = Source::Http { server }; + let source = Source::Http { server }; use crate::core::parser::record; self.cached_records - .retain(|_host, Record { expire, .. }| *expire < now); - if let Some(record) = self.cached_records.get(domain) { + .retain(|_host, Record { expire, .. }| *expire > now); + if let Some(record) = self.cached_records.get(name) { let endpoint_addrs: Vec<_> = record .addrs .iter() - .map(|endpoint: &EndpointAddr| (soource.clone(), *endpoint)) + .map(|endpoint: &EndpointAddr| (source.clone(), *endpoint)) .collect(); return Ok(stream::iter(endpoint_addrs).boxed()); } let response = self .http_client - .get(lookup_url(&self.base_url, domain)) + .get(lookup_url(&self.base_url, domain, sequence)) .send() .await; @@ -221,8 +230,9 @@ impl Resolve for HttpResolver { return Err(Error::ParseMultiResponse); } - let mut addrs = Vec::new(); + let mut endpoint_records = Vec::new(); for r in multi.records { + let publisher_chain_key = r.publisher_certificate_chain_key(); if !r.signature_fields.is_empty() { match r.signature_fields.verify(&r.dns, &r.cert) { Ok(true) => {} @@ -241,45 +251,50 @@ impl Resolve for HttpResolver { source: source.to_owned(), })?; - addrs.extend( - packet - .answers - .iter() - .filter_map(|answer| match answer.data() { - record::RData::E(ep) => { - if answer.name() != domain { - tracing::debug!( - answer_name = %answer.name(), - query = domain, - "ignored endpoint answer for different name" - ); - return None; - } - let endpoint = - TryInto::::try_into(ep.clone()).ok()?; - Some(endpoint) + endpoint_records.extend(packet.answers.iter().filter_map(|answer| { + match answer.data() { + record::RData::E(ep) => { + if answer.name() != domain { + tracing::debug!( + answer_name = %answer.name(), + query = domain, + "ignored endpoint answer for different name" + ); + return None; } - _ => { - tracing::debug!(?answer, "ignored record"); - None - } - }), - ); + Some((ep.clone(), publisher_chain_key.clone())) + } + _ => { + tracing::debug!(?answer, "ignored record"); + None + } + } + })); } + let addrs = + crate::resolvers::endpoint_group::selected_endpoint_records_with_fallback_chain_keys( + endpoint_records + .into_iter() + .map(|(endpoint, fallback_chain_key)| ((), endpoint, fallback_chain_key)), + sequence, + ) + .into_iter() + .map(|((), endpoint)| endpoint) + .collect::>(); if addrs.is_empty() { return Err(Error::NoRecordFound); } // cache the addrs self.cached_records.insert( - domain.to_string(), + name.to_string(), Record { addrs: addrs.clone(), expire: now + std::time::Duration::from_secs(300), }, ); - Ok(stream::iter(addrs.into_iter().map(move |ep| (soource.clone(), ep))).boxed()) + Ok(stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))).boxed()) }; Box::pin(lookup.map_err(io::Error::other)) } @@ -303,11 +318,26 @@ mod tests { #[test] fn http_lookup_url_does_not_duplicate_v2_base_path() { let base_url = Url::parse("https://dns.example.test/api/v2/").expect("url"); - let url = lookup_url(&base_url, "demo.dhttp.net"); + let url = lookup_url(&base_url, "demo.dhttp.net", None); assert_eq!( url.as_str(), "https://dns.example.test/api/v2/lookup?host=demo.dhttp.net" ); } + + #[test] + fn http_lookup_url_appends_sequence_query() { + let base_url = Url::parse("https://dns.example.test").expect("url"); + let url = lookup_url( + &base_url, + "demo.dhttp.net", + Some(CertificateSequence::from(7u8)), + ); + + assert_eq!( + url.as_str(), + "https://dns.example.test/api/v2/lookup?host=demo.dhttp.net&sequence=7" + ); + } } diff --git a/src/mdns.rs b/src/mdns.rs index fce3ee9..f2ea5f8 100644 --- a/src/mdns.rs +++ b/src/mdns.rs @@ -57,9 +57,16 @@ impl Publish for MdnsPublisher { impl Resolve for MdnsResolver { fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { let source = self.source(); - self.query(name.to_owned()) + let Some((domain, sequence)) = crate::resolvers::endpoint_lookup_name_and_sequence(name) + else { + return future::ready(Err(io::Error::other("no DNS record found"))).boxed(); + }; + self.query(domain.to_owned()) .map_ok(move |list| { - let endpoints = crate::resolvers::endpoint_group::selected_endpoint_addrs(list); + let endpoints = + crate::resolvers::endpoint_group::selected_endpoint_addrs_for_sequence( + list, sequence, + ); stream::iter(endpoints.into_iter().map(move |ep| (source.clone(), ep))).boxed() }) .boxed() @@ -257,6 +264,14 @@ impl MdnsResolvers { } pub async fn query(&self, name: &str) -> io::Result { + self.query_with_sequence(name, None).await + } + + pub async fn query_with_sequence( + &self, + name: &str, + sequence: Option, + ) -> io::Result { let mut lookup_futures = FuturesUnordered::new(); let mut has_resolver = false; self.for_each_resolver(|resolver| { @@ -295,7 +310,9 @@ impl MdnsResolvers { ); } - let records = crate::resolvers::endpoint_group::selected_endpoint_records(records); + let records = crate::resolvers::endpoint_group::selected_endpoint_records_for_sequence( + records, sequence, + ); Ok(stream::iter(records).boxed()) } @@ -354,6 +371,10 @@ impl Publish for MdnsResolvers { #[cfg(feature = "dquic-network")] impl Resolve for MdnsResolvers { fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { - self.query(name).boxed() + let Some((domain, sequence)) = crate::resolvers::endpoint_lookup_name_and_sequence(name) + else { + return future::ready(Err(io::Error::other("no DNS record found"))).boxed(); + }; + self.query_with_sequence(domain, sequence).boxed() } } diff --git a/src/publishers.rs b/src/publishers.rs index 0ce35a2..bd56e37 100644 --- a/src/publishers.rs +++ b/src/publishers.rs @@ -8,7 +8,16 @@ mod packet; mod publisher; #[cfg(all(feature = "publishers", feature = "dquic-network"))] -use std::{any::TypeId, net::SocketAddr, time::Duration}; +use std::{ + future::Future, + pin::Pin, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, + task::{Context, Poll}, + time::Duration, +}; #[cfg(feature = "publishers")] pub use address::{ @@ -22,9 +31,7 @@ pub use aggregate::{Publishers, PublishersError}; #[cfg(feature = "publishers")] use dhttp_identity::name::Name; #[cfg(all(feature = "publishers", feature = "dquic-network"))] -use dquic::{ - qinterface::component::location::AddressEvent, qtraversal::nat::client::ClientLocationData, -}; +use dquic::qinterface::component::local_endpoint::InterfaceEndpointUpdate; #[cfg(feature = "publishers")] pub use publisher::{Publisher, PublisherError}; @@ -42,11 +49,38 @@ pub const DEFAULT_PUBLISH_INTERVAL: Duration = Duration::from_secs(20); /// Network changes can leave an in-flight publish waiting on paths that no /// longer exist. Timing out the attempt keeps consecutive publishes /// independent: the next interval observes the current bindings again. +/// +/// This timeout must stay above the QUIC endpoint's default connect-path +/// timeout (20s) so a publish attempt can survive the transport's own path +/// discovery window instead of being aborted before connect has a chance to +/// complete. #[cfg(all(feature = "publishers", feature = "dquic-network"))] -pub const DEFAULT_PUBLISH_TIMEOUT: Duration = Duration::from_secs(10); +pub const DEFAULT_PUBLISH_TIMEOUT: Duration = Duration::from_secs(30); #[cfg(all(feature = "publishers", feature = "dquic-network"))] const PUBLISH_CHANGE_DEBOUNCE: Duration = Duration::from_millis(50); +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +struct ScheduledPublish<'a> { + future: futures::future::BoxFuture<'a, ()>, + attempt_started: Arc, +} + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +impl ScheduledPublish<'_> { + fn attempt_started(&self) -> bool { + self.attempt_started.load(Ordering::SeqCst) + } +} + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +impl Future for ScheduledPublish<'_> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.future.as_mut().poll(cx) + } +} + #[cfg(all(feature = "publishers", feature = "dquic-network"))] pub struct EndpointPublicationLoop { name: Name<'static>, @@ -106,46 +140,59 @@ where } pub async fn run(&self) -> ! { - let mut locations = self.source.subscribe(); + let mut local_endpoints = self.source.subscribe(); let interval = tokio::time::sleep(self.interval); tokio::pin!(interval); - let mut current_publish = self.new_publish_loop_future(); + let mut current_publish = Some(self.new_publish_loop_future()); loop { tokio::select! { - _ = &mut current_publish => { - current_publish = Self::pending_publish_loop_future(); + _ = async { + match current_publish.as_mut() { + Some(current_publish) => current_publish.await, + None => std::future::pending::<()>().await, + } + } => { + current_publish = None; } _ = &mut interval => { interval.as_mut().reset(tokio::time::Instant::now() + self.interval); - current_publish = self.new_publish_loop_future(); + if current_publish + .as_ref() + .is_some_and(ScheduledPublish::attempt_started) + { + continue; + } + current_publish = Some(self.new_publish_loop_future()); } - event = locations.recv() => { - let Some((bind_uri, event)) = event else { + update = local_endpoints.recv() => { + let Some((bind_uri, update)) = update else { continue; }; if !self.source.observes(&bind_uri) { continue; } - if !Self::location_event_requires_publish(&event) { + if !Self::local_endpoint_update_requires_publish(&update) { continue; } - current_publish = self.new_publish_loop_future(); + current_publish = Some(self.new_publish_loop_future()); } } } } - fn new_publish_loop_future(&self) -> futures::future::BoxFuture<'_, ()> { - Box::pin(async move { - tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE).await; - let _ = self.publish_attempt().await; - }) - } - - fn pending_publish_loop_future<'a>() -> futures::future::BoxFuture<'a, ()> { - Box::pin(std::future::pending()) + fn new_publish_loop_future(&self) -> ScheduledPublish<'_> { + let attempt_started = Arc::new(AtomicBool::new(false)); + let mark_attempt_started = attempt_started.clone(); + ScheduledPublish { + future: Box::pin(async move { + tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE).await; + mark_attempt_started.store(true, Ordering::SeqCst); + let _ = self.publish_attempt().await; + }), + attempt_started, + } } async fn publish_attempt(&self) -> bool { @@ -180,22 +227,340 @@ where } } - fn location_event_requires_publish(event: &AddressEvent) -> bool { - match event { - AddressEvent::Upsert(data) => { - if let Some(bound_addr) = data.downcast_ref::>() { - return bound_addr.is_ok(); - } - if let Some(stun_addr) = data.downcast_ref::() { - return stun_addr.is_ok(); + fn local_endpoint_update_requires_publish(update: &InterfaceEndpointUpdate) -> bool { + match update { + InterfaceEndpointUpdate::Upsert { .. } + | InterfaceEndpointUpdate::Remove { .. } + | InterfaceEndpointUpdate::Close => true, + } + } +} + +#[cfg(all(test, feature = "publishers", feature = "dquic-network"))] +mod tests { + use std::{ + fmt, + sync::{ + Arc, + atomic::{AtomicBool, AtomicUsize, Ordering}, + }, + time::Duration, + }; + + use dhttp_identity::name::Name; + use dquic::{ + qbase::net::addr::EndpointAddr, + qinterface::component::local_endpoint::{ + InterfaceEndpointKey, InterfaceEndpointUpdate, LocalEndpointSubscriber, LocalEndpoints, + }, + qresolve::{Publish, PublishFuture}, + }; + use futures::FutureExt; + use h3x::dquic::net::BindUri; + use tokio::sync::Notify; + + use super::{ + AddressView, AddressViewSource, EndpointPublicationLoop, PublishAddresses, PublishScope, + Publisher, Publishers, + }; + + #[derive(Debug, Default)] + struct PublishState { + started: AtomicUsize, + completed: AtomicUsize, + canceled: AtomicUsize, + releases: AtomicUsize, + release_notify: Notify, + } + + impl PublishState { + fn allow_attempts(&self, count: usize) { + self.releases.store(count, Ordering::SeqCst); + self.release_notify.notify_waiters(); + } + } + + struct AttemptGuard { + state: Arc, + completed: AtomicBool, + } + + impl AttemptGuard { + fn new(state: Arc) -> Self { + Self { + state, + completed: AtomicBool::new(false), + } + } + + fn complete(&self) { + self.completed.store(true, Ordering::SeqCst); + } + } + + impl Drop for AttemptGuard { + fn drop(&mut self) { + if !self.completed.load(Ordering::SeqCst) { + self.state.canceled.fetch_add(1, Ordering::SeqCst); + } + } + } + + #[derive(Debug)] + struct BlockingPublisher { + state: Arc, + } + + impl fmt::Display for BlockingPublisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("blocking publisher") + } + } + + impl Publish for BlockingPublisher { + fn publish<'a>(&'a self, _name: &'a str, _packet: &'a [u8]) -> PublishFuture<'a> { + let state = self.state.clone(); + async move { + let attempt = state.started.fetch_add(1, Ordering::SeqCst) + 1; + + let guard = AttemptGuard::new(state.clone()); + loop { + if state.releases.load(Ordering::SeqCst) >= attempt { + guard.complete(); + state.completed.fetch_add(1, Ordering::SeqCst); + return Ok(()); + } + + state.release_notify.notified().await; } - false } - AddressEvent::Remove(type_id) => { - *type_id == TypeId::of::>() - || *type_id == TypeId::of::() + .boxed() + } + } + + #[derive(Debug)] + struct DelayedPublisher { + delay: Duration, + } + + impl fmt::Display for DelayedPublisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("delayed publisher") + } + } + + impl Publish for DelayedPublisher { + fn publish<'a>(&'a self, _name: &'a str, _packet: &'a [u8]) -> PublishFuture<'a> { + let delay = self.delay; + async move { + tokio::time::sleep(delay).await; + Ok(()) + } + .boxed() + } + } + + #[derive(Clone)] + struct TestSource { + bind_uri: BindUri, + local_endpoints: Arc, + addresses: PublishAddresses, + } + + impl TestSource { + fn new(bind_uri: BindUri) -> Self { + Self { + bind_uri, + local_endpoints: Arc::new(LocalEndpoints::new()), + addresses: PublishAddresses::new().wide_area([EndpointAddr::direct( + "127.0.0.1:4433".parse().expect("socket address"), + )]), } - AddressEvent::Closed => true, } + + fn notify_publishable_local_endpoint(&self) { + let publishers = self.local_endpoints.publisher(self.bind_uri.clone()); + let mut direct = publishers + .direct_endpoint_publisher() + .expect("direct endpoint publisher"); + assert!(direct.upsert("127.0.0.1:4433".parse().expect("socket address"))); + } + } + + impl AddressViewSource for TestSource { + fn address_view(&self) -> impl AddressView + Send + Sync + '_ { + self.addresses.clone() + } + + fn subscribe(&self) -> LocalEndpointSubscriber { + self.local_endpoints.subscribe() + } + + fn observes(&self, bind_uri: &BindUri) -> bool { + bind_uri == &self.bind_uri + } + } + + async fn wait_until(description: &str, timeout: Duration, predicate: impl Fn() -> bool) { + let deadline = tokio::time::Instant::now() + timeout; + loop { + if predicate() { + return; + } + + if tokio::time::Instant::now() >= deadline { + panic!("timed out waiting for {description}"); + } + + tokio::time::sleep(Duration::from_millis(10)).await; + } + } + + fn name() -> Name<'static> { + "alice.dhttp.net".parse().expect("valid name") + } + + #[test] + fn typed_local_endpoint_updates_require_publication_refresh() { + let direct = InterfaceEndpointUpdate::Upsert { + key: InterfaceEndpointKey::Direct, + endpoint: EndpointAddr::direct("127.0.0.1:12345".parse().expect("addr")), + }; + let remove = InterfaceEndpointUpdate::Remove { + key: InterfaceEndpointKey::Direct, + }; + let close = InterfaceEndpointUpdate::Close; + + assert!( + EndpointPublicationLoop::::local_endpoint_update_requires_publish(&direct) + ); + assert!( + EndpointPublicationLoop::::local_endpoint_update_requires_publish(&remove) + ); + assert!( + EndpointPublicationLoop::::local_endpoint_update_requires_publish(&close) + ); + } + + #[tokio::test] + async fn publication_loop_replaces_inflight_publish_when_location_changes() { + let bind_uri: BindUri = "iface://v4.eth0:0/".parse().expect("bind uri"); + let source = TestSource::new(bind_uri); + let state = Arc::new(PublishState::default()); + let publishers = Publishers::new().with(Publisher::new( + PublishScope::WideArea, + Arc::new(BlockingPublisher { + state: state.clone(), + }), + )); + let loop_ = EndpointPublicationLoop::new(name(), publishers, source.clone()) + .with_interval(Duration::from_secs(60)) + .with_publish_timeout(Duration::from_secs(60)); + + let task = tokio::spawn(async move { + loop_.run().await; + }); + + wait_until( + "initial publish attempt to start", + Duration::from_secs(1), + || state.started.load(Ordering::SeqCst) == 1, + ) + .await; + + source.notify_publishable_local_endpoint(); + + wait_until( + "replacement publish attempt to start after the location update", + Duration::from_secs(1), + || state.started.load(Ordering::SeqCst) == 2, + ) + .await; + + assert_eq!( + state.canceled.load(Ordering::SeqCst), + 1, + "location updates should cancel the stale in-flight publish attempt" + ); + + state.allow_attempts(2); + wait_until( + "replacement publish attempt to complete", + Duration::from_secs(1), + || state.completed.load(Ordering::SeqCst) == 1, + ) + .await; + + task.abort(); + let _ = task.await; + } + + #[tokio::test] + async fn publication_loop_interval_does_not_cancel_active_publish_attempt() { + let bind_uri: BindUri = "iface://v4.eth0:0/".parse().expect("bind uri"); + let source = TestSource::new(bind_uri); + let state = Arc::new(PublishState::default()); + let publishers = Publishers::new().with(Publisher::new( + PublishScope::WideArea, + Arc::new(BlockingPublisher { + state: state.clone(), + }), + )); + let loop_ = EndpointPublicationLoop::new(name(), publishers, source) + .with_interval(Duration::from_millis(120)) + .with_publish_timeout(Duration::from_secs(1)); + + let task = tokio::spawn(async move { + loop_.run().await; + }); + + wait_until( + "initial publish attempt to start", + Duration::from_secs(1), + || state.started.load(Ordering::SeqCst) == 1, + ) + .await; + + tokio::time::sleep(Duration::from_millis(100)).await; + + assert_eq!( + state.started.load(Ordering::SeqCst), + 1, + "interval ticks should not schedule a replacement while the publish is still active" + ); + assert_eq!( + state.canceled.load(Ordering::SeqCst), + 0, + "interval ticks should not cancel the active publish attempt" + ); + + state.allow_attempts(1); + wait_until( + "active publish attempt to complete", + Duration::from_secs(1), + || state.completed.load(Ordering::SeqCst) == 1, + ) + .await; + + task.abort(); + let _ = task.await; + } + + #[tokio::test] + async fn default_publish_timeout_allows_slow_publish_attempts_to_finish() { + let bind_uri: BindUri = "iface://v4.eth0:0/".parse().expect("bind uri"); + let source = TestSource::new(bind_uri); + let publishers = Publishers::new().with(Publisher::new( + PublishScope::WideArea, + Arc::new(DelayedPublisher { + delay: Duration::from_secs(11), + }), + )); + let loop_ = EndpointPublicationLoop::new(name(), publishers, source); + + assert!( + loop_.publish_attempt().await, + "default publish timeout should allow a slow publish attempt to finish" + ); } } diff --git a/src/publishers/address.rs b/src/publishers/address.rs index ca8bcfa..a20cef7 100644 --- a/src/publishers/address.rs +++ b/src/publishers/address.rs @@ -6,7 +6,7 @@ use std::{net::SocketAddr, sync::OnceLock}; use dquic::qbase::net::{Family, addr::EndpointAddr}; #[cfg(feature = "dquic-network")] -use dquic::qinterface::component::location::Observer; +use dquic::qinterface::component::local_endpoint::LocalEndpointSubscriber; #[cfg(feature = "dquic-network")] use h3x::dquic::{ Network, @@ -55,7 +55,7 @@ where #[cfg(feature = "dquic-network")] pub trait AddressViewSource { fn address_view(&self) -> impl AddressView + Send + Sync + '_; - fn subscribe(&self) -> Observer; + fn subscribe(&self) -> LocalEndpointSubscriber; fn observes(&self, bind_uri: &BindUri) -> bool; } @@ -199,8 +199,8 @@ impl AddressViewSource for EndpointBindingAddresses { EndpointBindingAddressView::new(self.network.clone(), self.bind_patterns.clone()) } - fn subscribe(&self) -> Observer { - self.network.quic().locations().subscribe() + fn subscribe(&self) -> LocalEndpointSubscriber { + self.network.quic().local_endpoints().subscribe() } fn observes(&self, bind_uri: &BindUri) -> bool { diff --git a/src/publishers/aggregate.rs b/src/publishers/aggregate.rs index 4c61208..4a73789 100644 --- a/src/publishers/aggregate.rs +++ b/src/publishers/aggregate.rs @@ -1,6 +1,7 @@ use std::{error::Error, fmt}; use dhttp_identity::name::Name; +use snafu::Report; use super::{AddressView, Publisher, PublisherError}; @@ -75,7 +76,17 @@ impl Publishers { for publisher in &self.publishers { match publisher.publish(name, view).await { Ok(()) => succeeded = true, - Err(error) => errors.push((publisher.to_string(), error)), + Err(error) => { + let publisher_name = publisher.to_string(); + let report = Report::from_error(&error); + tracing::debug!( + publisher = %publisher_name, + error = %report, + name = %name, + "dns publisher failed" + ); + errors.push((publisher_name, error)); + } } } diff --git a/src/resolvers.rs b/src/resolvers.rs index fd95981..91cf984 100644 --- a/src/resolvers.rs +++ b/src/resolvers.rs @@ -27,7 +27,6 @@ use crate::mdns::MdnsResolvers; /// Extract and validate the DNS host from `name`, which may include a `:port` /// suffix. Returns `Some(host)` if the host part is a valid RFC-compliant DNS /// name, or `None` for raw IP addresses, bracketed IPv6, or malformed input. -#[cfg_attr(not(any(feature = "h3", feature = "http")), allow(dead_code))] pub(crate) fn resolvable_name(name: &str) -> Option<&str> { let host = match name.rsplit_once(':') { Some((h, port)) if !port.is_empty() && port.chars().all(|c| c.is_ascii_digit()) => h, @@ -37,6 +36,32 @@ pub(crate) fn resolvable_name(name: &str) -> Option<&str> { Some(host) } +#[cfg_attr( + not(any(feature = "h3", feature = "http", feature = "mdns")), + allow(dead_code) +)] +pub(crate) fn endpoint_lookup_name_and_sequence( + name: &str, +) -> Option<( + &str, + Option, +)> { + use dhttp_identity::certificate::CertificateSequence; + + let (host, sequence) = match name.rsplit_once(':') { + Some((host, digits)) + if !digits.is_empty() && digits.chars().all(|c| c.is_ascii_digit()) => + { + let sequence = digits.parse::().ok()?; + let sequence = CertificateSequence::try_from(sequence).ok()?; + (host, Some(sequence)) + } + _ => (name, None), + }; + + Some((resolvable_name(host)?, sequence)) +} + /// Default DNS-over-H3 server for DHTTP endpoints. pub const DHTTP_H3_DNS_SERVER: &str = crate::bootstrap::DHTTP_H3_DNS_SERVER; @@ -410,6 +435,35 @@ mod tests { assert_eq!(resolvable_name("[::1]:443"), None); } + #[test] + fn endpoint_lookup_name_and_sequence_accepts_plain_name() { + let (name, sequence) = + super::endpoint_lookup_name_and_sequence("example.dhttp.net").expect("dns name"); + + assert_eq!(name, "example.dhttp.net"); + assert_eq!(sequence, None); + } + + #[test] + fn endpoint_lookup_name_and_sequence_parses_numeric_selector() { + let (name, sequence) = + super::endpoint_lookup_name_and_sequence("reimu.hakurei.dhttp.net:1") + .expect("dns name"); + + assert_eq!(name, "reimu.hakurei.dhttp.net"); + assert_eq!( + sequence.map(dhttp_identity::certificate::CertificateSequence::get), + Some(1) + ); + } + + #[test] + fn endpoint_lookup_name_and_sequence_rejects_out_of_range_selector() { + let invalid = format!("example.dhttp.net:{}", (1u64 << 62) + 1); + + assert_eq!(super::endpoint_lookup_name_and_sequence(&invalid), None); + } + #[cfg(feature = "resolvers")] #[test] fn dns_scheme_round_trips_supported_schemes_and_rejects_dht() { diff --git a/src/resolvers/deferred.rs b/src/resolvers/deferred.rs index a961ac8..c824185 100644 --- a/src/resolvers/deferred.rs +++ b/src/resolvers/deferred.rs @@ -3,7 +3,7 @@ use std::{fmt, io}; use dquic::qresolve::{Publish, PublishFuture, RecordStream, Resolve, ResolveFuture}; use futures::FutureExt; use snafu::{ResultExt, Snafu}; -use tokio::sync::OnceCell; +use tokio::sync::{Notify, OnceCell}; #[derive(Debug, Snafu)] #[snafu(module, visibility(pub))] @@ -32,6 +32,7 @@ pub enum SetDeferredResolverError { pub struct DeferredResolver { inner: OnceCell, + initialized: Notify, } impl fmt::Debug for DeferredResolver { @@ -53,6 +54,7 @@ impl DeferredResolver { pub fn new() -> Self { Self { inner: OnceCell::new(), + initialized: Notify::new(), } } @@ -60,6 +62,7 @@ impl DeferredResolver { if self.inner.set(resolver).is_err() { return set_deferred_resolver_error::AlreadyInitializedSnafu.fail(); } + self.initialized.notify_waiters(); Ok(()) } @@ -67,6 +70,16 @@ impl DeferredResolver { pub fn get(&self) -> Option<&R> { self.inner.get() } + + async fn wait(&self) -> &R { + loop { + let initialized = self.initialized.notified(); + if let Some(resolver) = self.get() { + return resolver; + } + initialized.await; + } + } } impl fmt::Display for DeferredResolver @@ -101,7 +114,7 @@ where R: Resolve + 'static, { fn lookup<'a>(&'a self, name: &'a str) -> ResolveFuture<'a> { - async move { self.lookup_typed(name).await.map_err(io::Error::other) }.boxed() + async move { self.wait().await.lookup(name).await }.boxed() } } @@ -129,18 +142,13 @@ where R: Publish + 'static, { fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { - async move { - self.publish_typed(name, packet) - .await - .map_err(io::Error::other) - } - .boxed() + async move { self.wait().await.publish(name, packet).await }.boxed() } } #[cfg(test)] mod tests { - use std::fmt; + use std::{fmt, time::Duration}; use dquic::{ qbase::net::addr::EndpointAddr, @@ -191,6 +199,28 @@ mod tests { assert!(matches!(error, DeferredLookupError::Uninitialized)); } + #[tokio::test] + async fn resolve_trait_lookup_waits_until_set() { + let resolver = DeferredResolver::new(); + let mut lookup = resolver.lookup("example.test"); + + assert!( + tokio::time::timeout(Duration::from_millis(10), &mut lookup) + .await + .is_err(), + "trait lookup must not fail fast before set" + ); + + resolver.set(TestResolver).expect("first set succeeds"); + + let mut stream = lookup.await.expect("lookup completes after set"); + let (_source, endpoint) = stream.next().await.expect("forwarded endpoint"); + assert_eq!( + endpoint, + EndpointAddr::direct("127.0.0.1:4433".parse().unwrap()) + ); + } + #[tokio::test] async fn lookup_after_set_forwards_to_inner_resolver() { let resolver = DeferredResolver::new(); diff --git a/src/resolvers/endpoint_group.rs b/src/resolvers/endpoint_group.rs index 659fac3..5083467 100644 --- a/src/resolvers/endpoint_group.rs +++ b/src/resolvers/endpoint_group.rs @@ -1,4 +1,4 @@ -use dhttp_identity::certificate::CertificateChainKey; +use dhttp_identity::certificate::{CertificateChainKey, CertificateChainKind, CertificateSequence}; use dquic::qbase::net::addr::EndpointAddr as DquicEndpointAddr; use crate::core::parser::record::endpoint::EndpointAddr as DnsEndpointAddr; @@ -15,13 +15,49 @@ pub(crate) fn selected_endpoint_addrs( .collect() } +pub(crate) fn selected_endpoint_addrs_for_sequence( + records: impl IntoIterator, + sequence: Option, +) -> Vec { + match sequence { + Some(sequence) => selected_endpoint_records_for_sequence( + records.into_iter().map(|record| ((), record)), + Some(sequence), + ) + .into_iter() + .map(|((), endpoint)| endpoint) + .collect(), + None => selected_endpoint_addrs(records), + } +} + pub(crate) fn selected_endpoint_records( records: impl IntoIterator, +) -> Vec<(T, DquicEndpointAddr)> { + selected_endpoint_records_with_fallback_chain_keys( + records.into_iter().map(|(tag, record)| (tag, record, None)), + None, + ) +} + +pub(crate) fn selected_endpoint_records_for_sequence( + records: impl IntoIterator, + sequence: Option, +) -> Vec<(T, DquicEndpointAddr)> { + selected_endpoint_records_with_fallback_chain_keys( + records.into_iter().map(|(tag, record)| (tag, record, None)), + sequence, + ) +} + +pub(crate) fn selected_endpoint_records_with_fallback_chain_keys( + records: impl IntoIterator)>, + sequence: Option, ) -> Vec<(T, DquicEndpointAddr)> { let mut groups: Vec> = Vec::new(); - for (tag, record) in records { - let chain_key = record.certificate_chain_key(); + for (tag, record, fallback_chain_key) in records { + let chain_key = effective_chain_key(&record, fallback_chain_key); let Ok(endpoint) = DquicEndpointAddr::try_from(record) else { continue; }; @@ -35,12 +71,23 @@ pub(crate) fn selected_endpoint_records( groups.sort_by_key(|(chain_key, _)| { let primary_rank = match chain_key.kind() { - dhttp_identity::certificate::CertificateChainKind::Primary => 0, - dhttp_identity::certificate::CertificateChainKind::Secondary => 1, + CertificateChainKind::Primary => 0, + CertificateChainKind::Secondary => 1, }; (primary_rank, chain_key.sequence().get()) }); + if let Some(sequence) = sequence { + return groups + .into_iter() + .find(|(chain_key, _)| { + chain_key.kind() == CertificateChainKind::Primary + && chain_key.sequence() == sequence + }) + .map(|(_, endpoints)| endpoints) + .unwrap_or_default(); + } + groups .into_iter() .next() @@ -48,9 +95,22 @@ pub(crate) fn selected_endpoint_records( .unwrap_or_default() } +fn effective_chain_key( + record: &DnsEndpointAddr, + fallback_chain_key: Option, +) -> CertificateChainKey { + if record.is_main() || record.sequence().is_some() { + return record.certificate_chain_key(); + } + + fallback_chain_key.unwrap_or_else(|| record.certificate_chain_key()) +} + #[cfg(test)] mod tests { - use dhttp_identity::certificate::CertificateSequence; + use dhttp_identity::certificate::{ + CertificateChainKey, CertificateChainKind, CertificateSequence, + }; use crate::core::parser::record::endpoint::EndpointAddr; @@ -133,4 +193,60 @@ mod tests { dquic::qbase::net::addr::EndpointAddr::direct("192.0.2.52:4433".parse().unwrap()) ); } + + #[test] + fn selected_endpoint_addrs_for_sequence_filters_requested_primary_group() { + let selected = super::selected_endpoint_addrs_for_sequence( + [ + direct("192.0.2.10:4433", true, 0), + direct("192.0.2.11:4433", true, 0), + direct("192.0.2.20:4433", true, 1), + direct("192.0.2.30:4433", false, 1), + ], + Some(CertificateSequence::from(1u8)), + ); + + assert_eq!( + selected, + vec![dquic::qbase::net::addr::EndpointAddr::direct( + "192.0.2.20:4433".parse().unwrap() + )] + ); + } + + #[test] + fn selected_endpoint_addrs_for_sequence_returns_empty_when_primary_sequence_missing() { + let selected = super::selected_endpoint_addrs_for_sequence( + [ + direct("192.0.2.10:4433", true, 0), + direct("192.0.2.20:4433", false, 2), + ], + Some(CertificateSequence::from(9u8)), + ); + + assert!(selected.is_empty()); + } + + #[test] + fn selected_endpoint_records_for_sequence_uses_fallback_chain_key_when_packet_omits_it() { + let endpoint = EndpointAddr::direct_v4("192.0.2.60:4433".parse().unwrap()); + let selected = super::selected_endpoint_records_with_fallback_chain_keys( + [( + "wifi", + endpoint, + Some(CertificateChainKey::new( + CertificateSequence::from(1u8), + CertificateChainKind::Primary, + )), + )], + Some(CertificateSequence::from(1u8)), + ); + + assert_eq!(selected.len(), 1); + assert_eq!(selected[0].0, "wifi"); + assert_eq!( + selected[0].1, + dquic::qbase::net::addr::EndpointAddr::direct("192.0.2.60:4433".parse().unwrap()) + ); + } }