diff --git a/model_gateway/src/config/validation.rs b/model_gateway/src/config/validation.rs
index 13639baac..ce168b65f 100644
--- a/model_gateway/src/config/validation.rs
+++ b/model_gateway/src/config/validation.rs
@@ -827,51 +827,44 @@ impl ConfigValidator {
 
     fn validate_urls(urls: &[String]) -> ConfigResult<()> {
         for url in urls {
-            if url.is_empty() {
+            if let Err(reason) = validate_worker_url(url) {
                 return Err(ConfigError::InvalidValue {
                     field: "worker_url".to_string(),
                     value: url.clone(),
-                    reason: "URL cannot be empty".to_string(),
+                    reason,
                 });
             }
+        }
+        Ok(())
+    }
+}
 
-            // Case-insensitive scheme allow-list. Compare just the scheme
-            // segment so we don't allocate a lowercased copy of the full URL.
-            const ALLOWED_SCHEMES: &[&str] = &["http", "https", "grpc", "grpcs"];
-            let scheme = url.split_once("://").map_or("", |(s, _)| s);
-            if !ALLOWED_SCHEMES
-                .iter()
-                .any(|allowed| scheme.eq_ignore_ascii_case(allowed))
-            {
-                return Err(ConfigError::InvalidValue {
-                    field: "worker_url".to_string(),
-                    value: url.clone(),
-                    reason: "URL must start with http://, https://, grpc://, or grpcs://"
-                        .to_string(),
-                });
-            }
+/// Check that a worker URL is non-empty, uses http/https/grpc/grpcs (ASCII case-insensitive),
+/// and parses with a host; on failure return the reason for the caller to wrap in its own error.
+pub(crate) fn validate_worker_url(url: &str) -> Result<(), String> {
+    if url.is_empty() {
+        return Err("URL cannot be empty".to_string());
+    }
 
-            match ::url::Url::parse(url) {
-                Ok(parsed) => {
-                    if parsed.host_str().is_none() {
-                        return Err(ConfigError::InvalidValue {
-                            field: "worker_url".to_string(),
-                            value: url.clone(),
-                            reason: "URL must have a valid host".to_string(),
-                        });
-                    }
-                }
-                Err(e) => {
-                    return Err(ConfigError::InvalidValue {
-                        field: "worker_url".to_string(),
-                        value: url.clone(),
-                        reason: format!("Invalid URL format: {e}"),
-                    });
-                }
+    const ALLOWED_SCHEMES: &[&str] = &["http", "https", "grpc", "grpcs"];
+    let scheme = url.split_once("://").map_or("", |(s, _)| s);
+    if !ALLOWED_SCHEMES
+        .iter()
+        .any(|allowed| scheme.eq_ignore_ascii_case(allowed))
+    {
+        return Err("URL must start with http://, https://, grpc://, or grpcs://".to_string());
+    }
+
+    match ::url::Url::parse(url) {
+        Ok(parsed) => {
+            if parsed.host_str().is_none() {
+                return Err("URL must have a valid host".to_string());
             }
         }
-        Ok(())
+        Err(e) => return Err(format!("Invalid URL format: {e}")),
     }
+
+    Ok(())
 }
 
 fn validate_mebibyte_limit(field: &str, value_mb: usize) -> ConfigResult<()> {
diff --git a/model_gateway/src/worker/service.rs b/model_gateway/src/worker/service.rs
index 5c6a0cc53..7764f679e 100644
--- a/model_gateway/src/worker/service.rs
+++ b/model_gateway/src/worker/service.rs
@@ -16,7 +16,7 @@ use serde_json::json;
 use tracing::warn;
 
 use crate::{
-    config::RouterConfig,
+    config::{validation::validate_worker_url, RouterConfig},
     worker::{registry::WorkerId, worker::worker_to_info, WorkerRegistry},
     workflow::{Job, JobQueue},
 };
@@ -241,6 +241,8 @@ impl WorkerService {
         &self,
         config: WorkerSpec,
     ) -> Result {
+        validate_worker_url_request(&config.url)?;
+
         if self.router_config.api_key.is_some() && config.api_key.is_none() {
             warn!(
                 "Adding worker {} without API key while router has API key configured. \
@@ -431,3 +433,65 @@ impl WorkerService {
         Ok(UpdateWorkerResult { worker_id, url })
     }
 }
+
+/// Run [`validate_worker_url`] on `url`, mapping a failure to `WorkerServiceError::BadRequest`
+/// whose message quotes the offending URL and the validator's reason.
+fn validate_worker_url_request(url: &str) -> Result<(), WorkerServiceError> {
+    validate_worker_url(url).map_err(|reason| WorkerServiceError::BadRequest {
+        message: format!("Worker URL '{url}' is invalid: {reason}"),
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn validate_worker_url_request_accepts_all_four_schemes() {
+        assert!(validate_worker_url_request("http://10.0.0.5:8000").is_ok());
+        assert!(validate_worker_url_request("https://10.0.0.5:8000").is_ok());
+        assert!(validate_worker_url_request("grpc://10.0.0.5:8000").is_ok());
+        assert!(validate_worker_url_request("grpcs://10.0.0.5:8000").is_ok());
+    }
+
+    #[test]
+    fn validate_worker_url_request_accepts_case_insensitive_schemes() {
+        assert!(validate_worker_url_request("HTTP://10.0.0.5:8000").is_ok());
+        assert!(validate_worker_url_request("GrPc://10.0.0.5:8000").is_ok());
+    }
+
+    #[test]
+    fn validate_worker_url_request_rejects_bare_host_port_as_400() {
+        let err = validate_worker_url_request("10.0.0.5:8000").unwrap_err();
+        assert!(matches!(err, WorkerServiceError::BadRequest { .. }));
+        assert_eq!(err.status_code(), StatusCode::BAD_REQUEST);
+        assert!(err
+            .to_string()
+            .contains("http://, https://, grpc://, or grpcs://"));
+    }
+
+    #[test]
+    fn validate_worker_url_request_rejects_empty_as_400() {
+        let err = validate_worker_url_request("").unwrap_err();
+        assert!(matches!(err, WorkerServiceError::BadRequest { .. }));
+        assert!(err.to_string().contains("empty"));
+    }
+
+    #[test]
+    fn validate_worker_url_request_rejects_unknown_scheme() {
+        let err = validate_worker_url_request("ftp://10.0.0.5:8000").unwrap_err();
+        assert!(matches!(err, WorkerServiceError::BadRequest { .. }));
+    }
+
+    #[test]
+    fn validate_worker_url_request_rejects_missing_host() {
+        let err = validate_worker_url_request("http://").unwrap_err();
+        assert!(matches!(err, WorkerServiceError::BadRequest { .. }));
+    }
+
+    #[test]
+    fn validate_worker_url_request_rejects_unparsable_url() {
+        let err = validate_worker_url_request("http://[invalid").unwrap_err();
+        assert!(matches!(err, WorkerServiceError::BadRequest { .. }));
+    }
+}