diff --git a/Cargo.lock b/Cargo.lock index 6070048..568e17f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -974,7 +974,7 @@ dependencies = [ [[package]] name = "rocket_ext" -version = "0.3.0" +version = "0.3.1" dependencies = [ "anyhow", "http 1.3.1", diff --git a/Cargo.toml b/Cargo.toml index 0ebd0e5..af21de7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rocket_ext" -version = "0.3.0" +version = "0.3.1" edition = "2024" authors = ["Devin Bidwell "] keywords = ["rocket", "cors", "preflight", "headers"] diff --git a/src/cors.rs b/src/cors.rs index a189fe2..da92380 100644 --- a/src/cors.rs +++ b/src/cors.rs @@ -674,7 +674,8 @@ mod tests { assert_eq!(origins.len(), 1); assert!(origins.contains(&Origin { scheme: OriginScheme::Https, - host: String::from("test.com") + host: String::from("test.com"), + port: None })); Ok(()) @@ -691,11 +692,13 @@ mod tests { assert_eq!(origins.len(), 2); assert!(origins.contains(&Origin { scheme: OriginScheme::Https, - host: String::from("test.com") + host: String::from("test.com"), + port: None })); assert!(origins.contains(&Origin { scheme: OriginScheme::Https, - host: String::from("example.com") + host: String::from("example.com"), + port: None })); Ok(()) diff --git a/src/cors/origin.rs b/src/cors/origin.rs index 22eb284..98b16de 100644 --- a/src/cors/origin.rs +++ b/src/cors/origin.rs @@ -43,11 +43,13 @@ impl TryFrom<&str> for OriginScheme { pub struct Origin { pub(crate) host: String, pub(crate) scheme: OriginScheme, + pub(crate) port: Option, } impl std::fmt::Display for Origin { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}://{}", self.scheme, self.host) + let port_suffix = self.port.map(|port| format!(":{port}")).unwrap_or_default(); + write!(f, "{}://{}{}", self.scheme, self.host, port_suffix) } } @@ -63,8 +65,9 @@ impl TryFrom<&str> for Origin { }; let host = authority.host().to_owned(); + let port = authority.port(); - Ok(Self { scheme, host }) + Ok(Self { scheme, host, port }) } } @@ -94,8 +97,9 @@ impl<'a> TryFrom> for Origin { }; let host = authority.host().to_owned(); + let port = authority.port().to_owned(); - Ok(Self { scheme, host }) + Ok(Self { scheme, host, port }) } } @@ -135,6 +139,13 @@ mod tests { assert_eq!(OriginScheme::Https, origin.scheme); assert_eq!("test.com", origin.host); + assert_eq!(None, origin.port); + + let origin = Origin::try_from("https://test.com:42")?; + + assert_eq!(OriginScheme::Https, origin.scheme); + assert_eq!("test.com", origin.host); + assert_eq!(Some(42), origin.port); Ok(()) } @@ -145,6 +156,13 @@ mod tests { assert_eq!(OriginScheme::Https, origin.scheme); assert_eq!("test.com", origin.host); + assert_eq!(None, origin.port); + + let origin = Origin::try_from(String::from("https://test.com:42"))?; + + assert_eq!(OriginScheme::Https, origin.scheme); + assert_eq!("test.com", origin.host); + assert_eq!(Some(42), origin.port); Ok(()) } @@ -157,6 +175,16 @@ mod tests { assert_eq!(OriginScheme::Https, origin.scheme); assert_eq!("test.com", origin.host); + assert_eq!(None, origin.port); + + let ab = Absolute::parse_owned("https://test.com:42".into()).expect("A valid URI"); + + let origin = Origin::try_from(ab)?; + + assert_eq!(OriginScheme::Https, origin.scheme); + assert_eq!("test.com", origin.host); + assert_eq!(Some(42), origin.port); + Ok(()) } @@ -168,6 +196,16 @@ mod tests { assert_eq!(OriginScheme::Https, origin.scheme); assert_eq!("test.com", origin.host); + assert_eq!(None, origin.port); + + let uri = Uri::parse_any("https://test.com:42").expect("A valid uri"); + + let origin = Origin::try_from(uri)?; + + assert_eq!(OriginScheme::Https, origin.scheme); + assert_eq!("test.com", origin.host); + assert_eq!(Some(42), origin.port); + Ok(()) } @@ -176,10 +214,19 @@ mod tests { let origin = Origin { scheme: OriginScheme::Https, host: "test.com".into(), + port: None, }; assert_eq!("https://test.com", origin.to_string()); + let origin = Origin { + scheme: OriginScheme::Https, + host: "test.com".into(), + port: Some(42), + }; + + assert_eq!("https://test.com:42", origin.to_string()); + Ok(()) } @@ -200,7 +247,15 @@ mod tests { let origin = Origin { scheme: OriginScheme::Http, host: "localhost".into(), + port: None, }; assert_eq!(origin.to_string(), "http://localhost"); + + let origin = Origin { + scheme: OriginScheme::Http, + host: "localhost".into(), + port: Some(42), + }; + assert_eq!(origin.to_string(), "http://localhost:42"); } }